Line data Source code
1 : //! HTTP client and server impls.
2 : //! Other modules should use stuff from this module instead of
3 : //! directly relying on deps like `reqwest` (think loose coupling).
4 :
5 : pub mod health_server;
6 :
7 : use std::time::Duration;
8 :
9 : use anyhow::bail;
10 : use bytes::Bytes;
11 : use http::Method;
12 : use http_body_util::BodyExt;
13 : use hyper::body::Body;
14 : pub(crate) use reqwest::{Request, Response};
15 : use reqwest_middleware::RequestBuilder;
16 : pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
17 : pub(crate) use reqwest_retry::policies::ExponentialBackoff;
18 : pub(crate) use reqwest_retry::RetryTransientMiddleware;
19 : use serde::de::DeserializeOwned;
20 :
21 : use crate::metrics::{ConsoleRequest, Metrics};
22 : use crate::url::ApiUrl;
23 :
24 : /// This is the preferred way to create new http clients,
25 : /// because it takes care of observability (OpenTelemetry).
26 : /// We deliberately don't want to replace this with a public static.
27 1 : pub fn new_client() -> ClientWithMiddleware {
28 1 : let client = reqwest::ClientBuilder::new()
29 1 : .build()
30 1 : .expect("Failed to create http client");
31 1 :
32 1 : reqwest_middleware::ClientBuilder::new(client)
33 1 : .with(reqwest_tracing::TracingMiddleware::default())
34 1 : .build()
35 1 : }
36 :
37 0 : pub(crate) fn new_client_with_timeout(
38 0 : request_timeout: Duration,
39 0 : total_retry_duration: Duration,
40 0 : ) -> ClientWithMiddleware {
41 0 : let timeout_client = reqwest::ClientBuilder::new()
42 0 : .timeout(request_timeout)
43 0 : .build()
44 0 : .expect("Failed to create http client with timeout");
45 0 :
46 0 : let retry_policy =
47 0 : ExponentialBackoff::builder().build_with_total_retry_duration(total_retry_duration);
48 0 :
49 0 : reqwest_middleware::ClientBuilder::new(timeout_client)
50 0 : .with(reqwest_tracing::TracingMiddleware::default())
51 0 : // As per docs, "This middleware always errors when given requests with streaming bodies".
52 0 : // That's all right because we only use this client to send `serde_json::RawValue`, which
53 0 : // is not a stream.
54 0 : //
55 0 : // ex-maintainer note:
56 0 : // this limitation can be fixed if streaming is necessary.
57 0 : // retries will still not be performed, but it wont error immediately
58 0 : .with(RetryTransientMiddleware::new_with_policy(retry_policy))
59 0 : .build()
60 0 : }
61 :
62 : /// Thin convenience wrapper for an API provided by an http endpoint.
63 : #[derive(Debug, Clone)]
64 : pub struct Endpoint {
65 : /// API's base URL.
66 : endpoint: ApiUrl,
67 : /// Connection manager with built-in pooling.
68 : client: ClientWithMiddleware,
69 : }
70 :
71 : impl Endpoint {
72 : /// Construct a new HTTP endpoint wrapper.
73 : /// Http client is not constructed under the hood so that it can be shared.
74 2 : pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
75 2 : Self {
76 2 : endpoint,
77 2 : client: client.into(),
78 2 : }
79 2 : }
80 :
81 : #[inline(always)]
82 0 : pub(crate) fn url(&self) -> &ApiUrl {
83 0 : &self.endpoint
84 0 : }
85 :
86 : /// Return a [builder](RequestBuilder) for a `GET` request,
87 : /// appending a single `path` segment to the base endpoint URL.
88 2 : pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
89 2 : self.get_with_url(|u| {
90 2 : u.path_segments_mut().push(path);
91 2 : })
92 2 : }
93 :
94 : /// Return a [builder](RequestBuilder) for a `GET` request,
95 : /// accepting a closure to modify the url path segments for more complex paths queries.
96 2 : pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
97 2 : self.request_with_url(Method::GET, f)
98 2 : }
99 :
100 : /// Return a [builder](RequestBuilder) for a request,
101 : /// accepting a closure to modify the url path segments for more complex paths queries.
102 2 : pub(crate) fn request_with_url(
103 2 : &self,
104 2 : method: Method,
105 2 : f: impl for<'a> FnOnce(&'a mut ApiUrl),
106 2 : ) -> RequestBuilder {
107 2 : let mut url = self.endpoint.clone();
108 2 : f(&mut url);
109 2 : self.client.request(method, url.into_inner())
110 2 : }
111 :
112 : /// Execute a [request](reqwest::Request).
113 0 : pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
114 0 : let _timer = Metrics::get()
115 0 : .proxy
116 0 : .console_request_latency
117 0 : .start_timer(ConsoleRequest {
118 0 : request: request.url().path(),
119 0 : });
120 0 :
121 0 : self.client.execute(request).await
122 0 : }
123 : }
124 :
125 2 : pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
126 2 : mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
127 2 : limit: usize,
128 2 : ) -> anyhow::Result<D> {
129 : // We could use `b.limited().collect().await.to_bytes()` here
130 : // but this ends up being slightly more efficient as far as I can tell.
131 :
132 : // check the lower bound of the size hint.
133 : // in reqwest, this value is influenced by the Content-Length header.
134 2 : let lower_bound = match usize::try_from(b.size_hint().lower()) {
135 2 : Ok(bound) if bound <= limit => bound,
136 0 : _ => bail!("Content length exceeds limit of {limit} bytes"),
137 : };
138 2 : let mut bytes = Vec::with_capacity(lower_bound);
139 :
140 4 : while let Some(frame) = b.frame().await.transpose()? {
141 2 : if let Ok(data) = frame.into_data() {
142 2 : if bytes.len() + data.len() > limit {
143 0 : bail!("Content length exceeds limit of {limit} bytes")
144 2 : }
145 2 : bytes.extend_from_slice(&data);
146 0 : }
147 : }
148 :
149 2 : Ok(serde_json::from_slice::<D>(&bytes)?)
150 2 : }
151 :
#[cfg(test)]
mod tests {
    use reqwest::Client;

    use super::*;

    /// Build an [`Endpoint`] rooted at `http://example.com` with a plain reqwest client.
    fn example_endpoint() -> anyhow::Result<Endpoint> {
        Ok(Endpoint::new("http://example.com".parse()?, Client::new()))
    }

    /// `Some` query values are serialized as plain values, `None` ones are omitted.
    #[test]
    fn optional_query_params() -> anyhow::Result<()> {
        let endpoint = example_endpoint()?;

        // Validate that this pattern makes sense.
        let params = [
            ("foo", Some("10")), // should be just `foo=10`
            ("bar", None),       // shouldn't be passed at all
        ];
        let request = endpoint.get_path("frobnicate").query(&params).build()?;

        assert_eq!(
            request.url().as_str(),
            "http://example.com/frobnicate?foo=10"
        );

        Ok(())
    }

    /// A UUID used as a query value is rendered in its hyphenated form.
    #[test]
    fn uuid_params() -> anyhow::Result<()> {
        let endpoint = example_endpoint()?;

        let params = [("session_id", uuid::Uuid::nil())];
        let request = endpoint.get_path("frobnicate").query(&params).build()?;

        assert_eq!(
            request.url().as_str(),
            "http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
        );

        Ok(())
    }
}
|