Line data Source code
1 : use anyhow::{Result, anyhow};
2 : use axum::{RequestExt, body::Body};
3 : use axum_extra::{
4 : TypedHeader,
5 : headers::{Authorization, authorization::Bearer},
6 : };
7 : use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope};
8 : use futures::future::BoxFuture;
9 : use http::{Request, Response, StatusCode};
10 : use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
11 : use tower_http::auth::AsyncAuthorizeRequest;
12 : use tracing::{debug, warn};
13 :
14 : use crate::http::JsonResponse;
15 :
16 : #[derive(Clone, Debug)]
17 : pub(in crate::http) struct Authorize {
18 : compute_id: String,
19 : // BEGIN HADRON
20 : // Hadron instance ID. Only set if it's a Lakebase V1 a.k.a. Hadron instance.
21 : instance_id: Option<String>,
22 : // END HADRON
23 : jwks: JwkSet,
24 : validation: Validation,
25 : }
26 :
27 : impl Authorize {
28 0 : pub fn new(compute_id: String, instance_id: Option<String>, jwks: JwkSet) -> Self {
29 0 : let mut validation = Validation::new(Algorithm::EdDSA);
30 :
31 : // BEGIN HADRON
32 0 : let use_rsa = jwks.keys.iter().any(|jwk| {
33 0 : jwk.common
34 0 : .key_algorithm
35 0 : .is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256)
36 0 : });
37 0 : if use_rsa {
38 0 : validation = Validation::new(Algorithm::RS256);
39 0 : }
40 : // END HADRON
41 :
42 0 : validation.validate_exp = true;
43 : // Unused by the control plane
44 0 : validation.validate_nbf = false;
45 : // Unused by the control plane
46 0 : validation.validate_aud = false;
47 0 : validation.set_audience(&[COMPUTE_AUDIENCE]);
48 : // Nothing is currently required
49 0 : validation.set_required_spec_claims(&[] as &[&str; 0]);
50 :
51 0 : Self {
52 0 : compute_id,
53 0 : instance_id,
54 0 : jwks,
55 0 : validation,
56 0 : }
57 0 : }
58 : }
59 :
60 : impl AsyncAuthorizeRequest<Body> for Authorize {
61 : type RequestBody = Body;
62 : type ResponseBody = Body;
63 : type Future = BoxFuture<'static, Result<Request<Body>, Response<Self::ResponseBody>>>;
64 :
65 0 : fn authorize(&mut self, mut request: Request<Body>) -> Self::Future {
66 0 : let compute_id = self.compute_id.clone();
67 0 : let is_hadron_instance = self.instance_id.is_some();
68 0 : let jwks = self.jwks.clone();
69 0 : let validation = self.validation.clone();
70 :
71 0 : Box::pin(async move {
72 : // BEGIN HADRON
73 : // In Hadron deployments the "external" HTTP endpoint on compute_ctl can only be
74 : // accessed by trusted components (enforced by dblet network policy), so we can bypass
75 : // all auth here.
76 0 : if is_hadron_instance {
77 0 : return Ok(request);
78 0 : }
79 : // END HADRON
80 :
81 0 : let TypedHeader(Authorization(bearer)) = request
82 0 : .extract_parts::<TypedHeader<Authorization<Bearer>>>()
83 0 : .await
84 0 : .map_err(|_| {
85 0 : JsonResponse::error(StatusCode::BAD_REQUEST, "invalid authorization token")
86 0 : })?;
87 :
88 0 : let data = match Self::verify(&jwks, bearer.token(), &validation) {
89 0 : Ok(claims) => claims,
90 0 : Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
91 : };
92 :
93 0 : match data.claims.scope {
94 : // TODO: We should validate audience for every token, but
95 : // instead of this ad-hoc validation, we should turn
96 : // [`Validation::validate_aud`] on. This is merely a stopgap
97 : // while we roll out `aud` deployment. We return a 401
98 : // Unauthorized because when we eventually do use
99 : // [`Validation`], we will hit the above `Err` match arm which
100 : // returns 401 Unauthorized.
101 : Some(ComputeClaimsScope::Admin) => {
102 0 : let Some(ref audience) = data.claims.audience else {
103 0 : return Err(JsonResponse::error(
104 0 : StatusCode::UNAUTHORIZED,
105 0 : "missing audience in authorization token claims",
106 0 : ));
107 : };
108 :
109 0 : if !audience.iter().any(|a| a == COMPUTE_AUDIENCE) {
110 0 : return Err(JsonResponse::error(
111 0 : StatusCode::UNAUTHORIZED,
112 0 : "invalid audience in authorization token claims",
113 0 : ));
114 0 : }
115 : }
116 :
117 : // If the scope is not [`ComputeClaimsScope::Admin`], then we
118 : // must validate the compute_id
119 : _ => {
120 0 : let Some(ref claimed_compute_id) = data.claims.compute_id else {
121 0 : return Err(JsonResponse::error(
122 0 : StatusCode::FORBIDDEN,
123 0 : "missing compute_id in authorization token claims",
124 0 : ));
125 : };
126 :
127 0 : if *claimed_compute_id != compute_id {
128 0 : return Err(JsonResponse::error(
129 0 : StatusCode::FORBIDDEN,
130 0 : "invalid compute ID in authorization token claims",
131 0 : ));
132 0 : }
133 : }
134 : }
135 :
136 : // Make claims available to any subsequent middleware or request
137 : // handlers
138 0 : request.extensions_mut().insert(data.claims);
139 :
140 0 : Ok(request)
141 0 : })
142 0 : }
143 : }
144 :
145 : impl Authorize {
146 : /// Verify the token using the JSON Web Key set and return the token data.
147 0 : fn verify(
148 0 : jwks: &JwkSet,
149 0 : token: &str,
150 0 : validation: &Validation,
151 0 : ) -> Result<TokenData<ComputeClaims>> {
152 0 : debug_assert!(!jwks.keys.is_empty());
153 :
154 0 : debug!("verifying token {}", token);
155 :
156 0 : for jwk in jwks.keys.iter() {
157 0 : let decoding_key = match DecodingKey::from_jwk(jwk) {
158 0 : Ok(key) => key,
159 0 : Err(e) => {
160 0 : warn!(
161 0 : "failed to construct decoding key from {}: {}",
162 0 : jwk.common.key_id.as_ref().unwrap(),
163 : e
164 : );
165 :
166 0 : continue;
167 : }
168 : };
169 :
170 0 : match jsonwebtoken::decode::<ComputeClaims>(token, &decoding_key, validation) {
171 0 : Ok(data) => return Ok(data),
172 0 : Err(e) => {
173 0 : warn!(
174 0 : "failed to decode authorization token using {}: {}",
175 0 : jwk.common.key_id.as_ref().unwrap(),
176 : e
177 : );
178 :
179 0 : continue;
180 : }
181 : }
182 : }
183 :
184 0 : Err(anyhow!("failed to verify authorization token"))
185 0 : }
186 : }
|