use std::future::Future;
use std::io::Write as _;
use std::str::FromStr;
use std::time::Duration;

use anyhow::{Context, anyhow};
use bytes::{Bytes, BytesMut};
use hyper::header::{AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, HeaderName};
use hyper::http::HeaderValue;
use hyper::{Body, Method, Request, Response};
use jsonwebtoken::TokenData;
use metrics::{Encoder, IntCounter, TextEncoder, register_int_counter};
use once_cell::sync::Lazy;
use pprof::ProfilerGuardBuilder;
use pprof::protos::Message as _;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
use tokio::sync::{Mutex, Notify, mpsc};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::io::ReaderStream;
use tracing::{Instrument, debug, info, info_span, warn};
use utils::auth::{AuthError, Claims, SwappableJwtAuth};
use utils::metrics_collector::{METRICS_COLLECTOR, METRICS_STALE_MILLIS};

use crate::error::{ApiError, api_error_handler, route_error_handler};
use crate::request::{get_query_param, parse_query_param};

static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
    register_int_counter!(
        "libmetrics_metric_handler_requests_total",
        "Number of metric requests made"
    )
    .expect("failed to define a metric")
});

static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";

static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);

#[derive(Debug, Default, Clone)]
struct RequestId(String);

/// Adds a tracing `info_span!` instrumentation around the handler events, and logs the
/// request start and end events for non-GET requests and non-200 responses.
///
/// Usage: replace `my_handler` with `|r| request_span(r, my_handler)`.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// with this will get request info logged in the wrapping span, including the unique request ID.
///
/// This also handles errors, logging them and converting them to an HTTP error response.
///
/// NB: If the client disconnects, Hyper will drop the Future without polling it to
/// completion. In other words, the handler must be async cancellation safe! `request_span`
/// prints a warning to the log when that happens, so that you have some trace of it in
/// the log.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
///   With all the drawbacks of procmacros, this brings no difference implementation-wise,
///   and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. `get_with_span`, `post_with_span` methods to do similar logic,
///   implemented for [`RouterBuilder`].
///   Could be simpler, but we don't want to depend on [`routerify`] more, as we aim to switch to
///   another library later.
///
/// * In theory, a span guard could be created in a pre-request middleware and placed into a global
///   collection, to be dropped later, in a post-response middleware.
///   Due to the suspendable nature of futures, this would give contradictory results, which is
///   exactly the opposite of what `tracing-futures` tries to achieve with the `.instrument` used
///   in the current approach.
///
/// If needed, a declarative macro to substitute the `|r| ...` closure boilerplate could be introduced.
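///
/// A minimal wiring sketch (the route path and `my_handler` are illustrative, not part of this
/// module):
/// ```ignore
/// async fn my_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
///     Ok(Response::new(Body::from("ok")))
/// }
///
/// let router = Router::builder()
///     .get("/v1/status", |r| request_span(r, my_handler))
///     .err_handler(route_error_handler);
/// ```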
pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
where
    R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
    H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
{
    let request_id = request.context::<RequestId>().unwrap_or_default().0;
    let method = request.method();
    let path = request.uri().path();
    let request_span = info_span!("request", %method, %path, %request_id);

    let log_quietly = method == Method::GET;
    async move {
        let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
        if log_quietly {
            debug!("Handling request");
        } else {
            info!("Handling request");
        }

        // No special handling for panics here. There's a `tracing_panic_hook` from another
        // module to do that globally.
        let res = handler(request).await;

        cancellation_guard.disarm();

        // Log the result if needed.
        //
        // We also convert any errors into an Ok response with an HTTP error code here.
        // `make_router` sets a last-resort error handler that would do the same, but
        // we prefer to do it here, before we exit the request span, so that the error
        // is still logged within the span.
        //
        // (Because we convert errors to an Ok response, we never actually return an error,
        // and we could declare the function to return the never type (`!`). However,
        // using `routerify::RouterBuilder` requires a proper error type.)
        match res {
            Ok(response) => {
                let response_status = response.status();
                if log_quietly && response_status.is_success() {
                    debug!("Request handled, status: {response_status}");
                } else {
                    info!("Request handled, status: {response_status}");
                }
                Ok(response)
            }
            Err(err) => Ok(api_error_handler(err)),
        }
    }
    .instrument(request_span)
    .await
}

/// Drop guard to WARN in case the request was dropped before completion.
struct RequestCancelled {
    warn: Option<tracing::Span>,
}

impl RequestCancelled {
    /// Create the drop guard using [`tracing::Span::current`] as the span.
    fn warn_when_dropped_without_responding() -> Self {
        RequestCancelled {
            warn: Some(tracing::Span::current()),
        }
    }

    /// Consume the drop guard without logging anything.
    fn disarm(mut self) {
        self.warn = None;
    }
}

impl Drop for RequestCancelled {
    fn drop(&mut self) {
        if std::thread::panicking() {
            // we are unwinding due to panicking, assume we are not dropped for cancellation
        } else if let Some(span) = self.warn.take() {
            // the span has all of the info already, but the outer `.instrument(span)` has already
            // been dropped, so we need to manually re-enter it for this message.
            //
            // this is what the instrument would do before polling so it is fine.
            let _g = span.entered();
            warn!("request was dropped before completing");
        }
    }
}

/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
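///
/// A usage sketch (assuming a Tokio runtime; flushing calls `block_on` internally, so the writer
/// must be driven from a blocking thread):
/// ```ignore
/// use std::io::Write as _;
///
/// let (tx, rx) = tokio::sync::mpsc::channel(1);
/// tokio::task::spawn_blocking(move || {
///     let mut writer = ChannelWriter::new(128 * 1024, tx);
///     writer.write_all(b"hello")?;
///     writer.flush()
/// });
/// // Elsewhere, consume `rx` as a stream of `std::io::Result<Bytes>` chunks,
/// // e.g. via `Body::wrap_stream(ReceiverStream::new(rx))` as the handlers below do.
/// ```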
pub struct ChannelWriter {
    buffer: BytesMut,
    pub tx: mpsc::Sender<std::io::Result<Bytes>>,
    written: usize,
    /// Time spent waiting for the channel to make progress. This is not the same as the time to
    /// upload a buffer, which we cannot observe, but it lets us estimate the actual time taken
    /// excluding the time spent in `std::thread::park`.
    wait_time: std::time::Duration,
}

impl ChannelWriter {
    pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
        assert_ne!(buf_len, 0);
        ChannelWriter {
            // Split about half off the buffer from the start, because we flush depending on
            // capacity. The first flush will come sooner than without this, but resizes will
            // have a better chance of picking up the "other" half. Not guaranteed, of course.
            buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
            tx,
            written: 0,
            wait_time: std::time::Duration::ZERO,
        }
    }

    pub fn flush0(&mut self) -> std::io::Result<usize> {
        let n = self.buffer.len();
        if n == 0 {
            return Ok(0);
        }

        tracing::trace!(n, "flushing");
        let ready = self.buffer.split().freeze();

        let wait_started_at = std::time::Instant::now();

        // It is not ideal to call block_on from blocking code, but we are sure that this
        // operation does not spawn_blocking other tasks.
        let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
            self.tx.send(Ok(ready)).await.map_err(|_| ())?;

            // Throttle sending to allow reuse of our buffer in `write`.
            self.tx.reserve().await.map_err(|_| ())?;

            // Now the response task has picked up the buffer and hopefully started
            // sending it to the client.
            Ok(())
        });

        self.wait_time += wait_started_at.elapsed();

        if res.is_err() {
            return Err(std::io::ErrorKind::BrokenPipe.into());
        }
        self.written += n;
        Ok(n)
    }

    pub fn flushed_bytes(&self) -> usize {
        self.written
    }

    pub fn wait_time(&self) -> std::time::Duration {
        self.wait_time
    }
}

impl std::io::Write for ChannelWriter {
    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
        let remaining = self.buffer.capacity() - self.buffer.len();

        let out_of_space = remaining < buf.len();

        let original_len = buf.len();

        if out_of_space {
            let can_still_fit = buf.len() - remaining;
            self.buffer.extend_from_slice(&buf[..can_still_fit]);
            buf = &buf[can_still_fit..];
            self.flush0()?;
        }

        // Under normal operation this will often just move the pointer back to the beginning
        // of the allocation, because the previously split-off parts have already been sent
        // and dropped.
        self.buffer.extend_from_slice(buf);
        Ok(original_len)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.flush0().map(|_| ())
    }
}

pub async fn prometheus_metrics_handler(
    req: Request<Body>,
    force_metric_collection_on_scrape: bool,
) -> Result<Response<Body>, ApiError> {
    SERVE_METRICS_COUNT.inc();

    // HADRON
    let requested_use_latest = parse_query_param(&req, "use_latest")?;

    let use_latest = match requested_use_latest {
        None => force_metric_collection_on_scrape,
        Some(true) => true,
        // When we always force metric collection on scrape, nothing is cached, so we must
        // collect fresh metrics even if the caller asked for cached ones.
        Some(false) => force_metric_collection_on_scrape,
    };

    let started_at = std::time::Instant::now();

    let (tx, rx) = mpsc::channel(1);

    let body = Body::wrap_stream(ReceiverStream::new(rx));

    let mut writer = ChannelWriter::new(128 * 1024, tx);

    let encoder = TextEncoder::new();

    let response = Response::builder()
        .status(200)
        .header(CONTENT_TYPE, encoder.format_type())
        .body(body)
        .unwrap();

    let span = info_span!("blocking");
    tokio::task::spawn_blocking(move || {
        // There are situations where we lose scraped metrics under load; try to gather some
        // clues. Since all nodes are queried for this, keep the message count low.
        let spawned_at = std::time::Instant::now();

        let _span = span.entered();

        // HADRON
        let collected = if use_latest {
            // Skip caching the results if we always force metric collection on scrape.
            METRICS_COLLECTOR.run_once(!force_metric_collection_on_scrape)
        } else {
            METRICS_COLLECTOR.last_collected()
        };

        let gathered_at = std::time::Instant::now();

        let res = encoder
            .encode(&collected.metrics, &mut writer)
            .and_then(|_| writer.flush().map_err(|e| e.into()));

        // This instant is not when we finally got the full response sent; sending is done by
        // hyper in another task.
        let encoded_at = std::time::Instant::now();

        let spawned_in = spawned_at - started_at;
        let collected_in = gathered_at - spawned_at;
        // Subtract the wait time here in case the TCP connection was clogged.
        let encoded_in = encoded_at - gathered_at - writer.wait_time();
        let total = encoded_at - started_at;

        // HADRON
        let staleness_ms = (encoded_at - collected.collected_at).as_millis();
        METRICS_STALE_MILLIS.set(staleness_ms as i64);

        match res {
            Ok(()) => {
                tracing::info!(
                    bytes = writer.flushed_bytes(),
                    total_ms = total.as_millis(),
                    spawning_ms = spawned_in.as_millis(),
                    collection_ms = collected_in.as_millis(),
                    encoding_ms = encoded_in.as_millis(),
                    staleness_ms = staleness_ms,
                    "responded /metrics"
                );
            }
            Err(e) => {
                // There is a chance that this error is not the BrokenPipe we generate in the
                // writer for "closed connection", but it is highly unlikely.
                tracing::warn!(
                    after_bytes = writer.flushed_bytes(),
                    total_ms = total.as_millis(),
                    spawning_ms = spawned_in.as_millis(),
                    collection_ms = collected_in.as_millis(),
                    encoding_ms = encoded_in.as_millis(),
                    "failed to write out /metrics response: {e:?}"
                );
                // The semantics of this error are quite... unclear. We want to error the stream
                // out to abort the response, to somehow notify the client that we failed.
                //
                // Though most likely the reason for failure is that the receiver is already gone.
                drop(
                    writer
                        .tx
                        .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
                );
            }
        }
    });

    Ok(response)
}

/// Generates CPU profiles.
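///
/// Query parameters, as parsed below (defaults shown):
/// * `format`: `pprof` (default) or `svg`
/// * `seconds`: profiling duration, 1-60 (default 5)
/// * `frequency`: sampling frequency in Hz, at most 1000 (default 99)
/// * `force`: cancel an already-running profile instead of returning 409 (default false)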
pub async fn profile_cpu_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
    enum Format {
        Pprof,
        Svg,
    }

    // Parameters.
    let format = match get_query_param(&req, "format")?.as_deref() {
        None => Format::Pprof,
        Some("pprof") => Format::Pprof,
        Some("svg") => Format::Svg,
        Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
    };
    let seconds = match parse_query_param(&req, "seconds")? {
        None => 5,
        Some(seconds @ 1..=60) => seconds,
        Some(_) => return Err(ApiError::BadRequest(anyhow!("duration must be 1-60 secs"))),
    };
    let frequency_hz = match parse_query_param(&req, "frequency")? {
        None => 99,
        Some(1001..) => return Err(ApiError::BadRequest(anyhow!("frequency must be <=1000 Hz"))),
        Some(frequency) => frequency,
    };
    let force: bool = parse_query_param(&req, "force")?.unwrap_or_default();

    // Take the profile.
    static PROFILE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
    static PROFILE_CANCEL: Lazy<Notify> = Lazy::new(Notify::new);

    let report = {
        // Only allow one profiler at a time. If force is true, cancel a running profile (e.g. a
        // Grafana continuous profile). We use a try_lock() loop when cancelling instead of waiting
        // for a lock(), to avoid races where the notify isn't currently awaited.
        let _lock = loop {
            match PROFILE_LOCK.try_lock() {
                Ok(lock) => break lock,
                Err(_) if force => PROFILE_CANCEL.notify_waiters(),
                Err(_) => {
                    return Err(ApiError::Conflict(
                        "profiler already running (use ?force=true to cancel it)".into(),
                    ));
                }
            }
            tokio::time::sleep(Duration::from_millis(1)).await; // don't busy-wait
        };

        let guard = ProfilerGuardBuilder::default()
            .frequency(frequency_hz)
            .blocklist(&["libc", "libgcc", "pthread", "vdso"])
            .build()
            .map_err(|err| ApiError::InternalServerError(err.into()))?;

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(seconds)) => {},
            _ = PROFILE_CANCEL.notified() => {},
        };

        guard
            .report()
            .build()
            .map_err(|err| ApiError::InternalServerError(err.into()))?
    };

    // Return the report in the requested format.
    match format {
        Format::Pprof => {
            let body = report
                .pprof()
                .map_err(|err| ApiError::InternalServerError(err.into()))?
                .encode_to_vec();

            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"profile.pb\"")
                .body(Body::from(body))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }

        Format::Svg => {
            let mut body = Vec::new();
            report
                .flamegraph(&mut body)
                .map_err(|err| ApiError::InternalServerError(err.into()))?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "image/svg+xml")
                .body(Body::from(body))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
    }
}

/// Generates heap profiles.
///
/// This only works with jemalloc on Linux.
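///
/// The `format` query parameter, as parsed below, selects the output: `jemalloc` (the raw dump),
/// `pprof` (the default), or `svg` (a flamegraph).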
pub async fn profile_heap_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
    enum Format {
        Jemalloc,
        Pprof,
        Svg,
    }

    // Parameters.
    let format = match get_query_param(&req, "format")?.as_deref() {
        None => Format::Pprof,
        Some("jemalloc") => Format::Jemalloc,
        Some("pprof") => Format::Pprof,
        Some("svg") => Format::Svg,
        Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
    };

    // Obtain profiler handle.
    let mut prof_ctl = jemalloc_pprof::PROF_CTL
        .as_ref()
        .ok_or(ApiError::InternalServerError(anyhow!(
            "heap profiling not enabled"
        )))?
        .lock()
        .await;
    if !prof_ctl.activated() {
        return Err(ApiError::InternalServerError(anyhow!(
            "heap profiling not enabled"
        )));
    }

    // Take and return the profile.
    match format {
        Format::Jemalloc => {
            // NB: file is an open handle to a tempfile that's already deleted.
            let file = tokio::task::spawn_blocking(move || prof_ctl.dump())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            let stream = ReaderStream::new(tokio::fs::File::from_std(file));
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.dump\"")
                .body(Body::wrap_stream(stream))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }

        Format::Pprof => {
            let data = tokio::task::spawn_blocking(move || prof_ctl.dump_pprof())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.pb.gz\"")
                .body(Body::from(data))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }

        Format::Svg => {
            let svg = tokio::task::spawn_blocking(move || prof_ctl.dump_flamegraph())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "image/svg+xml")
                .body(Body::from(svg))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
    }
}

pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>()
-> Middleware<B, ApiError> {
    Middleware::pre(move |req| async move {
        let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
            Some(request_id) => request_id
                .to_str()
                .expect("extract request id value")
                .to_owned(),
            None => {
                let request_id = uuid::Uuid::new_v4();
                request_id.to_string()
            }
        };
        req.set_context(RequestId(request_id));

        Ok(req)
    })
}

async fn add_request_id_header_to_response(
    mut res: Response<Body>,
    req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
    if let Some(request_id) = req_info.context::<RequestId>() {
        if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
            res.headers_mut()
                .insert(&X_REQUEST_ID_HEADER, request_header_value);
        };
    };

    Ok(res)
}

pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
    Router::builder()
        .middleware(add_request_id_middleware())
        .middleware(Middleware::post_with_info(
            add_request_id_header_to_response,
        ))
        .err_handler(route_error_handler)
}

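/// Serves `spec` (an OpenAPI document) at `spec_mount_path`, and a Swagger UI page rendering it
/// (loaded from a CDN) at `ui_mount_path`.
///
/// A usage sketch (the spec bytes and mount paths are illustrative):
/// ```ignore
/// let router = attach_openapi_ui(make_router(), SPEC_YAML, "/swagger.yml", "/swagger");
/// ```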
pub fn attach_openapi_ui(
    router_builder: RouterBuilder<hyper::Body, ApiError>,
    spec: &'static [u8],
    spec_mount_path: &'static str,
    ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
    router_builder
        .get(spec_mount_path,
            move |r| request_span(r, move |_| async move {
                Ok(Response::builder().body(Body::from(spec)).unwrap())
            })
        )
        .get(ui_mount_path,
            move |r| request_span(r, move |_| async move {
                Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
    window.onload = function() {{
        const ui = SwaggerUIBundle({{
            "dom_id": "\#swagger-ui",
            presets: [
                SwaggerUIBundle.presets.apis,
                SwaggerUIBundle.SwaggerUIStandalonePreset
            ],
            layout: "BaseLayout",
            deepLinking: true,
            showExtensions: true,
            showCommonExtensions: true,
            url: "{spec_mount_path}",
        }})
        window.ui = ui;
    }};
</script>
</body>
</html>
"#))).unwrap())
            })
        )
}

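/// Extracts the token from an `Authorization` header value of the form `Bearer <token>`.
///
/// A quick behavioral sketch:
/// ```ignore
/// assert_eq!(parse_token("Bearer abc.def.ghi").unwrap(), "abc.def.ghi");
/// assert!(parse_token("Basic abc").is_err()); // wrong scheme
/// assert!(parse_token("Bearer").is_err());    // no space, no token
/// ```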
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
    // The header must be in the form `Bearer <token>`.
    let (prefix, token) = header_value
        .split_once(' ')
        .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
    if prefix != "Bearer" {
        return Err(ApiError::Unauthorized(
            "malformed authorization header".to_string(),
        ));
    }
    Ok(token)
}

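/// Pre-request middleware that validates a `Bearer` JWT and stores the decoded [`Claims`] in the
/// request context, where [`check_permission_with`] can later find them. `provide_auth` decides
/// per request which JWT auth to use; returning `None` skips authentication entirely.
///
/// A wiring sketch (the accessor closure is illustrative and assumes the auth object was
/// registered as routerify request data):
/// ```ignore
/// let router = make_router()
///     .data(auth)
///     .middleware(auth_middleware(|req| req.data::<SwappableJwtAuth>()));
/// ```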
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
    provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
) -> Middleware<B, ApiError> {
    Middleware::pre(move |req| async move {
        if let Some(auth) = provide_auth(&req) {
            match req.headers().get(AUTHORIZATION) {
                Some(value) => {
                    let header_value = value.to_str().map_err(|_| {
                        ApiError::Unauthorized("malformed authorization header".to_string())
                    })?;
                    let token = parse_token(header_value)?;

                    let data: TokenData<Claims> = auth.decode(token).map_err(|err| {
                        warn!("Authentication error: {err}");
                        // Rely on the `From<AuthError>` impl for `ApiError`.
                        err
                    })?;
                    req.set_context(data.claims);
                }
                None => {
                    return Err(ApiError::Unauthorized(
                        "missing authorization header".to_string(),
                    ));
                }
            }
        }
        Ok(req)
    })
}

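/// Builds a post-request middleware that adds a fixed header to every response, unless the
/// handler already set that header itself.
///
/// A usage sketch (the header name and value are made up for illustration):
/// ```ignore
/// let router = make_router()
///     .middleware(add_response_header_middleware("Server", "my-service")?);
/// ```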
pub fn add_response_header_middleware<B>(
    header: &str,
    value: &str,
) -> anyhow::Result<Middleware<B, ApiError>>
where
    B: hyper::body::HttpBody + Send + Sync + 'static,
{
    let name =
        HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
    let value =
        HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
    Ok(Middleware::post_with_info(
        move |mut response, request_info| {
            let name = name.clone();
            let value = value.clone();
            async move {
                let headers = response.headers_mut();
                if headers.contains_key(&name) {
                    warn!(
                        "{} response already contains header {:?}",
                        request_info.uri(),
                        &name,
                    );
                } else {
                    headers.insert(name, value);
                }
                Ok(response)
            }
        },
    ))
}

pub fn check_permission_with(
    req: &Request<Body>,
    check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
) -> Result<(), ApiError> {
    match req.context::<Claims>() {
        Some(claims) => Ok(check_permission(&claims)
            .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
        None => Ok(()), // claims is None because auth is disabled
    }
}

#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::net::{IpAddr, SocketAddr};

    use hyper::service::Service;
    use routerify::RequestServiceBuilder;

    use super::*;

    #[tokio::test]
    async fn test_request_id_returned() {
        let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
        let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let mut service = builder.build(remote_addr);
        if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
            panic!("request service is not ready: {e:?}");
        }

        let mut req: Request<Body> = Request::default();
        req.headers_mut()
            .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());

        let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();

        let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();

        assert!(header_val == "42", "response header mismatch");
    }

    #[tokio::test]
    async fn test_request_id_empty() {
        let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
        let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let mut service = builder.build(remote_addr);
        if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
            panic!("request service is not ready: {e:?}");
        }

        let req: Request<Body> = Request::default();
        let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();

        let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);

        assert_ne!(header_val, None, "response header should NOT be empty");
    }
}