|             Line data    Source code 
       1              : //! HTTP client and server impls.
       2              : //! Other modules should use stuff from this module instead of
       3              : //! directly relying on deps like `reqwest` (think loose coupling).
       4              : 
       5              : pub mod health_server;
       6              : 
       7              : use std::time::{Duration, Instant};
       8              : 
       9              : use bytes::Bytes;
      10              : use futures::FutureExt;
      11              : use http::Method;
      12              : use http_body_util::BodyExt;
      13              : use hyper::body::Body;
      14              : pub(crate) use reqwest::{Request, Response};
      15              : use reqwest_middleware::RequestBuilder;
      16              : pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
      17              : pub(crate) use reqwest_retry::RetryTransientMiddleware;
      18              : pub(crate) use reqwest_retry::policies::ExponentialBackoff;
      19              : use thiserror::Error;
      20              : 
      21              : use crate::metrics::{ConsoleRequest, Metrics};
      22              : use crate::url::ApiUrl;
      23              : 
      24              : /// This is the preferred way to create new http clients,
      25              : /// because it takes care of observability (OpenTelemetry).
      26              : /// We deliberately don't want to replace this with a public static.
      27            1 : pub fn new_client() -> ClientWithMiddleware {
      28            1 :     let client = reqwest::ClientBuilder::new()
      29            1 :         .build()
      30            1 :         .expect("Failed to create http client");
      31              : 
      32            1 :     reqwest_middleware::ClientBuilder::new(client)
      33            1 :         .with(reqwest_tracing::TracingMiddleware::default())
      34            1 :         .build()
      35            1 : }
      36              : 
      37            0 : pub(crate) fn new_client_with_timeout(
      38            0 :     request_timeout: Duration,
      39            0 :     total_retry_duration: Duration,
      40            0 : ) -> ClientWithMiddleware {
      41            0 :     let timeout_client = reqwest::ClientBuilder::new()
      42            0 :         .timeout(request_timeout)
      43            0 :         .build()
      44            0 :         .expect("Failed to create http client with timeout");
      45              : 
      46            0 :     let retry_policy =
      47            0 :         ExponentialBackoff::builder().build_with_total_retry_duration(total_retry_duration);
      48              : 
      49            0 :     reqwest_middleware::ClientBuilder::new(timeout_client)
      50            0 :         .with(reqwest_tracing::TracingMiddleware::default())
      51              :         // As per docs, "This middleware always errors when given requests with streaming bodies".
      52              :         // That's all right because we only use this client to send `serde_json::RawValue`, which
      53              :         // is not a stream.
      54              :         //
      55              :         // ex-maintainer note:
      56              :         // this limitation can be fixed if streaming is necessary.
      57              :         // retries will still not be performed, but it wont error immediately
      58            0 :         .with(RetryTransientMiddleware::new_with_policy(retry_policy))
      59            0 :         .build()
      60            0 : }
      61              : 
/// Thin convenience wrapper for an API provided by an http endpoint.
///
/// Pairs a base URL with an HTTP client. The client is supplied by the caller
/// (see [`Endpoint::new`]) so that it can be shared. Request helpers on this type
/// clone the base URL and let callers extend it before building a request.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// API's base URL. Cloned for every request built through this endpoint.
    endpoint: ApiUrl,
    /// Connection manager with built-in pooling.
    client: ClientWithMiddleware,
}
      70              : 
      71              : impl Endpoint {
      72              :     /// Construct a new HTTP endpoint wrapper.
      73              :     /// Http client is not constructed under the hood so that it can be shared.
      74            2 :     pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
      75            2 :         Self {
      76            2 :             endpoint,
      77            2 :             client: client.into(),
      78            2 :         }
      79            2 :     }
      80              : 
      81              :     #[inline(always)]
      82            0 :     pub(crate) fn url(&self) -> &ApiUrl {
      83            0 :         &self.endpoint
      84            0 :     }
      85              : 
      86              :     /// Return a [builder](RequestBuilder) for a `GET` request,
      87              :     /// appending a single `path` segment to the base endpoint URL.
      88            2 :     pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
      89            2 :         self.get_with_url(|u| {
      90            2 :             u.path_segments_mut().push(path);
      91            2 :         })
      92            2 :     }
      93              : 
      94              :     /// Return a [builder](RequestBuilder) for a `GET` request,
      95              :     /// accepting a closure to modify the url path segments for more complex paths queries.
      96            2 :     pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
      97            2 :         self.request_with_url(Method::GET, f)
      98            2 :     }
      99              : 
     100              :     /// Return a [builder](RequestBuilder) for a request,
     101              :     /// accepting a closure to modify the url path segments for more complex paths queries.
     102            2 :     pub(crate) fn request_with_url(
     103            2 :         &self,
     104            2 :         method: Method,
     105            2 :         f: impl for<'a> FnOnce(&'a mut ApiUrl),
     106            2 :     ) -> RequestBuilder {
     107            2 :         let mut url = self.endpoint.clone();
     108            2 :         f(&mut url);
     109            2 :         self.client.request(method, url.into_inner())
     110            2 :     }
     111              : 
     112              :     /// Execute a [request](reqwest::Request).
     113            0 :     pub(crate) fn execute(
     114            0 :         &self,
     115            0 :         request: Request,
     116            0 :     ) -> impl Future<Output = Result<Response, Error>> {
     117            0 :         let metric = Metrics::get()
     118            0 :             .proxy
     119            0 :             .console_request_latency
     120            0 :             .with_labels(ConsoleRequest {
     121            0 :                 request: request.url().path(),
     122            0 :             });
     123              : 
     124            0 :         let req = self.client.execute(request).boxed();
     125              : 
     126            0 :         async move {
     127            0 :             let start = Instant::now();
     128            0 :             scopeguard::defer!({
     129              :                 Metrics::get()
     130              :                     .proxy
     131              :                     .console_request_latency
     132              :                     .get_metric(metric)
     133              :                     .observe_duration_since(start);
     134              :             });
     135              : 
     136            0 :             req.await
     137            0 :         }
     138            0 :     }
     139              : }
     140              : 
