Line data Source code
1 : use std::cell::RefCell;
2 : use std::collections::HashMap;
3 : use std::sync::Arc;
4 : use std::sync::atomic::{AtomicU32, Ordering};
5 : use std::{env, io};
6 :
7 : use chrono::{DateTime, Utc};
8 : use opentelemetry::trace::TraceContextExt;
9 : use serde::ser::{SerializeMap, Serializer};
10 : use tracing::subscriber::Interest;
11 : use tracing::{Event, Metadata, Span, Subscriber, callsite, span};
12 : use tracing_opentelemetry::OpenTelemetrySpanExt;
13 : use tracing_subscriber::filter::{EnvFilter, LevelFilter};
14 : use tracing_subscriber::fmt::format::{Format, Full};
15 : use tracing_subscriber::fmt::time::SystemTime;
16 : use tracing_subscriber::fmt::{FormatEvent, FormatFields};
17 : use tracing_subscriber::layer::{Context, Layer};
18 : use tracing_subscriber::prelude::*;
19 : use tracing_subscriber::registry::{LookupSpan, SpanRef};
20 :
21 : /// Initialize logging and OpenTelemetry tracing and exporter.
22 : ///
23 : /// Logging can be configured using `RUST_LOG` environment variable.
24 : ///
25 : /// OpenTelemetry is configured with OTLP/HTTP exporter. It picks up
26 : /// configuration from environment variables. For example, to change the
27 : /// destination, set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`.
28 : /// See <https://opentelemetry.io/docs/reference/specification/sdk-environment-variables>
29 0 : pub async fn init() -> anyhow::Result<LoggingGuard> {
30 0 : let logfmt = LogFormat::from_env()?;
31 :
32 0 : let env_filter = EnvFilter::builder()
33 0 : .with_default_directive(LevelFilter::INFO.into())
34 0 : .from_env_lossy()
35 0 : .add_directive(
36 0 : "aws_config=info"
37 0 : .parse()
38 0 : .expect("this should be a valid filter directive"),
39 0 : )
40 0 : .add_directive(
41 0 : "azure_core::policies::transport=off"
42 0 : .parse()
43 0 : .expect("this should be a valid filter directive"),
44 0 : );
45 :
46 0 : let otlp_layer =
47 0 : tracing_utils::init_tracing("proxy", tracing_utils::ExportConfig::default()).await;
48 :
49 0 : let json_log_layer = if logfmt == LogFormat::Json {
50 0 : Some(JsonLoggingLayer::new(
51 0 : RealClock,
52 0 : StderrWriter {
53 0 : stderr: std::io::stderr(),
54 0 : },
55 0 : &["request_id", "session_id", "conn_id"],
56 0 : ))
57 : } else {
58 0 : None
59 : };
60 :
61 0 : let text_log_layer = if logfmt == LogFormat::Text {
62 0 : Some(
63 0 : tracing_subscriber::fmt::layer()
64 0 : .with_ansi(false)
65 0 : .with_writer(std::io::stderr)
66 0 : .with_target(false),
67 0 : )
68 : } else {
69 0 : None
70 : };
71 :
72 0 : tracing_subscriber::registry()
73 0 : .with(env_filter)
74 0 : .with(otlp_layer)
75 0 : .with(json_log_layer)
76 0 : .with(text_log_layer)
77 0 : .try_init()?;
78 :
79 0 : Ok(LoggingGuard)
80 0 : }
81 :
82 : /// Initialize logging for local_proxy with log prefix and no opentelemetry.
83 : ///
84 : /// Logging can be configured using `RUST_LOG` environment variable.
85 0 : pub fn init_local_proxy() -> anyhow::Result<LoggingGuard> {
86 0 : let env_filter = EnvFilter::builder()
87 0 : .with_default_directive(LevelFilter::INFO.into())
88 0 : .from_env_lossy();
89 0 :
90 0 : let fmt_layer = tracing_subscriber::fmt::layer()
91 0 : .with_ansi(false)
92 0 : .with_writer(std::io::stderr)
93 0 : .event_format(LocalProxyFormatter(Format::default().with_target(false)));
94 0 :
95 0 : tracing_subscriber::registry()
96 0 : .with(env_filter)
97 0 : .with(fmt_layer)
98 0 : .try_init()?;
99 :
100 0 : Ok(LoggingGuard)
101 0 : }
102 :
103 : pub struct LocalProxyFormatter(Format<Full, SystemTime>);
104 :
105 : impl<S, N> FormatEvent<S, N> for LocalProxyFormatter
106 : where
107 : S: Subscriber + for<'a> LookupSpan<'a>,
108 : N: for<'a> FormatFields<'a> + 'static,
109 : {
110 0 : fn format_event(
111 0 : &self,
112 0 : ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
113 0 : mut writer: tracing_subscriber::fmt::format::Writer<'_>,
114 0 : event: &tracing::Event<'_>,
115 0 : ) -> std::fmt::Result {
116 0 : writer.write_str("[local_proxy] ")?;
117 0 : self.0.format_event(ctx, writer, event)
118 0 : }
119 : }
120 :
121 : pub struct LoggingGuard;
122 :
123 : impl Drop for LoggingGuard {
124 0 : fn drop(&mut self) {
125 0 : // Shutdown trace pipeline gracefully, so that it has a chance to send any
126 0 : // pending traces before we exit.
127 0 : tracing::info!("shutting down the tracing machinery");
128 0 : tracing_utils::shutdown_tracing();
129 0 : }
130 : }
131 :
132 : #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
133 : enum LogFormat {
134 : Text,
135 : #[default]
136 : Json,
137 : }
138 :
139 : impl LogFormat {
140 0 : fn from_env() -> anyhow::Result<Self> {
141 0 : let logfmt = env::var("LOGFMT");
142 0 : Ok(match logfmt.as_deref() {
143 0 : Err(_) => LogFormat::default(),
144 0 : Ok("text") => LogFormat::Text,
145 0 : Ok("json") => LogFormat::Json,
146 0 : Ok(logfmt) => anyhow::bail!("unknown log format: {logfmt}"),
147 : })
148 0 : }
149 : }
150 :
/// Hands out an `io::Write` handle to write one log line into.
trait MakeWriter {
    fn make_writer(&self) -> impl io::Write;
}

