Line data Source code
1 : use std::cell::{Cell, RefCell};
2 : use std::collections::HashMap;
3 : use std::hash::BuildHasher;
4 : use std::sync::atomic::{AtomicU32, Ordering};
5 : use std::{array, env, fmt, io};
6 :
7 : use chrono::{DateTime, Utc};
8 : use indexmap::IndexSet;
9 : use opentelemetry::trace::TraceContextExt;
10 : use scopeguard::defer;
11 : use serde::ser::{SerializeMap, Serializer};
12 : use tracing::subscriber::Interest;
13 : use tracing::{Event, Metadata, Span, Subscriber, callsite, span};
14 : use tracing_opentelemetry::OpenTelemetrySpanExt;
15 : use tracing_subscriber::filter::{EnvFilter, LevelFilter};
16 : use tracing_subscriber::fmt::format::{Format, Full};
17 : use tracing_subscriber::fmt::time::SystemTime;
18 : use tracing_subscriber::fmt::{FormatEvent, FormatFields};
19 : use tracing_subscriber::layer::{Context, Layer};
20 : use tracing_subscriber::prelude::*;
21 : use tracing_subscriber::registry::{LookupSpan, SpanRef};
22 : use try_lock::TryLock;
23 :
24 : /// Initialize logging and OpenTelemetry tracing and exporter.
25 : ///
26 : /// Logging can be configured using `RUST_LOG` environment variable.
27 : ///
28 : /// OpenTelemetry is configured with OTLP/HTTP exporter. It picks up
29 : /// configuration from environment variables. For example, to change the
30 : /// destination, set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`.
31 : /// See <https://opentelemetry.io/docs/reference/specification/sdk-environment-variables>
32 0 : pub async fn init() -> anyhow::Result<LoggingGuard> {
33 0 : let logfmt = LogFormat::from_env()?;
34 :
35 0 : let env_filter = EnvFilter::builder()
36 0 : .with_default_directive(LevelFilter::INFO.into())
37 0 : .from_env_lossy()
38 0 : .add_directive(
39 0 : "aws_config=info"
40 0 : .parse()
41 0 : .expect("this should be a valid filter directive"),
42 0 : )
43 0 : .add_directive(
44 0 : "azure_core::policies::transport=off"
45 0 : .parse()
46 0 : .expect("this should be a valid filter directive"),
47 0 : );
48 :
49 0 : let otlp_layer = tracing_utils::init_tracing("proxy").await;
50 :
51 0 : let json_log_layer = if logfmt == LogFormat::Json {
52 0 : Some(JsonLoggingLayer::new(
53 0 : RealClock,
54 0 : StderrWriter {
55 0 : stderr: std::io::stderr(),
56 0 : },
57 0 : ["request_id", "session_id", "conn_id"],
58 0 : ))
59 : } else {
60 0 : None
61 : };
62 :
63 0 : let text_log_layer = if logfmt == LogFormat::Text {
64 0 : Some(
65 0 : tracing_subscriber::fmt::layer()
66 0 : .with_ansi(false)
67 0 : .with_writer(std::io::stderr)
68 0 : .with_target(false),
69 0 : )
70 : } else {
71 0 : None
72 : };
73 :
74 0 : tracing_subscriber::registry()
75 0 : .with(env_filter)
76 0 : .with(otlp_layer)
77 0 : .with(json_log_layer)
78 0 : .with(text_log_layer)
79 0 : .try_init()?;
80 :
81 0 : Ok(LoggingGuard)
82 0 : }
83 :
84 : /// Initialize logging for local_proxy with log prefix and no opentelemetry.
85 : ///
86 : /// Logging can be configured using `RUST_LOG` environment variable.
87 0 : pub fn init_local_proxy() -> anyhow::Result<LoggingGuard> {
88 0 : let env_filter = EnvFilter::builder()
89 0 : .with_default_directive(LevelFilter::INFO.into())
90 0 : .from_env_lossy();
91 0 :
92 0 : let fmt_layer = tracing_subscriber::fmt::layer()
93 0 : .with_ansi(false)
94 0 : .with_writer(std::io::stderr)
95 0 : .event_format(LocalProxyFormatter(Format::default().with_target(false)));
96 0 :
97 0 : tracing_subscriber::registry()
98 0 : .with(env_filter)
99 0 : .with(fmt_layer)
100 0 : .try_init()?;
101 :
102 0 : Ok(LoggingGuard)
103 0 : }
104 :
105 : pub struct LocalProxyFormatter(Format<Full, SystemTime>);
106 :
107 : impl<S, N> FormatEvent<S, N> for LocalProxyFormatter
108 : where
109 : S: Subscriber + for<'a> LookupSpan<'a>,
110 : N: for<'a> FormatFields<'a> + 'static,
111 : {
112 0 : fn format_event(
113 0 : &self,
114 0 : ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
115 0 : mut writer: tracing_subscriber::fmt::format::Writer<'_>,
116 0 : event: &tracing::Event<'_>,
117 0 : ) -> std::fmt::Result {
118 0 : writer.write_str("[local_proxy] ")?;
119 0 : self.0.format_event(ctx, writer, event)
120 0 : }
121 : }
122 :
123 : pub struct LoggingGuard;
124 :
125 : impl Drop for LoggingGuard {
126 0 : fn drop(&mut self) {
127 0 : // Shutdown trace pipeline gracefully, so that it has a chance to send any
128 0 : // pending traces before we exit.
129 0 : tracing::info!("shutting down the tracing machinery");
130 0 : tracing_utils::shutdown_tracing();
131 0 : }
132 : }
133 :
134 : // TODO: make JSON the default
135 : #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
136 : enum LogFormat {
137 : #[default]
138 : Text = 1,
139 : Json,
140 : }
141 :
142 : impl LogFormat {
143 0 : fn from_env() -> anyhow::Result<Self> {
144 0 : let logfmt = env::var("LOGFMT");
145 0 : Ok(match logfmt.as_deref() {
146 0 : Err(_) => LogFormat::default(),
147 0 : Ok("text") => LogFormat::Text,
148 0 : Ok("json") => LogFormat::Json,
149 0 : Ok(logfmt) => anyhow::bail!("unknown log format: {logfmt}"),
150 : })
151 0 : }
152 : }
153 :
/// Source of an `io::Write` handle, obtained once per log line.
trait MakeWriter {
    fn make_writer(&self) -> impl io::Write;
}

