Line data Source code
1 : use std::{collections::HashSet, net::SocketAddr};
2 :
3 : use anyhow::{Result, anyhow};
4 : use axum::{RequestExt, body::Body, extract::ConnectInfo};
5 : use axum_extra::{
6 : TypedHeader,
7 : headers::{Authorization, authorization::Bearer},
8 : };
9 : use futures::future::BoxFuture;
10 : use http::{Request, Response, StatusCode};
11 : use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
12 : use serde::Deserialize;
13 : use tower_http::auth::AsyncAuthorizeRequest;
14 : use tracing::warn;
15 :
16 : use crate::http::{JsonResponse, extract::RequestId};
17 :
18 0 : #[derive(Clone, Debug, Deserialize)]
19 : pub(in crate::http) struct Claims {
20 : compute_id: String,
21 : }
22 :
23 : #[derive(Clone, Debug)]
24 : pub(in crate::http) struct Authorize {
25 : compute_id: String,
26 : jwks: JwkSet,
27 : validation: Validation,
28 : }
29 :
30 : impl Authorize {
31 0 : pub fn new(compute_id: String, jwks: JwkSet) -> Self {
32 0 : let mut validation = Validation::new(Algorithm::EdDSA);
33 0 : // Nothing is currently required
34 0 : validation.required_spec_claims = HashSet::new();
35 0 : validation.validate_exp = true;
36 0 : // Unused by the control plane
37 0 : validation.validate_aud = false;
38 0 : // Unused by the control plane
39 0 : validation.validate_nbf = false;
40 0 :
41 0 : Self {
42 0 : compute_id,
43 0 : jwks,
44 0 : validation,
45 0 : }
46 0 : }
47 : }
48 :
49 : impl AsyncAuthorizeRequest<Body> for Authorize {
50 : type RequestBody = Body;
51 : type ResponseBody = Body;
52 : type Future = BoxFuture<'static, Result<Request<Body>, Response<Self::ResponseBody>>>;
53 :
54 0 : fn authorize(&mut self, mut request: Request<Body>) -> Self::Future {
55 0 : let compute_id = self.compute_id.clone();
56 0 : let jwks = self.jwks.clone();
57 0 : let validation = self.validation.clone();
58 0 :
59 0 : Box::pin(async move {
60 0 : let request_id = request.extract_parts::<RequestId>().await.unwrap();
61 0 :
62 0 : // TODO: Remove this check after a successful rollout
63 0 : if jwks.keys.is_empty() {
64 0 : warn!(%request_id, "Authorization has not been configured");
65 :
66 0 : return Ok(request);
67 0 : }
68 :
69 0 : let connect_info = request
70 0 : .extract_parts::<ConnectInfo<SocketAddr>>()
71 0 : .await
72 0 : .unwrap();
73 0 :
74 0 : // In the event the request is coming from the loopback interface,
75 0 : // allow all requests
76 0 : if connect_info.ip().is_loopback() {
77 0 : warn!(%request_id, "Bypassed authorization because request is coming from the loopback interface");
78 :
79 0 : return Ok(request);
80 0 : }
81 :
82 0 : let TypedHeader(Authorization(bearer)) = request
83 0 : .extract_parts::<TypedHeader<Authorization<Bearer>>>()
84 0 : .await
85 0 : .map_err(|_| {
86 0 : JsonResponse::error(StatusCode::BAD_REQUEST, "invalid authorization token")
87 0 : })?;
88 :
89 0 : let data = match Self::verify(&jwks, bearer.token(), &validation) {
90 0 : Ok(claims) => claims,
91 0 : Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
92 : };
93 :
94 0 : if data.claims.compute_id != compute_id {
95 0 : return Err(JsonResponse::error(
96 0 : StatusCode::UNAUTHORIZED,
97 0 : "invalid claims in authorization token",
98 0 : ));
99 0 : }
100 0 :
101 0 : // Make claims available to any subsequent middleware or request
102 0 : // handlers
103 0 : request.extensions_mut().insert(data.claims);
104 0 :
105 0 : Ok(request)
106 0 : })
107 0 : }
108 : }
109 :
110 : impl Authorize {
111 : /// Verify the token using the JSON Web Key set and return the token data.
112 0 : fn verify(jwks: &JwkSet, token: &str, validation: &Validation) -> Result<TokenData<Claims>> {
113 0 : debug_assert!(!jwks.keys.is_empty());
114 :
115 0 : for jwk in jwks.keys.iter() {
116 0 : let decoding_key = match DecodingKey::from_jwk(jwk) {
117 0 : Ok(key) => key,
118 0 : Err(e) => {
119 0 : warn!(
120 0 : "Failed to construct decoding key from {}: {}",
121 0 : jwk.common.key_id.as_ref().unwrap(),
122 : e
123 : );
124 :
125 0 : continue;
126 : }
127 : };
128 :
129 0 : match jsonwebtoken::decode::<Claims>(token, &decoding_key, validation) {
130 0 : Ok(data) => return Ok(data),
131 0 : Err(e) => {
132 0 : warn!(
133 0 : "Failed to decode authorization token using {}: {}",
134 0 : jwk.common.key_id.as_ref().unwrap(),
135 : e
136 : );
137 :
138 0 : continue;
139 : }
140 : }
141 : }
142 :
143 0 : Err(anyhow!("Failed to verify authorization token"))
144 0 : }
145 : }
|