LCOV - code coverage report
Current view: top level - proxy/src - http.rs (source / functions) Coverage Total Hit
Test: ccf45ed1c149555259baec52d6229a81013dcd6a.info Lines: 61.9 % 97 60
Test Date: 2024-08-21 17:32:46 Functions: 46.7 % 15 7

            Line data    Source code
       1              : //! HTTP client and server impls.
       2              : //! Other modules should use stuff from this module instead of
       3              : //! directly relying on deps like `reqwest` (think loose coupling).
       4              : 
       5              : pub mod health_server;
       6              : 
       7              : use std::time::Duration;
       8              : 
       9              : use anyhow::bail;
      10              : use bytes::Bytes;
      11              : use http_body_util::BodyExt;
      12              : use hyper1::body::Body;
      13              : use serde::de::DeserializeOwned;
      14              : 
      15              : pub use reqwest::{Request, Response, StatusCode};
      16              : pub use reqwest_middleware::{ClientWithMiddleware, Error};
      17              : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
      18              : 
      19              : use crate::{
      20              :     metrics::{ConsoleRequest, Metrics},
      21              :     url::ApiUrl,
      22              : };
      23              : use reqwest_middleware::RequestBuilder;
      24              : 
      25              : /// This is the preferred way to create new http clients,
      26              : /// because it takes care of observability (OpenTelemetry).
      27              : /// We deliberately don't want to replace this with a public static.
      28            2 : pub fn new_client() -> ClientWithMiddleware {
      29            2 :     let client = reqwest::ClientBuilder::new()
      30            2 :         .build()
      31            2 :         .expect("Failed to create http client");
      32            2 : 
      33            2 :     reqwest_middleware::ClientBuilder::new(client)
      34            2 :         .with(reqwest_tracing::TracingMiddleware::default())
      35            2 :         .build()
      36            2 : }
      37              : 
      38            0 : pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
      39            0 :     let timeout_client = reqwest::ClientBuilder::new()
      40            0 :         .timeout(default_timout)
      41            0 :         .build()
      42            0 :         .expect("Failed to create http client with timeout");
      43            0 : 
      44            0 :     let retry_policy =
      45            0 :         ExponentialBackoff::builder().build_with_total_retry_duration(default_timout);
      46            0 : 
      47            0 :     reqwest_middleware::ClientBuilder::new(timeout_client)
      48            0 :         .with(reqwest_tracing::TracingMiddleware::default())
      49            0 :         // As per docs, "This middleware always errors when given requests with streaming bodies".
      50            0 :         // That's all right because we only use this client to send `serde_json::RawValue`, which
      51            0 :         // is not a stream.
      52            0 :         //
      53            0 :         // ex-maintainer note:
      54            0 :         // this limitation can be fixed if streaming is necessary.
      55            0 :         // retries will still not be performed, but it wont error immediately
      56            0 :         .with(RetryTransientMiddleware::new_with_policy(retry_policy))
      57            0 :         .build()
      58            0 : }
      59              : 
      60              : /// Thin convenience wrapper for an API provided by an http endpoint.
      61              : #[derive(Debug, Clone)]
      62              : pub struct Endpoint {
      63              :     /// API's base URL.
      64              :     endpoint: ApiUrl,
      65              :     /// Connection manager with built-in pooling.
      66              :     client: ClientWithMiddleware,
      67              : }
      68              : 
      69              : impl Endpoint {
      70              :     /// Construct a new HTTP endpoint wrapper.
      71              :     /// Http client is not constructed under the hood so that it can be shared.
      72            4 :     pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
      73            4 :         Self {
      74            4 :             endpoint,
      75            4 :             client: client.into(),
      76            4 :         }
      77            4 :     }
      78              : 
      79              :     #[inline(always)]
      80            0 :     pub fn url(&self) -> &ApiUrl {
      81            0 :         &self.endpoint
      82            0 :     }
      83              : 
      84              :     /// Return a [builder](RequestBuilder) for a `GET` request,
      85              :     /// appending a single `path` segment to the base endpoint URL.
      86            4 :     pub fn get(&self, path: &str) -> RequestBuilder {
      87            4 :         let mut url = self.endpoint.clone();
      88            4 :         url.path_segments_mut().push(path);
      89            4 :         self.client.get(url.into_inner())
      90            4 :     }
      91              : 
      92              :     /// Execute a [request](reqwest::Request).
      93            0 :     pub async fn execute(&self, request: Request) -> Result<Response, Error> {
      94            0 :         let _timer = Metrics::get()
      95            0 :             .proxy
      96            0 :             .console_request_latency
      97            0 :             .start_timer(ConsoleRequest {
      98            0 :                 request: request.url().path(),
      99            0 :             });
     100            0 : 
     101            0 :         self.client.execute(request).await
     102            0 :     }
     103              : }
     104              : 
     105            4 : pub async fn parse_json_body_with_limit<D: DeserializeOwned>(
     106            4 :     mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
     107            4 :     limit: usize,
     108            4 : ) -> anyhow::Result<D> {
     109              :     // We could use `b.limited().collect().await.to_bytes()` here
     110              :     // but this ends up being slightly more efficient as far as I can tell.
     111              : 
     112              :     // check the lower bound of the size hint.
     113              :     // in reqwest, this value is influenced by the Content-Length header.
     114            4 :     let lower_bound = match usize::try_from(b.size_hint().lower()) {
     115            4 :         Ok(bound) if bound <= limit => bound,
     116            0 :         _ => bail!("Content length exceeds limit of {limit} bytes"),
     117              :     };
     118            4 :     let mut bytes = Vec::with_capacity(lower_bound);
     119              : 
     120            8 :     while let Some(frame) = b.frame().await.transpose()? {
     121            4 :         if let Ok(data) = frame.into_data() {
     122            4 :             if bytes.len() + data.len() > limit {
     123            0 :                 bail!("Content length exceeds limit of {limit} bytes")
     124            4 :             }
     125            4 :             bytes.extend_from_slice(&data);
     126            0 :         }
     127              :     }
     128              : 
     129            4 :     Ok(serde_json::from_slice::<D>(&bytes)?)
     130            4 : }
     131              : 
     132              : #[cfg(test)]
     133              : mod tests {
     134              :     use super::*;
     135              :     use reqwest::Client;
     136              : 
     137              :     #[test]
     138            2 :     fn optional_query_params() -> anyhow::Result<()> {
     139            2 :         let url = "http://example.com".parse()?;
     140            2 :         let endpoint = Endpoint::new(url, Client::new());
     141              : 
     142              :         // Validate that this pattern makes sense.
     143            2 :         let req = endpoint
     144            2 :             .get("frobnicate")
     145            2 :             .query(&[
     146            2 :                 ("foo", Some("10")), // should be just `foo=10`
     147            2 :                 ("bar", None),       // shouldn't be passed at all
     148            2 :             ])
     149            2 :             .build()?;
     150              : 
     151            2 :         assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");
     152              : 
     153            2 :         Ok(())
     154            2 :     }
     155              : 
     156              :     #[test]
     157            2 :     fn uuid_params() -> anyhow::Result<()> {
     158            2 :         let url = "http://example.com".parse()?;
     159            2 :         let endpoint = Endpoint::new(url, Client::new());
     160              : 
     161            2 :         let req = endpoint
     162            2 :             .get("frobnicate")
     163            2 :             .query(&[("session_id", uuid::Uuid::nil())])
     164            2 :             .build()?;
     165              : 
     166            2 :         assert_eq!(
     167            2 :             req.url().as_str(),
     168            2 :             "http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
     169            2 :         );
     170              : 
     171            2 :         Ok(())
     172            2 :     }
     173              : }
        

Generated by: LCOV version 2.1-beta