/// Failure modes of [`read_body_with_limit`].
///
/// `E` is the error type produced by the underlying body.
#[derive(Error, Debug)]
pub(crate) enum ReadBodyError<E> {
    /// The body's size hint, or the data actually received, exceeded `limit` bytes.
    #[error("Content length exceeds limit of {limit} bytes")]
    BodyTooLarge { limit: usize },

    /// Reading a frame from the underlying body failed.
    #[error(transparent)]
    Read(#[from] E),
}
     149              : 
     150            9 : pub(crate) async fn read_body_with_limit<E>(
     151            9 :     mut b: impl Body<Data = Bytes, Error = E> + Unpin,
     152            9 :     limit: usize,
     153            9 : ) -> Result<Vec<u8>, ReadBodyError<E>> {
     154              :     // We could use `b.limited().collect().await.to_bytes()` here
     155              :     // but this ends up being slightly more efficient as far as I can tell.
     156              : 
     157              :     // check the lower bound of the size hint.
     158              :     // in reqwest, this value is influenced by the Content-Length header.
     159            9 :     let lower_bound = match usize::try_from(b.size_hint().lower()) {
     160            9 :         Ok(bound) if bound <= limit => bound,
     161            0 :         _ => return Err(ReadBodyError::BodyTooLarge { limit }),
     162              :     };
     163            9 :     let mut bytes = Vec::with_capacity(lower_bound);
     164              : 
     165           18 :     while let Some(frame) = b.frame().await.transpose()? {
     166            9 :         if let Ok(data) = frame.into_data() {
     167            9 :             if bytes.len() + data.len() > limit {
     168            0 :                 return Err(ReadBodyError::BodyTooLarge { limit });
     169            9 :             }
     170            9 :             bytes.extend_from_slice(&data);
     171            0 :         }
     172              :     }
     173              : 
     174            9 :     Ok(bytes)
     175            9 : }
     176              : 
     177              : #[cfg(test)]
     178              : mod tests {
     179              :     use reqwest::Client;
     180              : 
     181              :     use super::*;
     182              : 
     183              :     #[test]
     184            1 :     fn optional_query_params() -> anyhow::Result<()> {
     185            1 :         let url = "http://example.com".parse()?;
     186            1 :         let endpoint = Endpoint::new(url, Client::new());
     187              : 
     188              :         // Validate that this pattern makes sense.
     189            1 :         let req = endpoint
     190            1 :             .get_path("frobnicate")
     191            1 :             .query(&[
     192            1 :                 ("foo", Some("10")), // should be just `foo=10`
     193            1 :                 ("bar", None),       // shouldn't be passed at all
     194            1 :             ])
     195            1 :             .build()?;
     196              : 
     197            1 :         assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");
     198              : 
     199            1 :         Ok(())
     200            1 :     }
     201              : 
     202              :     #[test]
     203            1 :     fn uuid_params() -> anyhow::Result<()> {
     204            1 :         let url = "http://example.com".parse()?;
     205            1 :         let endpoint = Endpoint::new(url, Client::new());
     206              : 
     207            1 :         let req = endpoint
     208            1 :             .get_path("frobnicate")
     209            1 :             .query(&[("session_id", uuid::Uuid::nil())])
     210            1 :             .build()?;
     211              : 
     212            1 :         assert_eq!(
     213            1 :             req.url().as_str(),
     214              :             "http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
     215              :         );
     216              : 
     217            1 :         Ok(())
     218            1 :     }
     219              : }
         |