Line data Source code
1 : //! HTTP client and server impls.
2 : //! Other modules should use stuff from this module instead of
3 : //! directly relying on deps like `reqwest` (think loose coupling).
4 :
5 : pub mod health_server;
6 :
7 : use std::time::Duration;
8 :
9 : use anyhow::bail;
10 : use bytes::Bytes;
11 : use http_body_util::BodyExt;
12 : use hyper1::body::Body;
13 : use serde::de::DeserializeOwned;
14 :
15 : pub use reqwest::{Request, Response, StatusCode};
16 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
17 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
18 :
19 : use crate::{
20 : metrics::{ConsoleRequest, Metrics},
21 : url::ApiUrl,
22 : };
23 : use reqwest_middleware::RequestBuilder;
24 :
25 : /// This is the preferred way to create new http clients,
26 : /// because it takes care of observability (OpenTelemetry).
27 : /// We deliberately don't want to replace this with a public static.
28 2 : pub fn new_client() -> ClientWithMiddleware {
29 2 : let client = reqwest::ClientBuilder::new()
30 2 : .build()
31 2 : .expect("Failed to create http client");
32 2 :
33 2 : reqwest_middleware::ClientBuilder::new(client)
34 2 : .with(reqwest_tracing::TracingMiddleware::default())
35 2 : .build()
36 2 : }
37 :
38 0 : pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
39 0 : let timeout_client = reqwest::ClientBuilder::new()
40 0 : .timeout(default_timout)
41 0 : .build()
42 0 : .expect("Failed to create http client with timeout");
43 0 :
44 0 : let retry_policy =
45 0 : ExponentialBackoff::builder().build_with_total_retry_duration(default_timout);
46 0 :
47 0 : reqwest_middleware::ClientBuilder::new(timeout_client)
48 0 : .with(reqwest_tracing::TracingMiddleware::default())
49 0 : // As per docs, "This middleware always errors when given requests with streaming bodies".
50 0 : // That's all right because we only use this client to send `serde_json::RawValue`, which
51 0 : // is not a stream.
52 0 : //
53 0 : // ex-maintainer note:
54 0 : // this limitation can be fixed if streaming is necessary.
55 0 : // retries will still not be performed, but it wont error immediately
56 0 : .with(RetryTransientMiddleware::new_with_policy(retry_policy))
57 0 : .build()
58 0 : }
59 :
60 : /// Thin convenience wrapper for an API provided by an http endpoint.
61 : #[derive(Debug, Clone)]
62 : pub struct Endpoint {
63 : /// API's base URL.
64 : endpoint: ApiUrl,
65 : /// Connection manager with built-in pooling.
66 : client: ClientWithMiddleware,
67 : }
68 :
69 : impl Endpoint {
70 : /// Construct a new HTTP endpoint wrapper.
71 : /// Http client is not constructed under the hood so that it can be shared.
72 4 : pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
73 4 : Self {
74 4 : endpoint,
75 4 : client: client.into(),
76 4 : }
77 4 : }
78 :
79 : #[inline(always)]
80 0 : pub fn url(&self) -> &ApiUrl {
81 0 : &self.endpoint
82 0 : }
83 :
84 : /// Return a [builder](RequestBuilder) for a `GET` request,
85 : /// appending a single `path` segment to the base endpoint URL.
86 4 : pub fn get(&self, path: &str) -> RequestBuilder {
87 4 : let mut url = self.endpoint.clone();
88 4 : url.path_segments_mut().push(path);
89 4 : self.client.get(url.into_inner())
90 4 : }
91 :
92 : /// Execute a [request](reqwest::Request).
93 0 : pub async fn execute(&self, request: Request) -> Result<Response, Error> {
94 0 : let _timer = Metrics::get()
95 0 : .proxy
96 0 : .console_request_latency
97 0 : .start_timer(ConsoleRequest {
98 0 : request: request.url().path(),
99 0 : });
100 0 :
101 0 : self.client.execute(request).await
102 0 : }
103 : }
104 :
105 4 : pub async fn parse_json_body_with_limit<D: DeserializeOwned>(
106 4 : mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
107 4 : limit: usize,
108 4 : ) -> anyhow::Result<D> {
109 : // We could use `b.limited().collect().await.to_bytes()` here
110 : // but this ends up being slightly more efficient as far as I can tell.
111 :
112 : // check the lower bound of the size hint.
113 : // in reqwest, this value is influenced by the Content-Length header.
114 4 : let lower_bound = match usize::try_from(b.size_hint().lower()) {
115 4 : Ok(bound) if bound <= limit => bound,
116 0 : _ => bail!("Content length exceeds limit of {limit} bytes"),
117 : };
118 4 : let mut bytes = Vec::with_capacity(lower_bound);
119 :
120 8 : while let Some(frame) = b.frame().await.transpose()? {
121 4 : if let Ok(data) = frame.into_data() {
122 4 : if bytes.len() + data.len() > limit {
123 0 : bail!("Content length exceeds limit of {limit} bytes")
124 4 : }
125 4 : bytes.extend_from_slice(&data);
126 0 : }
127 : }
128 :
129 4 : Ok(serde_json::from_slice::<D>(&bytes)?)
130 4 : }
131 :
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Client;

    /// Endpoint rooted at `http://example.com`, backed by a plain `reqwest` client.
    fn example_endpoint() -> anyhow::Result<Endpoint> {
        Ok(Endpoint::new("http://example.com".parse()?, Client::new()))
    }

    /// `None` query values must be omitted; `Some` values are emitted as `key=value`.
    #[test]
    fn optional_query_params() -> anyhow::Result<()> {
        // Validate that this pattern makes sense.
        let params = [
            ("foo", Some("10")), // should be just `foo=10`
            ("bar", None),       // shouldn't be passed at all
        ];
        let req = example_endpoint()?
            .get("frobnicate")
            .query(&params)
            .build()?;

        assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");

        Ok(())
    }

    /// A UUID query value is rendered in its hyphenated textual form.
    #[test]
    fn uuid_params() -> anyhow::Result<()> {
        let params = [("session_id", uuid::Uuid::nil())];
        let req = example_endpoint()?
            .get("frobnicate")
            .query(&params)
            .build()?;

        assert_eq!(
            req.url().as_str(),
            "http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
        );

        Ok(())
    }
}
|