LCOV - code coverage report
Current view: top level - libs/utils/src/http - endpoint.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 84.9 % 350 297
Test Date: 2024-02-07 07:37:29 Functions: 56.5 % 591 334

            Line data    Source code
       1              : use crate::auth::{AuthError, Claims, SwappableJwtAuth};
       2              : use crate::http::error::{api_error_handler, route_error_handler, ApiError};
       3              : use anyhow::Context;
       4              : use hyper::header::{HeaderName, AUTHORIZATION};
       5              : use hyper::http::HeaderValue;
       6              : use hyper::Method;
       7              : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
       8              : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
       9              : use once_cell::sync::Lazy;
      10              : use routerify::ext::RequestExt;
      11              : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
      12              : use tracing::{self, debug, info, info_span, warn, Instrument};
      13              : 
      14              : use std::future::Future;
      15              : use std::str::FromStr;
      16              : 
      17              : use bytes::{Bytes, BytesMut};
      18              : use std::io::Write as _;
      19              : use tokio::sync::mpsc;
      20              : use tokio_stream::wrappers::ReceiverStream;
      21              : 
      22          407 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
      23          407 :     register_int_counter!(
      24          407 :         "libmetrics_metric_handler_requests_total",
      25          407 :         "Number of metric requests made"
      26          407 :     )
      27          407 :     .expect("failed to define a metric")
      28          407 : });
      29              : 
      30              : static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
      31              : 
      32              : static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
      33        32476 : #[derive(Debug, Default, Clone)]
      34              : struct RequestId(String);
      35              : 
      36              : /// Adds a tracing info_span! instrumentation around the handler events,
      37              : /// logs the request start and end events for non-GET requests and non-200 responses.
      38              : ///
      39              : /// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
      40              : ///
      41              : /// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
      42              : /// with this will get request info logged in the wrapping span, including the unique request ID.
      43              : ///
      44              : /// This also handles errors, logging them and converting them to an HTTP error response.
      45              : ///
      46              : /// NB: If the client disconnects, Hyper will drop the Future, without polling it to
      47              : /// completion. In other words, the handler must be async cancellation safe! request_span
      48              : /// prints a warning to the log when that happens, so that you have some trace of it in
      49              : /// the log.
      50              : ///
      51              : ///
      52              : /// There could be other ways to implement similar functionality:
      53              : ///
      54              : /// * procmacros placed on top of all handler methods
      55              : /// With all the drawbacks of procmacros, brings no difference implementation-wise,
      56              : /// and little code reduction compared to the existing approach.
      57              : ///
      58              : /// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
      59              : /// implemented for [`RouterBuilder`].
      60              : /// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
      61              : ///
      62              : /// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
      63              : /// later, in a post-response middleware.
      64              : /// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
      65              : /// tries to achive with its `.instrument` used in the current approach.
      66              : ///
      67              : /// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
      68        16225 : pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
      69        16225 : where
      70        16225 :     R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
      71        16225 :     H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
      72        16225 : {
      73        16225 :     let request_id = request.context::<RequestId>().unwrap_or_default().0;
      74        16225 :     let method = request.method();
      75        16225 :     let path = request.uri().path();
      76        16225 :     let request_span = info_span!("request", %method, %path, %request_id);
      77              : 
      78        16225 :     let log_quietly = method == Method::GET;
      79        16225 :     async move {
      80        16225 :         let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
      81        16225 :         if log_quietly {
      82         7593 :             debug!("Handling request");
      83              :         } else {
      84         8632 :             info!("Handling request");
      85              :         }
      86              : 
      87              :         // No special handling for panics here. There's a `tracing_panic_hook` from another
      88              :         // module to do that globally.
      89        20999 :         let res = handler(request).await;
      90              : 
      91        16216 :         cancellation_guard.disarm();
      92        16216 : 
      93        16216 :         // Log the result if needed.
      94        16216 :         //
      95        16216 :         // We also convert any errors into an Ok response with HTTP error code here.
      96        16216 :         // `make_router` sets a last-resort error handler that would do the same, but
      97        16216 :         // we prefer to do it here, before we exit the request span, so that the error
      98        16216 :         // is still logged with the span.
      99        16216 :         //
     100        16216 :         // (Because we convert errors to Ok response, we never actually return an error,
     101        16216 :         // and we could declare the function to return the never type (`!`). However,
     102        16216 :         // using `routerify::RouterBuilder` requires a proper error type.)
     103        16216 :         match res {
     104        15927 :             Ok(response) => {
     105        15927 :                 let response_status = response.status();
     106        15927 :                 if log_quietly && response_status.is_success() {
     107         7461 :                     debug!("Request handled, status: {response_status}");
     108              :                 } else {
     109         8466 :                     info!("Request handled, status: {response_status}");
     110              :                 }
     111        15927 :                 Ok(response)
     112              :             }
     113          289 :             Err(err) => Ok(api_error_handler(err)),
     114              :         }
     115        16216 :     }
     116        16225 :     .instrument(request_span)
     117        20999 :     .await
     118        16216 : }
     119              : 
     120              : /// Drop guard to WARN in case the request was dropped before completion.
     121              : struct RequestCancelled {
     122              :     warn: Option<tracing::Span>,
     123              : }
     124              : 
     125              : impl RequestCancelled {
     126              :     /// Create the drop guard using the [`tracing::Span::current`] as the span.
     127        16225 :     fn warn_when_dropped_without_responding() -> Self {
     128        16225 :         RequestCancelled {
     129        16225 :             warn: Some(tracing::Span::current()),
     130        16225 :         }
     131        16225 :     }
     132              : 
     133              :     /// Consume the drop guard without logging anything.
     134        16216 :     fn disarm(mut self) {
     135        16216 :         self.warn = None;
     136        16216 :     }
     137              : }
     138              : 
     139              : impl Drop for RequestCancelled {
     140        16219 :     fn drop(&mut self) {
     141        16219 :         if std::thread::panicking() {
     142            0 :             // we are unwinding due to panicking, assume we are not dropped for cancellation
     143        16219 :         } else if let Some(span) = self.warn.take() {
     144              :             // the span has all of the info already, but the outer `.instrument(span)` has already
     145              :             // been dropped, so we need to manually re-enter it for this message.
     146              :             //
     147              :             // this is what the instrument would do before polling so it is fine.
     148            3 :             let _g = span.entered();
     149            3 :             warn!("request was dropped before completing");
     150        16216 :         }
     151        16219 :     }
     152              : }
     153              : 
     154              : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
     155              : pub struct ChannelWriter {
     156              :     buffer: BytesMut,
     157              :     pub tx: mpsc::Sender<std::io::Result<Bytes>>,
     158              :     written: usize,
     159              : }
     160              : 
     161              : impl ChannelWriter {
     162         1261 :     pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
     163         1261 :         assert_ne!(buf_len, 0);
     164         1261 :         ChannelWriter {
     165         1261 :             // split about half off the buffer from the start, because we flush depending on
     166         1261 :             // capacity. first flush will come sooner than without this, but now resizes will
     167         1261 :             // have better chance of picking up the "other" half. not guaranteed of course.
     168         1261 :             buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
     169         1261 :             tx,
     170         1261 :             written: 0,
     171         1261 :         }
     172         1261 :     }
     173              : 
     174         4615 :     pub fn flush0(&mut self) -> std::io::Result<usize> {
     175         4615 :         let n = self.buffer.len();
     176         4615 :         if n == 0 {
     177            0 :             return Ok(0);
     178         4615 :         }
     179         4615 : 
     180         4615 :         tracing::trace!(n, "flushing");
     181         4615 :         let ready = self.buffer.split().freeze();
     182         4615 : 
     183         4615 :         // not ideal to call from blocking code to block_on, but we are sure that this
     184         4615 :         // operation does not spawn_blocking other tasks
     185         4615 :         let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
     186         4615 :             self.tx.send(Ok(ready)).await.map_err(|_| ())?;
     187              : 
     188              :             // throttle sending to allow reuse of our buffer in `write`.
     189         4615 :             self.tx.reserve().await.map_err(|_| ())?;
     190              : 
     191              :             // now the response task has picked up the buffer and hopefully started
     192              :             // sending it to the client.
     193         4615 :             Ok(())
     194         4615 :         });
     195         4615 :         if res.is_err() {
     196            0 :             return Err(std::io::ErrorKind::BrokenPipe.into());
     197         4615 :         }
     198         4615 :         self.written += n;
     199         4615 :         Ok(n)
     200         4615 :     }
     201              : 
     202         1261 :     pub fn flushed_bytes(&self) -> usize {
     203         1261 :         self.written
     204         1261 :     }
     205              : }
     206              : 
     207              : impl std::io::Write for ChannelWriter {
     208     41531899 :     fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
     209     41531899 :         let remaining = self.buffer.capacity() - self.buffer.len();
     210     41531899 : 
     211     41531899 :         let out_of_space = remaining < buf.len();
     212     41531899 : 
     213     41531899 :         let original_len = buf.len();
     214     41531899 : 
     215     41531899 :         if out_of_space {
     216         3354 :             let can_still_fit = buf.len() - remaining;
     217         3354 :             self.buffer.extend_from_slice(&buf[..can_still_fit]);
     218         3354 :             buf = &buf[can_still_fit..];
     219         3354 :             self.flush0()?;
     220     41528545 :         }
     221              : 
     222              :         // assume that this will often under normal operation just move the pointer back to the
     223              :         // beginning of allocation, because previous split off parts are already sent and
     224              :         // dropped.
     225     41531899 :         self.buffer.extend_from_slice(buf);
     226     41531899 :         Ok(original_len)
     227     41531899 :     }
     228              : 
     229         1261 :     fn flush(&mut self) -> std::io::Result<()> {
     230         1261 :         self.flush0().map(|_| ())
     231         1261 :     }
     232              : }
     233              : 
     234         1255 : async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
     235         1255 :     SERVE_METRICS_COUNT.inc();
     236         1255 : 
     237         1255 :     let started_at = std::time::Instant::now();
     238         1255 : 
     239         1255 :     let (tx, rx) = mpsc::channel(1);
     240         1255 : 
     241         1255 :     let body = Body::wrap_stream(ReceiverStream::new(rx));
     242         1255 : 
     243         1255 :     let mut writer = ChannelWriter::new(128 * 1024, tx);
     244         1255 : 
     245         1255 :     let encoder = TextEncoder::new();
     246         1255 : 
     247         1255 :     let response = Response::builder()
     248         1255 :         .status(200)
     249         1255 :         .header(CONTENT_TYPE, encoder.format_type())
     250         1255 :         .body(body)
     251         1255 :         .unwrap();
     252              : 
     253         1255 :     let span = info_span!("blocking");
     254         1255 :     tokio::task::spawn_blocking(move || {
     255         1255 :         let _span = span.entered();
     256         1255 :         let metrics = metrics::gather();
     257         1255 :         let res = encoder
     258         1255 :             .encode(&metrics, &mut writer)
     259         1255 :             .and_then(|_| writer.flush().map_err(|e| e.into()));
     260         1255 : 
     261         1255 :         match res {
     262              :             Ok(()) => {
     263         1255 :                 tracing::info!(
     264         1255 :                     bytes = writer.flushed_bytes(),
     265         1255 :                     elapsed_ms = started_at.elapsed().as_millis(),
     266         1255 :                     "responded /metrics"
     267         1255 :                 );
     268              :             }
     269            0 :             Err(e) => {
     270            0 :                 tracing::warn!("failed to write out /metrics response: {e:#}");
     271              :                 // semantics of this error are quite... unclear. we want to error the stream out to
     272              :                 // abort the response to somehow notify the client that we failed.
     273              :                 //
     274              :                 // though, most likely the reason for failure is that the receiver is already gone.
     275            0 :                 drop(
     276            0 :                     writer
     277            0 :                         .tx
     278            0 :                         .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
     279            0 :                 );
     280              :             }
     281              :         }
     282         1255 :     });
     283         1255 : 
     284         1255 :     Ok(response)
     285         1255 : }
     286              : 
     287         1502 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
     288         1502 : ) -> Middleware<B, ApiError> {
     289        16260 :     Middleware::pre(move |req| async move {
     290        16260 :         let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
     291            2 :             Some(request_id) => request_id
     292            2 :                 .to_str()
     293            2 :                 .expect("extract request id value")
     294            2 :                 .to_owned(),
     295              :             None => {
     296        16258 :                 let request_id = uuid::Uuid::new_v4();
     297        16258 :                 request_id.to_string()
     298              :             }
     299              :         };
     300        16260 :         req.set_context(RequestId(request_id));
     301        16260 : 
     302        16260 :         Ok(req)
     303        16260 :     })
     304         1502 : }
     305              : 
     306        16251 : async fn add_request_id_header_to_response(
     307        16251 :     mut res: Response<Body>,
     308        16251 :     req_info: RequestInfo,
     309        16251 : ) -> Result<Response<Body>, ApiError> {
     310        16251 :     if let Some(request_id) = req_info.context::<RequestId>() {
     311        16251 :         if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
     312        16251 :             res.headers_mut()
     313        16251 :                 .insert(&X_REQUEST_ID_HEADER, request_header_value);
     314        16251 :         };
     315            0 :     };
     316              : 
     317        16251 :     Ok(res)
     318        16251 : }
     319              : 
     320         1502 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
     321         1502 :     Router::builder()
     322         1502 :         .middleware(add_request_id_middleware())
     323         1502 :         .middleware(Middleware::post_with_info(
     324         1502 :             add_request_id_header_to_response,
     325         1502 :         ))
     326         1502 :         .get("/metrics", |r| request_span(r, prometheus_metrics_handler))
     327         1502 :         .err_handler(route_error_handler)
     328         1502 : }
     329              : 
