LCOV - code coverage report
Current view: top level - compute_tools/src/http/middleware - authorize.rs (source / functions) Coverage Total Hit
Test: 98683a8629f0f7f0031d02e04512998d589d76ea.info Lines: 0.0 % 88 0
Test Date: 2025-04-11 16:58:57 Functions: 0.0 % 5 0

            Line data    Source code
       1              : use std::{collections::HashSet, net::SocketAddr};
       2              : 
       3              : use anyhow::{Result, anyhow};
       4              : use axum::{RequestExt, body::Body, extract::ConnectInfo};
       5              : use axum_extra::{
       6              :     TypedHeader,
       7              :     headers::{Authorization, authorization::Bearer},
       8              : };
       9              : use compute_api::requests::ComputeClaims;
      10              : use futures::future::BoxFuture;
      11              : use http::{Request, Response, StatusCode};
      12              : use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
      13              : use tower_http::auth::AsyncAuthorizeRequest;
      14              : use tracing::warn;
      15              : 
      16              : use crate::http::{JsonResponse, extract::RequestId};
      17              : 
/// Authorization state for the compute HTTP server, used through its
/// `AsyncAuthorizeRequest` implementation.
///
/// The value is cloned per request (see `authorize`), which is why it derives
/// `Clone`.
#[derive(Clone, Debug)]
pub(in crate::http) struct Authorize {
    /// Compute ID that the `compute_id` claim of an incoming token must equal.
    compute_id: String,
    /// JSON Web Key set whose keys are tried, in order, when verifying a
    /// bearer token.
    jwks: JwkSet,
    /// Validation settings passed to `jsonwebtoken::decode` for every token.
    validation: Validation,
}
      24              : 
      25              : impl Authorize {
      26            0 :     pub fn new(compute_id: String, jwks: JwkSet) -> Self {
      27            0 :         let mut validation = Validation::new(Algorithm::EdDSA);
      28            0 :         // Nothing is currently required
      29            0 :         validation.required_spec_claims = HashSet::new();
      30            0 :         validation.validate_exp = true;
      31            0 :         // Unused by the control plane
      32            0 :         validation.validate_aud = false;
      33            0 :         // Unused by the control plane
      34            0 :         validation.validate_nbf = false;
      35            0 : 
      36            0 :         Self {
      37            0 :             compute_id,
      38            0 :             jwks,
      39            0 :             validation,
      40            0 :         }
      41            0 :     }
      42              : }
      43              : 
      44              : impl AsyncAuthorizeRequest<Body> for Authorize {
      45              :     type RequestBody = Body;
      46              :     type ResponseBody = Body;
      47              :     type Future = BoxFuture<'static, Result<Request<Body>, Response<Self::ResponseBody>>>;
      48              : 
      49            0 :     fn authorize(&mut self, mut request: Request<Body>) -> Self::Future {
      50            0 :         let compute_id = self.compute_id.clone();
      51            0 :         let jwks = self.jwks.clone();
      52            0 :         let validation = self.validation.clone();
      53            0 : 
      54            0 :         Box::pin(async move {
      55            0 :             let request_id = request.extract_parts::<RequestId>().await.unwrap();
      56            0 : 
      57            0 :             // TODO: Remove this stanza after teaching neon_local and the
      58            0 :             // regression tests to use a JWT + JWKS.
      59            0 :             //
      60            0 :             // https://github.com/neondatabase/neon/issues/11316
      61            0 :             if cfg!(feature = "testing") {
      62            0 :                 warn!(%request_id, "Skipping compute_ctl authorization check");
      63              : 
      64            0 :                 return Ok(request);
      65            0 :             }
      66              : 
      67            0 :             let connect_info = request
      68            0 :                 .extract_parts::<ConnectInfo<SocketAddr>>()
      69            0 :                 .await
      70            0 :                 .unwrap();
      71            0 : 
      72            0 :             // In the event the request is coming from the loopback interface,
      73            0 :             // allow all requests
      74            0 :             if connect_info.ip().is_loopback() {
      75            0 :                 warn!(%request_id, "Bypassed authorization because request is coming from the loopback interface");
      76              : 
      77            0 :                 return Ok(request);
      78            0 :             }
      79              : 
      80            0 :             let TypedHeader(Authorization(bearer)) = request
      81            0 :                 .extract_parts::<TypedHeader<Authorization<Bearer>>>()
      82            0 :                 .await
      83            0 :                 .map_err(|_| {
      84            0 :                     JsonResponse::error(StatusCode::BAD_REQUEST, "invalid authorization token")
      85            0 :                 })?;
      86              : 
      87            0 :             let data = match Self::verify(&jwks, bearer.token(), &validation) {
      88            0 :                 Ok(claims) => claims,
      89            0 :                 Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
      90              :             };
      91              : 
      92            0 :             if data.claims.compute_id != compute_id {
      93            0 :                 return Err(JsonResponse::error(
      94            0 :                     StatusCode::UNAUTHORIZED,
      95            0 :                     "invalid claims in authorization token",
      96            0 :                 ));
      97            0 :             }
      98            0 : 
      99            0 :             // Make claims available to any subsequent middleware or request
     100            0 :             // handlers
     101            0 :             request.extensions_mut().insert(data.claims);
     102            0 : 
     103            0 :             Ok(request)
     104            0 :         })
     105            0 :     }
     106              : }
     107              : 
     108              : impl Authorize {
     109              :     /// Verify the token using the JSON Web Key set and return the token data.
     110            0 :     fn verify(
     111            0 :         jwks: &JwkSet,
     112            0 :         token: &str,
     113            0 :         validation: &Validation,
     114            0 :     ) -> Result<TokenData<ComputeClaims>> {
     115            0 :         for jwk in jwks.keys.iter() {
     116            0 :             let decoding_key = match DecodingKey::from_jwk(jwk) {
     117            0 :                 Ok(key) => key,
     118            0 :                 Err(e) => {
     119            0 :                     warn!(
     120            0 :                         "Failed to construct decoding key from {}: {}",
     121            0 :                         jwk.common.key_id.as_ref().unwrap(),
     122              :                         e
     123              :                     );
     124              : 
     125            0 :                     continue;
     126              :                 }
     127              :             };
     128              : 
     129            0 :             match jsonwebtoken::decode::<ComputeClaims>(token, &decoding_key, validation) {
     130            0 :                 Ok(data) => return Ok(data),
     131            0 :                 Err(e) => {
     132            0 :                     warn!(
     133            0 :                         "Failed to decode authorization token using {}: {}",
     134            0 :                         jwk.common.key_id.as_ref().unwrap(),
     135              :                         e
     136              :                     );
     137              : 
     138            0 :                     continue;
     139              :                 }
     140              :             }
     141              :         }
     142              : 
     143            0 :         Err(anyhow!("Failed to verify authorization token"))
     144            0 :     }
     145              : }
        

Generated by: LCOV version 2.1-beta