/// [`MakeWriter`] that targets the process's standard error.
struct StderrWriter {
    stderr: io::Stderr,
}

impl MakeWriter for StderrWriter {
    /// Returns a locked stderr handle; the lock is held until the handle is dropped.
    #[inline]
    fn make_writer(&self) -> impl io::Write {
        io::Stderr::lock(&self.stderr)
    }
}
165 :
166 : // TODO: move into separate module or even separate crate.
167 : trait Clock {
168 : fn now(&self) -> DateTime<Utc>;
169 : }
170 :
171 : struct RealClock;
172 :
173 : impl Clock for RealClock {
174 : #[inline]
175 0 : fn now(&self) -> DateTime<Utc> {
176 0 : Utc::now()
177 0 : }
178 : }
179 :
/// Field name under which the tracing macros store an event's message.
const MESSAGE_FIELD: &str = "message";

/// Maximum number of fields per span/event that this module tracks.
///
/// tracing used to enforce at most 32 fields per span/event. That limit seems to be
/// gone, although some documentation still mentions it. More than 32 fields is not
/// expected in practice, so we rely on the bound for some (minor) performance gains,
/// such as the 32-bit field-index bitset.
const MAX_TRACING_FIELDS: usize = u32::BITS as usize;
188 :
189 : thread_local! {
190 : /// Thread-local instance with per-thread buffer for log writing.
191 : static EVENT_FORMATTER: RefCell<EventFormatter> = const { RefCell::new(EventFormatter::new()) };
192 : /// Cached OS thread ID.
193 : static THREAD_ID: u64 = gettid::gettid();
194 : }
195 :
196 : /// Map for values fixed at callsite registration.
197 : // We use papaya here because registration rarely happens post-startup.
198 : // papaya is good for read-heavy workloads.
199 : //
200 : // We use rustc_hash here because callsite::Identifier will always be an integer with low-bit entropy,
201 : // since it's always a pointer to static mutable data. rustc_hash was designed for low-bit entropy.
202 : type CallsiteMap<T> =
203 : papaya::HashMap<callsite::Identifier, T, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
204 :
205 : /// Implements tracing layer to handle events specific to logging.
206 : struct JsonLoggingLayer<C: Clock, W: MakeWriter> {
207 : clock: C,
208 : writer: W,
209 :
210 : /// tracks which fields of each **event** are duplicates
211 : skipped_field_indices: CallsiteMap<SkippedFieldIndices>,
212 :
213 : span_info: CallsiteMap<CallsiteSpanInfo>,
214 :
215 : /// Fields we want to keep track of in a separate json object.
216 : extract_fields: &'static [&'static str],
217 : }
218 :
219 : impl<C: Clock, W: MakeWriter> JsonLoggingLayer<C, W> {
220 0 : fn new(clock: C, writer: W, extract_fields: &'static [&'static str]) -> Self {
221 0 : JsonLoggingLayer {
222 0 : clock,
223 0 : skipped_field_indices: CallsiteMap::default(),
224 0 : span_info: CallsiteMap::default(),
225 0 : writer,
226 0 : extract_fields,
227 0 : }
228 0 : }
229 :
230 : #[inline]
231 4 : fn span_info(&self, metadata: &'static Metadata<'static>) -> CallsiteSpanInfo {
232 4 : self.span_info
233 4 : .pin()
234 4 : .get_or_insert_with(metadata.callsite(), || {
235 2 : CallsiteSpanInfo::new(metadata, self.extract_fields)
236 4 : })
237 4 : .clone()
238 4 : }
239 : }
240 :
241 : impl<S, C: Clock + 'static, W: MakeWriter + 'static> Layer<S> for JsonLoggingLayer<C, W>
242 : where
243 : S: Subscriber + for<'a> LookupSpan<'a>,
244 : {
245 1 : fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
246 : use std::io::Write;
247 :
248 : // TODO: consider special tracing subscriber to grab timestamp very
249 : // early, before OTel machinery, and add as event extension.
250 1 : let now = self.clock.now();
251 1 :
252 1 : let res: io::Result<()> = EVENT_FORMATTER.with(|f| {
253 1 : let mut borrow = f.try_borrow_mut();
254 1 : let formatter = match borrow.as_deref_mut() {
255 1 : Ok(formatter) => formatter,
256 : // If the thread local formatter is borrowed,
257 : // then we likely hit an edge case were we panicked during formatting.
258 : // We allow the logging to proceed with an uncached formatter.
259 0 : Err(_) => &mut EventFormatter::new(),
260 : };
261 :
262 1 : formatter.reset();
263 1 : formatter.format(
264 1 : now,
265 1 : event,
266 1 : &ctx,
267 1 : &self.skipped_field_indices,
268 1 : self.extract_fields,
269 1 : )?;
270 1 : self.writer.make_writer().write_all(formatter.buffer())
271 1 : });
272 :
273 : // In case logging fails we generate a simpler JSON object.
274 1 : if let Err(err) = res {
275 0 : if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
276 0 : "timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
277 0 : "level": "ERROR",
278 0 : "message": format_args!("cannot log event: {err:?}"),
279 0 : "fields": {
280 0 : "event": format_args!("{event:?}"),
281 0 : },
282 0 : })) {
283 0 : line.push(b'\n');
284 0 : self.writer.make_writer().write_all(&line).ok();
285 0 : }
286 1 : }
287 1 : }
288 :
289 : /// Registers a SpanFields instance as span extension.
290 2 : fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
291 2 : let span = ctx.span(id).expect("span must exist");
292 2 :
293 2 : let mut fields = SpanFields::new(self.span_info(span.metadata()));
294 2 : attrs.record(&mut fields);
295 2 :
296 2 : // This is a new span: the extensions should not be locked
297 2 : // unless some layer spawned a thread to process this span.
298 2 : // I don't think any layers do that.
299 2 : span.extensions_mut().insert(fields);
300 2 : }
301 :
302 0 : fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
303 0 : let span = ctx.span(id).expect("span must exist");
304 0 :
305 0 : // assumption: `on_record` is rarely called.
306 0 : // assumption: a span being updated by one thread,
307 0 : // and formatted by another thread is even rarer.
308 0 : let mut ext = span.extensions_mut();
309 0 : if let Some(fields) = ext.get_mut::<SpanFields>() {
310 0 : values.record(fields);
311 0 : }
312 0 : }
313 :
314 : /// Called (lazily) roughly once per event/span instance. We quickly check
315 : /// for duplicate field names and record duplicates as skippable. Last field wins.
316 3 : fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
317 3 : debug_assert!(
318 3 : metadata.fields().len() <= MAX_TRACING_FIELDS,
319 0 : "callsite {metadata:?} has too many fields."
320 : );
321 :
322 3 : if !metadata.is_event() {
323 : // register the span info.
324 2 : self.span_info(metadata);
325 2 : // Must not be never because we wouldn't get trace and span data.
326 2 : return Interest::always();
327 1 : }
328 1 :
329 1 : let mut field_indices = SkippedFieldIndices::default();
330 1 : let mut seen_fields = HashMap::new();
331 5 : for field in metadata.fields() {
332 5 : if let Some(old_index) = seen_fields.insert(field.name(), field.index()) {
333 3 : field_indices.set(old_index);
334 3 : }
335 : }
336 :
337 1 : if !field_indices.is_empty() {
338 1 : self.skipped_field_indices
339 1 : .pin()
340 1 : .insert(metadata.callsite(), field_indices);
341 1 : }
342 :
343 1 : Interest::always()
344 3 : }
345 : }
346 :
347 : /// Any span info that is fixed to a particular callsite. Not variable between span instances.
348 : #[derive(Clone)]
349 : struct CallsiteSpanInfo {
350 : /// index of each field to extract. usize::MAX if not found.
351 : extract: Arc<[usize]>,
352 :
353 : /// tracks the fixed "callsite ID" for each span.
354 : /// note: this is not stable between runs.
355 : normalized_name: Arc<str>,
356 : }
357 :
358 : impl CallsiteSpanInfo {
359 2 : fn new(metadata: &'static Metadata<'static>, extract_fields: &[&'static str]) -> Self {
360 : // Start at 1 to reserve 0 for default.
361 : static COUNTER: AtomicU32 = AtomicU32::new(1);
362 :
363 4 : let names: Vec<&'static str> = metadata.fields().iter().map(|f| f.name()).collect();
364 2 :
365 2 : // get all the indices of span fields we want to focus
366 2 : let extract = extract_fields
367 2 : .iter()
368 2 : // use rposition, since we want last match wins.
369 2 : .map(|f1| names.iter().rposition(|f2| f1 == f2).unwrap_or(usize::MAX))
370 2 : .collect();
371 2 :
372 2 : // normalized_name is unique for each callsite, but it is not
373 2 : // unified across separate proxy instances.
374 2 : // todo: can we do better here?
375 2 : let cid = COUNTER.fetch_add(1, Ordering::Relaxed);
376 2 : let normalized_name = format!("{}#{cid}", metadata.name()).into();
377 2 :
378 2 : Self {
379 2 : extract,
380 2 : normalized_name,
381 2 : }
382 2 : }
383 : }
384 :
385 : /// Stores span field values recorded during the spans lifetime.
386 : struct SpanFields {
387 : values: [serde_json::Value; MAX_TRACING_FIELDS],
388 :
389 : /// cached span info so we can avoid extra hashmap lookups in the hot path.
390 : span_info: CallsiteSpanInfo,
391 : }
392 :
393 : impl SpanFields {
394 2 : fn new(span_info: CallsiteSpanInfo) -> Self {
395 2 : Self {
396 2 : span_info,
397 2 : values: [const { serde_json::Value::Null }; MAX_TRACING_FIELDS],
398 2 : }
399 2 : }
400 : }
401 :
402 : impl tracing::field::Visit for SpanFields {
403 : #[inline]
404 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
405 0 : self.values[field.index()] = serde_json::Value::from(value);
406 0 : }
407 :
408 : #[inline]
409 4 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
410 4 : self.values[field.index()] = serde_json::Value::from(value);
411 4 : }
412 :
413 : #[inline]
414 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
415 0 : self.values[field.index()] = serde_json::Value::from(value);
416 0 : }
417 :
418 : #[inline]
419 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
420 0 : if let Ok(value) = i64::try_from(value) {
421 0 : self.values[field.index()] = serde_json::Value::from(value);
422 0 : } else {
423 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
424 0 : }
425 0 : }
426 :
427 : #[inline]
428 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
429 0 : if let Ok(value) = u64::try_from(value) {
430 0 : self.values[field.index()] = serde_json::Value::from(value);
431 0 : } else {
432 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
433 0 : }
434 0 : }
435 :
436 : #[inline]
437 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
438 0 : self.values[field.index()] = serde_json::Value::from(value);
439 0 : }
440 :
441 : #[inline]
442 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
443 0 : self.values[field.index()] = serde_json::Value::from(value);
444 0 : }
445 :
446 : #[inline]
447 0 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
448 0 : self.values[field.index()] = serde_json::Value::from(value);
449 0 : }
450 :
451 : #[inline]
452 0 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
453 0 : self.values[field.index()] = serde_json::Value::from(format!("{value:?}"));
454 0 : }
455 :
456 : #[inline]
457 0 : fn record_error(
458 0 : &mut self,
459 0 : field: &tracing::field::Field,
460 0 : value: &(dyn std::error::Error + 'static),
461 0 : ) {
462 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
463 0 : }
464 : }
465 :
/// Set of field indices to skip when logging an event. It can list earlier occurrences
/// of duplicated field names or metafields not meant to be logged. It is stored as a
/// 32-bit bitset, so only indices `0..32` can be represented.
#[derive(Copy, Clone, Default)]
struct SkippedFieldIndices {
    // Bit `i` set means field index `i` is skipped. 32 bits matches `MAX_TRACING_FIELDS`.
    bits: u32,
}

