Line data Source code
1 : use std::{
2 : fmt::Display,
3 : ops::{Deref, DerefMut},
4 : };
5 :
6 : use axum::{extract::FromRequestParts, response::IntoResponse};
7 : use http::{StatusCode, request::Parts};
8 :
9 : use crate::http::{JsonResponse, headers::X_REQUEST_ID};
10 :
/// Request ID carried by an incoming request in its `X-Request-Id` header.
///
/// Usable directly as an axum extractor; see the `FromRequestParts` impl for
/// how the header is read and which rejections can be produced.
#[derive(Clone, Debug, Default)]
pub(crate) struct RequestId(pub String);
14 :
/// Rejection used for [`RequestId`].
///
/// Contains one variant for each way the [`RequestId`] extractor can
/// fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RequestIdRejection {
    /// The request has no `X-Request-Id` header.
    MissingRequestId,

    /// The header value could not be converted to a string.
    ///
    /// Despite the name, the extractor returns this whenever
    /// `HeaderValue::to_str` fails. That method only accepts visible ASCII,
    /// so a value that is valid but non-ASCII UTF-8 is rejected here as well.
    InvalidUtf8,
}
27 :
28 : impl RequestIdRejection {
29 0 : pub fn status(&self) -> StatusCode {
30 0 : match self {
31 0 : RequestIdRejection::MissingRequestId => StatusCode::INTERNAL_SERVER_ERROR,
32 0 : RequestIdRejection::InvalidUtf8 => StatusCode::BAD_REQUEST,
33 : }
34 0 : }
35 :
36 0 : pub fn message(&self) -> String {
37 0 : match self {
38 0 : RequestIdRejection::MissingRequestId => "request ID is missing",
39 0 : RequestIdRejection::InvalidUtf8 => "request ID is invalid UTF-8",
40 : }
41 0 : .to_string()
42 0 : }
43 : }
44 :
45 : impl IntoResponse for RequestIdRejection {
46 0 : fn into_response(self) -> axum::response::Response {
47 0 : JsonResponse::error(self.status(), self.message())
48 0 : }
49 : }
50 :
51 : impl<S> FromRequestParts<S> for RequestId
52 : where
53 : S: Send + Sync,
54 : {
55 : type Rejection = RequestIdRejection;
56 :
57 0 : async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
58 0 : match parts.headers.get(X_REQUEST_ID) {
59 0 : Some(value) => match value.to_str() {
60 0 : Ok(request_id) => Ok(Self(request_id.to_string())),
61 0 : Err(_) => Err(RequestIdRejection::InvalidUtf8),
62 : },
63 0 : None => Err(RequestIdRejection::MissingRequestId),
64 : }
65 0 : }
66 : }
67 :
68 : impl Deref for RequestId {
69 : type Target = String;
70 :
71 0 : fn deref(&self) -> &Self::Target {
72 0 : &self.0
73 0 : }
74 : }
75 :
76 : impl DerefMut for RequestId {
77 0 : fn deref_mut(&mut self) -> &mut Self::Target {
78 0 : &mut self.0
79 0 : }
80 : }
81 :
82 : impl Display for RequestId {
83 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 0 : f.write_str(&self.0)
85 0 : }
86 : }
|