Line data Source code
1 : //! HTTP client and server impls.
2 : //! Other modules should use stuff from this module instead of
3 : //! directly relying on deps like `reqwest` (think loose coupling).
4 :
5 : pub mod health_server;
6 :
7 : use std::time::Duration;
8 :
9 : use bytes::Bytes;
10 : use http::Method;
11 : use http_body_util::BodyExt;
12 : use hyper::body::Body;
13 : pub(crate) use reqwest::{Request, Response};
14 : use reqwest_middleware::RequestBuilder;
15 : pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
16 : pub(crate) use reqwest_retry::policies::ExponentialBackoff;
17 : pub(crate) use reqwest_retry::RetryTransientMiddleware;
18 : use thiserror::Error;
19 :
20 : use crate::metrics::{ConsoleRequest, Metrics};
21 : use crate::url::ApiUrl;
22 :
23 : /// This is the preferred way to create new http clients,
24 : /// because it takes care of observability (OpenTelemetry).
25 : /// We deliberately don't want to replace this with a public static.
26 1 : pub fn new_client() -> ClientWithMiddleware {
27 1 : let client = reqwest::ClientBuilder::new()
28 1 : .build()
29 1 : .expect("Failed to create http client");
30 1 :
31 1 : reqwest_middleware::ClientBuilder::new(client)
32 1 : .with(reqwest_tracing::TracingMiddleware::default())
33 1 : .build()
34 1 : }
35 :
36 0 : pub(crate) fn new_client_with_timeout(
37 0 : request_timeout: Duration,
38 0 : total_retry_duration: Duration,
39 0 : ) -> ClientWithMiddleware {
40 0 : let timeout_client = reqwest::ClientBuilder::new()
41 0 : .timeout(request_timeout)
42 0 : .build()
43 0 : .expect("Failed to create http client with timeout");
44 0 :
45 0 : let retry_policy =
46 0 : ExponentialBackoff::builder().build_with_total_retry_duration(total_retry_duration);
47 0 :
48 0 : reqwest_middleware::ClientBuilder::new(timeout_client)
49 0 : .with(reqwest_tracing::TracingMiddleware::default())
50 0 : // As per docs, "This middleware always errors when given requests with streaming bodies".
51 0 : // That's all right because we only use this client to send `serde_json::RawValue`, which
52 0 : // is not a stream.
53 0 : //
54 0 : // ex-maintainer note:
55 0 : // this limitation can be fixed if streaming is necessary.
56 0 : // retries will still not be performed, but it wont error immediately
57 0 : .with(RetryTransientMiddleware::new_with_policy(retry_policy))
58 0 : .build()
59 0 : }
60 :
61 : /// Thin convenience wrapper for an API provided by an http endpoint.
62 : #[derive(Debug, Clone)]
63 : pub struct Endpoint {
64 : /// API's base URL.
65 : endpoint: ApiUrl,
66 : /// Connection manager with built-in pooling.
67 : client: ClientWithMiddleware,
68 : }
69 :
70 : impl Endpoint {
71 : /// Construct a new HTTP endpoint wrapper.
72 : /// Http client is not constructed under the hood so that it can be shared.
73 2 : pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
74 2 : Self {
75 2 : endpoint,
76 2 : client: client.into(),
77 2 : }
78 2 : }
79 :
80 : #[inline(always)]
81 0 : pub(crate) fn url(&self) -> &ApiUrl {
82 0 : &self.endpoint
83 0 : }
84 :
85 : /// Return a [builder](RequestBuilder) for a `GET` request,
86 : /// appending a single `path` segment to the base endpoint URL.
87 2 : pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
88 2 : self.get_with_url(|u| {
89 2 : u.path_segments_mut().push(path);
90 2 : })
91 2 : }
92 :
93 : /// Return a [builder](RequestBuilder) for a `GET` request,
94 : /// accepting a closure to modify the url path segments for more complex paths queries.
95 2 : pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
96 2 : self.request_with_url(Method::GET, f)
97 2 : }
98 :
99 : /// Return a [builder](RequestBuilder) for a request,
100 : /// accepting a closure to modify the url path segments for more complex paths queries.
101 2 : pub(crate) fn request_with_url(
102 2 : &self,
103 2 : method: Method,
104 2 : f: impl for<'a> FnOnce(&'a mut ApiUrl),
105 2 : ) -> RequestBuilder {
106 2 : let mut url = self.endpoint.clone();
107 2 : f(&mut url);
108 2 : self.client.request(method, url.into_inner())
109 2 : }
110 :
111 : /// Execute a [request](reqwest::Request).
112 0 : pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
113 0 : let _timer = Metrics::get()
114 0 : .proxy
115 0 : .console_request_latency
116 0 : .start_timer(ConsoleRequest {
117 0 : request: request.url().path(),
118 0 : });
119 0 :
120 0 : self.client.execute(request).await
121 0 : }
122 : }
123 :
124 0 : #[derive(Error, Debug)]
125 : pub(crate) enum ReadBodyError<E> {
126 : #[error("Content length exceeds limit of {limit} bytes")]
127 : BodyTooLarge { limit: usize },
128 :
129 : #[error(transparent)]
130 : Read(#[from] E),
131 : }
132 :
133 9 : pub(crate) async fn read_body_with_limit<E>(
134 9 : mut b: impl Body<Data = Bytes, Error = E> + Unpin,
135 9 : limit: usize,
136 9 : ) -> Result<Vec<u8>, ReadBodyError<E>> {
137 : // We could use `b.limited().collect().await.to_bytes()` here
138 : // but this ends up being slightly more efficient as far as I can tell.
139 :
140 : // check the lower bound of the size hint.
141 : // in reqwest, this value is influenced by the Content-Length header.
142 9 : let lower_bound = match usize::try_from(b.size_hint().lower()) {
143 9 : Ok(bound) if bound <= limit => bound,
144 0 : _ => return Err(ReadBodyError::BodyTooLarge { limit }),
145 : };
146 9 : let mut bytes = Vec::with_capacity(lower_bound);
147 :
148 18 : while let Some(frame) = b.frame().await.transpose()? {
149 9 : if let Ok(data) = frame.into_data() {
150 9 : if bytes.len() + data.len() > limit {
151 0 : return Err(ReadBodyError::BodyTooLarge { limit });
152 9 : }
153 9 : bytes.extend_from_slice(&data);
154 0 : }
155 : }
156 :
157 9 : Ok(bytes)
158 9 : }
159 :
160 : #[cfg(test)]
161 : mod tests {
162 : use reqwest::Client;
163 :
164 : use super::*;
165 :
166 : #[test]
167 1 : fn optional_query_params() -> anyhow::Result<()> {
168 1 : let url = "http://example.com".parse()?;
169 1 : let endpoint = Endpoint::new(url, Client::new());
170 :
171 : // Validate that this pattern makes sense.
172 1 : let req = endpoint
173 1 : .get_path("frobnicate")
174 1 : .query(&[
175 1 : ("foo", Some("10")), // should be just `foo=10`
176 1 : ("bar", None), // shouldn't be passed at all
177 1 : ])
178 1 : .build()?;
179 :
180 1 : assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");
181 :
182 1 : Ok(())
183 1 : }
184 :
185 : #[test]
186 1 : fn uuid_params() -> anyhow::Result<()> {
187 1 : let url = "http://example.com".parse()?;
188 1 : let endpoint = Endpoint::new(url, Client::new());
189 :
190 1 : let req = endpoint
191 1 : .get_path("frobnicate")
192 1 : .query(&[("session_id", uuid::Uuid::nil())])
193 1 : .build()?;
194 :
195 1 : assert_eq!(
196 1 : req.url().as_str(),
197 1 : "http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
198 1 : );
199 :
200 1 : Ok(())
201 1 : }
202 : }
|