Line data Source code
1 : use std::cell::RefCell;
2 : use std::collections::HashMap;
3 : use std::sync::Arc;
4 : use std::sync::atomic::{AtomicU32, Ordering};
5 : use std::{env, io};
6 :
7 : use chrono::{DateTime, Utc};
8 : use opentelemetry::trace::TraceContextExt;
9 : use serde::ser::{SerializeMap, Serializer};
10 : use tracing::subscriber::Interest;
11 : use tracing::{Event, Metadata, Span, Subscriber, callsite, span};
12 : use tracing_opentelemetry::OpenTelemetrySpanExt;
13 : use tracing_subscriber::filter::{EnvFilter, LevelFilter};
14 : use tracing_subscriber::fmt::format::{Format, Full};
15 : use tracing_subscriber::fmt::time::SystemTime;
16 : use tracing_subscriber::fmt::{FormatEvent, FormatFields};
17 : use tracing_subscriber::layer::{Context, Layer};
18 : use tracing_subscriber::prelude::*;
19 : use tracing_subscriber::registry::{LookupSpan, SpanRef};
20 :
21 : /// Initialize logging and OpenTelemetry tracing and exporter.
22 : ///
23 : /// Logging can be configured using `RUST_LOG` environment variable.
24 : ///
25 : /// OpenTelemetry is configured with OTLP/HTTP exporter. It picks up
26 : /// configuration from environment variables. For example, to change the
27 : /// destination, set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`.
28 : /// See <https://opentelemetry.io/docs/reference/specification/sdk-environment-variables>
29 0 : pub async fn init() -> anyhow::Result<LoggingGuard> {
30 0 : let logfmt = LogFormat::from_env()?;
31 :
32 0 : let env_filter = EnvFilter::builder()
33 0 : .with_default_directive(LevelFilter::INFO.into())
34 0 : .from_env_lossy()
35 0 : .add_directive(
36 0 : "aws_config=info"
37 0 : .parse()
38 0 : .expect("this should be a valid filter directive"),
39 : )
40 0 : .add_directive(
41 0 : "azure_core::policies::transport=off"
42 0 : .parse()
43 0 : .expect("this should be a valid filter directive"),
44 : );
45 :
46 0 : let otlp_layer =
47 0 : tracing_utils::init_tracing("proxy", tracing_utils::ExportConfig::default()).await;
48 :
49 0 : let json_log_layer = if logfmt == LogFormat::Json {
50 0 : Some(JsonLoggingLayer::new(
51 0 : RealClock,
52 0 : StderrWriter {
53 0 : stderr: std::io::stderr(),
54 0 : },
55 0 : &["conn_id", "ep", "query_id", "request_id", "session_id"],
56 0 : ))
57 : } else {
58 0 : None
59 : };
60 :
61 0 : let text_log_layer = if logfmt == LogFormat::Text {
62 0 : Some(
63 0 : tracing_subscriber::fmt::layer()
64 0 : .with_ansi(false)
65 0 : .with_writer(std::io::stderr)
66 0 : .with_target(false),
67 0 : )
68 : } else {
69 0 : None
70 : };
71 :
72 0 : tracing_subscriber::registry()
73 0 : .with(env_filter)
74 0 : .with(otlp_layer)
75 0 : .with(json_log_layer)
76 0 : .with(text_log_layer)
77 0 : .try_init()?;
78 :
79 0 : Ok(LoggingGuard)
80 0 : }
81 :
82 : /// Initialize logging for local_proxy with log prefix and no opentelemetry.
83 : ///
84 : /// Logging can be configured using `RUST_LOG` environment variable.
85 0 : pub fn init_local_proxy() -> anyhow::Result<LoggingGuard> {
86 0 : let env_filter = EnvFilter::builder()
87 0 : .with_default_directive(LevelFilter::INFO.into())
88 0 : .from_env_lossy();
89 :
90 0 : let fmt_layer = tracing_subscriber::fmt::layer()
91 0 : .with_ansi(false)
92 0 : .with_writer(std::io::stderr)
93 0 : .event_format(LocalProxyFormatter(Format::default().with_target(false)));
94 :
95 0 : tracing_subscriber::registry()
96 0 : .with(env_filter)
97 0 : .with(fmt_layer)
98 0 : .try_init()?;
99 :
100 0 : Ok(LoggingGuard)
101 0 : }
102 :
103 : pub struct LocalProxyFormatter(Format<Full, SystemTime>);
104 :
105 : impl<S, N> FormatEvent<S, N> for LocalProxyFormatter
106 : where
107 : S: Subscriber + for<'a> LookupSpan<'a>,
108 : N: for<'a> FormatFields<'a> + 'static,
109 : {
110 0 : fn format_event(
111 0 : &self,
112 0 : ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
113 0 : mut writer: tracing_subscriber::fmt::format::Writer<'_>,
114 0 : event: &tracing::Event<'_>,
115 0 : ) -> std::fmt::Result {
116 0 : writer.write_str("[local_proxy] ")?;
117 0 : self.0.format_event(ctx, writer, event)
118 0 : }
119 : }
120 :
121 : pub struct LoggingGuard;
122 :
123 : impl Drop for LoggingGuard {
124 0 : fn drop(&mut self) {
125 : // Shutdown trace pipeline gracefully, so that it has a chance to send any
126 : // pending traces before we exit.
127 0 : tracing::info!("shutting down the tracing machinery");
128 0 : tracing_utils::shutdown_tracing();
129 0 : }
130 : }
131 :
132 : #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
133 : enum LogFormat {
134 : Text,
135 : #[default]
136 : Json,
137 : }
138 :
139 : impl LogFormat {
140 0 : fn from_env() -> anyhow::Result<Self> {
141 0 : let logfmt = env::var("LOGFMT");
142 0 : Ok(match logfmt.as_deref() {
143 0 : Err(_) => LogFormat::default(),
144 0 : Ok("text") => LogFormat::Text,
145 0 : Ok("json") => LogFormat::Json,
146 0 : Ok(logfmt) => anyhow::bail!("unknown log format: {logfmt}"),
147 : })
148 0 : }
149 : }
150 :
/// Produces a writer to which a single log line is written.
trait MakeWriter {
    fn make_writer(&self) -> impl io::Write;
}