/// Registers two GET routes on `router_builder` for serving an OpenAPI spec:
///
/// * `spec_mount_path` responds with the raw `spec` bytes as the body.
/// * `ui_mount_path` responds with an HTML page that loads Swagger UI (`swagger-ui-dist@3`
///   from cdn.jsdelivr.net) and points it at `spec_mount_path` through its `url` option.
///
/// Both handlers are wrapped in [`request_span`]. Neither response sets a `Content-Type`
/// header. NOTE(review): confirm browsers/clients handle that as intended.
///
/// In the `format!` template, `{{` / `}}` are escaped literal braces and `{}` is replaced by
/// `spec_mount_path`. The backslash in `"\#swagger-ui"` keeps the `"#` sequence from
/// terminating the `r#"..."#` raw string; JavaScript reads `\#` as `#`.
pub fn attach_openapi_ui(
    router_builder: RouterBuilder<hyper::Body, ApiError>,
    spec: &'static [u8],
    spec_mount_path: &'static str,
    ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
    router_builder
        // Serve the spec bytes verbatim.
        .get(spec_mount_path,
            move |r| request_span(r, move |_| async move {
                Ok(Response::builder().body(Body::from(spec)).unwrap())
            })
        )
        // Serve the Swagger UI page that fetches the spec from `spec_mount_path`.
        .get(ui_mount_path,
             move |r| request_span(r, move |_| async move {
                 Ok(Response::builder().body(Body::from(format!(r#"
                <!DOCTYPE html>
                <html lang="en">
                <head>
                <title>rweb</title>
                <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
                </head>
                <body>
                    <div id="swagger-ui"></div>
                    <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
                    <script>
                        window.onload = function() {{
                        const ui = SwaggerUIBundle({{
                            "dom_id": "\#swagger-ui",
                            presets: [
                            SwaggerUIBundle.presets.apis,
                            SwaggerUIBundle.SwaggerUIStandalonePreset
                            ],
                            layout: "BaseLayout",
                            deepLinking: true,
                            showExtensions: true,
                            showCommonExtensions: true,
                            url: "{}",
                        }})
                        window.ui = ui;
                    }};
                </script>
                </body>
                </html>
            "#, spec_mount_path))).unwrap())
             })
        )
}
     377              : 
     378          190 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
     379              :     // header must be in form Bearer <token>
     380          190 :     let (prefix, token) = header_value
     381          190 :         .split_once(' ')
     382          190 :         .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
     383          190 :     if prefix != "Bearer" {
     384            0 :         return Err(ApiError::Unauthorized(
     385            0 :             "malformed authorization header".to_string(),
     386            0 :         ));
     387          190 :     }
     388          190 :     Ok(token)
     389          190 : }
     390              : 
     391           42 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
     392           42 :     provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
     393           42 : ) -> Middleware<B, ApiError> {
     394          284 :     Middleware::pre(move |req| async move {
     395          284 :         if let Some(auth) = provide_auth(&req) {
     396          195 :             match req.headers().get(AUTHORIZATION) {
     397          190 :                 Some(value) => {
     398          190 :                     let header_value = value.to_str().map_err(|_| {
     399            0 :                         ApiError::Unauthorized("malformed authorization header".to_string())
     400          190 :                     })?;
     401          190 :                     let token = parse_token(header_value)?;
     402              : 
     403          190 :                     let data = auth.decode(token).map_err(|err| {
     404            3 :                         warn!("Authentication error: {err}");
     405              :                         // Rely on From<AuthError> for ApiError impl
     406            3 :                         err
     407          190 :                     })?;
     408          187 :                     req.set_context(data.claims);
     409              :                 }
     410              :                 None => {
     411            5 :                     return Err(ApiError::Unauthorized(
     412            5 :                         "missing authorization header".to_string(),
     413            5 :                     ))
     414              :                 }
     415              :             }
     416           89 :         }
     417          276 :         Ok(req)
     418          284 :     })
     419           42 : }
     420              : 
     421          604 : pub fn add_response_header_middleware<B>(
     422          604 :     header: &str,
     423          604 :     value: &str,
     424          604 : ) -> anyhow::Result<Middleware<B, ApiError>>
     425          604 : where
     426          604 :     B: hyper::body::HttpBody + Send + Sync + 'static,
     427          604 : {
     428          604 :     let name =
     429          604 :         HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
     430          604 :     let value =
     431          604 :         HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
     432          604 :     Ok(Middleware::post_with_info(
     433        10405 :         move |mut response, request_info| {
     434        10405 :             let name = name.clone();
     435        10405 :             let value = value.clone();
     436        10405 :             async move {
     437        10405 :                 let headers = response.headers_mut();
     438        10405 :                 if headers.contains_key(&name) {
     439            0 :                     warn!(
     440            0 :                         "{} response already contains header {:?}",
     441            0 :                         request_info.uri(),
     442            0 :                         &name,
     443            0 :                     );
     444        10405 :                 } else {
     445        10405 :                     headers.insert(name, value);
     446        10405 :                 }
     447        10405 :                 Ok(response)
     448        10405 :             }
     449        10405 :         },
     450          604 :     ))
     451          604 : }
     452              : 
     453        10383 : pub fn check_permission_with(
     454        10383 :     req: &Request<Body>,
     455        10383 :     check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
     456        10383 : ) -> Result<(), ApiError> {
     457        10383 :     match req.context::<Claims>() {
     458          112 :         Some(claims) => Ok(check_permission(&claims)
     459          112 :             .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
     460        10271 :         None => Ok(()), // claims is None because auth is disabled
     461              :     }
     462        10383 : }
     463              : 
     464              : #[cfg(test)]
     465              : mod tests {
     466              :     use super::*;
     467              :     use futures::future::poll_fn;
     468              :     use hyper::service::Service;
     469              :     use routerify::RequestServiceBuilder;
     470              :     use std::net::{IpAddr, SocketAddr};
     471              : 
     472            2 :     #[tokio::test]
     473            2 :     async fn test_request_id_returned() {
     474            2 :         let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
     475            2 :         let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
     476            2 :         let mut service = builder.build(remote_addr);
     477            2 :         if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
     478            0 :             panic!("request service is not ready: {:?}", e);
     479            2 :         }
     480            2 : 
     481            2 :         let mut req: Request<Body> = Request::default();
     482            2 :         req.headers_mut()
     483            2 :             .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
     484              : 
     485            2 :         let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
     486            2 : 
     487            2 :         let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
     488              : 
     489            2 :         assert!(header_val == "42", "response header mismatch");
     490              :     }
     491              : 
     492            2 :     #[tokio::test]
     493            2 :     async fn test_request_id_empty() {
     494            2 :         let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
     495            2 :         let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
     496            2 :         let mut service = builder.build(remote_addr);
     497            2 :         if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
     498            0 :             panic!("request service is not ready: {:?}", e);
     499            2 :         }
     500            2 : 
     501            2 :         let req: Request<Body> = Request::default();
     502            2 :         let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
     503            2 : 
     504            2 :         let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
     505            2 : 
     506            2 :         assert_ne!(header_val, None, "response header should NOT be empty");
     507              :     }
     508              : }
        

Generated by: LCOV version 2.1-beta