impl SkippedFieldIndices {
    /// Returns `true` if no index is marked as skipped.
    #[inline]
    fn is_empty(self) -> bool {
        self.bits == 0
    }

    /// Marks `index` as skipped.
    ///
    /// An index of 32 or more cannot be represented. It trips a debug assertion, and in
    /// release builds it is ignored rather than setting a wrapped-around (wrong) bit.
    #[inline]
    fn set(&mut self, index: usize) {
        debug_assert!(index < u32::BITS as usize, "index out of bounds of 32-bit set");
        if let Some(bit) = Self::bit(index) {
            self.bits |= bit;
        }
    }

    /// Returns `true` if `index` is marked as skipped. It is always `false` for indices
    /// that do not fit in the set.
    #[inline]
    fn contains(self, index: usize) -> bool {
        Self::bit(index).is_some_and(|bit| self.bits & bit != 0)
    }

    /// Single-bit mask for `index`, or `None` if `index` is outside the 32-bit set.
    #[inline]
    fn bit(index: usize) -> Option<u32> {
        // Lazy `then` (not `then_some`): the shift must only run once the range check
        // has passed, otherwise it overflows for index >= 32.
        (index < u32::BITS as usize).then(|| 1u32 << index)
    }
}
491 :
492 : /// Formats a tracing event and writes JSON to its internal buffer including a newline.
493 : // TODO: buffer capacity management, truncate if too large
494 : struct EventFormatter {
495 : logline_buffer: Vec<u8>,
496 : }
497 :
498 : impl EventFormatter {
499 : #[inline]
500 0 : const fn new() -> Self {
501 0 : EventFormatter {
502 0 : logline_buffer: Vec::new(),
503 0 : }
504 0 : }
505 :
506 : #[inline]
507 1 : fn buffer(&self) -> &[u8] {
508 1 : &self.logline_buffer
509 1 : }
510 :
511 : #[inline]
512 1 : fn reset(&mut self) {
513 1 : self.logline_buffer.clear();
514 1 : }
515 :
516 1 : fn format<S>(
517 1 : &mut self,
518 1 : now: DateTime<Utc>,
519 1 : event: &Event<'_>,
520 1 : ctx: &Context<'_, S>,
521 1 : skipped_field_indices: &CallsiteMap<SkippedFieldIndices>,
522 1 : extract_fields: &'static [&'static str],
523 1 : ) -> io::Result<()>
524 1 : where
525 1 : S: Subscriber + for<'a> LookupSpan<'a>,
526 1 : {
527 1 : let timestamp = now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
528 :
529 : use tracing_log::NormalizeEvent;
530 1 : let normalized_meta = event.normalized_metadata();
531 1 : let meta = normalized_meta.as_ref().unwrap_or_else(|| event.metadata());
532 1 :
533 1 : let skipped_field_indices = skipped_field_indices
534 1 : .pin()
535 1 : .get(&meta.callsite())
536 1 : .copied()
537 1 : .unwrap_or_default();
538 1 :
539 1 : let mut serialize = || {
540 1 : let mut serializer = serde_json::Serializer::new(&mut self.logline_buffer);
541 :
542 1 : let mut serializer = serializer.serialize_map(None)?;
543 :
544 : // Timestamp comes first, so raw lines can be sorted by timestamp.
545 1 : serializer.serialize_entry("timestamp", ×tamp)?;
546 :
547 : // Level next.
548 1 : serializer.serialize_entry("level", &meta.level().as_str())?;
549 :
550 : // Message next.
551 1 : serializer.serialize_key("message")?;
552 1 : let mut message_extractor =
553 1 : MessageFieldExtractor::new(serializer, skipped_field_indices);
554 1 : event.record(&mut message_extractor);
555 1 : let mut serializer = message_extractor.into_serializer()?;
556 :
557 : // Direct message fields.
558 1 : let mut fields_present = FieldsPresent(false, skipped_field_indices);
559 1 : event.record(&mut fields_present);
560 1 : if fields_present.0 {
561 1 : serializer.serialize_entry(
562 1 : "fields",
563 1 : &SerializableEventFields(event, skipped_field_indices),
564 1 : )?;
565 0 : }
566 :
567 1 : let spans = SerializableSpans {
568 1 : // collect all spans from parent to root.
569 1 : spans: ctx
570 1 : .event_span(event)
571 1 : .map_or(vec![], |parent| parent.scope().collect()),
572 1 : extracted: ExtractedSpanFields::new(extract_fields),
573 1 : };
574 1 : serializer.serialize_entry("spans", &spans)?;
575 :
576 : // TODO: thread-local cache?
577 1 : let pid = std::process::id();
578 1 : // Skip adding pid 1 to reduce noise for services running in containers.
579 1 : if pid != 1 {
580 1 : serializer.serialize_entry("process_id", &pid)?;
581 0 : }
582 :
583 1 : THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
584 :
585 : // TODO: tls cache? name could change
586 1 : if let Some(thread_name) = std::thread::current().name() {
587 1 : if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
588 1 : serializer.serialize_entry("thread_name", thread_name)?;
589 0 : }
590 0 : }
591 :
592 1 : if let Some(task_id) = tokio::task::try_id() {
593 0 : serializer.serialize_entry("task_id", &format_args!("{task_id}"))?;
594 1 : }
595 :
596 1 : serializer.serialize_entry("target", meta.target())?;
597 :
598 : // Skip adding module if it's the same as target.
599 1 : if let Some(module) = meta.module_path() {
600 1 : if module != meta.target() {
601 0 : serializer.serialize_entry("module", module)?;
602 1 : }
603 0 : }
604 :
605 1 : if let Some(file) = meta.file() {
606 1 : if let Some(line) = meta.line() {
607 1 : serializer.serialize_entry("src", &format_args!("{file}:{line}"))?;
608 : } else {
609 0 : serializer.serialize_entry("src", file)?;
610 : }
611 0 : }
612 :
613 : {
614 1 : let otel_context = Span::current().context();
615 1 : let otel_spanref = otel_context.span();
616 1 : let span_context = otel_spanref.span_context();
617 1 : if span_context.is_valid() {
618 0 : serializer.serialize_entry(
619 0 : "trace_id",
620 0 : &format_args!("{}", span_context.trace_id()),
621 0 : )?;
622 1 : }
623 : }
624 :
625 1 : if spans.extracted.has_values() {
626 : // TODO: add fields from event, too?
627 1 : serializer.serialize_entry("extract", &spans.extracted)?;
628 0 : }
629 :
630 1 : serializer.end()
631 1 : };
632 :
633 1 : serialize().map_err(io::Error::other)?;
634 1 : self.logline_buffer.push(b'\n');
635 1 : Ok(())
636 1 : }
637 : }
638 :
639 : /// Extracts the message field that's mixed will other fields.
640 : struct MessageFieldExtractor<S: serde::ser::SerializeMap> {
641 : serializer: S,
642 : skipped_field_indices: SkippedFieldIndices,
643 : state: Option<Result<(), S::Error>>,
644 : }
645 :
646 : impl<S: serde::ser::SerializeMap> MessageFieldExtractor<S> {
647 : #[inline]
648 1 : fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
649 1 : Self {
650 1 : serializer,
651 1 : skipped_field_indices,
652 1 : state: None,
653 1 : }
654 1 : }
655 :
656 : #[inline]
657 1 : fn into_serializer(mut self) -> Result<S, S::Error> {
658 1 : match self.state {
659 1 : Some(Ok(())) => {}
660 0 : Some(Err(err)) => return Err(err),
661 0 : None => self.serializer.serialize_value("")?,
662 : }
663 1 : Ok(self.serializer)
664 1 : }
665 :
666 : #[inline]
667 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
668 5 : self.state.is_none()
669 5 : && field.name() == MESSAGE_FIELD
670 2 : && !self.skipped_field_indices.contains(field.index())
671 5 : }
672 : }
673 :
674 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldExtractor<S> {
675 : #[inline]
676 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
677 0 : if self.accept_field(field) {
678 0 : self.state = Some(self.serializer.serialize_value(&value));
679 0 : }
680 0 : }
681 :
682 : #[inline]
683 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
684 3 : if self.accept_field(field) {
685 0 : self.state = Some(self.serializer.serialize_value(&value));
686 3 : }
687 3 : }
688 :
689 : #[inline]
690 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
691 0 : if self.accept_field(field) {
692 0 : self.state = Some(self.serializer.serialize_value(&value));
693 0 : }
694 0 : }
695 :
696 : #[inline]
697 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
698 0 : if self.accept_field(field) {
699 0 : self.state = Some(self.serializer.serialize_value(&value));
700 0 : }
701 0 : }
702 :
703 : #[inline]
704 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
705 0 : if self.accept_field(field) {
706 0 : self.state = Some(self.serializer.serialize_value(&value));
707 0 : }
708 0 : }
709 :
710 : #[inline]
711 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
712 0 : if self.accept_field(field) {
713 0 : self.state = Some(self.serializer.serialize_value(&value));
714 0 : }
715 0 : }
716 :
717 : #[inline]
718 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
719 0 : if self.accept_field(field) {
720 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:x?}")));
721 0 : }
722 0 : }
723 :
724 : #[inline]
725 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
726 1 : if self.accept_field(field) {
727 1 : self.state = Some(self.serializer.serialize_value(&value));
728 1 : }
729 1 : }
730 :
731 : #[inline]
732 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
733 1 : if self.accept_field(field) {
734 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:?}")));
735 1 : }
736 1 : }
737 :
738 : #[inline]
739 0 : fn record_error(
740 0 : &mut self,
741 0 : field: &tracing::field::Field,
742 0 : value: &(dyn std::error::Error + 'static),
743 0 : ) {
744 0 : if self.accept_field(field) {
745 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value}")));
746 0 : }
747 0 : }
748 : }
749 :
750 : /// Checks if there's any fields and field values present. If not, the JSON subobject
751 : /// can be skipped.
752 : // This is entirely optional and only cosmetic, though maybe helps a
753 : // bit during log parsing in dashboards when there's no field with empty object.
754 : struct FieldsPresent(pub bool, SkippedFieldIndices);
755 :
756 : // Even though some methods have an overhead (error, bytes) it is assumed the
757 : // compiler won't include this since we ignore the value entirely.
758 : impl tracing::field::Visit for FieldsPresent {
759 : #[inline]
760 5 : fn record_debug(&mut self, field: &tracing::field::Field, _: &dyn std::fmt::Debug) {
761 5 : if !self.1.contains(field.index())
762 2 : && field.name() != MESSAGE_FIELD
763 1 : && !field.name().starts_with("log.")
764 1 : {
765 1 : self.0 |= true;
766 4 : }
767 5 : }
768 : }
769 :
770 : /// Serializes the fields directly supplied with a log event.
771 : struct SerializableEventFields<'a, 'event>(&'a tracing::Event<'event>, SkippedFieldIndices);
772 :
773 : impl serde::ser::Serialize for SerializableEventFields<'_, '_> {
774 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
775 1 : where
776 1 : S: Serializer,
777 1 : {
778 : use serde::ser::SerializeMap;
779 1 : let serializer = serializer.serialize_map(None)?;
780 1 : let mut message_skipper = MessageFieldSkipper::new(serializer, self.1);
781 1 : self.0.record(&mut message_skipper);
782 1 : let serializer = message_skipper.into_serializer()?;
783 1 : serializer.end()
784 1 : }
785 : }
786 :
787 : /// A tracing field visitor that skips the message field.
788 : struct MessageFieldSkipper<S: serde::ser::SerializeMap> {
789 : serializer: S,
790 : skipped_field_indices: SkippedFieldIndices,
791 : state: Result<(), S::Error>,
792 : }
793 :
794 : impl<S: serde::ser::SerializeMap> MessageFieldSkipper<S> {
795 : #[inline]
796 1 : fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
797 1 : Self {
798 1 : serializer,
799 1 : skipped_field_indices,
800 1 : state: Ok(()),
801 1 : }
802 1 : }
803 :
804 : #[inline]
805 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
806 5 : self.state.is_ok()
807 5 : && field.name() != MESSAGE_FIELD
808 3 : && !field.name().starts_with("log.")
809 3 : && !self.skipped_field_indices.contains(field.index())
810 5 : }
811 :
812 : #[inline]
813 1 : fn into_serializer(self) -> Result<S, S::Error> {
814 1 : self.state?;
815 1 : Ok(self.serializer)
816 1 : }
817 : }
818 :
819 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldSkipper<S> {
820 : #[inline]
821 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
822 0 : if self.accept_field(field) {
823 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
824 0 : }
825 0 : }
826 :
827 : #[inline]
828 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
829 3 : if self.accept_field(field) {
830 1 : self.state = self.serializer.serialize_entry(field.name(), &value);
831 2 : }
832 3 : }
833 :
834 : #[inline]
835 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
836 0 : if self.accept_field(field) {
837 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
838 0 : }
839 0 : }
840 :
841 : #[inline]
842 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
843 0 : if self.accept_field(field) {
844 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
845 0 : }
846 0 : }
847 :
848 : #[inline]
849 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
850 0 : if self.accept_field(field) {
851 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
852 0 : }
853 0 : }
854 :
855 : #[inline]
856 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
857 0 : if self.accept_field(field) {
858 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
859 0 : }
860 0 : }
861 :
862 : #[inline]
863 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
864 0 : if self.accept_field(field) {
865 0 : self.state = self
866 0 : .serializer
867 0 : .serialize_entry(field.name(), &format_args!("{value:x?}"));
868 0 : }
869 0 : }
870 :
871 : #[inline]
872 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
873 1 : if self.accept_field(field) {
874 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
875 1 : }
876 1 : }
877 :
878 : #[inline]
879 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
880 1 : if self.accept_field(field) {
881 0 : self.state = self
882 0 : .serializer
883 0 : .serialize_entry(field.name(), &format_args!("{value:?}"));
884 1 : }
885 1 : }
886 :
887 : #[inline]
888 0 : fn record_error(
889 0 : &mut self,
890 0 : field: &tracing::field::Field,
891 0 : value: &(dyn std::error::Error + 'static),
892 0 : ) {
893 0 : if self.accept_field(field) {
894 0 : self.state = self.serializer.serialize_value(&format_args!("{value}"));
895 0 : }
896 0 : }
897 : }
898 :
899 : /// Serializes the span stack from root to leaf (parent of event) as object
900 : /// with the span names as keys. To prevent collision we append a numberic value
901 : /// to the name. Also, collects any span fields we're interested in. Last one
902 : /// wins.
903 : struct SerializableSpans<'ctx, S>
904 : where
905 : S: for<'lookup> LookupSpan<'lookup>,
906 : {
907 : spans: Vec<SpanRef<'ctx, S>>,
908 : extracted: ExtractedSpanFields,
909 : }
910 :
911 : impl<S> serde::ser::Serialize for SerializableSpans<'_, S>
912 : where
913 : S: for<'lookup> LookupSpan<'lookup>,
914 : {
915 1 : fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
916 1 : where
917 1 : Ser: serde::ser::Serializer,
918 1 : {
919 1 : let mut serializer = serializer.serialize_map(None)?;
920 :
921 2 : for span in self.spans.iter().rev() {
922 2 : let ext = span.extensions();
923 :
924 : // all spans should have this extension.
925 2 : let Some(fields) = ext.get() else { continue };
926 :
927 2 : self.extracted.layer_span(fields);
928 2 :
929 2 : let SpanFields { values, span_info } = fields;
930 2 : serializer.serialize_entry(
931 2 : &*span_info.normalized_name,
932 2 : &SerializableSpanFields {
933 2 : fields: span.metadata().fields(),
934 2 : values,
935 2 : },
936 2 : )?;
937 : }
938 :
939 1 : serializer.end()
940 1 : }
941 : }
942 :
943 : /// Serializes the span fields as object.
944 : struct SerializableSpanFields<'span> {
945 : fields: &'span tracing::field::FieldSet,
946 : values: &'span [serde_json::Value; MAX_TRACING_FIELDS],
947 : }
948 :
949 : impl serde::ser::Serialize for SerializableSpanFields<'_> {
950 2 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
951 2 : where
952 2 : S: serde::ser::Serializer,
953 2 : {
954 2 : let mut serializer = serializer.serialize_map(None)?;
955 :
956 4 : for (field, value) in std::iter::zip(self.fields, self.values) {
957 4 : if value.is_null() {
958 0 : continue;
959 4 : }
960 4 : serializer.serialize_entry(field.name(), value)?;
961 : }
962 :
963 2 : serializer.end()
964 2 : }
965 : }
966 :
967 : struct ExtractedSpanFields {
968 : names: &'static [&'static str],
969 : values: RefCell<Vec<serde_json::Value>>,
970 : }
971 :
972 : impl ExtractedSpanFields {
973 1 : fn new(names: &'static [&'static str]) -> Self {
974 1 : ExtractedSpanFields {
975 1 : names,
976 1 : values: RefCell::new(vec![serde_json::Value::Null; names.len()]),
977 1 : }
978 1 : }
979 :
980 2 : fn layer_span(&self, fields: &SpanFields) {
981 2 : let mut v = self.values.borrow_mut();
982 2 : let SpanFields { values, span_info } = fields;
983 :
984 : // extract the fields
985 2 : for (i, &j) in span_info.extract.iter().enumerate() {
986 2 : let Some(value) = values.get(j) else { continue };
987 :
988 2 : if !value.is_null() {
989 2 : // TODO: replace clone with reference, if possible.
990 2 : v[i] = value.clone();
991 2 : }
992 : }
993 2 : }
994 :
995 : #[inline]
996 1 : fn has_values(&self) -> bool {
997 1 : self.values.borrow().iter().any(|v| !v.is_null())
998 1 : }
999 : }
1000 :
1001 : impl serde::ser::Serialize for ExtractedSpanFields {
1002 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1003 1 : where
1004 1 : S: serde::ser::Serializer,
1005 1 : {
1006 1 : let mut serializer = serializer.serialize_map(None)?;
1007 :
1008 1 : let values = self.values.borrow();
1009 1 : for (key, value) in std::iter::zip(self.names, &*values) {
1010 1 : if value.is_null() {
1011 0 : continue;
1012 1 : }
1013 1 :
1014 1 : serializer.serialize_entry(key, value)?;
1015 : }
1016 :
1017 1 : serializer.end()
1018 1 : }
1019 : }
1020 :
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex, MutexGuard};

    use assert_json_diff::assert_json_eq;
    use tracing::info_span;

    use super::*;

    /// A clock that always reports the instant stored in `current_time`.
    struct TestClock {
        current_time: Mutex<DateTime<Utc>>,
    }

    impl Clock for Arc<TestClock> {
        fn now(&self) -> DateTime<Utc> {
            let guard = self.current_time.lock().expect("poisoned");
            *guard
        }
    }

    /// A writer that appends into a shared byte buffer while holding its lock.
    struct VecWriter<'a> {
        buffer: MutexGuard<'a, Vec<u8>>,
    }

    impl MakeWriter for Arc<Mutex<Vec<u8>>> {
        fn make_writer(&self) -> impl io::Write {
            let buffer = self.lock().expect("poisoned");
            VecWriter { buffer }
        }
    }

    impl io::Write for VecWriter<'_> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.buffer.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// The test checks three behaviours of the emitted JSON:
    /// - for duplicated event and span fields, the last value is kept;
    /// - same-named spans get numbered keys;
    /// - `extract` picks up the innermost span's value.
    #[test]
    fn test_field_collection() {
        let clock = Arc::new(TestClock {
            current_time: Mutex::new(Utc::now()),
        });
        let sink = Arc::new(Mutex::new(Vec::new()));

        let layer = JsonLoggingLayer {
            clock: clock.clone(),
            skipped_field_indices: papaya::HashMap::default(),
            span_info: papaya::HashMap::default(),
            writer: sink.clone(),
            extract_fields: &["x"],
        };
        let subscriber = tracing_subscriber::Registry::default().with(layer);

        tracing::subscriber::with_default(subscriber, || {
            let outer = info_span!("some_span", x = 24);
            outer.in_scope(|| {
                let inner = info_span!("some_span", x = 40, x = 41, x = 42);
                inner.in_scope(|| {
                    tracing::error!(
                        a = 1,
                        a = 2,
                        a = 3,
                        message = "explicit message field",
                        "implicit message field"
                    );
                });
            });
        });

        // The subscriber (and with it the layer's writer handle) is gone now,
        // so the buffer has a single owner again.
        let bytes = Arc::try_unwrap(sink)
            .expect("no other reference")
            .into_inner()
            .expect("poisoned");
        let actual: serde_json::Value = serde_json::from_slice(&bytes).expect("valid JSON");

        // These values vary between runs, so copy them from the actual output.
        let object = actual.as_object().unwrap();
        let src = object.get("src").unwrap().as_str().unwrap();
        let process_id = object.get("process_id").unwrap().as_number().unwrap();
        let thread_id = object.get("thread_id").unwrap().as_number().unwrap();

        let expected: serde_json::Value = serde_json::json!({
            "timestamp": clock.now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
            "level": "ERROR",
            "message": "explicit message field",
            "fields": {
                "a": 3
            },
            "spans": {
                "some_span#1": {
                    "x": 24
                },
                "some_span#2": {
                    "x": 42
                }
            },
            "extract": {
                "x": 42
            },
            "src": src,
            "target": "proxy::logging::tests",
            "process_id": process_id,
            "thread_id": thread_id,
            "thread_name": "logging::tests::test_field_collection"
        });

        assert_json_eq!(actual, expected);
    }
}
|