/// [`MakeWriter`] that hands out a locked handle to the process's stderr.
struct StderrWriter {
    stderr: io::Stderr,
}

impl MakeWriter for StderrWriter {
    #[inline]
    fn make_writer(&self) -> impl io::Write {
        let Self { stderr } = self;
        stderr.lock()
    }
}
165 :
166 : // TODO: move into separate module or even separate crate.
167 : trait Clock {
168 : fn now(&self) -> DateTime<Utc>;
169 : }
170 :
171 : struct RealClock;
172 :
173 : impl Clock for RealClock {
174 : #[inline]
175 0 : fn now(&self) -> DateTime<Utc> {
176 0 : Utc::now()
177 0 : }
178 : }
179 :
/// Field name under which the `tracing` crate stores an event's message.
const MESSAGE_FIELD: &str = "message";

/// Expected upper bound on the number of fields per span or event.
///
/// `tracing` used to enforce that spans/events have no more than 32 fields. It
/// seems this is no longer enforced, but the limit is still documented in some
/// places. We generally shouldn't see more than 32 fields anyway, so we rely on
/// it for some (minor) performance gains such as fixed-size per-span storage.
const MAX_TRACING_FIELDS: usize = 32;
188 :
189 : thread_local! {
190 : /// Thread-local instance with per-thread buffer for log writing.
191 : static EVENT_FORMATTER: RefCell<EventFormatter> = const { RefCell::new(EventFormatter::new()) };
192 : /// Cached OS thread ID.
193 : static THREAD_ID: u64 = gettid::gettid();
194 : }
195 :
196 : /// Map for values fixed at callsite registration.
197 : // We use papaya here because registration rarely happens post-startup.
198 : // papaya is good for read-heavy workloads.
199 : //
200 : // We use rustc_hash here because callsite::Identifier will always be an integer with low-bit entropy,
201 : // since it's always a pointer to static mutable data. rustc_hash was designed for low-bit entropy.
202 : type CallsiteMap<T> =
203 : papaya::HashMap<callsite::Identifier, T, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
204 :
205 : /// Implements tracing layer to handle events specific to logging.
206 : struct JsonLoggingLayer<C: Clock, W: MakeWriter> {
207 : clock: C,
208 : writer: W,
209 :
210 : /// tracks which fields of each **event** are duplicates
211 : skipped_field_indices: CallsiteMap<SkippedFieldIndices>,
212 :
213 : span_info: CallsiteMap<CallsiteSpanInfo>,
214 :
215 : /// Fields we want to keep track of in a separate json object.
216 : extract_fields: &'static [&'static str],
217 : }
218 :
219 : impl<C: Clock, W: MakeWriter> JsonLoggingLayer<C, W> {
220 0 : fn new(clock: C, writer: W, extract_fields: &'static [&'static str]) -> Self {
221 0 : JsonLoggingLayer {
222 0 : clock,
223 0 : skipped_field_indices: CallsiteMap::default(),
224 0 : span_info: CallsiteMap::default(),
225 0 : writer,
226 0 : extract_fields,
227 0 : }
228 0 : }
229 :
230 : #[inline]
231 4 : fn span_info(&self, metadata: &'static Metadata<'static>) -> CallsiteSpanInfo {
232 4 : self.span_info
233 4 : .pin()
234 4 : .get_or_insert_with(metadata.callsite(), || {
235 2 : CallsiteSpanInfo::new(metadata, self.extract_fields)
236 2 : })
237 4 : .clone()
238 4 : }
239 : }
240 :
241 : impl<S, C: Clock + 'static, W: MakeWriter + 'static> Layer<S> for JsonLoggingLayer<C, W>
242 : where
243 : S: Subscriber + for<'a> LookupSpan<'a>,
244 : {
245 1 : fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
246 : use std::io::Write;
247 :
248 : // TODO: consider special tracing subscriber to grab timestamp very
249 : // early, before OTel machinery, and add as event extension.
250 1 : let now = self.clock.now();
251 :
252 1 : let res: io::Result<()> = EVENT_FORMATTER.with(|f| {
253 1 : let mut borrow = f.try_borrow_mut();
254 1 : let formatter = match borrow.as_deref_mut() {
255 1 : Ok(formatter) => formatter,
256 : // If the thread local formatter is borrowed,
257 : // then we likely hit an edge case were we panicked during formatting.
258 : // We allow the logging to proceed with an uncached formatter.
259 0 : Err(_) => &mut EventFormatter::new(),
260 : };
261 :
262 1 : formatter.reset();
263 1 : formatter.format(
264 1 : now,
265 1 : event,
266 1 : &ctx,
267 1 : &self.skipped_field_indices,
268 1 : self.extract_fields,
269 0 : )?;
270 1 : self.writer.make_writer().write_all(formatter.buffer())
271 1 : });
272 :
273 : // In case logging fails we generate a simpler JSON object.
274 1 : if let Err(err) = res
275 0 : && let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
276 0 : "timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
277 0 : "level": "ERROR",
278 0 : "message": format_args!("cannot log event: {err:?}"),
279 0 : "fields": {
280 0 : "event": format_args!("{event:?}"),
281 0 : },
282 0 : }))
283 0 : {
284 0 : line.push(b'\n');
285 0 : self.writer.make_writer().write_all(&line).ok();
286 1 : }
287 1 : }
288 :
289 : /// Registers a SpanFields instance as span extension.
290 2 : fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
291 2 : let span = ctx.span(id).expect("span must exist");
292 :
293 2 : let mut fields = SpanFields::new(self.span_info(span.metadata()));
294 2 : attrs.record(&mut fields);
295 :
296 : // This is a new span: the extensions should not be locked
297 : // unless some layer spawned a thread to process this span.
298 : // I don't think any layers do that.
299 2 : span.extensions_mut().insert(fields);
300 2 : }
301 :
302 0 : fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
303 0 : let span = ctx.span(id).expect("span must exist");
304 :
305 : // assumption: `on_record` is rarely called.
306 : // assumption: a span being updated by one thread,
307 : // and formatted by another thread is even rarer.
308 0 : let mut ext = span.extensions_mut();
309 0 : if let Some(fields) = ext.get_mut::<SpanFields>() {
310 0 : values.record(fields);
311 0 : }
312 0 : }
313 :
314 : /// Called (lazily) roughly once per event/span instance. We quickly check
315 : /// for duplicate field names and record duplicates as skippable. Last field wins.
316 3 : fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
317 3 : debug_assert!(
318 3 : metadata.fields().len() <= MAX_TRACING_FIELDS,
319 0 : "callsite {metadata:?} has too many fields."
320 : );
321 :
322 3 : if !metadata.is_event() {
323 : // register the span info.
324 2 : self.span_info(metadata);
325 : // Must not be never because we wouldn't get trace and span data.
326 2 : return Interest::always();
327 1 : }
328 :
329 1 : let mut field_indices = SkippedFieldIndices::default();
330 1 : let mut seen_fields = HashMap::new();
331 5 : for field in metadata.fields() {
332 5 : if let Some(old_index) = seen_fields.insert(field.name(), field.index()) {
333 3 : field_indices.set(old_index);
334 3 : }
335 : }
336 :
337 1 : if !field_indices.is_empty() {
338 1 : self.skipped_field_indices
339 1 : .pin()
340 1 : .insert(metadata.callsite(), field_indices);
341 1 : }
342 :
343 1 : Interest::always()
344 3 : }
345 : }
346 :
347 : /// Any span info that is fixed to a particular callsite. Not variable between span instances.
348 : #[derive(Clone)]
349 : struct CallsiteSpanInfo {
350 : /// index of each field to extract. usize::MAX if not found.
351 : extract: Arc<[usize]>,
352 :
353 : /// tracks the fixed "callsite ID" for each span.
354 : /// note: this is not stable between runs.
355 : normalized_name: Arc<str>,
356 : }
357 :
358 : impl CallsiteSpanInfo {
359 2 : fn new(metadata: &'static Metadata<'static>, extract_fields: &[&'static str]) -> Self {
360 : // Start at 1 to reserve 0 for default.
361 : static COUNTER: AtomicU32 = AtomicU32::new(1);
362 :
363 4 : let names: Vec<&'static str> = metadata.fields().iter().map(|f| f.name()).collect();
364 :
365 : // get all the indices of span fields we want to focus
366 2 : let extract = extract_fields
367 2 : .iter()
368 : // use rposition, since we want last match wins.
369 2 : .map(|f1| names.iter().rposition(|f2| f1 == f2).unwrap_or(usize::MAX))
370 2 : .collect();
371 :
372 : // normalized_name is unique for each callsite, but it is not
373 : // unified across separate proxy instances.
374 : // todo: can we do better here?
375 2 : let cid = COUNTER.fetch_add(1, Ordering::Relaxed);
376 2 : let normalized_name = format!("{}#{cid}", metadata.name()).into();
377 :
378 2 : Self {
379 2 : extract,
380 2 : normalized_name,
381 2 : }
382 2 : }
383 : }
384 :
385 : /// Stores span field values recorded during the spans lifetime.
386 : struct SpanFields {
387 : values: [serde_json::Value; MAX_TRACING_FIELDS],
388 :
389 : /// cached span info so we can avoid extra hashmap lookups in the hot path.
390 : span_info: CallsiteSpanInfo,
391 : }
392 :
393 : impl SpanFields {
394 2 : fn new(span_info: CallsiteSpanInfo) -> Self {
395 : Self {
396 2 : span_info,
397 : values: [const { serde_json::Value::Null }; MAX_TRACING_FIELDS],
398 : }
399 2 : }
400 : }
401 :
402 : impl tracing::field::Visit for SpanFields {
403 : #[inline]
404 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
405 0 : self.values[field.index()] = serde_json::Value::from(value);
406 0 : }
407 :
408 : #[inline]
409 4 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
410 4 : self.values[field.index()] = serde_json::Value::from(value);
411 4 : }
412 :
413 : #[inline]
414 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
415 0 : self.values[field.index()] = serde_json::Value::from(value);
416 0 : }
417 :
418 : #[inline]
419 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
420 0 : if let Ok(value) = i64::try_from(value) {
421 0 : self.values[field.index()] = serde_json::Value::from(value);
422 0 : } else {
423 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
424 0 : }
425 0 : }
426 :
427 : #[inline]
428 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
429 0 : if let Ok(value) = u64::try_from(value) {
430 0 : self.values[field.index()] = serde_json::Value::from(value);
431 0 : } else {
432 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
433 0 : }
434 0 : }
435 :
436 : #[inline]
437 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
438 0 : self.values[field.index()] = serde_json::Value::from(value);
439 0 : }
440 :
441 : #[inline]
442 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
443 0 : self.values[field.index()] = serde_json::Value::from(value);
444 0 : }
445 :
446 : #[inline]
447 0 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
448 0 : self.values[field.index()] = serde_json::Value::from(value);
449 0 : }
450 :
451 : #[inline]
452 0 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
453 0 : self.values[field.index()] = serde_json::Value::from(format!("{value:?}"));
454 0 : }
455 :
456 : #[inline]
457 0 : fn record_error(
458 0 : &mut self,
459 0 : field: &tracing::field::Field,
460 0 : value: &(dyn std::error::Error + 'static),
461 0 : ) {
462 0 : self.values[field.index()] = serde_json::Value::from(format!("{value}"));
463 0 : }
464 : }
465 :
/// List of field indices skipped during logging. Can list duplicate fields or
/// metafields not meant to be logged.
///
/// Stored as a 32-bit set, enough for `MAX_TRACING_FIELDS` indices. Indices
/// that do not fit are never stored, so `contains` reports them as not skipped
/// instead of overflowing the shift.
#[derive(Copy, Clone, Default)]
struct SkippedFieldIndices {
    bits: u32,
}

