LCOV - code coverage report
Current view: top level - libs/http-utils/src - endpoint.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 12.8 % 524 67
Test Date: 2025-02-20 13:11:02 Functions: 2.2 % 540 12

            Line data    Source code
       1              : use crate::error::{api_error_handler, route_error_handler, ApiError};
       2              : use crate::pprof;
       3              : use crate::request::{get_query_param, parse_query_param};
       4              : use ::pprof::protos::Message as _;
       5              : use ::pprof::ProfilerGuardBuilder;
       6              : use anyhow::{anyhow, Context};
       7              : use bytes::{Bytes, BytesMut};
       8              : use hyper::header::{HeaderName, AUTHORIZATION, CONTENT_DISPOSITION};
       9              : use hyper::http::HeaderValue;
      10              : use hyper::Method;
      11              : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
      12              : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
      13              : use once_cell::sync::Lazy;
      14              : use regex::Regex;
      15              : use routerify::ext::RequestExt;
      16              : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
      17              : use tokio::sync::{mpsc, Mutex, Notify};
      18              : use tokio_stream::wrappers::ReceiverStream;
      19              : use tokio_util::io::ReaderStream;
      20              : use tracing::{debug, info, info_span, warn, Instrument};
      21              : use utils::auth::{AuthError, Claims, SwappableJwtAuth};
      22              : 
      23              : use std::future::Future;
      24              : use std::io::Write as _;
      25              : use std::str::FromStr;
      26              : use std::time::Duration;
      27              : 
      28            0 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
      29            0 :     register_int_counter!(
      30            0 :         "libmetrics_metric_handler_requests_total",
      31            0 :         "Number of metric requests made"
      32            0 :     )
      33            0 :     .expect("failed to define a metric")
      34            0 : });
      35              : 
      36              : static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
      37              : 
      38              : static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
      39              : #[derive(Debug, Default, Clone)]
      40              : struct RequestId(String);
      41              : 
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-2xx responses.
///
/// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// with this will get request info logged in the wrapping span, including the unique request ID.
///
/// This also handles errors, logging them and converting them to an HTTP error response.
///
/// NB: If the client disconnects, Hyper will drop the Future, without polling it to
/// completion. In other words, the handler must be async cancellation safe! request_span
/// prints a warning to the log when that happens, so that you have some trace of it in
/// the log.
///
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
///   With all the drawbacks of procmacros, brings no difference implementation-wise,
///   and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
///   implemented for [`RouterBuilder`].
///   Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
///   later, in a post-response middleware.
///   Due to the suspendable nature of futures, this would give contradictory results, which is exactly the opposite of what `tracing-futures`
///   tries to achieve with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
where
    R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
    H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
{
    // If no `RequestId` was stored in the request context, an empty ID is used.
    let request_id = request.context::<RequestId>().unwrap_or_default().0;
    let method = request.method();
    let path = request.uri().path();
    let request_span = info_span!("request", %method, %path, %request_id);

    // GET requests log start/end at DEBUG (unless the response is not 2xx); others at INFO.
    let log_quietly = method == Method::GET;
    async move {
        // Created inside the instrumented future, so `Span::current()` is `request_span` here.
        let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
        if log_quietly {
            debug!("Handling request");
        } else {
            info!("Handling request");
        }

        // No special handling for panics here. There's a `tracing_panic_hook` from another
        // module to do that globally.
        let res = handler(request).await;

        // Only reached when the handler future ran to completion. If the future is dropped
        // at the `.await` above, the guard is dropped still armed and logs a warning.
        cancellation_guard.disarm();

        // Log the result if needed.
        //
        // We also convert any errors into an Ok response with HTTP error code here.
        // `make_router` sets a last-resort error handler that would do the same, but
        // we prefer to do it here, before we exit the request span, so that the error
        // is still logged with the span.
        //
        // (Because we convert errors to Ok response, we never actually return an error,
        // and we could declare the function to return the never type (`!`). However,
        // using `routerify::RouterBuilder` requires a proper error type.)
        match res {
            Ok(response) => {
                let response_status = response.status();
                if log_quietly && response_status.is_success() {
                    debug!("Request handled, status: {response_status}");
                } else {
                    info!("Request handled, status: {response_status}");
                }
                Ok(response)
            }
            Err(err) => Ok(api_error_handler(err)),
        }
    }
    .instrument(request_span)
    .await
}
     125              : 
     126              : /// Drop guard to WARN in case the request was dropped before completion.
     127              : struct RequestCancelled {
     128              :     warn: Option<tracing::Span>,
     129              : }
     130              : 
     131              : impl RequestCancelled {
     132              :     /// Create the drop guard using the [`tracing::Span::current`] as the span.
     133            0 :     fn warn_when_dropped_without_responding() -> Self {
     134            0 :         RequestCancelled {
     135            0 :             warn: Some(tracing::Span::current()),
     136            0 :         }
     137            0 :     }
     138              : 
     139              :     /// Consume the drop guard without logging anything.
     140            0 :     fn disarm(mut self) {
     141            0 :         self.warn = None;
     142            0 :     }
     143              : }
     144              : 
     145              : impl Drop for RequestCancelled {
     146            0 :     fn drop(&mut self) {
     147            0 :         if std::thread::panicking() {
     148            0 :             // we are unwinding due to panicking, assume we are not dropped for cancellation
     149            0 :         } else if let Some(span) = self.warn.take() {
     150              :             // the span has all of the info already, but the outer `.instrument(span)` has already
     151              :             // been dropped, so we need to manually re-enter it for this message.
     152              :             //
     153              :             // this is what the instrument would do before polling so it is fine.
     154            0 :             let _g = span.entered();
     155            0 :             warn!("request was dropped before completing");
     156            0 :         }
     157            0 :     }
     158              : }
     159              : 
     160              : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
     161              : pub struct ChannelWriter {
     162              :     buffer: BytesMut,
     163              :     pub tx: mpsc::Sender<std::io::Result<Bytes>>,
     164              :     written: usize,
     165              :     /// Time spent waiting for the channel to make progress. It is not the same as time to upload a
     166              :     /// buffer because we cannot know anything about that, but this should allow us to understand
     167              :     /// the actual time taken without the time spent `std::thread::park`ed.
     168              :     wait_time: std::time::Duration,
     169              : }
     170              : 
     171              : impl ChannelWriter {
     172            0 :     pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
     173            0 :         assert_ne!(buf_len, 0);
     174            0 :         ChannelWriter {
     175            0 :             // split about half off the buffer from the start, because we flush depending on
     176            0 :             // capacity. first flush will come sooner than without this, but now resizes will
     177            0 :             // have better chance of picking up the "other" half. not guaranteed of course.
     178            0 :             buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
     179            0 :             tx,
     180            0 :             written: 0,
     181            0 :             wait_time: std::time::Duration::ZERO,
     182            0 :         }
     183            0 :     }
     184              : 
     185            0 :     pub fn flush0(&mut self) -> std::io::Result<usize> {
     186            0 :         let n = self.buffer.len();
     187            0 :         if n == 0 {
     188            0 :             return Ok(0);
     189            0 :         }
     190            0 : 
     191            0 :         tracing::trace!(n, "flushing");
     192            0 :         let ready = self.buffer.split().freeze();
     193            0 : 
     194            0 :         let wait_started_at = std::time::Instant::now();
     195            0 : 
     196            0 :         // not ideal to call from blocking code to block_on, but we are sure that this
     197            0 :         // operation does not spawn_blocking other tasks
     198            0 :         let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
     199            0 :             self.tx.send(Ok(ready)).await.map_err(|_| ())?;
     200              : 
     201              :             // throttle sending to allow reuse of our buffer in `write`.
     202            0 :             self.tx.reserve().await.map_err(|_| ())?;
     203              : 
     204              :             // now the response task has picked up the buffer and hopefully started
     205              :             // sending it to the client.
     206            0 :             Ok(())
     207            0 :         });
     208            0 : 
     209            0 :         self.wait_time += wait_started_at.elapsed();
     210            0 : 
     211            0 :         if res.is_err() {
     212            0 :             return Err(std::io::ErrorKind::BrokenPipe.into());
     213            0 :         }
     214            0 :         self.written += n;
     215            0 :         Ok(n)
     216            0 :     }
     217              : 
     218            0 :     pub fn flushed_bytes(&self) -> usize {
     219            0 :         self.written
     220            0 :     }
     221              : 
     222            0 :     pub fn wait_time(&self) -> std::time::Duration {
     223            0 :         self.wait_time
     224            0 :     }
     225              : }
     226              : 
     227              : impl std::io::Write for ChannelWriter {
     228            0 :     fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
     229            0 :         let remaining = self.buffer.capacity() - self.buffer.len();
     230            0 : 
     231            0 :         let out_of_space = remaining < buf.len();
     232            0 : 
     233            0 :         let original_len = buf.len();
     234            0 : 
     235            0 :         if out_of_space {
     236            0 :             let can_still_fit = buf.len() - remaining;
     237            0 :             self.buffer.extend_from_slice(&buf[..can_still_fit]);
     238            0 :             buf = &buf[can_still_fit..];
     239            0 :             self.flush0()?;
     240            0 :         }
     241              : 
     242              :         // assume that this will often under normal operation just move the pointer back to the
     243              :         // beginning of allocation, because previous split off parts are already sent and
     244              :         // dropped.
     245            0 :         self.buffer.extend_from_slice(buf);
     246            0 :         Ok(original_len)
     247            0 :     }
     248              : 
     249            0 :     fn flush(&mut self) -> std::io::Result<()> {
     250            0 :         self.flush0().map(|_| ())
     251            0 :     }
     252              : }
     253              : 
     254            0 : pub async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
     255            0 :     SERVE_METRICS_COUNT.inc();
     256            0 : 
     257            0 :     let started_at = std::time::Instant::now();
     258            0 : 
     259            0 :     let (tx, rx) = mpsc::channel(1);
     260            0 : 
     261            0 :     let body = Body::wrap_stream(ReceiverStream::new(rx));
     262            0 : 
     263            0 :     let mut writer = ChannelWriter::new(128 * 1024, tx);
     264            0 : 
     265            0 :     let encoder = TextEncoder::new();
     266            0 : 
     267            0 :     let response = Response::builder()
     268            0 :         .status(200)
     269            0 :         .header(CONTENT_TYPE, encoder.format_type())
     270            0 :         .body(body)
     271            0 :         .unwrap();
     272              : 
     273            0 :     let span = info_span!("blocking");
     274            0 :     tokio::task::spawn_blocking(move || {
     275            0 :         // there are situations where we lose scraped metrics under load, try to gather some clues
     276            0 :         // since all nodes are queried this, keep the message count low.
     277            0 :         let spawned_at = std::time::Instant::now();
     278            0 : 
     279            0 :         let _span = span.entered();
     280            0 : 
     281            0 :         let metrics = metrics::gather();
     282            0 : 
     283            0 :         let gathered_at = std::time::Instant::now();
     284            0 : 
     285            0 :         let res = encoder
     286            0 :             .encode(&metrics, &mut writer)
     287            0 :             .and_then(|_| writer.flush().map_err(|e| e.into()));
     288            0 : 
     289            0 :         // this instant is not when we finally got the full response sent, sending is done by hyper
     290            0 :         // in another task.
     291            0 :         let encoded_at = std::time::Instant::now();
     292            0 : 
     293            0 :         let spawned_in = spawned_at - started_at;
     294            0 :         let collected_in = gathered_at - spawned_at;
     295            0 :         // remove the wait time here in case the tcp connection was clogged
     296            0 :         let encoded_in = encoded_at - gathered_at - writer.wait_time();
     297            0 :         let total = encoded_at - started_at;
     298            0 : 
     299            0 :         match res {
     300              :             Ok(()) => {
     301            0 :                 tracing::info!(
     302            0 :                     bytes = writer.flushed_bytes(),
     303            0 :                     total_ms = total.as_millis(),
     304            0 :                     spawning_ms = spawned_in.as_millis(),
     305            0 :                     collection_ms = collected_in.as_millis(),
     306            0 :                     encoding_ms = encoded_in.as_millis(),
     307            0 :                     "responded /metrics"
     308              :                 );
     309              :             }
     310            0 :             Err(e) => {
     311            0 :                 // there is a chance that this error is not the BrokenPipe we generate in the writer
     312            0 :                 // for "closed connection", but it is highly unlikely.
     313            0 :                 tracing::warn!(
     314            0 :                     after_bytes = writer.flushed_bytes(),
     315            0 :                     total_ms = total.as_millis(),
     316            0 :                     spawning_ms = spawned_in.as_millis(),
     317            0 :                     collection_ms = collected_in.as_millis(),
     318            0 :                     encoding_ms = encoded_in.as_millis(),
     319            0 :                     "failed to write out /metrics response: {e:?}"
     320              :                 );
     321              :                 // semantics of this error are quite... unclear. we want to error the stream out to
     322              :                 // abort the response to somehow notify the client that we failed.
     323              :                 //
     324              :                 // though, most likely the reason for failure is that the receiver is already gone.
     325            0 :                 drop(
     326            0 :                     writer
     327            0 :                         .tx
     328            0 :                         .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
     329            0 :                 );
     330              :             }
     331              :         }
     332            0 :     });
     333            0 : 
     334            0 :     Ok(response)
     335            0 : }
     336              : 
     337              : /// Generates CPU profiles.
     338            0 : pub async fn profile_cpu_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
     339              :     enum Format {
     340              :         Pprof,
     341              :         Svg,
     342              :     }
     343              : 
     344              :     // Parameters.
     345            0 :     let format = match get_query_param(&req, "format")?.as_deref() {
     346            0 :         None => Format::Pprof,
     347            0 :         Some("pprof") => Format::Pprof,
     348            0 :         Some("svg") => Format::Svg,
     349            0 :         Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
     350              :     };
     351            0 :     let seconds = match parse_query_param(&req, "seconds")? {
     352            0 :         None => 5,
     353            0 :         Some(seconds @ 1..=60) => seconds,
     354            0 :         Some(_) => return Err(ApiError::BadRequest(anyhow!("duration must be 1-60 secs"))),
     355              :     };
     356            0 :     let frequency_hz = match parse_query_param(&req, "frequency")? {
     357            0 :         None => 99,
     358            0 :         Some(1001..) => return Err(ApiError::BadRequest(anyhow!("frequency must be <=1000 Hz"))),
     359            0 :         Some(frequency) => frequency,
     360              :     };
     361            0 :     let force: bool = parse_query_param(&req, "force")?.unwrap_or_default();
     362              : 
     363              :     // Take the profile.
     364            0 :     static PROFILE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
     365              :     static PROFILE_CANCEL: Lazy<Notify> = Lazy::new(Notify::new);
     366              : 
     367            0 :     let report = {
     368              :         // Only allow one profiler at a time. If force is true, cancel a running profile (e.g. a
     369              :         // Grafana continuous profile). We use a try_lock() loop when cancelling instead of waiting
     370              :         // for a lock(), to avoid races where the notify isn't currently awaited.
     371            0 :         let _lock = loop {
     372            0 :             match PROFILE_LOCK.try_lock() {
     373            0 :                 Ok(lock) => break lock,
     374            0 :                 Err(_) if force => PROFILE_CANCEL.notify_waiters(),
     375              :                 Err(_) => {
     376            0 :                     return Err(ApiError::Conflict(
     377            0 :                         "profiler already running (use ?force=true to cancel it)".into(),
     378            0 :                     ))
     379              :                 }
     380              :             }
     381            0 :             tokio::time::sleep(Duration::from_millis(1)).await; // don't busy-wait
     382              :         };
     383              : 
     384            0 :         let guard = ProfilerGuardBuilder::default()
     385            0 :             .frequency(frequency_hz)
     386            0 :             .blocklist(&["libc", "libgcc", "pthread", "vdso"])
     387            0 :             .build()
     388            0 :             .map_err(|err| ApiError::InternalServerError(err.into()))?;
     389              : 
     390            0 :         tokio::select! {
     391            0 :             _ = tokio::time::sleep(Duration::from_secs(seconds)) => {},
     392            0 :             _ = PROFILE_CANCEL.notified() => {},
     393              :         };
     394              : 
     395            0 :         guard
     396            0 :             .report()
     397            0 :             .build()
     398            0 :             .map_err(|err| ApiError::InternalServerError(err.into()))?
     399              :     };
     400              : 
     401              :     // Return the report in the requested format.
     402            0 :     match format {
     403              :         Format::Pprof => {
     404            0 :             let mut body = Vec::new();
     405            0 :             report
     406            0 :                 .pprof()
     407            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))?
     408            0 :                 .write_to_vec(&mut body)
     409            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))?;
     410              : 
     411            0 :             Response::builder()
     412            0 :                 .status(200)
     413            0 :                 .header(CONTENT_TYPE, "application/octet-stream")
     414            0 :                 .header(CONTENT_DISPOSITION, "attachment; filename=\"profile.pb\"")
     415            0 :                 .body(Body::from(body))
     416            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))
     417              :         }
     418              : 
     419              :         Format::Svg => {
     420            0 :             let mut body = Vec::new();
     421            0 :             report
     422            0 :                 .flamegraph(&mut body)
     423            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))?;
     424            0 :             Response::builder()
     425            0 :                 .status(200)
     426            0 :                 .header(CONTENT_TYPE, "image/svg+xml")
     427            0 :                 .body(Body::from(body))
     428            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))
     429              :         }
     430              :     }
     431            0 : }
     432              : 
     433              : /// Generates heap profiles.
     434              : ///
     435              : /// This only works with jemalloc on Linux.
     436            0 : pub async fn profile_heap_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
     437              :     enum Format {
     438              :         Jemalloc,
     439              :         Pprof,
     440              :         Svg,
     441              :     }
     442              : 
     443              :     // Parameters.
     444            0 :     let format = match get_query_param(&req, "format")?.as_deref() {
     445            0 :         None => Format::Pprof,
     446            0 :         Some("jemalloc") => Format::Jemalloc,
     447            0 :         Some("pprof") => Format::Pprof,
     448            0 :         Some("svg") => Format::Svg,
     449            0 :         Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
     450              :     };
     451              : 
     452              :     // Functions and mappings to strip when symbolizing pprof profiles. If true,
     453              :     // also remove child frames.
     454            0 :     static STRIP_FUNCTIONS: Lazy<Vec<(Regex, bool)>> = Lazy::new(|| {
     455            0 :         vec![
     456            0 :             (Regex::new("^__rust").unwrap(), false),
     457            0 :             (Regex::new("^_start$").unwrap(), false),
     458            0 :             (Regex::new("^irallocx_prof").unwrap(), true),
     459            0 :             (Regex::new("^prof_alloc_prep").unwrap(), true),
     460            0 :             (Regex::new("^std::rt::lang_start").unwrap(), false),
     461            0 :             (Regex::new("^std::sys::backtrace::__rust").unwrap(), false),
     462            0 :         ]
     463            0 :     });
     464              :     const STRIP_MAPPINGS: &[&str] = &["libc", "libgcc", "pthread", "vdso"];
     465              : 
     466              :     // Obtain profiler handle.
     467            0 :     let mut prof_ctl = jemalloc_pprof::PROF_CTL
     468            0 :         .as_ref()
     469            0 :         .ok_or(ApiError::InternalServerError(anyhow!(
     470            0 :             "heap profiling not enabled"
     471            0 :         )))?
     472            0 :         .lock()
     473            0 :         .await;
     474            0 :     if !prof_ctl.activated() {
     475            0 :         return Err(ApiError::InternalServerError(anyhow!(
     476            0 :             "heap profiling not enabled"
     477            0 :         )));
     478            0 :     }
     479            0 : 
     480            0 :     // Take and return the profile.
     481            0 :     match format {
     482              :         Format::Jemalloc => {
     483              :             // NB: file is an open handle to a tempfile that's already deleted.
     484            0 :             let file = tokio::task::spawn_blocking(move || prof_ctl.dump())
     485            0 :                 .await
     486            0 :                 .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
     487            0 :                 .map_err(ApiError::InternalServerError)?;
     488            0 :             let stream = ReaderStream::new(tokio::fs::File::from_std(file));
     489            0 :             Response::builder()
     490            0 :                 .status(200)
     491            0 :                 .header(CONTENT_TYPE, "application/octet-stream")
     492            0 :                 .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.dump\"")
     493            0 :                 .body(Body::wrap_stream(stream))
     494            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))
     495              :         }
     496              : 
     497              :         Format::Pprof => {
     498            0 :             let data = tokio::task::spawn_blocking(move || {
     499            0 :                 let bytes = prof_ctl.dump_pprof()?;
     500              :                 // Symbolize the profile.
     501              :                 // TODO: consider moving this upstream to jemalloc_pprof and avoiding the
     502              :                 // serialization roundtrip.
     503            0 :                 let profile = pprof::decode(&bytes)?;
     504            0 :                 let profile = pprof::symbolize(profile)?;
     505            0 :                 let profile = pprof::strip_locations(profile, STRIP_MAPPINGS, &STRIP_FUNCTIONS);
     506            0 :                 pprof::encode(&profile)
     507            0 :             })
     508            0 :             .await
     509            0 :             .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
     510            0 :             .map_err(ApiError::InternalServerError)?;
     511            0 :             Response::builder()
     512            0 :                 .status(200)
     513            0 :                 .header(CONTENT_TYPE, "application/octet-stream")
     514            0 :                 .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.pb\"")
     515            0 :                 .body(Body::from(data))
     516            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))
     517              :         }
     518              : 
     519              :         Format::Svg => {
     520            0 :             let body = tokio::task::spawn_blocking(move || {
     521            0 :                 let bytes = prof_ctl.dump_pprof()?;
     522            0 :                 let profile = pprof::decode(&bytes)?;
     523            0 :                 let profile = pprof::symbolize(profile)?;
     524            0 :                 let profile = pprof::strip_locations(profile, STRIP_MAPPINGS, &STRIP_FUNCTIONS);
     525            0 :                 let mut opts = inferno::flamegraph::Options::default();
     526            0 :                 opts.title = "Heap inuse".to_string();
     527            0 :                 opts.count_name = "bytes".to_string();
     528            0 :                 pprof::flamegraph(profile, &mut opts)
     529            0 :             })
     530            0 :             .await
     531            0 :             .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
     532            0 :             .map_err(ApiError::InternalServerError)?;
     533            0 :             Response::builder()
     534            0 :                 .status(200)
     535            0 :                 .header(CONTENT_TYPE, "image/svg+xml")
     536            0 :                 .body(Body::from(body))
     537            0 :                 .map_err(|err| ApiError::InternalServerError(err.into()))
     538              :         }
     539              :     }
     540            0 : }
     541              : 
     542            2 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
     543            2 : ) -> Middleware<B, ApiError> {
     544            2 :     Middleware::pre(move |req| async move {
     545            2 :         let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
     546            1 :             Some(request_id) => request_id
     547            1 :                 .to_str()
     548            1 :                 .expect("extract request id value")
     549            1 :                 .to_owned(),
     550              :             None => {
     551            1 :                 let request_id = uuid::Uuid::new_v4();
     552            1 :                 request_id.to_string()
     553              :             }
     554              :         };
     555            2 :         req.set_context(RequestId(request_id));
     556            2 : 
     557            2 :         Ok(req)
     558            2 :     })
     559            2 : }
     560              : 
     561            2 : async fn add_request_id_header_to_response(
     562            2 :     mut res: Response<Body>,
     563            2 :     req_info: RequestInfo,
     564            2 : ) -> Result<Response<Body>, ApiError> {
     565            2 :     if let Some(request_id) = req_info.context::<RequestId>() {
     566            2 :         if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
     567            2 :             res.headers_mut()
     568            2 :                 .insert(&X_REQUEST_ID_HEADER, request_header_value);
     569            2 :         };
     570            0 :     };
     571              : 
     572            2 :     Ok(res)
     573            2 : }
     574              : 
     575            2 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
     576            2 :     Router::builder()
     577            2 :         .middleware(add_request_id_middleware())
     578            2 :         .middleware(Middleware::post_with_info(
     579            2 :             add_request_id_header_to_response,
     580            2 :         ))
     581            2 :         .err_handler(route_error_handler)
     582            2 : }
     583              : 
     584            0 : pub fn attach_openapi_ui(
     585            0 :     router_builder: RouterBuilder<hyper::Body, ApiError>,
     586            0 :     spec: &'static [u8],
     587            0 :     spec_mount_path: &'static str,
     588            0 :     ui_mount_path: &'static str,
     589            0 : ) -> RouterBuilder<hyper::Body, ApiError> {
     590            0 :     router_builder
     591            0 :         .get(spec_mount_path,
     592            0 :             move |r| request_span(r, move |_| async move {
     593            0 :                 Ok(Response::builder().body(Body::from(spec)).unwrap())
     594            0 :             })
     595            0 :         )
     596            0 :         .get(ui_mount_path,
     597            0 :              move |r| request_span(r, move |_| async move {
     598            0 :                  Ok(Response::builder().body(Body::from(format!(r#"
     599            0 :                 <!DOCTYPE html>
     600            0 :                 <html lang="en">
     601            0 :                 <head>
     602            0 :                 <title>rweb</title>
     603            0 :                 <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
     604            0 :                 </head>
     605            0 :                 <body>
     606            0 :                     <div id="swagger-ui"></div>
     607            0 :                     <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
     608            0 :                     <script>
     609            0 :                         window.onload = function() {{
     610            0 :                         const ui = SwaggerUIBundle({{
     611            0 :                             "dom_id": "\#swagger-ui",
     612            0 :                             presets: [
     613            0 :                             SwaggerUIBundle.presets.apis,
     614            0 :                             SwaggerUIBundle.SwaggerUIStandalonePreset
     615            0 :                             ],
     616            0 :                             layout: "BaseLayout",
     617            0 :                             deepLinking: true,
     618            0 :                             showExtensions: true,
     619            0 :                             showCommonExtensions: true,
     620            0 :                             url: "{}",
     621            0 :                         }})
     622            0 :                         window.ui = ui;
     623            0 :                     }};
     624            0 :                 </script>
     625            0 :                 </body>
     626            0 :                 </html>
     627            0 :             "#, spec_mount_path))).unwrap())
     628            0 :              })
     629            0 :         )
     630            0 : }
     631              : 
     632            0 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
     633              :     // header must be in form Bearer <token>
     634            0 :     let (prefix, token) = header_value
     635            0 :         .split_once(' ')
     636            0 :         .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
     637            0 :     if prefix != "Bearer" {
     638            0 :         return Err(ApiError::Unauthorized(
     639            0 :             "malformed authorization header".to_string(),
     640            0 :         ));
     641            0 :     }
     642            0 :     Ok(token)
     643            0 : }
     644              : 
     645            0 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
     646            0 :     provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
     647            0 : ) -> Middleware<B, ApiError> {
     648            0 :     Middleware::pre(move |req| async move {
     649            0 :         if let Some(auth) = provide_auth(&req) {
     650            0 :             match req.headers().get(AUTHORIZATION) {
     651            0 :                 Some(value) => {
     652            0 :                     let header_value = value.to_str().map_err(|_| {
     653            0 :                         ApiError::Unauthorized("malformed authorization header".to_string())
     654            0 :                     })?;
     655            0 :                     let token = parse_token(header_value)?;
     656              : 
     657            0 :                     let data = auth.decode(token).map_err(|err| {
     658            0 :                         warn!("Authentication error: {err}");
     659              :                         // Rely on From<AuthError> for ApiError impl
     660            0 :                         err
     661            0 :                     })?;
     662            0 :                     req.set_context(data.claims);
     663              :                 }
     664              :                 None => {
     665            0 :                     return Err(ApiError::Unauthorized(
     666            0 :                         "missing authorization header".to_string(),
     667            0 :                     ))
     668              :                 }
     669              :             }
     670            0 :         }
     671            0 :         Ok(req)
     672            0 :     })
     673            0 : }
     674              : 
     675            0 : pub fn add_response_header_middleware<B>(
     676            0 :     header: &str,
     677            0 :     value: &str,
     678            0 : ) -> anyhow::Result<Middleware<B, ApiError>>
     679            0 : where
     680            0 :     B: hyper::body::HttpBody + Send + Sync + 'static,
     681            0 : {
     682            0 :     let name =
     683            0 :         HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
     684            0 :     let value =
     685            0 :         HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
     686            0 :     Ok(Middleware::post_with_info(
     687            0 :         move |mut response, request_info| {
     688            0 :             let name = name.clone();
     689            0 :             let value = value.clone();
     690            0 :             async move {
     691            0 :                 let headers = response.headers_mut();
     692            0 :                 if headers.contains_key(&name) {
     693            0 :                     warn!(
     694            0 :                         "{} response already contains header {:?}",
     695            0 :                         request_info.uri(),
     696            0 :                         &name,
     697              :                     );
     698            0 :                 } else {
     699            0 :                     headers.insert(name, value);
     700            0 :                 }
     701            0 :                 Ok(response)
     702            0 :             }
     703            0 :         },
     704            0 :     ))
     705            0 : }
     706              : 
     707            0 : pub fn check_permission_with(
     708            0 :     req: &Request<Body>,
     709            0 :     check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
     710            0 : ) -> Result<(), ApiError> {
     711            0 :     match req.context::<Claims>() {
     712            0 :         Some(claims) => Ok(check_permission(&claims)
     713            0 :             .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
     714            0 :         None => Ok(()), // claims is None because auth is disabled
     715              :     }
     716            0 : }
     717              : 
     718              : #[cfg(test)]
     719              : mod tests {
     720              :     use super::*;
     721              :     use hyper::service::Service;
     722              :     use routerify::RequestServiceBuilder;
     723              :     use std::future::poll_fn;
     724              :     use std::net::{IpAddr, SocketAddr};
     725              : 
     726              :     #[tokio::test]
     727            1 :     async fn test_request_id_returned() {
     728            1 :         let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
     729            1 :         let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
     730            1 :         let mut service = builder.build(remote_addr);
     731            1 :         if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
     732            1 :             panic!("request service is not ready: {:?}", e);
     733            1 :         }
     734            1 : 
     735            1 :         let mut req: Request<Body> = Request::default();
     736            1 :         req.headers_mut()
     737            1 :             .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
     738            1 : 
     739            1 :         let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
     740            1 : 
     741            1 :         let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
     742            1 : 
     743            1 :         assert!(header_val == "42", "response header mismatch");
     744            1 :     }
     745              : 
     746              :     #[tokio::test]
     747            1 :     async fn test_request_id_empty() {
     748            1 :         let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
     749            1 :         let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
     750            1 :         let mut service = builder.build(remote_addr);
     751            1 :         if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
     752            1 :             panic!("request service is not ready: {:?}", e);
     753            1 :         }
     754            1 : 
     755            1 :         let req: Request<Body> = Request::default();
     756            1 :         let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
     757            1 : 
     758            1 :         let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
     759            1 : 
     760            1 :         assert_ne!(header_val, None, "response header should NOT be empty");
     761            1 :     }
     762              : }
        

Generated by: LCOV version 2.1-beta