Line data Source code
1 : use crate::auth::{AuthError, Claims, SwappableJwtAuth};
2 : use crate::http::error::{api_error_handler, route_error_handler, ApiError};
3 : use anyhow::Context;
4 : use hyper::header::{HeaderName, AUTHORIZATION};
5 : use hyper::http::HeaderValue;
6 : use hyper::Method;
7 : use hyper::{header::CONTENT_TYPE, Body, Request, Response};
8 : use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
9 : use once_cell::sync::Lazy;
10 : use routerify::ext::RequestExt;
11 : use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
12 : use tracing::{debug, info, info_span, warn, Instrument};
13 :
14 : use std::future::Future;
15 : use std::str::FromStr;
16 :
17 : use bytes::{Bytes, BytesMut};
18 : use std::io::Write as _;
19 : use tokio::sync::mpsc;
20 : use tokio_stream::wrappers::ReceiverStream;
21 :
22 0 : static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
23 0 : register_int_counter!(
24 0 : "libmetrics_metric_handler_requests_total",
25 0 : "Number of metric requests made"
26 0 : )
27 0 : .expect("failed to define a metric")
28 0 : });
29 :
30 : static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
31 :
32 : static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
33 : #[derive(Debug, Default, Clone)]
34 : struct RequestId(String);
35 :
36 : /// Adds tracing `info_span!` instrumentation around the handler events,
37 : /// and logs the request start and end events for non-GET requests and non-2xx responses.
38 : ///
39 : /// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
40 : ///
41 : /// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
42 : /// with this will get request info logged in the wrapping span, including the unique request ID.
43 : ///
44 : /// This also handles errors, logging them and converting them to an HTTP error response.
45 : ///
46 : /// NB: If the client disconnects, Hyper will drop the Future without polling it to
47 : /// completion. In other words, the handler must be async cancellation safe! request_span
48 : /// logs a warning when that happens, so that you have some trace of it
49 : /// in the log.
50 : ///
51 : ///
52 : /// There could be other ways to implement similar functionality:
53 : ///
54 : /// * proc macros placed on top of all handler methods.
55 : /// Comes with all the drawbacks of proc macros, is no different implementation-wise,
56 : /// and offers little code reduction compared to the existing approach.
57 : ///
58 : /// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
59 : /// implemented for [`RouterBuilder`].
60 : /// Could be simpler, but we don't want to depend on [`routerify`] any further, since we aim to switch to another library later.
61 : ///
62 : /// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
63 : /// later, in a post-response middleware.
64 : /// Due to the suspendable nature of futures, this would give contradictory results, which is exactly the opposite of what `tracing-futures`
65 : /// tries to achieve with the `.instrument` call used in the current approach.
66 : ///
67 : /// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
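///
/// A minimal usage sketch (the handler name and route are illustrative, not part of this module):
///
/// ```ignore
/// async fn status_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
///     Ok(Response::builder().status(200).body(Body::empty()).unwrap())
/// }
///
/// let router = make_router()
///     .get("/v1/status", |r| request_span(r, status_handler))
///     .build()
///     .unwrap();
/// ```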
68 0 : pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
69 0 : where
70 0 : R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
71 0 : H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
72 0 : {
73 0 : let request_id = request.context::<RequestId>().unwrap_or_default().0;
74 0 : let method = request.method();
75 0 : let path = request.uri().path();
76 0 : let request_span = info_span!("request", %method, %path, %request_id);
77 :
78 0 : let log_quietly = method == Method::GET;
79 0 : async move {
80 0 : let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
81 0 : if log_quietly {
82 0 : debug!("Handling request");
83 : } else {
84 0 : info!("Handling request");
85 : }
86 :
87 : // No special handling for panics here. There's a `tracing_panic_hook` from another
88 : // module to do that globally.
89 0 : let res = handler(request).await;
90 :
91 0 : cancellation_guard.disarm();
92 0 :
93 0 : // Log the result if needed.
94 0 : //
95 0 : // We also convert any errors into an Ok response with HTTP error code here.
96 0 : // `make_router` sets a last-resort error handler that would do the same, but
97 0 : // we prefer to do it here, before we exit the request span, so that the error
98 0 : // is still logged with the span.
99 0 : //
100 0 : // (Because we convert errors to Ok response, we never actually return an error,
101 0 : // and we could declare the function to return the never type (`!`). However,
102 0 : // using `routerify::RouterBuilder` requires a proper error type.)
103 0 : match res {
104 0 : Ok(response) => {
105 0 : let response_status = response.status();
106 0 : if log_quietly && response_status.is_success() {
107 0 : debug!("Request handled, status: {response_status}");
108 : } else {
109 0 : info!("Request handled, status: {response_status}");
110 : }
111 0 : Ok(response)
112 : }
113 0 : Err(err) => Ok(api_error_handler(err)),
114 : }
115 0 : }
116 0 : .instrument(request_span)
117 0 : .await
118 0 : }
119 :
120 : /// Drop guard to WARN in case the request was dropped before completion.
121 : struct RequestCancelled {
122 : warn: Option<tracing::Span>,
123 : }
124 :
125 : impl RequestCancelled {
126 : /// Create the drop guard using the [`tracing::Span::current`] as the span.
127 0 : fn warn_when_dropped_without_responding() -> Self {
128 0 : RequestCancelled {
129 0 : warn: Some(tracing::Span::current()),
130 0 : }
131 0 : }
132 :
133 : /// Consume the drop guard without logging anything.
134 0 : fn disarm(mut self) {
135 0 : self.warn = None;
136 0 : }
137 : }
138 :
139 : impl Drop for RequestCancelled {
140 0 : fn drop(&mut self) {
141 0 : if std::thread::panicking() {
142 0 : // we are unwinding due to panicking, assume we are not dropped for cancellation
143 0 : } else if let Some(span) = self.warn.take() {
144 : // the span has all of the info already, but the outer `.instrument(span)` has already
145 : // been dropped, so we need to manually re-enter it for this message.
146 : //
147 : // this is what `.instrument` would do before polling, so it is fine.
148 0 : let _g = span.entered();
149 0 : warn!("request was dropped before completing");
150 0 : }
151 0 : }
152 : }
153 :
154 : /// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
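///
/// A minimal usage sketch (buffer size and payload are illustrative); the receiving half of the
/// channel is typically wrapped into a streaming response body, as `prometheus_metrics_handler`
/// does below. `flush` ends up calling `block_on`, so writes should happen off the async
/// executor, e.g. inside `spawn_blocking`:
///
/// ```ignore
/// let (tx, rx) = tokio::sync::mpsc::channel(1);
/// let body = Body::wrap_stream(tokio_stream::wrappers::ReceiverStream::new(rx));
/// let mut writer = ChannelWriter::new(64 * 1024, tx);
/// tokio::task::spawn_blocking(move || {
///     use std::io::Write as _;
///     writer.write_all(b"payload")?;
///     writer.flush()
/// });
/// ```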
155 : pub struct ChannelWriter {
156 : buffer: BytesMut,
157 : pub tx: mpsc::Sender<std::io::Result<Bytes>>,
158 : written: usize,
159 : /// Time spent waiting for the channel to make progress. It is not the same as the time to
160 : /// upload a buffer, because we cannot know anything about that, but it should let us see
161 : /// the actual time taken excluding the time spent `std::thread::park`ed.
162 : wait_time: std::time::Duration,
163 : }
164 :
165 : impl ChannelWriter {
166 0 : pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
167 0 : assert_ne!(buf_len, 0);
168 0 : ChannelWriter {
169 0 : // Split off about half of the buffer from the start, because we flush based on remaining
170 0 : // capacity. The first flush will come sooner than it would otherwise, but resizes now have
171 0 : // a better chance of picking up the "other" half. Not guaranteed, of course.
172 0 : buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
173 0 : tx,
174 0 : written: 0,
175 0 : wait_time: std::time::Duration::ZERO,
176 0 : }
177 0 : }
178 :
179 0 : pub fn flush0(&mut self) -> std::io::Result<usize> {
180 0 : let n = self.buffer.len();
181 0 : if n == 0 {
182 0 : return Ok(0);
183 0 : }
184 0 :
185 0 : tracing::trace!(n, "flushing");
186 0 : let ready = self.buffer.split().freeze();
187 0 :
188 0 : let wait_started_at = std::time::Instant::now();
189 0 :
190 0 : // Not ideal to call block_on from blocking code, but we are sure that this
191 0 : // operation does not spawn_blocking other tasks.
192 0 : let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
193 0 : self.tx.send(Ok(ready)).await.map_err(|_| ())?;
194 :
195 : // throttle sending to allow reuse of our buffer in `write`.
196 0 : self.tx.reserve().await.map_err(|_| ())?;
197 :
198 : // now the response task has picked up the buffer and hopefully started
199 : // sending it to the client.
200 0 : Ok(())
201 0 : });
202 0 :
203 0 : self.wait_time += wait_started_at.elapsed();
204 0 :
205 0 : if res.is_err() {
206 0 : return Err(std::io::ErrorKind::BrokenPipe.into());
207 0 : }
208 0 : self.written += n;
209 0 : Ok(n)
210 0 : }
211 :
212 0 : pub fn flushed_bytes(&self) -> usize {
213 0 : self.written
214 0 : }
215 :
216 0 : pub fn wait_time(&self) -> std::time::Duration {
217 0 : self.wait_time
218 0 : }
219 : }
220 :
221 : impl std::io::Write for ChannelWriter {
222 0 : fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
223 0 : let remaining = self.buffer.capacity() - self.buffer.len();
224 0 :
225 0 : let out_of_space = remaining < buf.len();
226 0 :
227 0 : let original_len = buf.len();
228 0 :
229 0 : if out_of_space {
230 0 : let can_still_fit = buf.len() - remaining;
231 0 : self.buffer.extend_from_slice(&buf[..can_still_fit]);
232 0 : buf = &buf[can_still_fit..];
233 0 : self.flush0()?;
234 0 : }
235 :
236 : // Under normal operation this will often just move the pointer back to the beginning
237 : // of the allocation, because previously split-off parts have already been sent and
238 : // dropped.
239 0 : self.buffer.extend_from_slice(buf);
240 0 : Ok(original_len)
241 0 : }
242 :
243 0 : fn flush(&mut self) -> std::io::Result<()> {
244 0 : self.flush0().map(|_| ())
245 0 : }
246 : }
247 :
248 0 : pub async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
249 0 : SERVE_METRICS_COUNT.inc();
250 0 :
251 0 : let started_at = std::time::Instant::now();
252 0 :
253 0 : let (tx, rx) = mpsc::channel(1);
254 0 :
255 0 : let body = Body::wrap_stream(ReceiverStream::new(rx));
256 0 :
257 0 : let mut writer = ChannelWriter::new(128 * 1024, tx);
258 0 :
259 0 : let encoder = TextEncoder::new();
260 0 :
261 0 : let response = Response::builder()
262 0 : .status(200)
263 0 : .header(CONTENT_TYPE, encoder.format_type())
264 0 : .body(body)
265 0 : .unwrap();
266 :
267 0 : let span = info_span!("blocking");
268 0 : tokio::task::spawn_blocking(move || {
269 0 : // There are situations where we lose scraped metrics under load; try to gather some clues.
270 0 : // Since all nodes are queried for this, keep the message count low.
271 0 : let spawned_at = std::time::Instant::now();
272 0 :
273 0 : let _span = span.entered();
274 0 :
275 0 : let metrics = metrics::gather();
276 0 :
277 0 : let gathered_at = std::time::Instant::now();
278 0 :
279 0 : let res = encoder
280 0 : .encode(&metrics, &mut writer)
281 0 : .and_then(|_| writer.flush().map_err(|e| e.into()));
282 0 :
283 0 : // this instant is not when the full response has actually been sent; sending is done by
284 0 : // hyper in another task.
285 0 : let encoded_at = std::time::Instant::now();
286 0 :
287 0 : let spawned_in = spawned_at - started_at;
288 0 : let collected_in = gathered_at - spawned_at;
289 0 : // subtract the wait time here, in case the TCP connection was clogged
290 0 : let encoded_in = encoded_at - gathered_at - writer.wait_time();
291 0 : let total = encoded_at - started_at;
292 0 :
293 0 : match res {
294 : Ok(()) => {
295 0 : tracing::info!(
296 0 : bytes = writer.flushed_bytes(),
297 0 : total_ms = total.as_millis(),
298 0 : spawning_ms = spawned_in.as_millis(),
299 0 : collection_ms = collected_in.as_millis(),
300 0 : encoding_ms = encoded_in.as_millis(),
301 0 : "responded /metrics"
302 : );
303 : }
304 0 : Err(e) => {
305 0 : // there is a chance that this error is not the BrokenPipe we generate in the writer
306 0 : // for "closed connection", but it is highly unlikely.
307 0 : tracing::warn!(
308 0 : after_bytes = writer.flushed_bytes(),
309 0 : total_ms = total.as_millis(),
310 0 : spawning_ms = spawned_in.as_millis(),
311 0 : collection_ms = collected_in.as_millis(),
312 0 : encoding_ms = encoded_in.as_millis(),
313 0 : "failed to write out /metrics response: {e:?}"
314 : );
315 : // semantics of this error are quite... unclear. we want to error the stream out to
316 : // abort the response and somehow notify the client that we failed.
317 : //
318 : // though, most likely the reason for failure is that the receiver is already gone.
319 0 : drop(
320 0 : writer
321 0 : .tx
322 0 : .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
323 0 : );
324 : }
325 : }
326 0 : });
327 0 :
328 0 : Ok(response)
329 0 : }
330 :
331 2 : pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
332 2 : ) -> Middleware<B, ApiError> {
333 2 : Middleware::pre(move |req| async move {
334 2 : let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
335 1 : Some(request_id) => request_id
336 1 : .to_str()
337 1 : .expect("extract request id value")
338 1 : .to_owned(),
339 : None => {
340 1 : let request_id = uuid::Uuid::new_v4();
341 1 : request_id.to_string()
342 : }
343 : };
344 2 : req.set_context(RequestId(request_id));
345 2 :
346 2 : Ok(req)
347 2 : })
348 2 : }
349 :
350 2 : async fn add_request_id_header_to_response(
351 2 : mut res: Response<Body>,
352 2 : req_info: RequestInfo,
353 2 : ) -> Result<Response<Body>, ApiError> {
354 2 : if let Some(request_id) = req_info.context::<RequestId>() {
355 2 : if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
356 2 : res.headers_mut()
357 2 : .insert(&X_REQUEST_ID_HEADER, request_header_value);
358 2 : };
359 0 : };
360 :
361 2 : Ok(res)
362 2 : }
363 :
364 2 : pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
365 2 : Router::builder()
366 2 : .middleware(add_request_id_middleware())
367 2 : .middleware(Middleware::post_with_info(
368 2 : add_request_id_header_to_response,
369 2 : ))
370 2 : .err_handler(route_error_handler)
371 2 : }
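
// A minimal serving sketch (the bind address is illustrative), assuming `routerify`'s
// `RouterService` and hyper's `Server` are available in this build:
//
//     let router = make_router().build().unwrap();
//     let service = routerify::RouterService::new(router).unwrap();
//     let addr: std::net::SocketAddr = "127.0.0.1:9898".parse().unwrap();
//     hyper::Server::bind(&addr).serve(service).await.unwrap();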
372 :
373 0 : pub fn attach_openapi_ui(
374 0 : router_builder: RouterBuilder<hyper::Body, ApiError>,
375 0 : spec: &'static [u8],
376 0 : spec_mount_path: &'static str,
377 0 : ui_mount_path: &'static str,
378 0 : ) -> RouterBuilder<hyper::Body, ApiError> {
379 0 : router_builder
380 0 : .get(spec_mount_path,
381 0 : move |r| request_span(r, move |_| async move {
382 0 : Ok(Response::builder().body(Body::from(spec)).unwrap())
383 0 : })
384 0 : )
385 0 : .get(ui_mount_path,
386 0 : move |r| request_span(r, move |_| async move {
387 0 : Ok(Response::builder().body(Body::from(format!(r#"
388 0 : <!DOCTYPE html>
389 0 : <html lang="en">
390 0 : <head>
391 0 : <title>rweb</title>
392 0 : <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
393 0 : </head>
394 0 : <body>
395 0 : <div id="swagger-ui"></div>
396 0 : <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
397 0 : <script>
398 0 : window.onload = function() {{
399 0 : const ui = SwaggerUIBundle({{
400 0 : "dom_id": "\#swagger-ui",
401 0 : presets: [
402 0 : SwaggerUIBundle.presets.apis,
403 0 : SwaggerUIBundle.SwaggerUIStandalonePreset
404 0 : ],
405 0 : layout: "BaseLayout",
406 0 : deepLinking: true,
407 0 : showExtensions: true,
408 0 : showCommonExtensions: true,
409 0 : url: "{}",
410 0 : }})
411 0 : window.ui = ui;
412 0 : }};
413 0 : </script>
414 0 : </body>
415 0 : </html>
416 0 : "#, spec_mount_path))).unwrap())
417 0 : })
418 0 : )
419 0 : }
420 :
421 0 : fn parse_token(header_value: &str) -> Result<&str, ApiError> {
422 : // header must be in the form `Bearer <token>`
423 0 : let (prefix, token) = header_value
424 0 : .split_once(' ')
425 0 : .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
426 0 : if prefix != "Bearer" {
427 0 : return Err(ApiError::Unauthorized(
428 0 : "malformed authorization header".to_string(),
429 0 : ));
430 0 : }
431 0 : Ok(token)
432 0 : }
433 :
434 0 : pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
435 0 : provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
436 0 : ) -> Middleware<B, ApiError> {
437 0 : Middleware::pre(move |req| async move {
438 0 : if let Some(auth) = provide_auth(&req) {
439 0 : match req.headers().get(AUTHORIZATION) {
440 0 : Some(value) => {
441 0 : let header_value = value.to_str().map_err(|_| {
442 0 : ApiError::Unauthorized("malformed authorization header".to_string())
443 0 : })?;
444 0 : let token = parse_token(header_value)?;
445 :
446 0 : let data = auth.decode(token).map_err(|err| {
447 0 : warn!("Authentication error: {err}");
448 : // Rely on From<AuthError> for ApiError impl
449 0 : err
450 0 : })?;
451 0 : req.set_context(data.claims);
452 : }
453 : None => {
454 0 : return Err(ApiError::Unauthorized(
455 0 : "missing authorization header".to_string(),
456 0 : ))
457 : }
458 : }
459 0 : }
460 0 : Ok(req)
461 0 : })
462 0 : }
463 :
464 0 : pub fn add_response_header_middleware<B>(
465 0 : header: &str,
466 0 : value: &str,
467 0 : ) -> anyhow::Result<Middleware<B, ApiError>>
468 0 : where
469 0 : B: hyper::body::HttpBody + Send + Sync + 'static,
470 0 : {
471 0 : let name =
472 0 : HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
473 0 : let value =
474 0 : HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
475 0 : Ok(Middleware::post_with_info(
476 0 : move |mut response, request_info| {
477 0 : let name = name.clone();
478 0 : let value = value.clone();
479 0 : async move {
480 0 : let headers = response.headers_mut();
481 0 : if headers.contains_key(&name) {
482 0 : warn!(
483 0 : "{} response already contains header {:?}",
484 0 : request_info.uri(),
485 0 : &name,
486 : );
487 0 : } else {
488 0 : headers.insert(name, value);
489 0 : }
490 0 : Ok(response)
491 0 : }
492 0 : },
493 0 : ))
494 0 : }
495 :
496 0 : pub fn check_permission_with(
497 0 : req: &Request<Body>,
498 0 : check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
499 0 : ) -> Result<(), ApiError> {
500 0 : match req.context::<Claims>() {
501 0 : Some(claims) => Ok(check_permission(&claims)
502 0 : .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
503 0 : None => Ok(()), // claims is None because auth is disabled
504 : }
505 0 : }
506 :
507 : #[cfg(test)]
508 : mod tests {
509 : use super::*;
510 : use futures::future::poll_fn;
511 : use hyper::service::Service;
512 : use routerify::RequestServiceBuilder;
513 : use std::net::{IpAddr, SocketAddr};
514 :
515 : #[tokio::test]
516 1 : async fn test_request_id_returned() {
517 1 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
518 1 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
519 1 : let mut service = builder.build(remote_addr);
520 1 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
521 1 : panic!("request service is not ready: {:?}", e);
522 1 : }
523 1 :
524 1 : let mut req: Request<Body> = Request::default();
525 1 : req.headers_mut()
526 1 : .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
527 1 :
528 1 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
529 1 :
530 1 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
531 1 :
532 1 : assert!(header_val == "42", "response header mismatch");
533 1 : }
534 :
535 : #[tokio::test]
536 1 : async fn test_request_id_empty() {
537 1 : let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
538 1 : let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
539 1 : let mut service = builder.build(remote_addr);
540 1 : if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
541 1 : panic!("request service is not ready: {:?}", e);
542 1 : }
543 1 :
544 1 : let req: Request<Body> = Request::default();
545 1 : let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
546 1 :
547 1 : let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
548 1 :
549 1 : assert_ne!(header_val, None, "response header should NOT be empty");
550 1 : }
551 : }