LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 49aa928ec5b4b510172d8b5c6d154da28e70a46c.info Lines: 89.2 % 858 765
Test Date: 2024-11-13 18:23:39 Functions: 50.7 % 205 104

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::future::Future;
       3              : use std::sync::Arc;
       4              : use std::time::{Duration, SystemTime};
       5              : 
       6              : use arc_swap::ArcSwapOption;
       7              : use dashmap::DashMap;
       8              : use jose_jwk::crypto::KeyInfo;
       9              : use reqwest::{redirect, Client};
      10              : use reqwest_retry::policies::ExponentialBackoff;
      11              : use reqwest_retry::RetryTransientMiddleware;
      12              : use serde::de::Visitor;
      13              : use serde::{Deserialize, Deserializer};
      14              : use serde_json::value::RawValue;
      15              : use signature::Verifier;
      16              : use thiserror::Error;
      17              : use tokio::time::Instant;
      18              : 
      19              : use crate::auth::backend::ComputeCredentialKeys;
      20              : use crate::context::RequestMonitoring;
      21              : use crate::control_plane::errors::GetEndpointJwksError;
      22              : use crate::http::read_body_with_limit;
      23              : use crate::intern::RoleNameInt;
      24              : use crate::types::{EndpointId, RoleName};
      25              : 
      26              : // TODO(conrad): make these configurable.
      27              : const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
      28              : const MIN_RENEW: Duration = Duration::from_secs(30);
      29              : const AUTO_RENEW: Duration = Duration::from_secs(300);
      30              : const MAX_RENEW: Duration = Duration::from_secs(3600);
      31              : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
      32              : const JWKS_USER_AGENT: &str = "neon-proxy";
      33              : 
      34              : const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
      35              : const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
      36              : const JWKS_FETCH_RETRIES: u32 = 3;
      37              : 
      38              : /// How to get the JWT auth rules
      39              : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
      40              :     fn fetch_auth_rules(
      41              :         &self,
      42              :         ctx: &RequestMonitoring,
      43              :         endpoint: EndpointId,
      44              :     ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
      45              : }
      46              : 
      47            0 : #[derive(Error, Debug)]
      48              : pub(crate) enum FetchAuthRulesError {
      49              :     #[error(transparent)]
      50              :     GetEndpointJwks(#[from] GetEndpointJwksError),
      51              : 
      52              :     #[error("JWKs settings for this role were not configured")]
      53              :     RoleJwksNotConfigured,
      54              : }
      55              : 
      56              : #[derive(Clone)]
      57              : pub(crate) struct AuthRule {
      58              :     pub(crate) id: String,
      59              :     pub(crate) jwks_url: url::Url,
      60              :     pub(crate) audience: Option<String>,
      61              :     pub(crate) role_names: Vec<RoleNameInt>,
      62              : }
      63              : 
      64              : pub struct JwkCache {
      65              :     client: reqwest_middleware::ClientWithMiddleware,
      66              : 
      67              :     map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
      68              : }
      69              : 
      70              : pub(crate) struct JwkCacheEntry {
      71              :     /// Should refetch at least every hour to verify when old keys have been removed.
      72              :     /// Should refetch when new key IDs are seen only every 5 minutes or so
      73              :     last_retrieved: Instant,
      74              : 
      75              :     /// cplane will return multiple JWKs urls that we need to scrape.
      76              :     key_sets: ahash::HashMap<String, KeySet>,
      77              : }
      78              : 
      79              : impl JwkCacheEntry {
      80           19 :     fn find_jwk_and_audience(
      81           19 :         &self,
      82           19 :         key_id: &str,
      83           19 :         role_name: &RoleName,
      84           19 :     ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      85           19 :         self.key_sets
      86           19 :             .values()
      87           19 :             // make sure our requested role has access to the key set
      88           29 :             .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
      89           19 :             // try and find the requested key-id in the key set
      90           22 :             .find_map(|key_set| {
      91           22 :                 key_set
      92           22 :                     .find_key(key_id)
      93           22 :                     .map(|jwk| (jwk, key_set.audience.as_deref()))
      94           22 :             })
      95           19 :     }
      96              : }
      97              : 
      98              : struct KeySet {
      99              :     jwks: jose_jwk::JwkSet,
     100              :     audience: Option<String>,
     101              :     role_names: Vec<RoleNameInt>,
     102              : }
     103              : 
     104              : impl KeySet {
     105           22 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
     106           22 :         self.jwks
     107           22 :             .keys
     108           22 :             .iter()
     109           30 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
     110           22 :     }
     111              : }
     112              : 
     113              : pub(crate) struct JwkCacheEntryLock {
     114              :     cached: ArcSwapOption<JwkCacheEntry>,
     115              :     lookup: tokio::sync::Semaphore,
     116              : }
     117              : 
     118              : impl Default for JwkCacheEntryLock {
     119            7 :     fn default() -> Self {
     120            7 :         JwkCacheEntryLock {
     121            7 :             cached: ArcSwapOption::empty(),
     122            7 :             lookup: tokio::sync::Semaphore::new(1),
     123            7 :         }
     124            7 :     }
     125              : }
     126              : 
     127           18 : #[derive(Deserialize)]
     128              : struct JwkSet<'a> {
     129              :     /// we parse into raw-value because not all keys in a JWKS are ones
     130              :     /// we can parse directly, so we parse them lazily.
     131              :     #[serde(borrow)]
     132              :     keys: Vec<&'a RawValue>,
     133              : }
     134              : 
     135              : impl JwkCacheEntryLock {
     136            7 :     async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
     137            7 :         JwkRenewalPermit::acquire_permit(self).await
     138            7 :     }
     139              : 
     140            0 :     fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
     141            0 :         JwkRenewalPermit::try_acquire_permit(self)
     142            0 :     }
     143              : 
     144            7 :     async fn renew_jwks<F: FetchAuthRules>(
     145            7 :         &self,
     146            7 :         _permit: JwkRenewalPermit<'_>,
     147            7 :         ctx: &RequestMonitoring,
     148            7 :         client: &reqwest_middleware::ClientWithMiddleware,
     149            7 :         endpoint: EndpointId,
     150            7 :         auth_rules: &F,
     151            7 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     152            7 :         // double check that no one beat us to updating the cache.
     153            7 :         let now = Instant::now();
     154            7 :         let guard = self.cached.load_full();
     155            7 :         if let Some(cached) = guard {
     156            0 :             let last_update = now.duration_since(cached.last_retrieved);
     157            0 :             if last_update < Duration::from_secs(300) {
     158            0 :                 return Ok(cached);
     159            0 :             }
     160            7 :         }
     161              : 
     162            7 :         let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
     163            7 :         let mut key_sets =
     164            7 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     165              : 
     166              :         // TODO(conrad): run concurrently
     167              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     168           16 :         for rule in rules {
     169            9 :             let req = client.get(rule.jwks_url.clone());
     170            9 :             // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
     171            9 :             // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     172           21 :             match req.send().await.and_then(|r| {
     173            9 :                 r.error_for_status()
     174            9 :                     .map_err(reqwest_middleware::Error::Reqwest)
     175            9 :             }) {
     176              :                 // todo: should we re-insert JWKs if we want to keep this JWKs URL?
     177              :                 // I expect these failures would be quite sparse.
     178            0 :                 Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
     179            9 :                 Ok(r) => {
     180            9 :                     let resp: http::Response<reqwest::Body> = r.into();
     181              : 
     182            9 :                     let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
     183            0 :                         .await
     184              :                     {
     185            9 :                         Ok(bytes) => bytes,
     186            0 :                         Err(e) => {
     187            0 :                             tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
     188            0 :                             continue;
     189              :                         }
     190              :                     };
     191              : 
     192            9 :                     match serde_json::from_slice::<JwkSet>(&bytes) {
     193            0 :                         Err(e) => {
     194            0 :                             tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
     195              :                         }
     196            9 :                         Ok(jwks) => {
     197            9 :                             // size_of::<&RawValue>() == 16
     198            9 :                             // size_of::<jose_jwk::Jwk>() == 288
     199            9 :                             // better to not pre-allocate this as it might be pretty large - especially if it has many
     200            9 :                             // keys we don't want or need.
     201            9 :                             // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
     202            9 :                             // this would consume 8MiB just like that!
     203            9 :                             let mut keys = vec![];
     204            9 :                             let mut failed = 0;
     205           23 :                             for key in jwks.keys {
     206           14 :                                 match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
     207           13 :                                     Ok(key) => {
     208           13 :                                         // if `use` (called `cls` in rust) is specified to be something other than signing,
     209           13 :                                         // we can skip storing it.
     210           13 :                                         if key
     211           13 :                                             .prm
     212           13 :                                             .cls
     213           13 :                                             .as_ref()
     214           13 :                                             .is_some_and(|c| *c != jose_jwk::Class::Signing)
     215              :                                         {
     216            0 :                                             continue;
     217           13 :                                         }
     218           13 : 
     219           13 :                                         keys.push(key);
     220              :                                     }
     221            1 :                                     Err(e) => {
     222            1 :                                         tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
     223            1 :                                         failed += 1;
     224              :                                     }
     225              :                                 }
     226              :                             }
     227            9 :                             keys.shrink_to_fit();
     228            9 : 
     229            9 :                             if failed > 0 {
     230            1 :                                 tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
     231            8 :                             }
     232              : 
     233            9 :                             if keys.is_empty() {
     234            0 :                                 tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
     235            0 :                                 continue;
     236            9 :                             }
     237            9 : 
     238            9 :                             let jwks = jose_jwk::JwkSet { keys };
     239            9 :                             key_sets.insert(
     240            9 :                                 rule.id,
     241            9 :                                 KeySet {
     242            9 :                                     jwks,
     243            9 :                                     audience: rule.audience,
     244            9 :                                     role_names: rule.role_names,
     245            9 :                                 },
     246            9 :                             );
     247              :                         }
     248              :                     };
     249              :                 }
     250              :             }
     251              :         }
     252              : 
     253            7 :         let entry = Arc::new(JwkCacheEntry {
     254            7 :             last_retrieved: now,
     255            7 :             key_sets,
     256            7 :         });
     257            7 :         self.cached.swap(Some(Arc::clone(&entry)));
     258            7 : 
     259            7 :         Ok(entry)
     260            7 :     }
     261              : 
     262           19 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     263           19 :         self: &Arc<Self>,
     264           19 :         ctx: &RequestMonitoring,
     265           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     266           19 :         endpoint: EndpointId,
     267           19 :         fetch: &F,
     268           19 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     269           19 :         let now = Instant::now();
     270           19 :         let guard = self.cached.load_full();
     271              : 
     272              :         // if we have no cached JWKs, try and get some
     273           19 :         let Some(cached) = guard else {
     274            7 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     275            7 :             let permit = self.acquire_permit().await;
     276           21 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     277              :         };
     278              : 
     279           12 :         let last_update = now.duration_since(cached.last_retrieved);
     280           12 : 
     281           12 :         // check if the cached JWKs need updating.
     282           12 :         if last_update > MAX_RENEW {
     283            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     284            0 :             let permit = self.acquire_permit().await;
     285              : 
     286              :             // it's been too long since we checked the keys. wait for them to update.
     287            0 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     288           12 :         }
     289           12 : 
     290           12 :         // every 5 minutes we should spawn a job to eagerly update the token.
     291           12 :         if last_update > AUTO_RENEW {
     292            0 :             if let Some(permit) = self.try_acquire_permit() {
     293            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     294            0 :                 let permit = permit.into_owned();
     295            0 :                 let entry = self.clone();
     296            0 :                 let client = client.clone();
     297            0 :                 let fetch = fetch.clone();
     298            0 :                 let ctx = ctx.clone();
     299            0 :                 tokio::spawn(async move {
     300            0 :                     if let Err(e) = entry
     301            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
     302            0 :                         .await
     303              :                     {
     304            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     305            0 :                     }
     306            0 :                 });
     307            0 :             } else {
     308            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     309              :             }
     310           12 :         }
     311              : 
     312           12 :         Ok(cached)
     313           19 :     }
     314              : 
     315           19 :     async fn check_jwt<F: FetchAuthRules>(
     316           19 :         self: &Arc<Self>,
     317           19 :         ctx: &RequestMonitoring,
     318           19 :         jwt: &str,
     319           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     320           19 :         endpoint: EndpointId,
     321           19 :         role_name: &RoleName,
     322           19 :         fetch: &F,
     323           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     324              :         // JWT compact form is defined to be
     325              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     326              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     327              : 
     328           19 :         let (header_payload, signature) = jwt
     329           19 :             .rsplit_once('.')
     330           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     331           19 :         let (header, payload) = header_payload
     332           19 :             .split_once('.')
     333           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     334              : 
     335           19 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
     336           19 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
     337              : 
     338           19 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
     339              : 
     340           19 :         let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
     341              : 
     342           19 :         let mut guard = self
     343           19 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
     344           21 :             .await?;
     345              : 
     346              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     347           18 :         let (jwk, expected_audience) = loop {
     348           19 :             match guard.find_jwk_and_audience(&kid, role_name) {
     349           18 :                 Some(jwk) => break jwk,
     350            1 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     351            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     352              : 
     353            0 :                     let permit = self.acquire_permit().await;
     354            0 :                     guard = self
     355            0 :                         .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
     356            0 :                         .await?;
     357              :                 }
     358            1 :                 _ => return Err(JwtError::JwkNotFound),
     359              :             }
     360              :         };
     361              : 
     362           18 :         if !jwk.is_supported(&header.algorithm) {
     363            0 :             return Err(JwtError::SignatureAlgorithmNotSupported);
     364           18 :         }
     365           18 : 
     366           18 :         match &jwk.key {
     367           13 :             jose_jwk::Key::Ec(key) => {
     368           13 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     369              :             }
     370            5 :             jose_jwk::Key::Rsa(key) => {
     371            5 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
     372              :             }
     373            0 :             key => return Err(JwtError::UnsupportedKeyType(key.into())),
     374              :         };
     375              : 
     376           17 :         let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
     377           17 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
     378              : 
     379           17 :         tracing::debug!(?payload, "JWT signature valid with claims");
     380              : 
     381           17 :         if let Some(aud) = expected_audience {
     382            7 :             if payload.audience.0.iter().all(|s| s != aud) {
     383            5 :                 return Err(JwtError::InvalidClaims(
     384            5 :                     JwtClaimsError::InvalidJwtTokenAudience,
     385            5 :                 ));
     386            2 :             }
     387           10 :         }
     388              : 
     389           12 :         let now = SystemTime::now();
     390              : 
     391           12 :         if let Some(exp) = payload.expiration {
     392           11 :             if now >= exp + CLOCK_SKEW_LEEWAY {
     393            1 :                 return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
     394           10 :             }
     395            1 :         }
     396              : 
     397           11 :         if let Some(nbf) = payload.not_before {
     398           10 :             if nbf >= now + CLOCK_SKEW_LEEWAY {
     399            1 :                 return Err(JwtError::InvalidClaims(
     400            1 :                     JwtClaimsError::JwtTokenNotYetReadyToUse,
     401            1 :                 ));
     402            9 :             }
     403            1 :         }
     404              : 
     405           10 :         Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     406           19 :     }
     407              : }
     408              : 
     409              : impl JwkCache {
     410           19 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     411           19 :         &self,
     412           19 :         ctx: &RequestMonitoring,
     413           19 :         endpoint: EndpointId,
     414           19 :         role_name: &RoleName,
     415           19 :         fetch: &F,
     416           19 :         jwt: &str,
     417           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     418           19 :         // try with just a read lock first
     419           19 :         let key = (endpoint.clone(), role_name.clone());
     420           19 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     421           19 :         let entry = entry.unwrap_or_else(|| {
     422            7 :             // acquire a write lock after to insert.
     423            7 :             let entry = self.map.entry(key).or_default();
     424            7 :             Arc::clone(&*entry)
     425           19 :         });
     426           19 : 
     427           19 :         entry
     428           19 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     429           21 :             .await
     430           19 :     }
     431              : }
     432              : 
     433              : impl Default for JwkCache {
     434            9 :     fn default() -> Self {
     435            9 :         let client = Client::builder()
     436            9 :             .user_agent(JWKS_USER_AGENT)
     437            9 :             .redirect(redirect::Policy::none())
     438            9 :             .tls_built_in_native_certs(true)
     439            9 :             .connect_timeout(JWKS_CONNECT_TIMEOUT)
     440            9 :             .timeout(JWKS_FETCH_TIMEOUT)
     441            9 :             .build()
     442            9 :             .expect("client config should be valid");
     443            9 : 
     444            9 :         // Retry up to 3 times with increasing intervals between attempts.
     445            9 :         let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
     446            9 : 
     447            9 :         let client = reqwest_middleware::ClientBuilder::new(client)
     448            9 :             .with(RetryTransientMiddleware::new_with_policy(retry_policy))
     449            9 :             .build();
     450            9 : 
     451            9 :         JwkCache {
     452            9 :             client,
     453            9 :             map: DashMap::default(),
     454            9 :         }
     455            9 :     }
     456              : }
     457              : 
     458           13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
     459              :     use ecdsa::Signature;
     460              :     use signature::Verifier;
     461              : 
     462           13 :     match key.crv {
     463              :         jose_jwk::EcCurves::P256 => {
     464           13 :             let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
     465           13 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     466           13 :             let sig = Signature::from_slice(sig)?;
     467           13 :             key.verify(data, &sig)?;
     468              :         }
     469            0 :         key => return Err(JwtError::UnsupportedEcKeyType(key)),
     470              :     }
     471              : 
     472           12 :     Ok(())
     473           13 : }
     474              : 
     475            5 : fn verify_rsa_signature(
     476            5 :     data: &[u8],
     477            5 :     sig: &[u8],
     478            5 :     key: &jose_jwk::Rsa,
     479            5 :     alg: &jose_jwa::Algorithm,
     480            5 : ) -> Result<(), JwtError> {
     481              :     use jose_jwa::{Algorithm, Signing};
     482              :     use rsa::pkcs1v15::{Signature, VerifyingKey};
     483              :     use rsa::RsaPublicKey;
     484              : 
     485            5 :     let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
     486              : 
     487            5 :     match alg {
     488              :         Algorithm::Signing(Signing::Rs256) => {
     489            5 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     490            5 :             let sig = Signature::try_from(sig)?;
     491            5 :             key.verify(data, &sig)?;
     492              :         }
     493            0 :         _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
     494              :     };
     495              : 
     496            5 :     Ok(())
     497            5 : }
     498              : 
     499              : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
     500           57 : #[derive(serde::Deserialize, serde::Serialize)]
     501              : struct JwtHeader<'a> {
     502              :     /// must be a supported alg
     503              :     #[serde(rename = "alg")]
     504              :     algorithm: jose_jwa::Algorithm,
     505              :     /// key id, must be provided for our usecase
     506              :     #[serde(rename = "kid", borrow)]
     507              :     key_id: Option<Cow<'a, str>>,
     508              : }
     509              : 
     510              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     511          121 : #[derive(serde::Deserialize, Debug)]
     512              : #[allow(dead_code)]
     513              : struct JwtPayload<'a> {
     514              :     /// Audience - Recipient for which the JWT is intended
     515              :     #[serde(rename = "aud", default)]
     516              :     audience: OneOrMany,
     517              :     /// Expiration - Time after which the JWT expires
     518              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     519              :     expiration: Option<SystemTime>,
     520              :     /// Not before - Time after which the JWT expires
     521              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     522              :     not_before: Option<SystemTime>,
     523              : 
     524              :     // the following entries are only extracted for the sake of debug logging.
     525              :     /// Issuer of the JWT
     526              :     #[serde(rename = "iss", borrow)]
     527              :     issuer: Option<Cow<'a, str>>,
     528              :     /// Subject of the JWT (the user)
     529              :     #[serde(rename = "sub", borrow)]
     530              :     subject: Option<Cow<'a, str>>,
     531              :     /// Unique token identifier
     532              :     #[serde(rename = "jti", borrow)]
     533              :     jwt_id: Option<Cow<'a, str>>,
     534              :     /// Unique session identifier
     535              :     #[serde(rename = "sid", borrow)]
     536              :     session_id: Option<Cow<'a, str>>,
     537              : }
     538              : 
     539              : /// `OneOrMany` supports parsing either a single item or an array of items.
     540              : ///
     541              : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
     542              : ///
     543              : /// > The "aud" (audience) claim identifies the recipients that the JWT is
     544              : /// > intended for.  Each principal intended to process the JWT MUST
     545              : /// > identify itself with a value in the audience claim.  If the principal
     546              : /// > processing the claim does not identify itself with a value in the
     547              : /// > "aud" claim when this claim is present, then the JWT MUST be
     548              : /// > rejected.  In the general case, the "aud" value is **an array of case-
     549              : /// > sensitive strings**, each containing a StringOrURI value.  In the
     550              : /// > special case when the JWT has one audience, the "aud" value MAY be a
     551              : /// > **single case-sensitive string** containing a StringOrURI value.  The
     552              : /// > interpretation of audience values is generally application specific.
     553              : /// > Use of this claim is OPTIONAL.
     554              : #[derive(Default, Debug)]
     555              : struct OneOrMany(Vec<String>);
     556              : 
     557              : impl<'de> Deserialize<'de> for OneOrMany {
     558           15 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     559           15 :     where
     560           15 :         D: Deserializer<'de>,
     561           15 :     {
     562              :         struct OneOrManyVisitor;
     563              :         impl<'de> Visitor<'de> for OneOrManyVisitor {
     564              :             type Value = OneOrMany;
     565              : 
     566            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     567            0 :                 formatter.write_str("a single string or an array of strings")
     568            0 :             }
     569              : 
     570            2 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     571            2 :             where
     572            2 :                 E: serde::de::Error,
     573            2 :             {
     574            2 :                 Ok(OneOrMany(vec![v.to_owned()]))
     575            2 :             }
     576              : 
     577           13 :             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
     578           13 :             where
     579           13 :                 A: serde::de::SeqAccess<'de>,
     580           13 :             {
     581           13 :                 let mut v = vec![];
     582           44 :                 while let Some(s) = seq.next_element()? {
     583           31 :                     v.push(s);
     584           31 :                 }
     585           13 :                 Ok(OneOrMany(v))
     586           13 :             }
     587              :         }
     588           15 :         deserializer.deserialize_any(OneOrManyVisitor)
     589           15 :     }
     590              : }
     591              : 
     592           21 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     593           21 :     let d = <Option<u64>>::deserialize(d)?;
     594           21 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     595           21 : }
     596              : 
     597              : struct JwkRenewalPermit<'a> {
     598              :     inner: Option<JwkRenewalPermitInner<'a>>,
     599              : }
     600              : 
     601              : enum JwkRenewalPermitInner<'a> {
     602              :     Owned(Arc<JwkCacheEntryLock>),
     603              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     604              : }
     605              : 
     606              : impl JwkRenewalPermit<'_> {
     607            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     608            0 :         JwkRenewalPermit {
     609            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     610            0 :         }
     611            0 :     }
     612              : 
     613            7 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     614            7 :         match from.lookup.acquire().await {
     615            7 :             Ok(permit) => {
     616            7 :                 permit.forget();
     617            7 :                 JwkRenewalPermit {
     618            7 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     619            7 :                 }
     620              :             }
     621            0 :             Err(_) => panic!("semaphore should not be closed"),
     622              :         }
     623            7 :     }
     624              : 
     625            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     626            0 :         match from.lookup.try_acquire() {
     627            0 :             Ok(permit) => {
     628            0 :                 permit.forget();
     629            0 :                 Some(JwkRenewalPermit {
     630            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     631            0 :                 })
     632              :             }
     633            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     634            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     635              :         }
     636            0 :     }
     637              : }
     638              : 
     639              : impl JwkRenewalPermitInner<'_> {
     640            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     641            0 :         match self {
     642            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     643            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     644              :         }
     645            0 :     }
     646              : }
     647              : 
     648              : impl Drop for JwkRenewalPermit<'_> {
     649            7 :     fn drop(&mut self) {
     650            7 :         let entry = match &self.inner {
     651            0 :             None => return,
     652            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     653            7 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     654              :         };
     655            7 :         entry.lookup.add_permits(1);
     656            7 :     }
     657              : }
     658              : 
     659            1 : #[derive(Error, Debug)]
     660              : #[non_exhaustive]
     661              : pub(crate) enum JwtError {
     662              :     #[error("jwk not found")]
     663              :     JwkNotFound,
     664              : 
     665              :     #[error("missing key id")]
     666              :     MissingKeyId,
     667              : 
     668              :     #[error("Provided authentication token is not a valid JWT encoding")]
     669              :     JwtEncoding(#[from] JwtEncodingError),
     670              : 
     671              :     #[error(transparent)]
     672              :     InvalidClaims(#[from] JwtClaimsError),
     673              : 
     674              :     #[error("invalid P256 key")]
     675              :     InvalidP256Key(jose_jwk::crypto::Error),
     676              : 
     677              :     #[error("invalid RSA key")]
     678              :     InvalidRsaKey(jose_jwk::crypto::Error),
     679              : 
     680              :     #[error("invalid RSA signing algorithm")]
     681              :     InvalidRsaSigningAlgorithm,
     682              : 
     683              :     #[error("unsupported EC key type {0:?}")]
     684              :     UnsupportedEcKeyType(jose_jwk::EcCurves),
     685              : 
     686              :     #[error("unsupported key type {0:?}")]
     687              :     UnsupportedKeyType(KeyType),
     688              : 
     689              :     #[error("signature algorithm not supported")]
     690              :     SignatureAlgorithmNotSupported,
     691              : 
     692              :     #[error("signature error: {0}")]
     693              :     Signature(#[from] signature::Error),
     694              : 
     695              :     #[error("failed to fetch auth rules: {0}")]
     696              :     FetchAuthRules(#[from] FetchAuthRulesError),
     697              : }
     698              : 
     699              : impl From<base64::DecodeError> for JwtError {
     700            0 :     fn from(err: base64::DecodeError) -> Self {
     701            0 :         JwtEncodingError::Base64Decode(err).into()
     702            0 :     }
     703              : }
     704              : 
     705              : impl From<serde_json::Error> for JwtError {
     706            0 :     fn from(err: serde_json::Error) -> Self {
     707            0 :         JwtEncodingError::SerdeJson(err).into()
     708            0 :     }
     709              : }
     710              : 
     711            0 : #[derive(Error, Debug)]
     712              : #[non_exhaustive]
     713              : pub enum JwtEncodingError {
     714              :     #[error(transparent)]
     715              :     Base64Decode(#[from] base64::DecodeError),
     716              : 
     717              :     #[error(transparent)]
     718              :     SerdeJson(#[from] serde_json::Error),
     719              : 
     720              :     #[error("invalid compact form")]
     721              :     InvalidCompactForm,
     722              : }
     723              : 
     724            0 : #[derive(Error, Debug, PartialEq)]
     725              : #[non_exhaustive]
     726              : pub enum JwtClaimsError {
     727              :     #[error("invalid JWT token audience")]
     728              :     InvalidJwtTokenAudience,
     729              : 
     730              :     #[error("JWT token has expired")]
     731              :     JwtTokenHasExpired,
     732              : 
     733              :     #[error("JWT token is not yet ready to use")]
     734              :     JwtTokenNotYetReadyToUse,
     735              : }
     736              : 
     737              : #[allow(dead_code, reason = "Debug use only")]
     738              : #[derive(Debug)]
     739              : pub(crate) enum KeyType {
     740              :     Ec(jose_jwk::EcCurves),
     741              :     Rsa,
     742              :     Oct,
     743              :     Okp(jose_jwk::OkpCurves),
     744              :     Unknown,
     745              : }
     746              : 
     747              : impl From<&jose_jwk::Key> for KeyType {
     748            0 :     fn from(key: &jose_jwk::Key) -> Self {
     749            0 :         match key {
     750            0 :             jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
     751            0 :             jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
     752            0 :             jose_jwk::Key::Oct(_oct) => Self::Oct,
     753            0 :             jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
     754            0 :             _ => Self::Unknown,
     755              :         }
     756            0 :     }
     757              : }
     758              : 
     759              : #[cfg(test)]
     760              : mod tests {
     761              :     use std::future::IntoFuture;
     762              :     use std::net::SocketAddr;
     763              :     use std::time::SystemTime;
     764              : 
     765              :     use base64::URL_SAFE_NO_PAD;
     766              :     use bytes::Bytes;
     767              :     use http::Response;
     768              :     use http_body_util::Full;
     769              :     use hyper::service::service_fn;
     770              :     use hyper_util::rt::TokioIo;
     771              :     use rand::rngs::OsRng;
     772              :     use rsa::pkcs8::DecodePrivateKey;
     773              :     use serde::Serialize;
     774              :     use serde_json::json;
     775              :     use signature::Signer;
     776              :     use tokio::net::TcpListener;
     777              : 
     778              :     use super::*;
     779              :     use crate::types::RoleName;
     780              : 
     781            6 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     782            6 :         let sk = p256::SecretKey::random(&mut OsRng);
     783            6 :         let pk = sk.public_key().into();
     784            6 :         let jwk = jose_jwk::Jwk {
     785            6 :             key: jose_jwk::Key::Ec(pk),
     786            6 :             prm: jose_jwk::Parameters {
     787            6 :                 kid: Some(kid),
     788            6 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     789            6 :                 ..Default::default()
     790            6 :             },
     791            6 :         };
     792            6 :         (sk, jwk)
     793            6 :     }
     794              : 
     795            4 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     796            4 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     797            4 :         let pk = sk.to_public_key().into();
     798            4 :         let jwk = jose_jwk::Jwk {
     799            4 :             key: jose_jwk::Key::Rsa(pk),
     800            4 :             prm: jose_jwk::Parameters {
     801            4 :                 kid: Some(kid),
     802            4 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     803            4 :                 ..Default::default()
     804            4 :             },
     805            4 :         };
     806            4 :         (sk, jwk)
     807            4 :     }
     808              : 
     809            8 :     fn now() -> u64 {
     810            8 :         SystemTime::now()
     811            8 :             .duration_since(SystemTime::UNIX_EPOCH)
     812            8 :             .unwrap()
     813            8 :             .as_secs()
     814            8 :     }
     815              : 
     816            7 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     817            7 :         let now = now();
     818            7 :         let body = typed_json::json! {{
     819            7 :             "exp": now + 3600,
     820            7 :             "nbf": now,
     821            7 :             "aud": ["audience1", "neon", "audience2"],
     822            7 :             "sub": "user1",
     823            7 :             "sid": "session1",
     824            7 :             "jti": "token1",
     825            7 :             "iss": "neon-testing",
     826            7 :         }};
     827            7 :         build_custom_jwt_payload(kid, body, sig)
     828            7 :     }
     829              : 
     830           15 :     fn build_custom_jwt_payload(
     831           15 :         kid: String,
     832           15 :         body: impl Serialize,
     833           15 :         sig: jose_jwa::Signing,
     834           15 :     ) -> String {
     835           15 :         let header = JwtHeader {
     836           15 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     837           15 :             key_id: Some(Cow::Owned(kid)),
     838           15 :         };
     839           15 : 
     840           15 :         let header =
     841           15 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     842           15 :         let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
     843           15 : 
     844           15 :         format!("{header}.{body}")
     845           15 :     }
     846              : 
     847            3 :     fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
     848              :         use p256::ecdsa::{Signature, SigningKey};
     849              : 
     850            3 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     851            3 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     852            3 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     853            3 : 
     854            3 :         format!("{payload}.{sig}")
     855            3 :     }
     856              : 
     857            8 :     fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
     858              :         use p256::ecdsa::{Signature, SigningKey};
     859              : 
     860            8 :         let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
     861            8 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     862            8 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     863            8 : 
     864            8 :         format!("{payload}.{sig}")
     865            8 :     }
     866              : 
     867            4 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     868              :         use rsa::pkcs1v15::SigningKey;
     869              :         use rsa::signature::SignatureEncoding;
     870              : 
     871            4 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     872            4 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     873            4 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     874            4 : 
     875            4 :         format!("{payload}.{sig}")
     876            4 :     }
     877              : 
     878              :     // RSA key gen is slow....
     879              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     880              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     881              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     882              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     883              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     884              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     885              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     886              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     887              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     888              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     889              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     890              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     891              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     892              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     893              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     894              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     895              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     896              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     897              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     898              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     899              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     900              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     901              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     902              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     903              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     904              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     905              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     906              : -----END PRIVATE KEY-----
     907              : ";
     908              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     909              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     910              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     911              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     912              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     913              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     914              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     915              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     916              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     917              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     918              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     919              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     920              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     921              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     922              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     923              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     924              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     925              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     926              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     927              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     928              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     929              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     930              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     931              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     932              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     933              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     934              : 5JiconnI5aLek0QVPoFaVXFa
     935              : -----END PRIVATE KEY-----
     936              : ";
     937              : 
     938              :     #[derive(Clone)]
     939              :     struct Fetch(Vec<AuthRule>);
     940              : 
     941              :     impl FetchAuthRules for Fetch {
     942            7 :         async fn fetch_auth_rules(
     943            7 :             &self,
     944            7 :             _ctx: &RequestMonitoring,
     945            7 :             _endpoint: EndpointId,
     946            7 :         ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     947            7 :             Ok(self.0.clone())
     948            7 :         }
     949              :     }
     950              : 
     951            6 :     async fn jwks_server(
     952            6 :         router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
     953            6 :     ) -> SocketAddr {
     954            6 :         let router = Arc::new(router);
     955            9 :         let service = service_fn(move |req| {
     956            9 :             let router = Arc::clone(&router);
     957            9 :             async move {
     958            9 :                 match router(req.uri().path()) {
     959            9 :                     Some(body) => Response::builder()
     960            9 :                         .status(200)
     961            9 :                         .body(Full::new(Bytes::from(body))),
     962            0 :                     None => Response::builder()
     963            0 :                         .status(404)
     964            0 :                         .body(Full::new(Bytes::new())),
     965              :                 }
     966            9 :             }
     967            9 :         });
     968              : 
     969            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     970            6 :         let server = hyper::server::conn::http1::Builder::new();
     971            6 :         let addr = listener.local_addr().unwrap();
     972            6 :         tokio::spawn(async move {
     973              :             loop {
     974           12 :                 let (s, _) = listener.accept().await.unwrap();
     975            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
     976            6 :                 tokio::spawn(serve.into_future());
     977              :             }
     978            6 :         });
     979            6 : 
     980            6 :         addr
     981            6 :     }
     982              : 
     983              :     #[tokio::test]
     984            1 :     async fn check_jwt_happy_path() {
     985            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
     986            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
     987            1 :         let (ec1, jwk3) = new_ec_jwk("ec1".into());
     988            1 :         let (ec2, jwk4) = new_ec_jwk("ec2".into());
     989            1 : 
     990            1 :         let foo_jwks = jose_jwk::JwkSet {
     991            1 :             keys: vec![jwk1, jwk3],
     992            1 :         };
     993            1 :         let bar_jwks = jose_jwk::JwkSet {
     994            1 :             keys: vec![jwk2, jwk4],
     995            1 :         };
     996            1 : 
     997            4 :         let jwks_addr = jwks_server(move |path| match path {
     998            4 :             "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
     999            2 :             "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
    1000            1 :             _ => None,
    1001            4 :         })
    1002            1 :         .await;
    1003            1 : 
    1004            1 :         let role_name1 = RoleName::from("anonymous");
    1005            1 :         let role_name2 = RoleName::from("authenticated");
    1006            1 : 
    1007            1 :         let roles = vec![
    1008            1 :             RoleNameInt::from(&role_name1),
    1009            1 :             RoleNameInt::from(&role_name2),
    1010            1 :         ];
    1011            1 :         let rules = vec![
    1012            1 :             AuthRule {
    1013            1 :                 id: "foo".to_owned(),
    1014            1 :                 jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
    1015            1 :                 audience: None,
    1016            1 :                 role_names: roles.clone(),
    1017            1 :             },
    1018            1 :             AuthRule {
    1019            1 :                 id: "bar".to_owned(),
    1020            1 :                 jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
    1021            1 :                 audience: None,
    1022            1 :                 role_names: roles.clone(),
    1023            1 :             },
    1024            1 :         ];
    1025            1 : 
    1026            1 :         let fetch = Fetch(rules);
    1027            1 :         let jwk_cache = JwkCache::default();
    1028            1 : 
    1029            1 :         let endpoint = EndpointId::from("ep");
    1030            1 : 
    1031            1 :         let jwt1 = new_rsa_jwt("rs1".into(), rs1);
    1032            1 :         let jwt2 = new_rsa_jwt("rs2".into(), rs2);
    1033            1 :         let jwt3 = new_ec_jwt("ec1".into(), &ec1);
    1034            1 :         let jwt4 = new_ec_jwt("ec2".into(), &ec2);
    1035            1 : 
    1036            1 :         let tokens = [jwt1, jwt2, jwt3, jwt4];
    1037            1 :         let role_names = [role_name1, role_name2];
    1038            3 :         for role in &role_names {
    1039           10 :             for token in &tokens {
    1040            8 :                 jwk_cache
    1041            8 :                     .check_jwt(
    1042            8 :                         &RequestMonitoring::test(),
    1043            8 :                         endpoint.clone(),
    1044            8 :                         role,
    1045            8 :                         &fetch,
    1046            8 :                         token,
    1047            8 :                     )
    1048            6 :                     .await
    1049            8 :                     .unwrap();
    1050            1 :             }
    1051            1 :         }
    1052            1 :     }
    1053              : 
    1054              :     /// AWS Cognito escapes the `/` in the URL.
    1055              :     #[tokio::test]
    1056            1 :     async fn check_jwt_regression_cognito_issuer() {
    1057            1 :         let (key, jwk) = new_ec_jwk("key".into());
    1058            1 : 
    1059            1 :         let now = now();
    1060            1 :         let token = new_custom_ec_jwt(
    1061            1 :             "key".into(),
    1062            1 :             &key,
    1063            1 :             typed_json::json! {{
    1064            1 :                 "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1065            1 :                 // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
    1066            1 :                 // to write that escape character. instead I will make a bogus URL using `\` instead.
    1067            1 :                 "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
    1068            1 :                 "client_id": "abcdefghijklmnopqrstuvwxyz",
    1069            1 :                 "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
    1070            1 :                 "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
    1071            1 :                 "token_use": "access",
    1072            1 :                 "scope": "aws.cognito.signin.user.admin",
    1073            1 :                 "auth_time":now,
    1074            1 :                 "exp":now + 60,
    1075            1 :                 "iat":now,
    1076            1 :                 "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
    1077            1 :                 "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1078            1 :             }},
    1079            1 :         );
    1080            1 : 
    1081            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1082            1 : 
    1083            1 :         let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
    1084            1 : 
    1085            1 :         let role_name = RoleName::from("anonymous");
    1086            1 :         let rules = vec![AuthRule {
    1087            1 :             id: "aws-cognito".to_owned(),
    1088            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1089            1 :             audience: None,
    1090            1 :             role_names: vec![RoleNameInt::from(&role_name)],
    1091            1 :         }];
    1092            1 : 
    1093            1 :         let fetch = Fetch(rules);
    1094            1 :         let jwk_cache = JwkCache::default();
    1095            1 : 
    1096            1 :         let endpoint = EndpointId::from("ep");
    1097            1 : 
    1098            1 :         jwk_cache
    1099            1 :             .check_jwt(
    1100            1 :                 &RequestMonitoring::test(),
    1101            1 :                 endpoint.clone(),
    1102            1 :                 &role_name,
    1103            1 :                 &fetch,
    1104            1 :                 &token,
    1105            1 :             )
    1106            3 :             .await
    1107            1 :             .unwrap();
    1108            1 :     }
    1109              : 
    1110              :     #[tokio::test]
    1111            1 :     async fn check_jwt_invalid_signature() {
    1112            1 :         let (_, jwk) = new_ec_jwk("1".into());
    1113            1 :         let (key, _) = new_ec_jwk("1".into());
    1114            1 : 
    1115            1 :         // has a matching kid, but signed by the wrong key
    1116            1 :         let bad_jwt = new_ec_jwt("1".into(), &key);
    1117            1 : 
    1118            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1119            1 :         let jwks_addr = jwks_server(move |path| match path {
    1120            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1121            1 :             _ => None,
    1122            1 :         })
    1123            1 :         .await;
    1124            1 : 
    1125            1 :         let role = RoleName::from("authenticated");
    1126            1 : 
    1127            1 :         let rules = vec![AuthRule {
    1128            1 :             id: String::new(),
    1129            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1130            1 :             audience: None,
    1131            1 :             role_names: vec![RoleNameInt::from(&role)],
    1132            1 :         }];
    1133            1 : 
    1134            1 :         let fetch = Fetch(rules);
    1135            1 :         let jwk_cache = JwkCache::default();
    1136            1 : 
    1137            1 :         let ep = EndpointId::from("ep");
    1138            1 : 
    1139            1 :         let ctx = RequestMonitoring::test();
    1140            1 :         let err = jwk_cache
    1141            1 :             .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
    1142            3 :             .await
    1143            1 :             .unwrap_err();
    1144            1 :         assert!(
    1145            1 :             matches!(err, JwtError::Signature(_)),
    1146            1 :             "expected \"signature error\", got {err:?}"
    1147            1 :         );
    1148            1 :     }
    1149              : 
    1150              :     #[tokio::test]
    1151            1 :     async fn check_jwt_unknown_role() {
    1152            1 :         let (key, jwk) = new_rsa_jwk(RS1, "1".into());
    1153            1 :         let jwt = new_rsa_jwt("1".into(), key);
    1154            1 : 
    1155            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1156            1 :         let jwks_addr = jwks_server(move |path| match path {
    1157            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1158            1 :             _ => None,
    1159            1 :         })
    1160            1 :         .await;
    1161            1 : 
    1162            1 :         let role = RoleName::from("authenticated");
    1163            1 :         let rules = vec![AuthRule {
    1164            1 :             id: String::new(),
    1165            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1166            1 :             audience: None,
    1167            1 :             role_names: vec![RoleNameInt::from(&role)],
    1168            1 :         }];
    1169            1 : 
    1170            1 :         let fetch = Fetch(rules);
    1171            1 :         let jwk_cache = JwkCache::default();
    1172            1 : 
    1173            1 :         let ep = EndpointId::from("ep");
    1174            1 : 
    1175            1 :         // this role_name is not accepted
    1176            1 :         let bad_role_name = RoleName::from("cloud_admin");
    1177            1 : 
    1178            1 :         let ctx = RequestMonitoring::test();
    1179            1 :         let err = jwk_cache
    1180            1 :             .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
    1181            3 :             .await
    1182            1 :             .unwrap_err();
    1183            1 : 
    1184            1 :         assert!(
    1185            1 :             matches!(err, JwtError::JwkNotFound),
    1186            1 :             "expected \"jwk not found\", got {err:?}"
    1187            1 :         );
    1188            1 :     }
    1189              : 
    1190              :     #[tokio::test]
    1191            1 :     async fn check_jwt_invalid_claims() {
    1192            1 :         let (key, jwk) = new_ec_jwk("1".into());
    1193            1 : 
    1194            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1195            1 :         let jwks_addr = jwks_server(move |path| match path {
    1196            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1197            1 :             _ => None,
    1198            1 :         })
    1199            1 :         .await;
    1200            1 : 
    1201            1 :         let now = SystemTime::now()
    1202            1 :             .duration_since(SystemTime::UNIX_EPOCH)
    1203            1 :             .unwrap()
    1204            1 :             .as_secs();
    1205            1 : 
    1206            1 :         struct Test {
    1207            1 :             body: serde_json::Value,
    1208            1 :             error: JwtClaimsError,
    1209            1 :         }
    1210            1 : 
    1211            1 :         let table = vec![
    1212            1 :             Test {
    1213            1 :                 body: json! {{
    1214            1 :                     "nbf": now + 60,
    1215            1 :                     "aud": "neon",
    1216            1 :                 }},
    1217            1 :                 error: JwtClaimsError::JwtTokenNotYetReadyToUse,
    1218            1 :             },
    1219            1 :             Test {
    1220            1 :                 body: json! {{
    1221            1 :                     "exp": now - 60,
    1222            1 :                     "aud": ["neon"],
    1223            1 :                 }},
    1224            1 :                 error: JwtClaimsError::JwtTokenHasExpired,
    1225            1 :             },
    1226            1 :             Test {
    1227            1 :                 body: json! {{
    1228            1 :                 }},
    1229            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1230            1 :             },
    1231            1 :             Test {
    1232            1 :                 body: json! {{
    1233            1 :                     "aud": [],
    1234            1 :                 }},
    1235            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1236            1 :             },
    1237            1 :             Test {
    1238            1 :                 body: json! {{
    1239            1 :                     "aud": "foo",
    1240            1 :                 }},
    1241            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1242            1 :             },
    1243            1 :             Test {
    1244            1 :                 body: json! {{
    1245            1 :                     "aud": ["foo"],
    1246            1 :                 }},
    1247            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1248            1 :             },
    1249            1 :             Test {
    1250            1 :                 body: json! {{
    1251            1 :                     "aud": ["foo", "bar"],
    1252            1 :                 }},
    1253            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1254            1 :             },
    1255            1 :         ];
    1256            1 : 
    1257            1 :         let role = RoleName::from("authenticated");
    1258            1 : 
    1259            1 :         let rules = vec![AuthRule {
    1260            1 :             id: String::new(),
    1261            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1262            1 :             audience: Some("neon".to_string()),
    1263            1 :             role_names: vec![RoleNameInt::from(&role)],
    1264            1 :         }];
    1265            1 : 
    1266            1 :         let fetch = Fetch(rules);
    1267            1 :         let jwk_cache = JwkCache::default();
    1268            1 : 
    1269            1 :         let ep = EndpointId::from("ep");
    1270            1 : 
    1271            1 :         let ctx = RequestMonitoring::test();
    1272            8 :         for test in table {
    1273            7 :             let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
    1274            7 : 
    1275            7 :             match jwk_cache
    1276            7 :                 .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
    1277            3 :                 .await
    1278            1 :             {
    1279            7 :                 Err(JwtError::InvalidClaims(error)) if error == test.error => {}
    1280            1 :                 Err(err) => {
    1281            0 :                     panic!("expected {:?}, got {err:?}", test.error)
    1282            1 :                 }
    1283            1 :                 Ok(_payload) => {
    1284            0 :                     panic!("expected {:?}, got ok", test.error)
    1285            1 :                 }
    1286            1 :             }
    1287            1 :         }
    1288            1 :     }
    1289              : 
    1290              :     #[tokio::test]
    1291            1 :     async fn check_jwk_keycloak_regression() {
    1292            1 :         let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
    1293            1 :         let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
    1294            1 : 
    1295            1 :         // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
    1296            1 :         // This is taken directly from keycloak.
    1297            1 :         let invalid_jwk = serde_json::json! {
    1298            1 :             {
    1299            1 :                 "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
    1300            1 :                 "kty": "RSA",
    1301            1 :                 "alg": "RSA-OAEP",
    1302            1 :                 "use": "enc",
    1303            1 :                 "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
    1304            1 :                 "e": "AQAB",
    1305            1 :                 "x5c": [
    1306            1 :                     "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
    1307            1 :                 ],
    1308            1 :                 "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
    1309            1 :                 "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
    1310            1 :             }
    1311            1 :         };
    1312            1 : 
    1313            1 :         let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
    1314            1 :         let jwks_addr = jwks_server(move |path| match path {
    1315            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1316            1 :             _ => None,
    1317            1 :         })
    1318            1 :         .await;
    1319            1 : 
    1320            1 :         let role_name = RoleName::from("anonymous");
    1321            1 :         let role = RoleNameInt::from(&role_name);
    1322            1 : 
    1323            1 :         let rules = vec![AuthRule {
    1324            1 :             id: "foo".to_owned(),
    1325            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1326            1 :             audience: None,
    1327            1 :             role_names: vec![role],
    1328            1 :         }];
    1329            1 : 
    1330            1 :         let fetch = Fetch(rules);
    1331            1 :         let jwk_cache = JwkCache::default();
    1332            1 : 
    1333            1 :         let endpoint = EndpointId::from("ep");
    1334            1 : 
    1335            1 :         let token = new_rsa_jwt("rs1".into(), rs);
    1336            1 : 
    1337            1 :         jwk_cache
    1338            1 :             .check_jwt(
    1339            1 :                 &RequestMonitoring::test(),
    1340            1 :                 endpoint.clone(),
    1341            1 :                 &role_name,
    1342            1 :                 &fetch,
    1343            1 :                 &token,
    1344            1 :             )
    1345            3 :             .await
    1346            1 :             .unwrap();
    1347            1 :     }
    1348              : }
        

Generated by: LCOV version 2.1-beta