TLA Line data Source code
1 : use crate::auth::{Claims, JwtAuth};
2 : use crate::http::error::{api_error_handler, route_error_handler, ApiError};
3 : use anyhow::Context;
4 : use hyper::header::{HeaderName, AUTHORIZATION};
5 : use hyper::http::HeaderValue;
6 : use hyper::Method;
7 : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
8 : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
9 : use once_cell::sync::Lazy;
10 : use routerify::ext::RequestExt;
11 : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
12 : use tracing::{self, debug, info, info_span, warn, Instrument};
13 :
14 : use std::future::Future;
15 : use std::str::FromStr;
16 :
17 CBC 81 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
18 81 : register_int_counter!(
19 81 : "libmetrics_metric_handler_requests_total",
20 81 : "Number of metric requests made"
21 81 : )
22 81 : .expect("failed to define a metric")
23 81 : });
24 :
/// Name of the header carrying the per-request id, both on incoming requests and on responses.
static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";

/// [`HeaderName`] form of [`X_REQUEST_ID_HEADER_STR`], built at compile time.
static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
/// Request id stored in the routerify request context by `add_request_id_middleware`.
///
/// It is read back by `request_span` (for the tracing span) and by the response post-middleware
/// (to echo it in the `x-request-id` header). `Default` (an empty string) is what `request_span`
/// falls back to when no id is present in the context.
#[derive(Debug, Default, Clone)]
struct RequestId(String);
30 :
/// Adds a tracing `info_span!` instrumentation around the handler events.
///
/// Request start and end are logged at INFO for non-GET requests. For GET requests they are
/// logged at DEBUG, unless the response status is not 2xx, in which case the end event is
/// logged at INFO.
///
/// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// with this will get request info logged in the wrapping span, including the unique request ID.
///
/// This also handles errors, logging them and converting them to an HTTP error response.
///
/// NB: If the client disconnects, Hyper will drop the Future, without polling it to
/// completion. In other words, the handler must be async cancellation safe! request_span
/// prints a warning to the log when that happens, so that you have some trace of it in
/// the log.
///
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
///   With all the drawbacks of procmacros, brings no difference implementation-wise,
///   and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
///   implemented for [`RouterBuilder`].
///   Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
///   later, in a post-response middleware.
///   Due to the suspendable nature of futures, this would give contradictory results, which is exactly the opposite of what `tracing-futures`
///   tries to achieve with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
where
    R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
    H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
{
    // Falls back to an empty id if the request-id middleware did not run for this route.
    let request_id = request.context::<RequestId>().unwrap_or_default().0;
    let method = request.method();
    let path = request.uri().path();
    let request_span = info_span!("request", %method, %path, %request_id);

    // Computed before `request` is moved into the future below, while `method` still borrows it.
    let log_quietly = method == Method::GET;
    async move {
        // Created inside the instrumented future so that `Span::current()` inside the guard
        // resolves to `request_span` when this future is polled.
        let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
        if log_quietly {
            debug!("Handling request");
        } else {
            info!("Handling request");
        }

        // No special handling for panics here. There's a `tracing_panic_hook` from another
        // module to do that globally.
        let res = handler(request).await;

        // The handler finished, so the guard must not report a cancellation when it drops.
        cancellation_guard.disarm();

        // Log the result if needed.
        //
        // We also convert any errors into an Ok response with HTTP error code here.
        // `make_router` sets a last-resort error handler that would do the same, but
        // we prefer to do it here, before we exit the request span, so that the error
        // is still logged with the span.
        //
        // (Because we convert errors to Ok response, we never actually return an error,
        // and we could declare the function to return the never type (`!`). However,
        // using `routerify::RouterBuilder` requires a proper error type.)
        match res {
            Ok(response) => {
                let response_status = response.status();
                if log_quietly && response_status.is_success() {
                    debug!("Request handled, status: {response_status}");
                } else {
                    info!("Request handled, status: {response_status}");
                }
                Ok(response)
            }
            Err(err) => Ok(api_error_handler(err)),
        }
    }
    .instrument(request_span)
    .await
}
114 :
115 : /// Drop guard to WARN in case the request was dropped before completion.
116 : struct RequestCancelled {
117 : warn: Option<tracing::Span>,
118 : }
119 :
120 : impl RequestCancelled {
121 : /// Create the drop guard using the [`tracing::Span::current`] as the span.
122 8440 : fn warn_when_dropped_without_responding() -> Self {
123 8440 : RequestCancelled {
124 8440 : warn: Some(tracing::Span::current()),
125 8440 : }
126 8440 : }
127 :
128 : /// Consume the drop guard without logging anything.
129 8433 : fn disarm(mut self) {
130 8433 : self.warn = None;
131 8433 : }
132 : }
133 :
134 : impl Drop for RequestCancelled {
135 8435 : fn drop(&mut self) {
136 8435 : if std::thread::panicking() {
137 UBC 0 : // we are unwinding due to panicking, assume we are not dropped for cancellation
138 CBC 8435 : } else if let Some(span) = self.warn.take() {
139 : // the span has all of the info already, but the outer `.instrument(span)` has already
140 : // been dropped, so we need to manually re-enter it for this message.
141 : //
142 : // this is what the instrument would do before polling so it is fine.
143 2 : let _g = span.entered();
144 2 : warn!("request was dropped before completing");
145 8433 : }
146 8435 : }
147 : }
148 :
149 565 : async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
150 565 : use bytes::{Bytes, BytesMut};
151 565 : use std::io::Write as _;
152 565 : use tokio::sync::mpsc;
153 565 : use tokio_stream::wrappers::ReceiverStream;
154 565 :
155 565 : SERVE_METRICS_COUNT.inc();
156 565 :
157 565 : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
158 565 : struct ChannelWriter {
159 565 : buffer: BytesMut,
160 565 : tx: mpsc::Sender<std::io::Result<Bytes>>,
161 565 : written: usize,
162 565 : }
163 565 :
164 565 : impl ChannelWriter {
165 565 : fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
166 565 : assert_ne!(buf_len, 0);
167 565 : ChannelWriter {
168 565 : // split about half off the buffer from the start, because we flush depending on
169 565 : // capacity. first flush will come sooner than without this, but now resizes will
170 565 : // have better chance of picking up the "other" half. not guaranteed of course.
171 565 : buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
172 565 : tx,
173 565 : written: 0,
174 565 : }
175 565 : }
176 565 :
177 1557 : fn flush0(&mut self) -> std::io::Result<usize> {
178 1557 : let n = self.buffer.len();
179 1557 : if n == 0 {
180 565 : return Ok(0);
181 1557 : }
182 1557 :
183 1557 : tracing::trace!(n, "flushing");
184 1557 : let ready = self.buffer.split().freeze();
185 1557 :
186 1557 : // not ideal to call from blocking code to block_on, but we are sure that this
187 1557 : // operation does not spawn_blocking other tasks
188 1557 : let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
189 1557 : self.tx.send(Ok(ready)).await.map_err(|_| ())?;
190 565 :
191 565 : // throttle sending to allow reuse of our buffer in `write`.
192 1557 : self.tx.reserve().await.map_err(|_| ())?;
193 565 :
194 565 : // now the response task has picked up the buffer and hopefully started
195 565 : // sending it to the client.
196 1557 : Ok(())
197 1557 : });
198 1557 : if res.is_err() {
199 565 : return Err(std::io::ErrorKind::BrokenPipe.into());
200 1557 : }
201 1557 : self.written += n;
202 1557 : Ok(n)
203 1557 : }
204 565 :
205 565 : fn flushed_bytes(&self) -> usize {
206 565 : self.written
207 565 : }
208 565 : }
209 565 :
210 565 : impl std::io::Write for ChannelWriter {
211 14299926 : fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
212 14299926 : let remaining = self.buffer.capacity() - self.buffer.len();
213 14299926 :
214 14299926 : let out_of_space = remaining < buf.len();
215 14299926 :
216 14299926 : let original_len = buf.len();
217 14299926 :
218 14299926 : if out_of_space {
219 992 : let can_still_fit = buf.len() - remaining;
220 992 : self.buffer.extend_from_slice(&buf[..can_still_fit]);
221 992 : buf = &buf[can_still_fit..];
222 992 : self.flush0()?;
223 14298934 : }
224 565 :
225 565 : // assume that this will often under normal operation just move the pointer back to the
226 565 : // beginning of allocation, because previous split off parts are already sent and
227 565 : // dropped.
228 14299926 : self.buffer.extend_from_slice(buf);
229 14299926 : Ok(original_len)
230 14299926 : }
231 565 :
232 565 : fn flush(&mut self) -> std::io::Result<()> {
233 565 : self.flush0().map(|_| ())
234 565 : }
235 565 : }
236 565 :
237 565 : let started_at = std::time::Instant::now();
238 565 :
239 565 : let (tx, rx) = mpsc::channel(1);
240 565 :
241 565 : let body = Body::wrap_stream(ReceiverStream::new(rx));
242 565 :
243 565 : let mut writer = ChannelWriter::new(128 * 1024, tx);
244 565 :
245 565 : let encoder = TextEncoder::new();
246 565 :
247 565 : let response = Response::builder()
248 565 : .status(200)
249 565 : .header(CONTENT_TYPE, encoder.format_type())
250 565 : .body(body)
251 565 : .unwrap();
252 :
253 565 : let span = info_span!("blocking");
254 565 : tokio::task::spawn_blocking(move || {
255 565 : let _span = span.entered();
256 565 : let metrics = metrics::gather();
257 565 : let res = encoder
258 565 : .encode(&metrics, &mut writer)
259 565 : .and_then(|_| writer.flush().map_err(|e| e.into()));
260 565 :
261 565 : match res {
262 : Ok(()) => {
263 565 : tracing::info!(
264 565 : bytes = writer.flushed_bytes(),
265 565 : elapsed_ms = started_at.elapsed().as_millis(),
266 565 : "responded /metrics"
267 565 : );
268 : }
269 UBC 0 : Err(e) => {
270 0 : tracing::warn!("failed to write out /metrics response: {e:#}");
271 : // semantics of this error are quite... unclear. we want to error the stream out to
272 : // abort the response to somehow notify the client that we failed.
273 : //
274 : // though, most likely the reason for failure is that the receiver is already gone.
275 0 : drop(
276 0 : writer
277 0 : .tx
278 0 : .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
279 0 : );
280 : }
281 : }
282 CBC 565 : });
283 565 :
284 565 : Ok(response)
285 565 : }
286 :
287 1078 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
288 1078 : ) -> Middleware<B, ApiError> {
289 8463 : Middleware::pre(move |req| async move {
290 8463 : let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
291 1 : Some(request_id) => request_id
292 1 : .to_str()
293 1 : .expect("extract request id value")
294 1 : .to_owned(),
295 : None => {
296 8462 : let request_id = uuid::Uuid::new_v4();
297 8462 : request_id.to_string()
298 : }
299 : };
300 8463 : req.set_context(RequestId(request_id));
301 8463 :
302 8463 : Ok(req)
303 8463 : })
304 1078 : }
305 :
306 8456 : async fn add_request_id_header_to_response(
307 8456 : mut res: Response<Body>,
308 8456 : req_info: RequestInfo,
309 8456 : ) -> Result<Response<Body>, ApiError> {
310 8456 : if let Some(request_id) = req_info.context::<RequestId>() {
311 8456 : if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
312 8456 : res.headers_mut()
313 8456 : .insert(&X_REQUEST_ID_HEADER, request_header_value);
314 8456 : };
315 UBC 0 : };
316 :
317 CBC 8456 : Ok(res)
318 8456 : }
319 :
320 1078 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
321 1078 : Router::builder()
322 1078 : .middleware(add_request_id_middleware())
323 1078 : .middleware(Middleware::post_with_info(
324 1078 : add_request_id_header_to_response,
325 1078 : ))
326 1078 : .get("/metrics", |r| request_span(r, prometheus_metrics_handler))
327 1078 : .err_handler(route_error_handler)
328 1078 : }
329 :
330 560 : pub fn attach_openapi_ui(
331 560 : router_builder: RouterBuilder<hyper::Body, ApiError>,
332 560 : spec: &'static [u8],
333 560 : spec_mount_path: &'static str,
334 560 : ui_mount_path: &'static str,
335 560 : ) -> RouterBuilder<hyper::Body, ApiError> {
336 560 : router_builder
337 560 : .get(spec_mount_path,
338 560 : move |r| request_span(r, move |_| async move {
339 UBC 0 : Ok(Response::builder().body(Body::from(spec)).unwrap())
340 CBC 560 : })
341 560 : )
342 560 : .get(ui_mount_path,
343 560 : move |r| request_span(r, move |_| async move {
344 UBC 0 : Ok(Response::builder().body(Body::from(format!(r#"
345 0 : <!DOCTYPE html>
346 0 : <html lang="en">
347 0 : <head>
348 0 : <title>rweb</title>
349 0 : <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
350 0 : </head>
351 0 : <body>
352 0 : <div id="swagger-ui"></div>
353 0 : <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
354 0 : <script>
355 0 : window.onload = function() {{
356 0 : const ui = SwaggerUIBundle({{
357 0 : "dom_id": "\#swagger-ui",
358 0 : presets: [
359 0 : SwaggerUIBundle.presets.apis,
360 0 : SwaggerUIBundle.SwaggerUIStandalonePreset
361 0 : ],
362 0 : layout: "BaseLayout",
363 0 : deepLinking: true,
364 0 : showExtensions: true,
365 0 : showCommonExtensions: true,
366 0 : url: "{}",
367 0 : }})
368 0 : window.ui = ui;
369 0 : }};
370 0 : </script>
371 0 : </body>
372 0 : </html>
373 0 : "#, spec_mount_path))).unwrap())
374 CBC 560 : })
375 560 : )
376 560 : }
377 :
378 86 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
379 : // header must be in form Bearer <token>
380 86 : let (prefix, token) = header_value
381 86 : .split_once(' ')
382 86 : .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
383 86 : if prefix != "Bearer" {
384 UBC 0 : return Err(ApiError::Unauthorized(
385 0 : "malformed authorization header".to_string(),
386 0 : ));
387 CBC 86 : }
388 86 : Ok(token)
389 86 : }
390 :
391 27 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
392 27 : provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
393 27 : ) -> Middleware<B, ApiError> {
394 27 : Middleware::pre(move |req| async move {
395 141 : if let Some(auth) = provide_auth(&req) {
396 91 : match req.headers().get(AUTHORIZATION) {
397 86 : Some(value) => {
398 86 : let header_value = value.to_str().map_err(|_| {
399 UBC 0 : ApiError::Unauthorized("malformed authorization header".to_string())
400 CBC 86 : })?;
401 86 : let token = parse_token(header_value)?;
402 :
403 86 : let data = auth
404 86 : .decode(token)
405 86 : .map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
406 86 : req.set_context(data.claims);
407 : }
408 : None => {
409 5 : return Err(ApiError::Unauthorized(
410 5 : "missing authorization header".to_string(),
411 5 : ))
412 : }
413 : }
414 50 : }
415 136 : Ok(req)
416 141 : })
417 27 : }
418 :
419 560 : pub fn add_response_header_middleware<B>(
420 560 : header: &str,
421 560 : value: &str,
422 560 : ) -> anyhow::Result<Middleware<B, ApiError>>
423 560 : where
424 560 : B: hyper::body::HttpBody + Send + Sync + 'static,
425 560 : {
426 560 : let name =
427 560 : HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
428 560 : let value =
429 560 : HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
430 560 : Ok(Middleware::post_with_info(
431 7156 : move |mut response, request_info| {
432 7156 : let name = name.clone();
433 7156 : let value = value.clone();
434 7156 : async move {
435 7156 : let headers = response.headers_mut();
436 7156 : if headers.contains_key(&name) {
437 UBC 0 : warn!(
438 0 : "{} response already contains header {:?}",
439 0 : request_info.uri(),
440 0 : &name,
441 0 : );
442 CBC 7156 : } else {
443 7156 : headers.insert(name, value);
444 7156 : }
445 7156 : Ok(response)
446 7156 : }
447 7156 : },
448 560 : ))
449 560 : }
450 :
451 7649 : pub fn check_permission_with(
452 7649 : req: &Request<Body>,
453 7649 : check_permission: impl Fn(&Claims) -> Result<(), anyhow::Error>,
454 7649 : ) -> Result<(), ApiError> {
455 7649 : match req.context::<Claims>() {
456 86 : Some(claims) => {
457 86 : Ok(check_permission(&claims).map_err(|err| ApiError::Forbidden(err.to_string()))?)
458 : }
459 7563 : None => Ok(()), // claims is None because auth is disabled
460 : }
461 7649 : }
462 :
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::poll_fn;
    use hyper::service::Service;
    use routerify::RequestServiceBuilder;
    use std::net::{IpAddr, SocketAddr};

    /// Builds a fresh service from [`make_router`], waits until it is ready and sends `req` through it.
    async fn call_router(req: Request<Body>) -> Response<Body> {
        let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
        let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let mut service = builder.build(remote_addr);
        if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
            panic!("request service is not ready: {:?}", e);
        }
        service.call(req).await.unwrap()
    }

    #[tokio::test]
    async fn test_request_id_returned() {
        let mut req: Request<Body> = Request::default();
        req.headers_mut()
            .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());

        let resp = call_router(req).await;

        let echoed = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
        assert!(echoed == "42", "response header mismatch");
    }

    #[tokio::test]
    async fn test_request_id_empty() {
        let resp = call_router(Request::default()).await;

        assert!(
            resp.headers().get(&X_REQUEST_ID_HEADER).is_some(),
            "response header should NOT be empty"
        );
    }
}
|