impl SkippedFieldIndices {
    /// Single-bit mask for `index`, or 0 when `index` is outside the set.
    #[inline]
    fn mask(index: usize) -> u32 {
        if index < u32::BITS as usize {
            1 << index
        } else {
            0
        }
    }

    /// Returns whether no index is marked.
    #[inline]
    fn is_empty(self) -> bool {
        self.bits == 0
    }

    /// Marks `index` as skipped.
    ///
    /// Debug builds assert that `index` fits in the set (valid indices are
    /// `0..32`). Release builds ignore out-of-range indices rather than letting
    /// the shift wrap onto an unrelated bit.
    #[inline]
    fn set(&mut self, index: usize) {
        debug_assert!(
            index < u32::BITS as usize,
            "index out of bounds of 32-bit set"
        );
        self.bits |= Self::mask(index);
    }

    /// Returns whether `index` is marked as skipped; out-of-range indices never are.
    #[inline]
    fn contains(self, index: usize) -> bool {
        self.bits & Self::mask(index) != 0
    }
}
491 :
492 : /// Formats a tracing event and writes JSON to its internal buffer including a newline.
493 : // TODO: buffer capacity management, truncate if too large
494 : struct EventFormatter {
495 : logline_buffer: Vec<u8>,
496 : }
497 :
498 : impl EventFormatter {
499 : #[inline]
500 0 : const fn new() -> Self {
501 0 : EventFormatter {
502 0 : logline_buffer: Vec::new(),
503 0 : }
504 0 : }
505 :
506 : #[inline]
507 1 : fn buffer(&self) -> &[u8] {
508 1 : &self.logline_buffer
509 1 : }
510 :
511 : #[inline]
512 1 : fn reset(&mut self) {
513 1 : self.logline_buffer.clear();
514 1 : }
515 :
516 1 : fn format<S>(
517 1 : &mut self,
518 1 : now: DateTime<Utc>,
519 1 : event: &Event<'_>,
520 1 : ctx: &Context<'_, S>,
521 1 : skipped_field_indices: &CallsiteMap<SkippedFieldIndices>,
522 1 : extract_fields: &'static [&'static str],
523 1 : ) -> io::Result<()>
524 1 : where
525 1 : S: Subscriber + for<'a> LookupSpan<'a>,
526 : {
527 1 : let timestamp = now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
528 :
529 : use tracing_log::NormalizeEvent;
530 1 : let normalized_meta = event.normalized_metadata();
531 1 : let meta = normalized_meta.as_ref().unwrap_or_else(|| event.metadata());
532 :
533 1 : let skipped_field_indices = skipped_field_indices
534 1 : .pin()
535 1 : .get(&meta.callsite())
536 1 : .copied()
537 1 : .unwrap_or_default();
538 :
539 1 : let mut serialize = || {
540 1 : let mut serializer = serde_json::Serializer::new(&mut self.logline_buffer);
541 :
542 1 : let mut serializer = serializer.serialize_map(None)?;
543 :
544 : // Timestamp comes first, so raw lines can be sorted by timestamp.
545 1 : serializer.serialize_entry("timestamp", ×tamp)?;
546 :
547 : // Level next.
548 1 : serializer.serialize_entry("level", &meta.level().as_str())?;
549 :
550 : // Message next.
551 1 : serializer.serialize_key("message")?;
552 1 : let mut message_extractor =
553 1 : MessageFieldExtractor::new(serializer, skipped_field_indices);
554 1 : event.record(&mut message_extractor);
555 1 : let mut serializer = message_extractor.into_serializer()?;
556 :
557 : // Direct message fields.
558 1 : let mut fields_present = FieldsPresent(false, skipped_field_indices);
559 1 : event.record(&mut fields_present);
560 1 : if fields_present.0 {
561 1 : serializer.serialize_entry(
562 1 : "fields",
563 1 : &SerializableEventFields(event, skipped_field_indices),
564 0 : )?;
565 0 : }
566 :
567 1 : let spans = SerializableSpans {
568 : // collect all spans from parent to root.
569 1 : spans: ctx
570 1 : .event_span(event)
571 1 : .map_or(vec![], |parent| parent.scope().collect()),
572 1 : extracted: ExtractedSpanFields::new(extract_fields),
573 : };
574 1 : serializer.serialize_entry("spans", &spans)?;
575 :
576 : // TODO: thread-local cache?
577 1 : let pid = std::process::id();
578 : // Skip adding pid 1 to reduce noise for services running in containers.
579 1 : if pid != 1 {
580 1 : serializer.serialize_entry("process_id", &pid)?;
581 0 : }
582 :
583 1 : THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
584 :
585 : // TODO: tls cache? name could change
586 1 : if let Some(thread_name) = std::thread::current().name()
587 1 : && !thread_name.is_empty()
588 1 : && thread_name != "tokio-runtime-worker"
589 : {
590 1 : serializer.serialize_entry("thread_name", thread_name)?;
591 0 : }
592 :
593 1 : if let Some(task_id) = tokio::task::try_id() {
594 0 : serializer.serialize_entry("task_id", &format_args!("{task_id}"))?;
595 1 : }
596 :
597 1 : serializer.serialize_entry("target", meta.target())?;
598 :
599 : // Skip adding module if it's the same as target.
600 1 : if let Some(module) = meta.module_path()
601 1 : && module != meta.target()
602 : {
603 0 : serializer.serialize_entry("module", module)?;
604 1 : }
605 :
606 1 : if let Some(file) = meta.file() {
607 1 : if let Some(line) = meta.line() {
608 1 : serializer.serialize_entry("src", &format_args!("{file}:{line}"))?;
609 : } else {
610 0 : serializer.serialize_entry("src", file)?;
611 : }
612 0 : }
613 :
614 : {
615 1 : let otel_context = Span::current().context();
616 1 : let otel_spanref = otel_context.span();
617 1 : let span_context = otel_spanref.span_context();
618 1 : if span_context.is_valid() {
619 0 : serializer.serialize_entry(
620 0 : "trace_id",
621 0 : &format_args!("{}", span_context.trace_id()),
622 0 : )?;
623 1 : }
624 : }
625 :
626 1 : if spans.extracted.has_values() {
627 : // TODO: add fields from event, too?
628 1 : serializer.serialize_entry("extract", &spans.extracted)?;
629 0 : }
630 :
631 1 : serializer.end()
632 1 : };
633 :
634 1 : serialize().map_err(io::Error::other)?;
635 1 : self.logline_buffer.push(b'\n');
636 1 : Ok(())
637 1 : }
638 : }
639 :
640 : /// Extracts the message field that's mixed will other fields.
641 : struct MessageFieldExtractor<S: serde::ser::SerializeMap> {
642 : serializer: S,
643 : skipped_field_indices: SkippedFieldIndices,
644 : state: Option<Result<(), S::Error>>,
645 : }
646 :
647 : impl<S: serde::ser::SerializeMap> MessageFieldExtractor<S> {
648 : #[inline]
649 1 : fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
650 1 : Self {
651 1 : serializer,
652 1 : skipped_field_indices,
653 1 : state: None,
654 1 : }
655 1 : }
656 :
657 : #[inline]
658 1 : fn into_serializer(mut self) -> Result<S, S::Error> {
659 1 : match self.state {
660 1 : Some(Ok(())) => {}
661 0 : Some(Err(err)) => return Err(err),
662 0 : None => self.serializer.serialize_value("")?,
663 : }
664 1 : Ok(self.serializer)
665 1 : }
666 :
667 : #[inline]
668 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
669 5 : self.state.is_none()
670 5 : && field.name() == MESSAGE_FIELD
671 2 : && !self.skipped_field_indices.contains(field.index())
672 5 : }
673 : }
674 :
675 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldExtractor<S> {
676 : #[inline]
677 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
678 0 : if self.accept_field(field) {
679 0 : self.state = Some(self.serializer.serialize_value(&value));
680 0 : }
681 0 : }
682 :
683 : #[inline]
684 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
685 3 : if self.accept_field(field) {
686 0 : self.state = Some(self.serializer.serialize_value(&value));
687 3 : }
688 3 : }
689 :
690 : #[inline]
691 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
692 0 : if self.accept_field(field) {
693 0 : self.state = Some(self.serializer.serialize_value(&value));
694 0 : }
695 0 : }
696 :
697 : #[inline]
698 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
699 0 : if self.accept_field(field) {
700 0 : self.state = Some(self.serializer.serialize_value(&value));
701 0 : }
702 0 : }
703 :
704 : #[inline]
705 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
706 0 : if self.accept_field(field) {
707 0 : self.state = Some(self.serializer.serialize_value(&value));
708 0 : }
709 0 : }
710 :
711 : #[inline]
712 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
713 0 : if self.accept_field(field) {
714 0 : self.state = Some(self.serializer.serialize_value(&value));
715 0 : }
716 0 : }
717 :
718 : #[inline]
719 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
720 0 : if self.accept_field(field) {
721 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:x?}")));
722 0 : }
723 0 : }
724 :
725 : #[inline]
726 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
727 1 : if self.accept_field(field) {
728 1 : self.state = Some(self.serializer.serialize_value(&value));
729 1 : }
730 1 : }
731 :
732 : #[inline]
733 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
734 1 : if self.accept_field(field) {
735 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value:?}")));
736 1 : }
737 1 : }
738 :
739 : #[inline]
740 0 : fn record_error(
741 0 : &mut self,
742 0 : field: &tracing::field::Field,
743 0 : value: &(dyn std::error::Error + 'static),
744 0 : ) {
745 0 : if self.accept_field(field) {
746 0 : self.state = Some(self.serializer.serialize_value(&format_args!("{value}")));
747 0 : }
748 0 : }
749 : }
750 :
751 : /// Checks if there's any fields and field values present. If not, the JSON subobject
752 : /// can be skipped.
753 : // This is entirely optional and only cosmetic, though maybe helps a
754 : // bit during log parsing in dashboards when there's no field with empty object.
755 : struct FieldsPresent(pub bool, SkippedFieldIndices);
756 :
757 : // Even though some methods have an overhead (error, bytes) it is assumed the
758 : // compiler won't include this since we ignore the value entirely.
759 : impl tracing::field::Visit for FieldsPresent {
760 : #[inline]
761 5 : fn record_debug(&mut self, field: &tracing::field::Field, _: &dyn std::fmt::Debug) {
762 5 : if !self.1.contains(field.index())
763 2 : && field.name() != MESSAGE_FIELD
764 1 : && !field.name().starts_with("log.")
765 1 : {
766 1 : self.0 |= true;
767 4 : }
768 5 : }
769 : }
770 :
771 : /// Serializes the fields directly supplied with a log event.
772 : struct SerializableEventFields<'a, 'event>(&'a tracing::Event<'event>, SkippedFieldIndices);
773 :
774 : impl serde::ser::Serialize for SerializableEventFields<'_, '_> {
775 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
776 1 : where
777 1 : S: Serializer,
778 : {
779 : use serde::ser::SerializeMap;
780 1 : let serializer = serializer.serialize_map(None)?;
781 1 : let mut message_skipper = MessageFieldSkipper::new(serializer, self.1);
782 1 : self.0.record(&mut message_skipper);
783 1 : let serializer = message_skipper.into_serializer()?;
784 1 : serializer.end()
785 1 : }
786 : }
787 :
788 : /// A tracing field visitor that skips the message field.
789 : struct MessageFieldSkipper<S: serde::ser::SerializeMap> {
790 : serializer: S,
791 : skipped_field_indices: SkippedFieldIndices,
792 : state: Result<(), S::Error>,
793 : }
794 :
795 : impl<S: serde::ser::SerializeMap> MessageFieldSkipper<S> {
796 : #[inline]
797 1 : fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
798 1 : Self {
799 1 : serializer,
800 1 : skipped_field_indices,
801 1 : state: Ok(()),
802 1 : }
803 1 : }
804 :
805 : #[inline]
806 5 : fn accept_field(&self, field: &tracing::field::Field) -> bool {
807 5 : self.state.is_ok()
808 5 : && field.name() != MESSAGE_FIELD
809 3 : && !field.name().starts_with("log.")
810 3 : && !self.skipped_field_indices.contains(field.index())
811 5 : }
812 :
813 : #[inline]
814 1 : fn into_serializer(self) -> Result<S, S::Error> {
815 1 : self.state?;
816 1 : Ok(self.serializer)
817 1 : }
818 : }
819 :
820 : impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldSkipper<S> {
821 : #[inline]
822 0 : fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
823 0 : if self.accept_field(field) {
824 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
825 0 : }
826 0 : }
827 :
828 : #[inline]
829 3 : fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
830 3 : if self.accept_field(field) {
831 1 : self.state = self.serializer.serialize_entry(field.name(), &value);
832 2 : }
833 3 : }
834 :
835 : #[inline]
836 0 : fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
837 0 : if self.accept_field(field) {
838 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
839 0 : }
840 0 : }
841 :
842 : #[inline]
843 0 : fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
844 0 : if self.accept_field(field) {
845 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
846 0 : }
847 0 : }
848 :
849 : #[inline]
850 0 : fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
851 0 : if self.accept_field(field) {
852 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
853 0 : }
854 0 : }
855 :
856 : #[inline]
857 0 : fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
858 0 : if self.accept_field(field) {
859 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
860 0 : }
861 0 : }
862 :
863 : #[inline]
864 0 : fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
865 0 : if self.accept_field(field) {
866 0 : self.state = self
867 0 : .serializer
868 0 : .serialize_entry(field.name(), &format_args!("{value:x?}"));
869 0 : }
870 0 : }
871 :
872 : #[inline]
873 1 : fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
874 1 : if self.accept_field(field) {
875 0 : self.state = self.serializer.serialize_entry(field.name(), &value);
876 1 : }
877 1 : }
878 :
879 : #[inline]
880 1 : fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
881 1 : if self.accept_field(field) {
882 0 : self.state = self
883 0 : .serializer
884 0 : .serialize_entry(field.name(), &format_args!("{value:?}"));
885 1 : }
886 1 : }
887 :
888 : #[inline]
889 0 : fn record_error(
890 0 : &mut self,
891 0 : field: &tracing::field::Field,
892 0 : value: &(dyn std::error::Error + 'static),
893 0 : ) {
894 0 : if self.accept_field(field) {
895 0 : self.state = self.serializer.serialize_value(&format_args!("{value}"));
896 0 : }
897 0 : }
898 : }
899 :
900 : /// Serializes the span stack from root to leaf (parent of event) as object
901 : /// with the span names as keys. To prevent collision we append a numberic value
902 : /// to the name. Also, collects any span fields we're interested in. Last one
903 : /// wins.
904 : struct SerializableSpans<'ctx, S>
905 : where
906 : S: for<'lookup> LookupSpan<'lookup>,
907 : {
908 : spans: Vec<SpanRef<'ctx, S>>,
909 : extracted: ExtractedSpanFields,
910 : }
911 :
912 : impl<S> serde::ser::Serialize for SerializableSpans<'_, S>
913 : where
914 : S: for<'lookup> LookupSpan<'lookup>,
915 : {
916 1 : fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
917 1 : where
918 1 : Ser: serde::ser::Serializer,
919 : {
920 1 : let mut serializer = serializer.serialize_map(None)?;
921 :
922 2 : for span in self.spans.iter().rev() {
923 2 : let ext = span.extensions();
924 :
925 : // all spans should have this extension.
926 2 : let Some(fields) = ext.get() else { continue };
927 :
928 2 : self.extracted.layer_span(fields);
929 :
930 2 : let SpanFields { values, span_info } = fields;
931 2 : serializer.serialize_entry(
932 2 : &*span_info.normalized_name,
933 2 : &SerializableSpanFields {
934 2 : fields: span.metadata().fields(),
935 2 : values,
936 2 : },
937 0 : )?;
938 : }
939 :
940 1 : serializer.end()
941 1 : }
942 : }
943 :
944 : /// Serializes the span fields as object.
945 : struct SerializableSpanFields<'span> {
946 : fields: &'span tracing::field::FieldSet,
947 : values: &'span [serde_json::Value; MAX_TRACING_FIELDS],
948 : }
949 :
950 : impl serde::ser::Serialize for SerializableSpanFields<'_> {
951 2 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
952 2 : where
953 2 : S: serde::ser::Serializer,
954 : {
955 2 : let mut serializer = serializer.serialize_map(None)?;
956 :
957 4 : for (field, value) in std::iter::zip(self.fields, self.values) {
958 4 : if value.is_null() {
959 0 : continue;
960 4 : }
961 4 : serializer.serialize_entry(field.name(), value)?;
962 : }
963 :
964 2 : serializer.end()
965 2 : }
966 : }
967 :
968 : struct ExtractedSpanFields {
969 : names: &'static [&'static str],
970 : values: RefCell<Vec<serde_json::Value>>,
971 : }
972 :
973 : impl ExtractedSpanFields {
974 1 : fn new(names: &'static [&'static str]) -> Self {
975 1 : ExtractedSpanFields {
976 1 : names,
977 1 : values: RefCell::new(vec![serde_json::Value::Null; names.len()]),
978 1 : }
979 1 : }
980 :
981 2 : fn layer_span(&self, fields: &SpanFields) {
982 2 : let mut v = self.values.borrow_mut();
983 2 : let SpanFields { values, span_info } = fields;
984 :
985 : // extract the fields
986 2 : for (i, &j) in span_info.extract.iter().enumerate() {
987 2 : let Some(value) = values.get(j) else { continue };
988 :
989 2 : if !value.is_null() {
990 2 : // TODO: replace clone with reference, if possible.
991 2 : v[i] = value.clone();
992 2 : }
993 : }
994 2 : }
995 :
996 : #[inline]
997 1 : fn has_values(&self) -> bool {
998 1 : self.values.borrow().iter().any(|v| !v.is_null())
999 1 : }
1000 : }
1001 :
1002 : impl serde::ser::Serialize for ExtractedSpanFields {
1003 1 : fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1004 1 : where
1005 1 : S: serde::ser::Serializer,
1006 : {
1007 1 : let mut serializer = serializer.serialize_map(None)?;
1008 :
1009 1 : let values = self.values.borrow();
1010 1 : for (key, value) in std::iter::zip(self.names, &*values) {
1011 1 : if value.is_null() {
1012 0 : continue;
1013 1 : }
1014 :
1015 1 : serializer.serialize_entry(key, value)?;
1016 : }
1017 :
1018 1 : serializer.end()
1019 1 : }
1020 : }
1021 :
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex, MutexGuard};

    use assert_json_diff::assert_json_eq;
    use tracing::info_span;

    use super::*;

    /// Clock that reports whatever instant is stored in `current_time`.
    struct TestClock {
        current_time: Mutex<DateTime<Utc>>,
    }

    impl Clock for Arc<TestClock> {
        fn now(&self) -> DateTime<Utc> {
            *self.current_time.lock().expect("poisoned")
        }
    }

    /// `io::Write` adapter that appends into a mutex-protected byte buffer.
    /// It holds the lock for as long as the writer is alive.
    struct VecWriter<'a> {
        buffer: MutexGuard<'a, Vec<u8>>,
    }

    impl MakeWriter for Arc<Mutex<Vec<u8>>> {
        fn make_writer(&self) -> impl io::Write {
            let buffer = self.lock().expect("poisoned");
            VecWriter { buffer }
        }
    }

    impl io::Write for VecWriter<'_> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            io::Write::write(&mut *self.buffer, buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            // The underlying `Vec` is written directly; there is nothing to flush.
            Ok(())
        }
    }

    #[test]
    fn test_field_collection() {
        let clock = Arc::new(TestClock {
            current_time: Mutex::new(Utc::now()),
        });
        let sink = Arc::new(Mutex::new(Vec::<u8>::new()));

        let log_layer = JsonLoggingLayer {
            clock: Arc::clone(&clock),
            skipped_field_indices: papaya::HashMap::default(),
            span_info: papaya::HashMap::default(),
            writer: Arc::clone(&sink),
            extract_fields: &["x"],
        };
        let subscriber = tracing_subscriber::Registry::default().with(log_layer);

        // Two nested spans share a name (so their keys must be disambiguated)
        // and both set the extract field `x`. The event repeats `a` and sets
        // both an explicit and an implicit message.
        tracing::subscriber::with_default(subscriber, || {
            info_span!("some_span", x = 24).in_scope(|| {
                info_span!("some_span", x = 40, x = 41, x = 42).in_scope(|| {
                    tracing::error!(
                        a = 1,
                        a = 2,
                        a = 3,
                        message = "explicit message field",
                        "implicit message field"
                    );
                });
            });
        });

        let written = Arc::try_unwrap(sink)
            .expect("no other reference")
            .into_inner()
            .expect("poisoned");
        let actual: serde_json::Value = serde_json::from_slice(&written).expect("valid JSON");

        // These depend on the environment or the source location, so they are
        // taken from the output itself.
        let record = actual.as_object().unwrap();
        let src = record.get("src").unwrap().as_str().unwrap();
        let process_id = record.get("process_id").unwrap().as_number().unwrap();
        let thread_id = record.get("thread_id").unwrap().as_number().unwrap();

        let expected: serde_json::Value = serde_json::json!(
            {
                "timestamp": clock.now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
                "level": "ERROR",
                "message": "explicit message field",
                "fields": {
                    "a": 3,
                },
                "spans": {
                    "some_span#1":{
                        "x": 24,
                    },
                    "some_span#2": {
                        "x": 42,
                    }
                },
                "extract": {
                    "x": 42,
                },
                "src": src,
                "target": "proxy::logging::tests",
                "process_id": process_id,
                "thread_id": thread_id,
                "thread_name": "logging::tests::test_field_collection",
            }
        );

        assert_json_eq!(actual, expected);
    }
}
|