/// [`MakeWriter`] that hands out locked handles to the process's stderr.
struct StderrWriter {
    stderr: io::Stderr,
}

impl MakeWriter for StderrWriter {
    /// Returns a locked stderr handle; the lock is held while the handle lives.
    #[inline]
    fn make_writer(&self) -> impl io::Write {
        io::Stderr::lock(&self.stderr)
    }
}
168 :
169 : // TODO: move into separate module or even separate crate.
170 : trait Clock {
171 : fn now(&self) -> DateTime<Utc>;
172 : }
173 :
174 : struct RealClock;
175 :
176 : impl Clock for RealClock {
177 : #[inline]
178 0 : fn now(&self) -> DateTime<Utc> {
179 0 : Utc::now()
180 0 : }
181 : }
182 :
/// Name of the field used by tracing crate to store the event message.
/// It is serialized as the top-level `message` key rather than under `fields`.
const MESSAGE_FIELD: &str = "message";

thread_local! {
    /// Protects against deadlocks and double panics during log writing.
    /// The current panic handler will use tracing to log panic information.
    /// Set to `true` while this thread is inside `JsonLoggingLayer::on_event`;
    /// a nested event seen while it is set uses a fresh formatter instead of
    /// borrowing `EVENT_FORMATTER` a second time.
    static REENTRANCY_GUARD: Cell<bool> = const { Cell::new(false) };
    /// Thread-local instance with per-thread buffer for log writing.
    /// The buffer is cleared (not freed) before each event, so its allocation is reused.
    static EVENT_FORMATTER: RefCell<EventFormatter> = RefCell::new(EventFormatter::new());
    /// Cached OS thread ID, looked up once per thread and emitted as `thread_id`.
    static THREAD_ID: u64 = gettid::gettid();
}
195 :
196 : /// Implements tracing layer to handle events specific to logging.
197 : struct JsonLoggingLayer<C: Clock, W: MakeWriter, const F: usize> {
198 : clock: C,
199 : skipped_field_indices: papaya::HashMap<callsite::Identifier, SkippedFieldIndices>,
200 : callsite_ids: papaya::HashMap<callsite::Identifier, CallsiteId>,
201 : writer: W,
202 : // We use a const generic and arrays to bypass one heap allocation.
203 : extract_fields: IndexSet<&'static str>,
204 : _marker: std::marker::PhantomData<[&'static str; F]>,
205 : }
206 :
207 : impl<C: Clock, W: MakeWriter, const F: usize> JsonLoggingLayer<C, W, F> {
208 0 : fn new(clock: C, writer: W, extract_fields: [&'static str; F]) -> Self {
209 0 : JsonLoggingLayer {
210 0 : clock,
211 0 : skipped_field_indices: papaya::HashMap::default(),
212 0 : callsite_ids: papaya::HashMap::default(),
213 0 : writer,
214 0 : extract_fields: IndexSet::from_iter(extract_fields),
215 0 : _marker: std::marker::PhantomData,
216 0 : }
217 0 : }
218 :
219 : #[inline]
220 2 : fn callsite_id(&self, cs: callsite::Identifier) -> CallsiteId {
221 2 : *self
222 2 : .callsite_ids
223 2 : .pin()
224 2 : .get_or_insert_with(cs, CallsiteId::next)
225 2 : }
226 : }
227 :
228 : impl<S, C: Clock + 'static, W: MakeWriter + 'static, const F: usize> Layer<S>
229 : for JsonLoggingLayer<C, W, F>
230 : where
231 : S: Subscriber + for<'a> LookupSpan<'a>,
232 : {
233 1 : fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
234 : use std::io::Write;
235 :
236 : // TODO: consider special tracing subscriber to grab timestamp very
237 : // early, before OTel machinery, and add as event extension.
238 1 : let now = self.clock.now();
239 1 :
240 1 : let res: io::Result<()> = REENTRANCY_GUARD.with(move |entered| {
241 1 : if entered.get() {
242 0 : let mut formatter = EventFormatter::new();
243 0 : formatter.format::<S, F>(
244 0 : now,
245 0 : event,
246 0 : &ctx,
247 0 : &self.skipped_field_indices,
248 0 : &self.callsite_ids,
249 0 : &self.extract_fields,
250 0 : )?;
251 0 : self.writer.make_writer().write_all(formatter.buffer())
252 : } else {
253 1 : entered.set(true);
254 1 : defer!(entered.set(false););
255 1 :
256 1 : EVENT_FORMATTER.with_borrow_mut(move |formatter| {
257 1 : formatter.reset();
258 1 : formatter.format::<S, F>(
259 1 : now,
260 1 : event,
261 1 : &ctx,
262 1 : &self.skipped_field_indices,
263 1 : &self.callsite_ids,
264 1 : &self.extract_fields,
265 1 : )?;
266 1 : self.writer.make_writer().write_all(formatter.buffer())
267 1 : })
268 : }
269 1 : });
270 :
271 : // In case logging fails we generate a simpler JSON object.
272 1 : if let Err(err) = res {
273 0 : if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
274 0 : "timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
275 0 : "level": "ERROR",
276 0 : "message": format_args!("cannot log event: {err:?}"),
277 0 : "fields": {
278 0 : "event": format_args!("{event:?}"),
279 0 : },
280 0 : })) {
281 0 : line.push(b'\n');
282 0 : self.writer.make_writer().write_all(&line).ok();
283 0 : }
284 1 : }
285 1 : }
286 :
287 : /// Registers a SpanFields instance as span extension.
288 2 : fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
289 2 : let span = ctx.span(id).expect("span must exist");
290 2 : let fields = SpanFields::default();
291 2 : fields.record_fields(attrs);
292 2 :
293 2 : // This could deadlock when there's a panic somewhere in the tracing
294 2 : // event handling and a read or write guard is still held. This includes
295 2 : // the OTel subscriber.
296 2 : let mut exts = span.extensions_mut();
297 2 :
298 2 : exts.insert(fields);
299 2 : }
300 :
301 0 : fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
302 0 : let span = ctx.span(id).expect("span must exist");
303 0 : let ext = span.extensions();
304 0 : if let Some(data) = ext.get::<SpanFields>() {
305 0 : data.record_fields(values);
306 0 : }
307 0 : }
308 :
309 : /// Called (lazily) whenever a new log call is executed. We quickly check
310 : /// for duplicate field names and record duplicates as skippable. Last one
311 : /// wins.
312 3 : fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
313 3 : if !metadata.is_event() {
314 2 : self.callsite_id(metadata.callsite());
315 2 : // Must not be never because we wouldn't get trace and span data.
316 2 : return Interest::always();
317 1 : }
318 1 :
319 1 : let mut field_indices = SkippedFieldIndices::default();
320 1 : let mut seen_fields = HashMap::<&'static str, usize>::new();
321 5 : for field in metadata.fields() {
322 : use std::collections::hash_map::Entry;
323 5 : match seen_fields.entry(field.name()) {
324 2 : Entry::Vacant(entry) => {
325 2 : // field not seen yet
326 2 : entry.insert(field.index());
327 2 : }
328 3 : Entry::Occupied(mut entry) => {
329 3 : // replace currently stored index
330 3 : let old_index = entry.insert(field.index());
331 3 : // ... and append it to list of skippable indices
332 3 : field_indices.push(old_index);
333 3 : }
334 : }
335 : }
336 :
337 1 : if !field_indices.is_empty() {
338 1 : self.skipped_field_indices
339 1 : .pin()
340 1 : .insert(metadata.callsite(), field_indices);
341 1 : }
342 :
343 1 : Interest::always()
344 3 : }
345 : }
346 :
/// Small numeric identifier assigned to span callsites. The default value `0`
/// is used for callsites that never got an ID.
#[derive(Copy, Clone, Debug, Default)]
#[repr(transparent)]
struct CallsiteId(u32);

