LCOV - code coverage report
Current view: top level - compute_tools/src/http/middleware - authorize.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 91 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 8 0

            Line data    Source code
       1              : use anyhow::{Result, anyhow};
       2              : use axum::{RequestExt, body::Body};
       3              : use axum_extra::{
       4              :     TypedHeader,
       5              :     headers::{Authorization, authorization::Bearer},
       6              : };
       7              : use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope};
       8              : use futures::future::BoxFuture;
       9              : use http::{Request, Response, StatusCode};
      10              : use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
      11              : use tower_http::auth::AsyncAuthorizeRequest;
      12              : use tracing::{debug, warn};
      13              : 
      14              : use crate::http::JsonResponse;
      15              : 
      16              : #[derive(Clone, Debug)]
      17              : pub(in crate::http) struct Authorize {
      18              :     compute_id: String,
      19              :     // BEGIN HADRON
      20              :     // Hadron instance ID. Only set if it's a Lakebase V1 a.k.a. Hadron instance.
      21              :     instance_id: Option<String>,
      22              :     // END HADRON
      23              :     jwks: JwkSet,
      24              :     validation: Validation,
      25              : }
      26              : 
      27              : impl Authorize {
      28            0 :     pub fn new(compute_id: String, instance_id: Option<String>, jwks: JwkSet) -> Self {
      29            0 :         let mut validation = Validation::new(Algorithm::EdDSA);
      30              : 
      31              :         // BEGIN HADRON
      32            0 :         let use_rsa = jwks.keys.iter().any(|jwk| {
      33            0 :             jwk.common
      34            0 :                 .key_algorithm
      35            0 :                 .is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256)
      36            0 :         });
      37            0 :         if use_rsa {
      38            0 :             validation = Validation::new(Algorithm::RS256);
      39            0 :         }
      40              :         // END HADRON
      41              : 
      42            0 :         validation.validate_exp = true;
      43              :         // Unused by the control plane
      44            0 :         validation.validate_nbf = false;
      45              :         // Unused by the control plane
      46            0 :         validation.validate_aud = false;
      47            0 :         validation.set_audience(&[COMPUTE_AUDIENCE]);
      48              :         // Nothing is currently required
      49            0 :         validation.set_required_spec_claims(&[] as &[&str; 0]);
      50              : 
      51            0 :         Self {
      52            0 :             compute_id,
      53            0 :             instance_id,
      54            0 :             jwks,
      55            0 :             validation,
      56            0 :         }
      57            0 :     }
      58              : }
      59              : 
      60              : impl AsyncAuthorizeRequest<Body> for Authorize {
      61              :     type RequestBody = Body;
      62              :     type ResponseBody = Body;
      63              :     type Future = BoxFuture<'static, Result<Request<Body>, Response<Self::ResponseBody>>>;
      64              : 
      65            0 :     fn authorize(&mut self, mut request: Request<Body>) -> Self::Future {
      66            0 :         let compute_id = self.compute_id.clone();
      67            0 :         let is_hadron_instance = self.instance_id.is_some();
      68            0 :         let jwks = self.jwks.clone();
      69            0 :         let validation = self.validation.clone();
      70              : 
      71            0 :         Box::pin(async move {
      72              :             // BEGIN HADRON
      73              :             // In Hadron deployments the "external" HTTP endpoint on compute_ctl can only be
      74              :             // accessed by trusted components (enforced by dblet network policy), so we can bypass
      75              :             // all auth here.
      76            0 :             if is_hadron_instance {
      77            0 :                 return Ok(request);
      78            0 :             }
      79              :             // END HADRON
      80              : 
      81            0 :             let TypedHeader(Authorization(bearer)) = request
      82            0 :                 .extract_parts::<TypedHeader<Authorization<Bearer>>>()
      83            0 :                 .await
      84            0 :                 .map_err(|_| {
      85            0 :                     JsonResponse::error(StatusCode::BAD_REQUEST, "invalid authorization token")
      86            0 :                 })?;
      87              : 
      88            0 :             let data = match Self::verify(&jwks, bearer.token(), &validation) {
      89            0 :                 Ok(claims) => claims,
      90            0 :                 Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
      91              :             };
      92              : 
      93            0 :             match data.claims.scope {
      94              :                 // TODO: We should validate audience for every token, but
      95              :                 // instead of this ad-hoc validation, we should turn
      96              :                 // [`Validation::validate_aud`] on. This is merely a stopgap
      97              :                 // while we roll out `aud` deployment. We return a 401
      98              :                 // Unauthorized because when we eventually do use
      99              :                 // [`Validation`], we will hit the above `Err` match arm which
     100              :                 // returns 401 Unauthorized.
     101              :                 Some(ComputeClaimsScope::Admin) => {
     102            0 :                     let Some(ref audience) = data.claims.audience else {
     103            0 :                         return Err(JsonResponse::error(
     104            0 :                             StatusCode::UNAUTHORIZED,
     105            0 :                             "missing audience in authorization token claims",
     106            0 :                         ));
     107              :                     };
     108              : 
     109            0 :                     if !audience.iter().any(|a| a == COMPUTE_AUDIENCE) {
     110            0 :                         return Err(JsonResponse::error(
     111            0 :                             StatusCode::UNAUTHORIZED,
     112            0 :                             "invalid audience in authorization token claims",
     113            0 :                         ));
     114            0 :                     }
     115              :                 }
     116              : 
     117              :                 // If the scope is not [`ComputeClaimsScope::Admin`], then we
     118              :                 // must validate the compute_id
     119              :                 _ => {
     120            0 :                     let Some(ref claimed_compute_id) = data.claims.compute_id else {
     121            0 :                         return Err(JsonResponse::error(
     122            0 :                             StatusCode::FORBIDDEN,
     123            0 :                             "missing compute_id in authorization token claims",
     124            0 :                         ));
     125              :                     };
     126              : 
     127            0 :                     if *claimed_compute_id != compute_id {
     128            0 :                         return Err(JsonResponse::error(
     129            0 :                             StatusCode::FORBIDDEN,
     130            0 :                             "invalid compute ID in authorization token claims",
     131            0 :                         ));
     132            0 :                     }
     133              :                 }
     134              :             }
     135              : 
     136              :             // Make claims available to any subsequent middleware or request
     137              :             // handlers
     138            0 :             request.extensions_mut().insert(data.claims);
     139              : 
     140            0 :             Ok(request)
     141            0 :         })
     142            0 :     }
     143              : }
     144              : 
     145              : impl Authorize {
     146              :     /// Verify the token using the JSON Web Key set and return the token data.
     147            0 :     fn verify(
     148            0 :         jwks: &JwkSet,
     149            0 :         token: &str,
     150            0 :         validation: &Validation,
     151            0 :     ) -> Result<TokenData<ComputeClaims>> {
     152            0 :         debug_assert!(!jwks.keys.is_empty());
     153              : 
     154            0 :         debug!("verifying token {}", token);
     155              : 
     156            0 :         for jwk in jwks.keys.iter() {
     157            0 :             let decoding_key = match DecodingKey::from_jwk(jwk) {
     158            0 :                 Ok(key) => key,
     159            0 :                 Err(e) => {
     160            0 :                     warn!(
     161            0 :                         "failed to construct decoding key from {}: {}",
     162            0 :                         jwk.common.key_id.as_ref().unwrap(),
     163              :                         e
     164              :                     );
     165              : 
     166            0 :                     continue;
     167              :                 }
     168              :             };
     169              : 
     170            0 :             match jsonwebtoken::decode::<ComputeClaims>(token, &decoding_key, validation) {
     171            0 :                 Ok(data) => return Ok(data),
     172            0 :                 Err(e) => {
     173            0 :                     warn!(
     174            0 :                         "failed to decode authorization token using {}: {}",
     175            0 :                         jwk.common.key_id.as_ref().unwrap(),
     176              :                         e
     177              :                     );
     178              : 
     179            0 :                     continue;
     180              :                 }
     181              :             }
     182              :         }
     183              : 
     184            0 :         Err(anyhow!("failed to verify authorization token"))
     185            0 :     }
     186              : }
        

Generated by: LCOV version 2.1-beta