Line data Source code
1 : use crate::auth::{AuthError, Claims, SwappableJwtAuth};
2 : use crate::http::error::{api_error_handler, route_error_handler, ApiError};
3 : use anyhow::Context;
4 : use hyper::header::{HeaderName, AUTHORIZATION};
5 : use hyper::http::HeaderValue;
6 : use hyper::Method;
7 : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
8 : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
9 : use once_cell::sync::Lazy;
10 : use routerify::ext::RequestExt;
11 : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
12 : use tracing::{self, debug, info, info_span, warn, Instrument};
13 :
14 : use std::future::Future;
15 : use std::str::FromStr;
16 :
17 : use bytes::{Bytes, BytesMut};
18 : use std::io::Write as _;
19 : use tokio::sync::mpsc;
20 : use tokio_stream::wrappers::ReceiverStream;
21 :
22 407 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
23 407 : register_int_counter!(
24 407 : "libmetrics_metric_handler_requests_total",
25 407 : "Number of metric requests made"
26 407 : )
27 407 : .expect("failed to define a metric")
28 407 : });
29 :
30 : static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
31 :
32 : static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
33 32476 : #[derive(Debug, Default, Clone)]
34 : struct RequestId(String);
35 :
36 : /// Adds tracing `info_span!` instrumentation around the handler events, and
37 : /// logs the request start and end events for non-GET requests and non-200 responses.
38 : ///
39 : /// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
40 : ///
41 : /// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
42 : /// with this will get request info logged in the wrapping span, including the unique request ID.
43 : ///
44 : /// This also handles errors, logging them and converting them to an HTTP error response.
45 : ///
46 : /// NB: If the client disconnects, Hyper will drop the future without polling it to
47 : /// completion. In other words, the handler must be async-cancellation-safe!
48 : /// `request_span` logs a warning when that happens, so that you have some trace of
49 : /// it in the log.
50 : ///
51 : ///
52 : /// There could be other ways to implement similar functionality:
53 : ///
54 : /// * procmacros placed on top of all handler methods
55 : /// With all the drawbacks of procmacros, this would be no different implementation-wise,
56 : /// and would bring little code reduction compared to the existing approach.
57 : ///
58 : /// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
59 : /// implemented for [`RouterBuilder`].
60 : /// Could be simpler, but we don't want to depend on [`routerify`] any further, as we aim to switch to another library later.
61 : ///
62 : /// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
63 : /// later, in a post-response middleware.
64 : /// Due to the suspendable nature of futures, this would give contradictory results, which is exactly the opposite of what `tracing-futures`
65 : /// tries to achieve with the `.instrument` used in the current approach.
66 : ///
67 : /// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
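 : /// # Example
 : ///
 : /// A minimal sketch of wiring a handler through `request_span` (the route path and
 : /// `my_handler` are illustrative, not part of this crate):
 : ///
 : /// ```ignore
 : /// async fn my_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
 : ///     Ok(Response::new(Body::from("ok")))
 : /// }
 : ///
 : /// let router = Router::builder()
 : ///     .get("/v1/status", |r| request_span(r, my_handler))
 : ///     .build()
 : ///     .unwrap();
 : /// ```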
68 16225 : pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
69 16225 : where
70 16225 : R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
71 16225 : H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
72 16225 : {
73 16225 : let request_id = request.context::<RequestId>().unwrap_or_default().0;
74 16225 : let method = request.method();
75 16225 : let path = request.uri().path();
76 16225 : let request_span = info_span!("request", %method, %path, %request_id);
77 :
78 16225 : let log_quietly = method == Method::GET;
79 16225 : async move {
80 16225 : let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
81 16225 : if log_quietly {
82 7593 : debug!("Handling request");
83 : } else {
84 8632 : info!("Handling request");
85 : }
86 :
87 : // No special handling for panics here. There's a `tracing_panic_hook` from another
88 : // module to do that globally.
89 20999 : let res = handler(request).await;
90 :
91 16216 : cancellation_guard.disarm();
92 16216 :
93 16216 : // Log the result if needed.
94 16216 : //
95 16216 : // We also convert any errors into an Ok response with HTTP error code here.
96 16216 : // `make_router` sets a last-resort error handler that would do the same, but
97 16216 : // we prefer to do it here, before we exit the request span, so that the error
98 16216 : // is still logged with the span.
99 16216 : //
100 16216 : // (Because we convert errors to Ok response, we never actually return an error,
101 16216 : // and we could declare the function to return the never type (`!`). However,
102 16216 : // using `routerify::RouterBuilder` requires a proper error type.)
103 16216 : match res {
104 15927 : Ok(response) => {
105 15927 : let response_status = response.status();
106 15927 : if log_quietly && response_status.is_success() {
107 7461 : debug!("Request handled, status: {response_status}");
108 : } else {
109 8466 : info!("Request handled, status: {response_status}");
110 : }
111 15927 : Ok(response)
112 : }
113 289 : Err(err) => Ok(api_error_handler(err)),
114 : }
115 16216 : }
116 16225 : .instrument(request_span)
117 20999 : .await
118 16216 : }
119 :
120 : /// Drop guard to WARN in case the request was dropped before completion.
121 : struct RequestCancelled {
122 : warn: Option<tracing::Span>,
123 : }
124 :
125 : impl RequestCancelled {
126 : /// Create the drop guard, using [`tracing::Span::current`] as the span.
127 16225 : fn warn_when_dropped_without_responding() -> Self {
128 16225 : RequestCancelled {
129 16225 : warn: Some(tracing::Span::current()),
130 16225 : }
131 16225 : }
132 :
133 : /// Consume the drop guard without logging anything.
134 16216 : fn disarm(mut self) {
135 16216 : self.warn = None;
136 16216 : }
137 : }
138 :
139 : impl Drop for RequestCancelled {
140 16219 : fn drop(&mut self) {
141 16219 : if std::thread::panicking() {
142 0 : // we are unwinding due to panicking, assume we are not dropped for cancellation
143 16219 : } else if let Some(span) = self.warn.take() {
144 : // the span has all of the info already, but the outer `.instrument(span)` has already
145 : // been dropped, so we need to manually re-enter it for this message.
146 : //
147 : // This is what `.instrument` would do before polling, so it is fine.
148 3 : let _g = span.entered();
149 3 : warn!("request was dropped before completing");
150 16216 : }
151 16219 : }
152 : }
153 :
154 : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
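 : ///
 : /// A minimal usage sketch (the buffer size and payload are illustrative; see
 : /// `prometheus_metrics_handler` below for the real wiring):
 : ///
 : /// ```ignore
 : /// use std::io::Write as _;
 : ///
 : /// let (tx, rx) = tokio::sync::mpsc::channel(1);
 : /// let body = Body::wrap_stream(tokio_stream::wrappers::ReceiverStream::new(rx));
 : /// let mut writer = ChannelWriter::new(128 * 1024, tx);
 : /// // Must run in a blocking context (e.g. `spawn_blocking`), since flushing uses `block_on`.
 : /// writer.write_all(b"hello")?;
 : /// writer.flush()?;
 : /// ```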
155 : pub struct ChannelWriter {
156 : buffer: BytesMut,
157 : pub tx: mpsc::Sender<std::io::Result<Bytes>>,
158 : written: usize,
159 : }
160 :
161 : impl ChannelWriter {
162 1261 : pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
163 1261 : assert_ne!(buf_len, 0);
164 1261 : ChannelWriter {
165 1261 : // Split about half off the buffer from the start, because we flush based on
166 1261 : // capacity. The first flush will come sooner than it would otherwise, but resizes
167 1261 : // will have a better chance of picking up the "other" half. Not guaranteed, of course.
168 1261 : buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
169 1261 : tx,
170 1261 : written: 0,
171 1261 : }
172 1261 : }
173 :
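 : /// Sends the buffered bytes down the channel, then reserves capacity for the next
 : /// send so the receiver gets a chance to catch up. Returns the number of bytes
 : /// flushed, or `BrokenPipe` if the receiving side of the channel is gone. Blocks on
 : /// the runtime, so it must be called from a blocking context.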
174 4615 : pub fn flush0(&mut self) -> std::io::Result<usize> {
175 4615 : let n = self.buffer.len();
176 4615 : if n == 0 {
177 0 : return Ok(0);
178 4615 : }
179 4615 :
180 4615 : tracing::trace!(n, "flushing");
181 4615 : let ready = self.buffer.split().freeze();
182 4615 :
183 4615 : // Not ideal to call block_on from blocking code, but we are sure that this
184 4615 : // operation does not itself spawn_blocking other tasks.
185 4615 : let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
186 4615 : self.tx.send(Ok(ready)).await.map_err(|_| ())?;
187 :
188 : // throttle sending to allow reuse of our buffer in `write`.
189 4615 : self.tx.reserve().await.map_err(|_| ())?;
190 :
191 : // now the response task has picked up the buffer and hopefully started
192 : // sending it to the client.
193 4615 : Ok(())
194 4615 : });
195 4615 : if res.is_err() {
196 0 : return Err(std::io::ErrorKind::BrokenPipe.into());
197 4615 : }
198 4615 : self.written += n;
199 4615 : Ok(n)
200 4615 : }
201 :
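 : /// Total number of bytes flushed into the channel so far.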
202 1261 : pub fn flushed_bytes(&self) -> usize {
203 1261 : self.written
204 1261 : }
205 : }
206 :
207 : impl std::io::Write for ChannelWriter {
208 41531899 : fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
209 41531899 : let remaining = self.buffer.capacity() - self.buffer.len();
210 41531899 :
211 41531899 : let out_of_space = remaining < buf.len();
212 41531899 :
213 41531899 : let original_len = buf.len();
214 41531899 :
215 41531899 : if out_of_space {
216 3354 : let can_still_fit = buf.len() - remaining;
217 3354 : self.buffer.extend_from_slice(&buf[..can_still_fit]);
218 3354 : buf = &buf[can_still_fit..];
219 3354 : self.flush0()?;
220 41528545 : }
221 :
222 : // Under normal operation, this will often just move the write pointer back to the
223 : // beginning of the allocation, because previously split-off parts have already
224 : // been sent and dropped.
225 41531899 : self.buffer.extend_from_slice(buf);
226 41531899 : Ok(original_len)
227 41531899 : }
228 :
229 1261 : fn flush(&mut self) -> std::io::Result<()> {
230 1261 : self.flush0().map(|_| ())
231 1261 : }
232 : }
233 :
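 : /// Handler for `GET /metrics`: gathers all registered metrics, encodes them in the
 : /// Prometheus text format on a blocking task, and streams the output to the client
 : /// through a `ChannelWriter`.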
234 1255 : async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
235 1255 : SERVE_METRICS_COUNT.inc();
236 1255 :
237 1255 : let started_at = std::time::Instant::now();
238 1255 :
239 1255 : let (tx, rx) = mpsc::channel(1);
240 1255 :
241 1255 : let body = Body::wrap_stream(ReceiverStream::new(rx));
242 1255 :
243 1255 : let mut writer = ChannelWriter::new(128 * 1024, tx);
244 1255 :
245 1255 : let encoder = TextEncoder::new();
246 1255 :
247 1255 : let response = Response::builder()
248 1255 : .status(200)
249 1255 : .header(CONTENT_TYPE, encoder.format_type())
250 1255 : .body(body)
251 1255 : .unwrap();
252 :
253 1255 : let span = info_span!("blocking");
254 1255 : tokio::task::spawn_blocking(move || {
255 1255 : let _span = span.entered();
256 1255 : let metrics = metrics::gather();
257 1255 : let res = encoder
258 1255 : .encode(&metrics, &mut writer)
259 1255 : .and_then(|_| writer.flush().map_err(|e| e.into()));
260 1255 :
261 1255 : match res {
262 : Ok(()) => {
263 1255 : tracing::info!(
264 1255 : bytes = writer.flushed_bytes(),
265 1255 : elapsed_ms = started_at.elapsed().as_millis(),
266 1255 : "responded /metrics"
267 1255 : );
268 : }
269 0 : Err(e) => {
270 0 : tracing::warn!("failed to write out /metrics response: {e:#}");
271 : // The semantics of this error are somewhat unclear. We want to error out the stream
272 : // to abort the response, so the client is somehow notified that we failed.
273 : //
274 : // Though most likely the reason for the failure is that the receiver is already gone.
275 0 : drop(
276 0 : writer
277 0 : .tx
278 0 : .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
279 0 : );
280 : }
281 : }
282 1255 : });
283 1255 :
284 1255 : Ok(response)
285 1255 : }
286 :
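 : /// Pre-request middleware that attaches a `RequestId` to the request context,
 : /// taken from the `x-request-id` header if present, or freshly generated otherwise.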
287 1502 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
288 1502 : ) -> Middleware<B, ApiError> {
289 16260 : Middleware::pre(move |req| async move {
290 16260 : let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
291 2 : Some(request_id) => request_id
292 2 : .to_str()
293 2 : .expect("extract request id value")
294 2 : .to_owned(),
295 : None => {
296 16258 : let request_id = uuid::Uuid::new_v4();
297 16258 : request_id.to_string()
298 : }
299 : };
300 16260 : req.set_context(RequestId(request_id));
301 16260 :
302 16260 : Ok(req)
303 16260 : })
304 1502 : }
305 :
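 : /// Post-response middleware that copies the `RequestId` from the request context
 : /// into the `x-request-id` response header.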
306 16251 : async fn add_request_id_header_to_response(
307 16251 : mut res: Response<Body>,
308 16251 : req_info: RequestInfo,
309 16251 : ) -> Result<Response<Body>, ApiError> {
310 16251 : if let Some(request_id) = req_info.context::<RequestId>() {
311 16251 : if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
312 16251 : res.headers_mut()
313 16251 : .insert(&X_REQUEST_ID_HEADER, request_header_value);
314 16251 : };
315 0 : };
316 :
317 16251 : Ok(res)
318 16251 : }
319 :
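 : /// Builds the base router: the request-id middlewares, the `/metrics` endpoint, and
 : /// the last-resort error handler.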
320 1502 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
321 1502 : Router::builder()
322 1502 : .middleware(add_request_id_middleware())
323 1502 : .middleware(Middleware::post_with_info(
324 1502 : add_request_id_header_to_response,
325 1502 : ))
326 1502 : .get("/metrics", |r| request_span(r, prometheus_metrics_handler))
327 1502 : .err_handler(route_error_handler)
328 1502 : }
329 :
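 : /// Serves the raw OpenAPI spec at `spec_mount_path` and a Swagger UI page (loaded
 : /// from a CDN) at `ui_mount_path`.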
330 604 : pub fn attach_openapi_ui(
331 604 : router_builder: RouterBuilder<hyper::Body, ApiError>,
332 604 : spec: &'static [u8],
333 604 : spec_mount_path: &'static str,
334 604 : ui_mount_path: &'static str,
335 604 : ) -> RouterBuilder<hyper::Body, ApiError> {
336 604 : router_builder
337 604 : .get(spec_mount_path,
338 604 : move |r| request_span(r, move |_| async move {
339 0 : Ok(Response::builder().body(Body::from(spec)).unwrap())
340 604 : })
341 604 : )
342 604 : .get(ui_mount_path,
343 604 : move |r| request_span(r, move |_| async move {
344 0 : Ok(Response::builder().body(Body::from(format!(r#"
345 0 : <!DOCTYPE html>
346 0 : <html lang="en">
347 0 : <head>
348 0 : <title>rweb</title>
349 0 : <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
350 0 : </head>
351 0 : <body>
352 0 : <div id="swagger-ui"></div>
353 0 : <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
354 0 : <script>
355 0 : window.onload = function() {{
356 0 : const ui = SwaggerUIBundle({{
357 0 : "dom_id": "\#swagger-ui",
358 0 : presets: [
359 0 : SwaggerUIBundle.presets.apis,
360 0 : SwaggerUIBundle.SwaggerUIStandalonePreset
361 0 : ],
362 0 : layout: "BaseLayout",
363 0 : deepLinking: true,
364 0 : showExtensions: true,
365 0 : showCommonExtensions: true,
366 0 : url: "{}",
367 0 : }})
368 0 : window.ui = ui;
369 0 : }};
370 0 : </script>
371 0 : </body>
372 0 : </html>
373 0 : "#, spec_mount_path))).unwrap())
374 604 : })
375 604 : )
376 604 : }
377 :
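 : /// Extracts the token from an `Authorization: Bearer <token>` header value.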
378 190 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
379 : // header must be in form Bearer <token>
380 190 : let (prefix, token) = header_value
381 190 : .split_once(' ')
382 190 : .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
383 190 : if prefix != "Bearer" {
384 0 : return Err(ApiError::Unauthorized(
385 0 : "malformed authorization header".to_string(),
386 0 : ));
387 190 : }
388 190 : Ok(token)
389 190 : }
390 :
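 : /// Pre-request middleware that validates the JWT from the `Authorization` header and
 : /// stores the decoded `Claims` in the request context. `provide_auth` picks the
 : /// `SwappableJwtAuth` to use for a given request; returning `None` skips
 : /// authentication for that request.
 : ///
 : /// A minimal wiring sketch, assuming some `AUTH` cell initialized at startup
 : /// (illustrative, not part of this crate):
 : ///
 : /// ```ignore
 : /// fn provide_auth(_req: &Request<Body>) -> Option<&SwappableJwtAuth> {
 : ///     AUTH.get()
 : /// }
 : ///
 : /// let router = make_router()
 : ///     .middleware(auth_middleware(provide_auth))
 : ///     .build()
 : ///     .unwrap();
 : /// ```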
391 42 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
392 42 : provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
393 42 : ) -> Middleware<B, ApiError> {
394 284 : Middleware::pre(move |req| async move {
395 284 : if let Some(auth) = provide_auth(&req) {
396 195 : match req.headers().get(AUTHORIZATION) {
397 190 : Some(value) => {
398 190 : let header_value = value.to_str().map_err(|_| {
399 0 : ApiError::Unauthorized("malformed authorization header".to_string())
400 190 : })?;
401 190 : let token = parse_token(header_value)?;
402 :
403 190 : let data = auth.decode(token).map_err(|err| {
404 3 : warn!("Authentication error: {err}");
405 : // Rely on From<AuthError> for ApiError impl
406 3 : err
407 190 : })?;
408 187 : req.set_context(data.claims);
409 : }
410 : None => {
411 5 : return Err(ApiError::Unauthorized(
412 5 : "missing authorization header".to_string(),
413 5 : ))
414 : }
415 : }
416 89 : }
417 276 : Ok(req)
418 284 : })
419 42 : }
420 :
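 : /// Post-response middleware that adds a fixed `header: value` pair to every response
 : /// that does not already contain that header. Fails up front if the header name or
 : /// value is invalid.
 : ///
 : /// A minimal sketch (the header name and value are illustrative), inside a function
 : /// returning `anyhow::Result`:
 : ///
 : /// ```ignore
 : /// let builder = make_router()
 : ///     .middleware(add_response_header_middleware("Server", "my-service")?);
 : /// ```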
421 604 : pub fn add_response_header_middleware<B>(
422 604 : header: &str,
423 604 : value: &str,
424 604 : ) -> anyhow::Result<Middleware<B, ApiError>>
425 604 : where
426 604 : B: hyper::body::HttpBody + Send + Sync + 'static,
427 604 : {
428 604 : let name =
429 604 : HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
430 604 : let value =
431 604 : HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
432 604 : Ok(Middleware::post_with_info(
433 10405 : move |mut response, request_info| {
434 10405 : let name = name.clone();
435 10405 : let value = value.clone();
436 10405 : async move {
437 10405 : let headers = response.headers_mut();
438 10405 : if headers.contains_key(&name) {
439 0 : warn!(
440 0 : "{} response already contains header {:?}",
441 0 : request_info.uri(),
442 0 : &name,
443 0 : );
444 10405 : } else {
445 10405 : headers.insert(name, value);
446 10405 : }
447 10405 : Ok(response)
448 10405 : }
449 10405 : },
450 604 : ))
451 604 : }
452 :
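 : /// Runs `check_permission` against the `Claims` stored in the request context by
 : /// `auth_middleware`. If there are no claims (auth is disabled), the check passes.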
453 10383 : pub fn check_permission_with(
454 10383 : req: &Request<Body>,
455 10383 : check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
456 10383 : ) -> Result<(), ApiError> {
457 10383 : match req.context::<Claims>() {
458 112 : Some(claims) => Ok(check_permission(&claims)
459 112 : .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
460 10271 : None => Ok(()), // claims is None because auth is disabled
461 : }
462 10383 : }
463 :
464 : #[cfg(test)]
465 : mod tests {
466 : use super::*;
467 : use futures::future::poll_fn;
468 : use hyper::service::Service;
469 : use routerify::RequestServiceBuilder;
470 : use std::net::{IpAddr, SocketAddr};
471 :
472 2 : #[tokio::test]
473 2 : async fn test_request_id_returned() {
474 2 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
475 2 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
476 2 : let mut service = builder.build(remote_addr);
477 2 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
478 0 : panic!("request service is not ready: {:?}", e);
479 2 : }
480 2 :
481 2 : let mut req: Request<Body> = Request::default();
482 2 : req.headers_mut()
483 2 : .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
484 :
485 2 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
486 2 :
487 2 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
488 :
489 2 : assert!(header_val == "42", "response header mismatch");
490 : }
491 :
492 2 : #[tokio::test]
493 2 : async fn test_request_id_empty() {
494 2 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
495 2 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
496 2 : let mut service = builder.build(remote_addr);
497 2 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
498 0 : panic!("request service is not ready: {:?}", e);
499 2 : }
500 2 :
501 2 : let req: Request<Body> = Request::default();
502 2 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
503 2 :
504 2 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
505 2 :
506 2 : assert_ne!(header_val, None, "response header should NOT be empty");
507 : }
508 : }