impl CallsiteId {
    /// Allocates the next ID from a process-wide counter. IDs start at 1 so
    /// that 0 stays reserved for [`CallsiteId::default`].
    #[inline]
    fn next() -> Self {
        static NEXT_ID: AtomicU32 = AtomicU32::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(id)
    }
}

impl fmt::Display for CallsiteId {
    /// Formats the raw number, honoring any width/fill flags.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
366 :
367 : /// Stores span field values recorded during the spans lifetime.
368 : #[derive(Default)]
369 : struct SpanFields {
370 : // TODO: Switch to custom enum with lasso::Spur for Strings?
371 : fields: papaya::HashMap<&'static str, serde_json::Value>,
372 : }
373 :
374 : impl SpanFields {
375 : #[inline]
376 2 : fn record_fields<R: tracing_subscriber::field::RecordFields>(&self, fields: R) {
377 2 : fields.record(&mut SpanFieldsRecorder {
378 2 : fields: self.fields.pin(),
379 2 : });
380 2 : }
381 : }
382 :
383 : /// Implements a tracing field visitor to convert and store values.
384 : struct SpanFieldsRecorder<'m, S, G> {
385 : fields: papaya::HashMapRef<'m, &'static str, serde_json::Value, S, G>,
386 : }
387 :
388 : impl<S: BuildHasher, G: papaya::Guard> tracing::field::Visit for SpanFieldsRecorder<'_, S, G> {
389 : #[inline]
390 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
391 0 : self.fields
392 0 : .insert(field.name(), serde_json::Value::from(value));
393 0 : }
394 :
395 : #[inline]
396 4 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
397 4 : self.fields
398 4 : .insert(field.name(), serde_json::Value::from(value));
399 4 : }
400 :
401 : #[inline]
402 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
403 0 : self.fields
404 0 : .insert(field.name(), serde_json::Value::from(value));
405 0 : }
406 :
407 : #[inline]
408 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
409 0 : if let Ok(value) = i64::try_from(value) {
410 0 : self.fields
411 0 : .insert(field.name(), serde_json::Value::from(value));
412 0 : } else {
413 0 : self.fields
414 0 : .insert(field.name(), serde_json::Value::from(format!("{value}")));
415 0 : }
416 0 : }
417 :
418 : #[inline]
419 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
420 0 : if let Ok(value) = u64::try_from(value) {
421 0 : self.fields
422 0 : .insert(field.name(), serde_json::Value::from(value));
423 0 : } else {
424 0 : self.fields
425 0 : .insert(field.name(), serde_json::Value::from(format!("{value}")));
426 0 : }
427 0 : }
428 :
429 : #[inline]
430 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
431 0 : self.fields
432 0 : .insert(field.name(), serde_json::Value::from(value));
433 0 : }
434 :
435 : #[inline]
436 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
437 0 : self.fields
438 0 : .insert(field.name(), serde_json::Value::from(value));
439 0 : }
440 :
441 : #[inline]
442 0 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
443 0 : self.fields
444 0 : .insert(field.name(), serde_json::Value::from(value));
445 0 : }
446 :
447 : #[inline]
448 0 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
449 0 : self.fields
450 0 : .insert(field.name(), serde_json::Value::from(format!("{value:?}")));
451 0 : }
452 :
453 : #[inline]
454 0 : fn record_error(
455 0 : &mut self,
456 0 : field: &tracing::field::Field,
457 0 : value: &(dyn std::error::Error + 'static),
458 0 : ) {
459 0 : self.fields
460 0 : .insert(field.name(), serde_json::Value::from(format!("{value}")));
461 0 : }
462 : }
463 :
/// Bit set of field indices skipped during logging. Can list duplicate
/// fields or metafields not meant to be logged. Only indices below 64 fit.
#[derive(Clone, Default)]
struct SkippedFieldIndices {
    bits: u64,
}

