Line data Source code
1 : use std::collections::HashSet;
2 :
3 : use anyhow::{Result, anyhow};
4 : use axum::{RequestExt, body::Body};
5 : use axum_extra::{
6 : TypedHeader,
7 : headers::{Authorization, authorization::Bearer},
8 : };
9 : use compute_api::requests::ComputeClaims;
10 : use futures::future::BoxFuture;
11 : use http::{Request, Response, StatusCode};
12 : use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
13 : use tower_http::auth::AsyncAuthorizeRequest;
14 : use tracing::{debug, warn};
15 :
16 : use crate::http::{JsonResponse, extract::RequestId};
17 :
18 : #[derive(Clone, Debug)]
19 : pub(in crate::http) struct Authorize {
20 : compute_id: String,
21 : jwks: JwkSet,
22 : validation: Validation,
23 : }
24 :
25 : impl Authorize {
26 0 : pub fn new(compute_id: String, jwks: JwkSet) -> Self {
27 0 : let mut validation = Validation::new(Algorithm::EdDSA);
28 0 : // Nothing is currently required
29 0 : validation.required_spec_claims = HashSet::new();
30 0 : validation.validate_exp = true;
31 0 : // Unused by the control plane
32 0 : validation.validate_aud = false;
33 0 : // Unused by the control plane
34 0 : validation.validate_nbf = false;
35 0 :
36 0 : Self {
37 0 : compute_id,
38 0 : jwks,
39 0 : validation,
40 0 : }
41 0 : }
42 : }
43 :
44 : impl AsyncAuthorizeRequest<Body> for Authorize {
45 : type RequestBody = Body;
46 : type ResponseBody = Body;
47 : type Future = BoxFuture<'static, Result<Request<Body>, Response<Self::ResponseBody>>>;
48 :
49 0 : fn authorize(&mut self, mut request: Request<Body>) -> Self::Future {
50 0 : let compute_id = self.compute_id.clone();
51 0 : let jwks = self.jwks.clone();
52 0 : let validation = self.validation.clone();
53 0 :
54 0 : Box::pin(async move {
55 0 : let request_id = request.extract_parts::<RequestId>().await.unwrap();
56 0 :
57 0 : // TODO(tristan957): Remove this stanza after teaching neon_local
58 0 : // and the regression tests to use a JWT + JWKS.
59 0 : //
60 0 : // https://github.com/neondatabase/neon/issues/11316
61 0 : if cfg!(feature = "testing") {
62 0 : warn!(%request_id, "Skipping compute_ctl authorization check");
63 :
64 0 : return Ok(request);
65 0 : }
66 :
67 0 : let TypedHeader(Authorization(bearer)) = request
68 0 : .extract_parts::<TypedHeader<Authorization<Bearer>>>()
69 0 : .await
70 0 : .map_err(|_| {
71 0 : JsonResponse::error(StatusCode::BAD_REQUEST, "invalid authorization token")
72 0 : })?;
73 :
74 0 : let data = match Self::verify(&jwks, bearer.token(), &validation) {
75 0 : Ok(claims) => claims,
76 0 : Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
77 : };
78 :
79 0 : if data.claims.compute_id != compute_id {
80 0 : return Err(JsonResponse::error(
81 0 : StatusCode::UNAUTHORIZED,
82 0 : "invalid compute ID in authorization token claims",
83 0 : ));
84 0 : }
85 0 :
86 0 : // Make claims available to any subsequent middleware or request
87 0 : // handlers
88 0 : request.extensions_mut().insert(data.claims);
89 0 :
90 0 : Ok(request)
91 0 : })
92 0 : }
93 : }
94 :
95 : impl Authorize {
96 : /// Verify the token using the JSON Web Key set and return the token data.
97 0 : fn verify(
98 0 : jwks: &JwkSet,
99 0 : token: &str,
100 0 : validation: &Validation,
101 0 : ) -> Result<TokenData<ComputeClaims>> {
102 0 : debug_assert!(!jwks.keys.is_empty());
103 :
104 0 : debug!("verifying token {}", token);
105 :
106 0 : for jwk in jwks.keys.iter() {
107 0 : let decoding_key = match DecodingKey::from_jwk(jwk) {
108 0 : Ok(key) => key,
109 0 : Err(e) => {
110 0 : warn!(
111 0 : "failed to construct decoding key from {}: {}",
112 0 : jwk.common.key_id.as_ref().unwrap(),
113 : e
114 : );
115 :
116 0 : continue;
117 : }
118 : };
119 :
120 0 : match jsonwebtoken::decode::<ComputeClaims>(token, &decoding_key, validation) {
121 0 : Ok(data) => return Ok(data),
122 0 : Err(e) => {
123 0 : warn!(
124 0 : "failed to decode authorization token using {}: {}",
125 0 : jwk.common.key_id.as_ref().unwrap(),
126 : e
127 : );
128 :
129 0 : continue;
130 : }
131 : }
132 : }
133 :
134 0 : Err(anyhow!("failed to verify authorization token"))
135 0 : }
136 : }
|