Line data Source code
1 : use std::ops::{Deref, DerefMut};
2 :
3 : use axum::{
4 : async_trait,
5 : extract::{rejection::JsonRejection, FromRequest, Request},
6 : };
7 : use compute_api::responses::GenericAPIError;
8 : use http::StatusCode;
9 :
10 : /// Custom `Json` extractor, so that we can format errors into
11 : /// `JsonResponse<GenericAPIError>`.
12 : #[derive(Debug, Clone, Copy, Default)]
13 : pub(crate) struct Json<T>(pub T);
14 :
15 : #[async_trait]
16 : impl<S, T> FromRequest<S> for Json<T>
17 : where
18 : axum::Json<T>: FromRequest<S, Rejection = JsonRejection>,
19 : S: Send + Sync,
20 : {
21 : type Rejection = (StatusCode, axum::Json<GenericAPIError>);
22 :
23 0 : async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
24 0 : match axum::Json::<T>::from_request(req, state).await {
25 0 : Ok(value) => Ok(Self(value.0)),
26 0 : Err(rejection) => Err((
27 0 : rejection.status(),
28 0 : axum::Json(GenericAPIError {
29 0 : error: rejection.body_text().to_lowercase(),
30 0 : }),
31 0 : )),
32 : }
33 0 : }
34 : }
35 :
36 : impl<T> Deref for Json<T> {
37 : type Target = T;
38 :
39 0 : fn deref(&self) -> &Self::Target {
40 0 : &self.0
41 0 : }
42 : }
43 :
44 : impl<T> DerefMut for Json<T> {
45 0 : fn deref_mut(&mut self) -> &mut Self::Target {
46 0 : &mut self.0
47 0 : }
48 : }
|