impl SkippedFieldIndices {
    /// Single-bit mask for `index`.
    ///
    /// # Panics
    ///
    /// Panics with "field index too large" if `index` (cast to `u32`) is 64
    /// or more.
    #[inline]
    fn mask(index: usize) -> u64 {
        1u64.checked_shl(index as u32)
            .expect("field index too large")
    }

    /// True when no index has been recorded.
    #[inline]
    fn is_empty(&self) -> bool {
        self.bits.count_ones() == 0
    }

    /// Marks `index` as skipped. Panics as described on [`Self::mask`].
    #[inline]
    fn push(&mut self, index: usize) {
        let bit = Self::mask(index);
        self.bits |= bit;
    }

    /// Whether `index` is marked as skipped. Panics as described on [`Self::mask`].
    #[inline]
    fn contains(&self, index: usize) -> bool {
        (self.bits & Self::mask(index)) != 0
    }
}
493 :
494 : /// Formats a tracing event and writes JSON to its internal buffer including a newline.
495 : // TODO: buffer capacity management, truncate if too large
496 : struct EventFormatter {
497 : logline_buffer: Vec<u8>,
498 : }
499 :
500 : impl EventFormatter {
501 : #[inline]
502 1 : fn new() -> Self {
503 1 : EventFormatter {
504 1 : logline_buffer: Vec::new(),
505 1 : }
506 1 : }
507 :
508 : #[inline]
509 1 : fn buffer(&self) -> &[u8] {
510 1 : &self.logline_buffer
511 1 : }
512 :
513 : #[inline]
514 1 : fn reset(&mut self) {
515 1 : self.logline_buffer.clear();
516 1 : }
517 :
518 1 : fn format<S, const F: usize>(
519 1 : &mut self,
520 1 : now: DateTime<Utc>,
521 1 : event: &Event<'_>,
522 1 : ctx: &Context<'_, S>,
523 1 : skipped_field_indices: &papaya::HashMap<callsite::Identifier, SkippedFieldIndices>,
524 1 : callsite_ids: &papaya::HashMap<callsite::Identifier, CallsiteId>,
525 1 : extract_fields: &IndexSet<&'static str>,
526 1 : ) -> io::Result<()>
527 1 : where
528 1 : S: Subscriber + for<'a> LookupSpan<'a>,
529 1 : {
530 1 : let timestamp = now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
531 :
532 : use tracing_log::NormalizeEvent;
533 1 : let normalized_meta = event.normalized_metadata();
534 1 : let meta = normalized_meta.as_ref().unwrap_or_else(|| event.metadata());
535 1 :
536 1 : let skipped_field_indices = skipped_field_indices.pin();
537 1 : let skipped_field_indices = skipped_field_indices.get(&meta.callsite());
538 1 :
539 1 : let mut serialize = || {
540 1 : let mut serializer = serde_json::Serializer::new(&mut self.logline_buffer);
541 :
542 1 : let mut serializer = serializer.serialize_map(None)?;
543 :
544 : // Timestamp comes first, so raw lines can be sorted by timestamp.
545 1 : serializer.serialize_entry("timestamp", ×tamp)?;
546 :
547 : // Level next.
548 1 : serializer.serialize_entry("level", &meta.level().as_str())?;
549 :
550 : // Message next.
551 1 : serializer.serialize_key("message")?;
552 1 : let mut message_extractor =
553 1 : MessageFieldExtractor::new(serializer, skipped_field_indices);
554 1 : event.record(&mut message_extractor);
555 1 : let mut serializer = message_extractor.into_serializer()?;
556 :
557 : // Direct message fields.
558 1 : let mut fields_present = FieldsPresent(false, skipped_field_indices);
559 1 : event.record(&mut fields_present);
560 1 : if fields_present.0 {
561 1 : serializer.serialize_entry(
562 1 : "fields",
563 1 : &SerializableEventFields(event, skipped_field_indices),
564 1 : )?;
565 0 : }
566 :
567 1 : let spans = SerializableSpans {
568 1 : ctx,
569 1 : callsite_ids,
570 1 : extract: ExtractedSpanFields::<'_, F>::new(extract_fields),
571 1 : };
572 1 : serializer.serialize_entry("spans", &spans)?;
573 :
574 : // TODO: thread-local cache?
575 1 : let pid = std::process::id();
576 1 : // Skip adding pid 1 to reduce noise for services running in containers.
577 1 : if pid != 1 {
578 1 : serializer.serialize_entry("process_id", &pid)?;
579 0 : }
580 :
581 1 : THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
582 :
583 : // TODO: tls cache? name could change
584 1 : if let Some(thread_name) = std::thread::current().name() {
585 1 : if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
586 1 : serializer.serialize_entry("thread_name", thread_name)?;
587 0 : }
588 0 : }
589 :
590 1 : if let Some(task_id) = tokio::task::try_id() {
591 0 : serializer.serialize_entry("task_id", &format_args!("{task_id}"))?;
592 1 : }
593 :
594 1 : serializer.serialize_entry("target", meta.target())?;
595 :
596 : // Skip adding module if it's the same as target.
597 1 : if let Some(module) = meta.module_path() {
598 1 : if module != meta.target() {
599 0 : serializer.serialize_entry("module", module)?;
600 1 : }
601 0 : }
602 :
603 1 : if let Some(file) = meta.file() {
604 1 : if let Some(line) = meta.line() {
605 1 : serializer.serialize_entry("src", &format_args!("{file}:{line}"))?;
606 : } else {
607 0 : serializer.serialize_entry("src", file)?;
608 : }
609 0 : }
610 :
611 : {
612 1 : let otel_context = Span::current().context();
613 1 : let otel_spanref = otel_context.span();
614 1 : let span_context = otel_spanref.span_context();
615 1 : if span_context.is_valid() {
616 0 : serializer.serialize_entry(
617 0 : "trace_id",
618 0 : &format_args!("{}", span_context.trace_id()),
619 0 : )?;
620 1 : }
621 : }
622 :
623 1 : if spans.extract.has_values() {
624 : // TODO: add fields from event, too?
625 1 : serializer.serialize_entry("extract", &spans.extract)?;
626 0 : }
627 :
628 1 : serializer.end()
629 1 : };
630 :
631 1 : serialize().map_err(io::Error::other)?;
632 1 : self.logline_buffer.push(b'\n');
633 1 : Ok(())
634 1 : }
635 : }
636 :
637 : /// Extracts the message field that's mixed will other fields.
638 : struct MessageFieldExtractor<'a, S: serde::ser::SerializeMap> {
639 : serializer: S,
640 : skipped_field_indices: Option<&'a SkippedFieldIndices>,
641 : state: Option<Result<(), S::Error>>,
642 : }
643 :
644 : impl<'a, S: serde::ser::SerializeMap> MessageFieldExtractor<'a, S> {
645 : #[inline]
646 1 : fn new(serializer: S, skipped_field_indices: Option<&'a SkippedFieldIndices>) -> Self {
647 1 : Self {
648 1 : serializer,
649 1 : skipped_field_indices,
650 1 : state: None,
651 1 : }
652 1 : }
653 :
654 : #[inline]
655 1 : fn into_serializer(mut self) -> Result<S, S::Error> {
656 1 : match self.state {
657 1 : Some(Ok(())) => {}
658 0 : Some(Err(err)) => return Err(err),
659 0 : None => self.serializer.serialize_value("")?,
660 : }
661 1 : Ok(self.serializer)
662 1 : }
663 :
664 : #[inline]
665 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
666 5 : self.state.is_none()
667 5 : && field.name() == MESSAGE_FIELD
668 2 : && !self
669 2 : .skipped_field_indices
670 2 : .is_some_and(|i| i.contains(field.index()))
671 5 : }
672 : }
673 :
674 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldExtractor<'_, S> {
675 : #[inline]
676 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
677 0 : if self.accept_field(field) {
678 0 : self.state = Some(self.serializer.serialize_value(&value));
679 0 : }
680 0 : }
681 :
682 : #[inline]
683 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
684 3 : if self.accept_field(field) {
685 0 : self.state = Some(self.serializer.serialize_value(&value));
686 3 : }
687 3 : }
688 :
689 : #[inline]
690 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
691 0 : if self.accept_field(field) {
692 0 : self.state = Some(self.serializer.serialize_value(&value));
693 0 : }
694 0 : }
695 :
696 : #[inline]
697 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
698 0 : if self.accept_field(field) {
699 0 : self.state = Some(self.serializer.serialize_value(&value));
700 0 : }
701 0 : }
702 :
703 : #[inline]
704 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
705 0 : if self.accept_field(field) {
706 0 : self.state = Some(self.serializer.serialize_value(&value));
707 0 : }
708 0 : }
709 :
710 : #[inline]
711 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
712 0 : if self.accept_field(field) {
713 0 : self.state = Some(self.serializer.serialize_value(&value));
714 0 : }
715 0 : }
716 :
717 : #[inline]
718 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
719 0 : if self.accept_field(field) {
720 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:x?}")));
721 0 : }
722 0 : }
723 :
724 : #[inline]
725 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
726 1 : if self.accept_field(field) {
727 1 : self.state = Some(self.serializer.serialize_value(&value));
728 1 : }
729 1 : }
730 :
731 : #[inline]
732 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
733 1 : if self.accept_field(field) {
734 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:?}")));
735 1 : }
736 1 : }
737 :
738 : #[inline]
739 0 : fn record_error(
740 0 : &mut self,
741 0 : field: &tracing::field::Field,
742 0 : value: &(dyn std::error::Error + 'static),
743 0 : ) {
744 0 : if self.accept_field(field) {
745 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value}")));
746 0 : }
747 0 : }
748 : }
749 :
750 : /// Checks if there's any fields and field values present. If not, the JSON subobject
751 : /// can be skipped.
752 : // This is entirely optional and only cosmetic, though maybe helps a
753 : // bit during log parsing in dashboards when there's no field with empty object.
754 : struct FieldsPresent<'a>(pub bool, Option<&'a SkippedFieldIndices>);
755 :
756 : // Even though some methods have an overhead (error, bytes) it is assumed the
757 : // compiler won't include this since we ignore the value entirely.
758 : impl tracing::field::Visit for FieldsPresent<'_> {
759 : #[inline]
760 5 : fn record_debug(&mut self, field: &tracing::field::Field, _: &dyn std::fmt::Debug) {
761 5 : if !self.1.is_some_and(|i| i.contains(field.index()))
762 2 : && field.name() != MESSAGE_FIELD
763 1 : && !field.name().starts_with("log.")
764 1 : {
765 1 : self.0 |= true;
766 4 : }
767 5 : }
768 : }
769 :
770 : /// Serializes the fields directly supplied with a log event.
771 : struct SerializableEventFields<'a, 'event>(
772 : &'a tracing::Event<'event>,
773 : Option<&'a SkippedFieldIndices>,
774 : );
775 :
776 : impl serde::ser::Serialize for SerializableEventFields<'_, '_> {
777 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
778 1 : where
779 1 : S: Serializer,
780 1 : {
781 : use serde::ser::SerializeMap;
782 1 : let serializer = serializer.serialize_map(None)?;
783 1 : let mut message_skipper = MessageFieldSkipper::new(serializer, self.1);
784 1 : self.0.record(&mut message_skipper);
785 1 : let serializer = message_skipper.into_serializer()?;
786 1 : serializer.end()
787 1 : }
788 : }
789 :
790 : /// A tracing field visitor that skips the message field.
791 : struct MessageFieldSkipper<'a, S: serde::ser::SerializeMap> {
792 : serializer: S,
793 : skipped_field_indices: Option<&'a SkippedFieldIndices>,
794 : state: Result<(), S::Error>,
795 : }
796 :
797 : impl<'a, S: serde::ser::SerializeMap> MessageFieldSkipper<'a, S> {
798 : #[inline]
799 1 : fn new(serializer: S, skipped_field_indices: Option<&'a SkippedFieldIndices>) -> Self {
800 1 : Self {
801 1 : serializer,
802 1 : skipped_field_indices,
803 1 : state: Ok(()),
804 1 : }
805 1 : }
806 :
807 : #[inline]
808 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
809 5 : self.state.is_ok()
810 5 : && field.name() != MESSAGE_FIELD
811 3 : && !field.name().starts_with("log.")
812 3 : && !self
813 3 : .skipped_field_indices
814 3 : .is_some_and(|i| i.contains(field.index()))
815 5 : }
816 :
817 : #[inline]
818 1 : fn into_serializer(self) -> Result<S, S::Error> {
819 1 : self.state?;
820 1 : Ok(self.serializer)
821 1 : }
822 : }
823 :
824 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldSkipper<'_, S> {
825 : #[inline]
826 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
827 0 : if self.accept_field(field) {
828 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
829 0 : }
830 0 : }
831 :
832 : #[inline]
833 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
834 3 : if self.accept_field(field) {
835 1 : self.state = self.serializer.serialize_entry(field.name(), &value);
836 2 : }
837 3 : }
838 :
839 : #[inline]
840 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
841 0 : if self.accept_field(field) {
842 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
843 0 : }
844 0 : }
845 :
846 : #[inline]
847 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
848 0 : if self.accept_field(field) {
849 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
850 0 : }
851 0 : }
852 :
853 : #[inline]
854 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
855 0 : if self.accept_field(field) {
856 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
857 0 : }
858 0 : }
859 :
860 : #[inline]
861 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
862 0 : if self.accept_field(field) {
863 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
864 0 : }
865 0 : }
866 :
867 : #[inline]
868 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
869 0 : if self.accept_field(field) {
870 0 : self.state = self
871 0 : .serializer
872 0 : .serialize_entry(field.name(), &format_args!("{value:x?}"));
873 0 : }
874 0 : }
875 :
876 : #[inline]
877 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
878 1 : if self.accept_field(field) {
879 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
880 1 : }
881 1 : }
882 :
883 : #[inline]
884 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
885 1 : if self.accept_field(field) {
886 0 : self.state = self
887 0 : .serializer
888 0 : .serialize_entry(field.name(), &format_args!("{value:?}"));
889 1 : }
890 1 : }
891 :
892 : #[inline]
893 0 : fn record_error(
894 0 : &mut self,
895 0 : field: &tracing::field::Field,
896 0 : value: &(dyn std::error::Error + 'static),
897 0 : ) {
898 0 : if self.accept_field(field) {
899 0 : self.state = self.serializer.serialize_value(&format_args!("{value}"));
900 0 : }
901 0 : }
902 : }
903 :
904 : /// Serializes the span stack from root to leaf (parent of event) as object
905 : /// with the span names as keys. To prevent collision we append a numberic value
906 : /// to the name. Also, collects any span fields we're interested in. Last one
907 : /// wins.
908 : struct SerializableSpans<'a, 'ctx, Span, const F: usize>
909 : where
910 : Span: Subscriber + for<'lookup> LookupSpan<'lookup>,
911 : {
912 : ctx: &'a Context<'ctx, Span>,
913 : callsite_ids: &'a papaya::HashMap<callsite::Identifier, CallsiteId>,
914 : extract: ExtractedSpanFields<'a, F>,
915 : }
916 :
917 : impl<Span, const F: usize> serde::ser::Serialize for SerializableSpans<'_, '_, Span, F>
918 : where
919 : Span: Subscriber + for<'lookup> LookupSpan<'lookup>,
920 : {
921 1 : fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
922 1 : where
923 1 : Ser: serde::ser::Serializer,
924 1 : {
925 1 : let mut serializer = serializer.serialize_map(None)?;
926 :
927 1 : if let Some(leaf_span) = self.ctx.lookup_current() {
928 2 : for span in leaf_span.scope().from_root() {
929 : // Append a numeric callsite ID to the span name to keep the name unique
930 : // in the JSON object.
931 2 : let cid = self
932 2 : .callsite_ids
933 2 : .pin()
934 2 : .get(&span.metadata().callsite())
935 2 : .copied()
936 2 : .unwrap_or_default();
937 2 :
938 2 : // Loki turns the # into an underscore during field name concatenation.
939 2 : serializer.serialize_key(&format_args!("{}#{}", span.metadata().name(), &cid))?;
940 :
941 2 : serializer.serialize_value(&SerializableSpanFields {
942 2 : span: &span,
943 2 : extract: &self.extract,
944 2 : })?;
945 : }
946 0 : }
947 :
948 1 : serializer.end()
949 1 : }
950 : }
951 :
952 : /// Serializes the span fields as object.
953 : struct SerializableSpanFields<'a, 'span, Span, const F: usize>
954 : where
955 : Span: for<'lookup> LookupSpan<'lookup>,
956 : {
957 : span: &'a SpanRef<'span, Span>,
958 : extract: &'a ExtractedSpanFields<'a, F>,
959 : }
960 :
961 : impl<Span, const F: usize> serde::ser::Serialize for SerializableSpanFields<'_, '_, Span, F>
962 : where
963 : Span: for<'lookup> LookupSpan<'lookup>,
964 : {
965 2 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
966 2 : where
967 2 : S: serde::ser::Serializer,
968 2 : {
969 2 : let mut serializer = serializer.serialize_map(None)?;
970 :
971 2 : let ext = self.span.extensions();
972 2 : if let Some(data) = ext.get::<SpanFields>() {
973 2 : for (name, value) in &data.fields.pin() {
974 2 : serializer.serialize_entry(name, value)?;
975 : // TODO: replace clone with reference, if possible.
976 2 : self.extract.set(name, value.clone());
977 : }
978 0 : }
979 :
980 2 : serializer.end()
981 2 : }
982 : }
983 :
984 : struct ExtractedSpanFields<'a, const F: usize> {
985 : names: &'a IndexSet<&'static str>,
986 : // TODO: replace TryLock with something local thread and interior mutability.
987 : // serde API doesn't let us use `mut`.
988 : values: TryLock<([Option<serde_json::Value>; F], bool)>,
989 : }
990 :
991 : impl<'a, const F: usize> ExtractedSpanFields<'a, F> {
992 1 : fn new(names: &'a IndexSet<&'static str>) -> Self {
993 1 : ExtractedSpanFields {
994 1 : names,
995 1 : values: TryLock::new((array::from_fn(|_| Option::default()), false)),
996 1 : }
997 1 : }
998 :
999 : #[inline]
1000 2 : fn set(&self, name: &'static str, value: serde_json::Value) {
1001 2 : if let Some((index, _)) = self.names.get_full(name) {
1002 2 : let mut fields = self.values.try_lock().expect("thread-local use");
1003 2 : fields.0[index] = Some(value);
1004 2 : fields.1 = true;
1005 2 : }
1006 2 : }
1007 :
1008 : #[inline]
1009 1 : fn has_values(&self) -> bool {
1010 1 : self.values.try_lock().expect("thread-local use").1
1011 1 : }
1012 : }
1013 :
1014 : impl<const F: usize> serde::ser::Serialize for ExtractedSpanFields<'_, F> {
1015 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1016 1 : where
1017 1 : S: serde::ser::Serializer,
1018 1 : {
1019 1 : let mut serializer = serializer.serialize_map(None)?;
1020 :
1021 1 : let values = self.values.try_lock().expect("thread-local use");
1022 1 : for (i, value) in values.0.iter().enumerate() {
1023 1 : if let Some(value) = value {
1024 1 : let key = self.names[i];
1025 1 : serializer.serialize_entry(key, value)?;
1026 0 : }
1027 : }
1028 :
1029 1 : serializer.end()
1030 1 : }
1031 : }
1032 :
#[cfg(test)]
// Tests unwrap freely; a panic is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::marker::PhantomData;
    use std::sync::{Arc, Mutex, MutexGuard};

    use assert_json_diff::assert_json_eq;
    use tracing::info_span;

    use super::*;

    /// Clock that reports whatever instant is stored in `current_time`.
    struct TestClock {
        current_time: Mutex<DateTime<Utc>>,
    }

    impl Clock for Arc<TestClock> {
        fn now(&self) -> DateTime<Utc> {
            let stored = self.current_time.lock().expect("poisoned");
            *stored
        }
    }

    /// Writer that appends into a byte buffer it holds locked.
    struct VecWriter<'a> {
        buffer: MutexGuard<'a, Vec<u8>>,
    }

    impl MakeWriter for Arc<Mutex<Vec<u8>>> {
        fn make_writer(&self) -> impl io::Write {
            let buffer = self.lock().expect("poisoned");
            VecWriter { buffer }
        }
    }

    impl io::Write for VecWriter<'_> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.buffer.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Checks field de-duplication and span-field extraction: with repeated
    /// keys the last value wins, both for event fields and for span fields.
    #[test]
    fn test_field_collection() {
        let clock = Arc::new(TestClock {
            current_time: Mutex::new(Utc::now()),
        });
        let sink = Arc::new(Mutex::new(Vec::new()));
        let log_layer = JsonLoggingLayer {
            clock: Arc::clone(&clock),
            skipped_field_indices: papaya::HashMap::default(),
            callsite_ids: papaya::HashMap::default(),
            writer: Arc::clone(&sink),
            extract_fields: IndexSet::from_iter(["x"]),
            _marker: PhantomData::<[&'static str; 1]>,
        };

        let registry = tracing_subscriber::Registry::default().with(log_layer);

        tracing::subscriber::with_default(registry, || {
            info_span!("some_span", x = 24).in_scope(|| {
                info_span!("some_span", x = 40, x = 41, x = 42).in_scope(|| {
                    tracing::error!(
                        a = 1,
                        a = 2,
                        a = 3,
                        message = "explicit message field",
                        "implicit message field"
                    );
                });
            });
        });

        // The subscriber (and its writer handle) is gone, so we own the buffer.
        let written = Arc::try_unwrap(sink)
            .expect("no other reference")
            .into_inner()
            .expect("poisoned");
        let actual: serde_json::Value = serde_json::from_slice(&written).expect("valid JSON");
        let actual_obj = actual.as_object().unwrap();

        let expected: serde_json::Value = serde_json::json!({
            "timestamp": clock.now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
            "level": "ERROR",
            "message": "explicit message field",
            "fields": {
                "a": 3,
            },
            "spans": {
                "some_span#1": {
                    "x": 24,
                },
                "some_span#2": {
                    "x": 42,
                }
            },
            "extract": {
                "x": 42,
            },
            "src": actual_obj["src"].as_str().unwrap(),
            "target": "proxy::logging::tests",
            "process_id": actual_obj["process_id"].as_number().unwrap(),
            "thread_id": actual_obj["thread_id"].as_number().unwrap(),
            "thread_name": "logging::tests::test_field_collection",
        });

        assert_json_eq!(actual, expected);
    }
}
|