Line data Source code
1 : use crate::auth::{AuthError, Claims, SwappableJwtAuth};
2 : use crate::http::error::{api_error_handler, route_error_handler, ApiError};
3 : use anyhow::Context;
4 : use hyper::header::{HeaderName, AUTHORIZATION};
5 : use hyper::http::HeaderValue;
6 : use hyper::Method;
7 : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
8 : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
9 : use once_cell::sync::Lazy;
10 : use routerify::ext::RequestExt;
11 : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
12 : use tracing::{self, debug, info, info_span, warn, Instrument};
13 :
14 : use std::future::Future;
15 : use std::str::FromStr;
16 :
17 : use bytes::{Bytes, BytesMut};
18 : use std::io::Write as _;
19 : use tokio::sync::mpsc;
20 : use tokio_stream::wrappers::ReceiverStream;
21 :
22 0 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
23 0 : register_int_counter!(
24 0 : "libmetrics_metric_handler_requests_total",
25 0 : "Number of metric requests made"
26 0 : )
27 0 : .expect("failed to define a metric")
28 0 : });
29 :
30 : static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
31 :
32 : static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
33 4 : #[derive(Debug, Default, Clone)]
34 : struct RequestId(String);
35 :
36 : /// Adds a tracing `info_span!` instrumentation around the handler events, and
37 : /// logs the request start and end for non-GET requests and non-200 responses.
38 : ///
39 : /// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
40 : ///
41 : /// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
42 : /// with this will get request info logged in the wrapping span, including the unique request ID.
43 : ///
44 : /// This also handles errors, logging them and converting them to an HTTP error response.
45 : ///
46 : /// NB: If the client disconnects, Hyper drops the Future without polling it to
47 : /// completion. In other words, the handler must be async-cancellation safe!
48 : /// request_span logs a warning when that happens, so that you have some trace of it
49 : /// in the log.
50 : ///
51 : ///
52 : /// There could be other ways to implement similar functionality:
53 : ///
54 : /// * procmacros placed on top of all handler methods
55 : /// With all the drawbacks of procmacros, it brings no difference implementation-wise,
56 : /// and little code reduction compared to the existing approach.
57 : ///
58 : /// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
59 : /// implemented for [`RouterBuilder`].
60 : /// Could be simpler, but we don't want to depend on [`routerify`] any further, since we aim to switch to another library later.
61 : ///
62 : /// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
63 : /// later, in a post-response middleware.
64 : /// Due to the suspendable nature of futures, this would give contradictory results, which is exactly the opposite of what `tracing-futures`
65 : /// tries to achieve with the `.instrument` call used in the current approach.
66 : ///
67 : /// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
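///
/// A minimal usage sketch (the `status_handler` name and route below are hypothetical, not part
/// of this module):
///
/// ```ignore
/// async fn status_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
///     Ok(Response::builder().body(Body::from("ok")).unwrap())
/// }
///
/// let router = Router::builder()
///     .get("/v1/status", |r| request_span(r, status_handler))
///     .build()
///     .unwrap();
/// ```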
68 0 : pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
69 0 : where
70 0 : R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
71 0 : H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
72 0 : {
73 0 : let request_id = request.context::<RequestId>().unwrap_or_default().0;
74 0 : let method = request.method();
75 0 : let path = request.uri().path();
76 0 : let request_span = info_span!("request", %method, %path, %request_id);
77 :
78 0 : let log_quietly = method == Method::GET;
79 0 : async move {
80 0 : let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
81 0 : if log_quietly {
82 0 : debug!("Handling request");
83 : } else {
84 0 : info!("Handling request");
85 : }
86 :
87 : // No special handling for panics here. There's a `tracing_panic_hook` from another
88 : // module to do that globally.
89 0 : let res = handler(request).await;
90 :
91 0 : cancellation_guard.disarm();
92 0 :
93 0 : // Log the result if needed.
94 0 : //
95 0 : // We also convert any errors into an Ok response with HTTP error code here.
96 0 : // `make_router` sets a last-resort error handler that would do the same, but
97 0 : // we prefer to do it here, before we exit the request span, so that the error
98 0 : // is still logged with the span.
99 0 : //
100 0 : // (Because we convert errors to Ok response, we never actually return an error,
101 0 : // and we could declare the function to return the never type (`!`). However,
102 0 : // using `routerify::RouterBuilder` requires a proper error type.)
103 0 : match res {
104 0 : Ok(response) => {
105 0 : let response_status = response.status();
106 0 : if log_quietly && response_status.is_success() {
107 0 : debug!("Request handled, status: {response_status}");
108 : } else {
109 0 : info!("Request handled, status: {response_status}");
110 : }
111 0 : Ok(response)
112 : }
113 0 : Err(err) => Ok(api_error_handler(err)),
114 : }
115 0 : }
116 0 : .instrument(request_span)
117 0 : .await
118 0 : }
119 :
120 : /// Drop guard to WARN in case the request was dropped before completion.
121 : struct RequestCancelled {
122 : warn: Option<tracing::Span>,
123 : }
124 :
125 : impl RequestCancelled {
126 : /// Create the drop guard using [`tracing::Span::current`] as the span.
127 0 : fn warn_when_dropped_without_responding() -> Self {
128 0 : RequestCancelled {
129 0 : warn: Some(tracing::Span::current()),
130 0 : }
131 0 : }
132 :
133 : /// Consume the drop guard without logging anything.
134 0 : fn disarm(mut self) {
135 0 : self.warn = None;
136 0 : }
137 : }
138 :
139 : impl Drop for RequestCancelled {
140 0 : fn drop(&mut self) {
141 0 : if std::thread::panicking() {
142 0 : // we are unwinding due to panicking, assume we are not dropped for cancellation
143 0 : } else if let Some(span) = self.warn.take() {
144 : // the span has all of the info already, but the outer `.instrument(span)` has already
145 : // been dropped, so we need to manually re-enter it for this message.
146 : //
147 : // this is what the instrument would do before polling so it is fine.
148 0 : let _g = span.entered();
149 0 : warn!("request was dropped before completing");
150 0 : }
151 0 : }
152 : }
153 :
154 : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
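///
/// A rough usage sketch, mirroring how `prometheus_metrics_handler` below streams its response
/// (the 64 KiB buffer size is only an illustrative value):
///
/// ```ignore
/// let (tx, rx) = tokio::sync::mpsc::channel(1);
/// let body = Body::wrap_stream(tokio_stream::wrappers::ReceiverStream::new(rx));
/// let mut writer = ChannelWriter::new(64 * 1024, tx);
///
/// // The writing side must run on a blocking thread, because flushing uses `block_on` internally.
/// tokio::task::spawn_blocking(move || {
///     use std::io::Write as _;
///     writer.write_all(b"hello").and_then(|()| writer.flush())
/// });
/// ```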
155 : pub struct ChannelWriter {
156 : buffer: BytesMut,
157 : pub tx: mpsc::Sender<std::io::Result<Bytes>>,
158 : written: usize,
159 : }
160 :
161 : impl ChannelWriter {
162 0 : pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
163 0 : assert_ne!(buf_len, 0);
164 0 : ChannelWriter {
165 0 :             // Split roughly half of the buffer off from the start, because we flush based on
166 0 :             // capacity. The first flush will come sooner than without this, but subsequent
167 0 :             // resizes have a better chance of reclaiming the "other" half. Not guaranteed, of course.
168 0 : buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
169 0 : tx,
170 0 : written: 0,
171 0 : }
172 0 : }
173 :
174 0 : pub fn flush0(&mut self) -> std::io::Result<usize> {
175 0 : let n = self.buffer.len();
176 0 : if n == 0 {
177 0 : return Ok(0);
178 0 : }
179 0 :
180 0 : tracing::trace!(n, "flushing");
181 0 : let ready = self.buffer.split().freeze();
182 0 :
183 0 :         // Not ideal to call block_on from blocking code, but we are sure that this
184 0 :         // operation does not spawn_blocking other tasks.
185 0 : let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
186 0 : self.tx.send(Ok(ready)).await.map_err(|_| ())?;
187 :
188 : // throttle sending to allow reuse of our buffer in `write`.
189 0 : self.tx.reserve().await.map_err(|_| ())?;
190 :
191 : // now the response task has picked up the buffer and hopefully started
192 : // sending it to the client.
193 0 : Ok(())
194 0 : });
195 0 : if res.is_err() {
196 0 : return Err(std::io::ErrorKind::BrokenPipe.into());
197 0 : }
198 0 : self.written += n;
199 0 : Ok(n)
200 0 : }
201 :
202 0 : pub fn flushed_bytes(&self) -> usize {
203 0 : self.written
204 0 : }
205 : }
206 :
207 : impl std::io::Write for ChannelWriter {
208 0 : fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
209 0 : let remaining = self.buffer.capacity() - self.buffer.len();
210 0 :
211 0 : let out_of_space = remaining < buf.len();
212 0 :
213 0 : let original_len = buf.len();
214 0 :
215 0 : if out_of_space {
216 0 :             let can_still_fit = remaining;
217 0 : self.buffer.extend_from_slice(&buf[..can_still_fit]);
218 0 : buf = &buf[can_still_fit..];
219 0 : self.flush0()?;
220 0 : }
221 :
222 :         // Assume that under normal operation this will often just move the pointer back to the
223 :         // beginning of the allocation, because previously split-off parts have already been sent
224 :         // and dropped.
225 0 : self.buffer.extend_from_slice(buf);
226 0 : Ok(original_len)
227 0 : }
228 :
229 0 : fn flush(&mut self) -> std::io::Result<()> {
230 0 : self.flush0().map(|_| ())
231 0 : }
232 : }
233 :
234 0 : async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
235 0 : SERVE_METRICS_COUNT.inc();
236 0 :
237 0 : let started_at = std::time::Instant::now();
238 0 :
239 0 : let (tx, rx) = mpsc::channel(1);
240 0 :
241 0 : let body = Body::wrap_stream(ReceiverStream::new(rx));
242 0 :
243 0 : let mut writer = ChannelWriter::new(128 * 1024, tx);
244 0 :
245 0 : let encoder = TextEncoder::new();
246 0 :
247 0 : let response = Response::builder()
248 0 : .status(200)
249 0 : .header(CONTENT_TYPE, encoder.format_type())
250 0 : .body(body)
251 0 : .unwrap();
252 :
253 0 : let span = info_span!("blocking");
254 0 : tokio::task::spawn_blocking(move || {
255 0 : let _span = span.entered();
256 0 : let metrics = metrics::gather();
257 0 : let res = encoder
258 0 : .encode(&metrics, &mut writer)
259 0 : .and_then(|_| writer.flush().map_err(|e| e.into()));
260 0 :
261 0 : match res {
262 : Ok(()) => {
263 0 : tracing::info!(
264 0 : bytes = writer.flushed_bytes(),
265 0 : elapsed_ms = started_at.elapsed().as_millis(),
266 0 : "responded /metrics"
267 0 : );
268 : }
269 0 : Err(e) => {
270 0 : tracing::warn!("failed to write out /metrics response: {e:#}");
271 :                 // The semantics of this error are quite unclear. We want to error out the stream to
272 :                 // abort the response and somehow notify the client that we failed.
273 :                 //
274 :                 // Though most likely the reason for the failure is that the receiver is already gone.
275 0 : drop(
276 0 : writer
277 0 : .tx
278 0 : .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
279 0 : );
280 : }
281 : }
282 0 : });
283 0 :
284 0 : Ok(response)
285 0 : }
286 :
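/// Pre-request middleware that takes the request ID from the `x-request-id` header, or generates
/// a fresh UUID when the header is absent, and stores it in the request context so that
/// `request_span` can log it and `add_request_id_header_to_response` can echo it back.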
287 4 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
288 4 : ) -> Middleware<B, ApiError> {
289 4 : Middleware::pre(move |req| async move {
290 4 : let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
291 2 : Some(request_id) => request_id
292 2 : .to_str()
293 2 : .expect("extract request id value")
294 2 : .to_owned(),
295 : None => {
296 2 : let request_id = uuid::Uuid::new_v4();
297 2 : request_id.to_string()
298 : }
299 : };
300 4 : req.set_context(RequestId(request_id));
301 4 :
302 4 : Ok(req)
303 4 : })
304 4 : }
305 :
306 4 : async fn add_request_id_header_to_response(
307 4 : mut res: Response<Body>,
308 4 : req_info: RequestInfo,
309 4 : ) -> Result<Response<Body>, ApiError> {
310 4 : if let Some(request_id) = req_info.context::<RequestId>() {
311 4 : if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
312 4 : res.headers_mut()
313 4 : .insert(&X_REQUEST_ID_HEADER, request_header_value);
314 4 : };
315 0 : };
316 :
317 4 : Ok(res)
318 4 : }
319 :
320 4 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
321 4 : Router::builder()
322 4 : .middleware(add_request_id_middleware())
323 4 : .middleware(Middleware::post_with_info(
324 4 : add_request_id_header_to_response,
325 4 : ))
326 4 : .get("/metrics", |r| request_span(r, prometheus_metrics_handler))
327 4 : .err_handler(route_error_handler)
328 4 : }
329 :
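/// Serves the given OpenAPI spec at `spec_mount_path` and a Swagger UI page (loaded from a CDN)
/// that renders it at `ui_mount_path`.
///
/// A rough usage sketch; the spec file name is hypothetical and assumed to be embedded in the
/// binary:
///
/// ```ignore
/// let router = attach_openapi_ui(make_router(), include_bytes!("openapi_spec.yml"),
///                                "/swagger.yml", "/swagger_ui");
/// ```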
330 0 : pub fn attach_openapi_ui(
331 0 : router_builder: RouterBuilder<hyper::Body, ApiError>,
332 0 : spec: &'static [u8],
333 0 : spec_mount_path: &'static str,
334 0 : ui_mount_path: &'static str,
335 0 : ) -> RouterBuilder<hyper::Body, ApiError> {
336 0 : router_builder
337 0 : .get(spec_mount_path,
338 0 : move |r| request_span(r, move |_| async move {
339 0 : Ok(Response::builder().body(Body::from(spec)).unwrap())
340 0 : })
341 0 : )
342 0 : .get(ui_mount_path,
343 0 : move |r| request_span(r, move |_| async move {
344 0 : Ok(Response::builder().body(Body::from(format!(r#"
345 0 : <!DOCTYPE html>
346 0 : <html lang="en">
347 0 : <head>
348 0 : <title>rweb</title>
349 0 : <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
350 0 : </head>
351 0 : <body>
352 0 : <div id="swagger-ui"></div>
353 0 : <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
354 0 : <script>
355 0 : window.onload = function() {{
356 0 : const ui = SwaggerUIBundle({{
357 0 : "dom_id": "\#swagger-ui",
358 0 : presets: [
359 0 : SwaggerUIBundle.presets.apis,
360 0 : SwaggerUIBundle.SwaggerUIStandalonePreset
361 0 : ],
362 0 : layout: "BaseLayout",
363 0 : deepLinking: true,
364 0 : showExtensions: true,
365 0 : showCommonExtensions: true,
366 0 : url: "{}",
367 0 : }})
368 0 : window.ui = ui;
369 0 : }};
370 0 : </script>
371 0 : </body>
372 0 : </html>
373 0 : "#, spec_mount_path))).unwrap())
374 0 : })
375 0 : )
376 0 : }
377 :
378 0 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
379 :     // header must be of the form "Bearer <token>"
380 0 : let (prefix, token) = header_value
381 0 : .split_once(' ')
382 0 : .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
383 0 : if prefix != "Bearer" {
384 0 : return Err(ApiError::Unauthorized(
385 0 : "malformed authorization header".to_string(),
386 0 : ));
387 0 : }
388 0 : Ok(token)
389 0 : }
390 :
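/// Pre-request middleware that authenticates the request against the [`SwappableJwtAuth`]
/// returned by `provide_auth`, and stores the decoded [`Claims`] in the request context.
/// When `provide_auth` returns `None`, authentication is skipped.
///
/// A rough wiring sketch (the `JWT_AUTH` static is hypothetical, not part of this module):
///
/// ```ignore
/// static JWT_AUTH: once_cell::sync::OnceCell<SwappableJwtAuth> = once_cell::sync::OnceCell::new();
///
/// let router = make_router()
///     .middleware(auth_middleware(|_req| JWT_AUTH.get()));
/// ```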
391 0 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
392 0 : provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
393 0 : ) -> Middleware<B, ApiError> {
394 0 : Middleware::pre(move |req| async move {
395 0 : if let Some(auth) = provide_auth(&req) {
396 0 : match req.headers().get(AUTHORIZATION) {
397 0 : Some(value) => {
398 0 : let header_value = value.to_str().map_err(|_| {
399 0 : ApiError::Unauthorized("malformed authorization header".to_string())
400 0 : })?;
401 0 : let token = parse_token(header_value)?;
402 :
403 0 : let data = auth.decode(token).map_err(|err| {
404 0 : warn!("Authentication error: {err}");
405 : // Rely on From<AuthError> for ApiError impl
406 0 : err
407 0 : })?;
408 0 : req.set_context(data.claims);
409 : }
410 : None => {
411 0 : return Err(ApiError::Unauthorized(
412 0 : "missing authorization header".to_string(),
413 0 : ))
414 : }
415 : }
416 0 : }
417 0 : Ok(req)
418 0 : })
419 0 : }
420 :
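/// Post-response middleware that adds a fixed header to every response, unless the handler
/// already set a header with that name.
///
/// A usage sketch; the header name and value are only examples:
///
/// ```ignore
/// let router = make_router()
///     .middleware(add_response_header_middleware("Server", "my-service")?);
/// ```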
421 0 : pub fn add_response_header_middleware<B>(
422 0 : header: &str,
423 0 : value: &str,
424 0 : ) -> anyhow::Result<Middleware<B, ApiError>>
425 0 : where
426 0 : B: hyper::body::HttpBody + Send + Sync + 'static,
427 0 : {
428 0 : let name =
429 0 : HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
430 0 : let value =
431 0 : HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
432 0 : Ok(Middleware::post_with_info(
433 0 : move |mut response, request_info| {
434 0 : let name = name.clone();
435 0 : let value = value.clone();
436 0 : async move {
437 0 : let headers = response.headers_mut();
438 0 : if headers.contains_key(&name) {
439 0 : warn!(
440 0 : "{} response already contains header {:?}",
441 0 : request_info.uri(),
442 0 : &name,
443 0 : );
444 0 : } else {
445 0 : headers.insert(name, value);
446 0 : }
447 0 : Ok(response)
448 0 : }
449 0 : },
450 0 : ))
451 0 : }
452 :
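/// Runs `check_permission` against the [`Claims`] that [`auth_middleware`] stored in the request
/// context. When no claims are present (auth is disabled), the check passes.
///
/// A sketch of a typical call from inside a handler; `has_scope` is a hypothetical helper
/// returning `Result<(), AuthError>`:
///
/// ```ignore
/// check_permission_with(&req, |claims| has_scope(claims, "admin"))?;
/// ```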
453 0 : pub fn check_permission_with(
454 0 : req: &Request<Body>,
455 0 : check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
456 0 : ) -> Result<(), ApiError> {
457 0 : match req.context::<Claims>() {
458 0 : Some(claims) => Ok(check_permission(&claims)
459 0 : .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
460 0 : None => Ok(()), // claims is None because auth is disabled
461 : }
462 0 : }
463 :
464 : #[cfg(test)]
465 : mod tests {
466 : use super::*;
467 : use futures::future::poll_fn;
468 : use hyper::service::Service;
469 : use routerify::RequestServiceBuilder;
470 : use std::net::{IpAddr, SocketAddr};
471 :
472 2 : #[tokio::test]
473 2 : async fn test_request_id_returned() {
474 2 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
475 2 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
476 2 : let mut service = builder.build(remote_addr);
477 2 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
478 2 : panic!("request service is not ready: {:?}", e);
479 2 : }
480 2 :
481 2 : let mut req: Request<Body> = Request::default();
482 2 : req.headers_mut()
483 2 : .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
484 2 :
485 2 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
486 2 :
487 2 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
488 2 :
489 2 : assert!(header_val == "42", "response header mismatch");
490 2 : }
491 :
492 2 : #[tokio::test]
493 2 : async fn test_request_id_empty() {
494 2 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
495 2 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
496 2 : let mut service = builder.build(remote_addr);
497 2 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
498 2 : panic!("request service is not ready: {:?}", e);
499 2 : }
500 2 :
501 2 : let req: Request<Body> = Request::default();
502 2 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
503 2 :
504 2 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
505 2 :
506 2 : assert_ne!(header_val, None, "response header should NOT be empty");
507 2 : }
508 : }
|