LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 87.1 % 757 659
Test Date: 2025-07-16 12:29:03 Functions: 72.0 % 132 95

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::future::Future;
       3              : use std::sync::Arc;
       4              : use std::time::{Duration, SystemTime};
       5              : 
       6              : use arc_swap::ArcSwapOption;
       7              : use base64::Engine as _;
       8              : use base64::prelude::BASE64_URL_SAFE_NO_PAD;
       9              : use clashmap::ClashMap;
      10              : use jose_jwk::crypto::KeyInfo;
      11              : use reqwest::{Client, redirect};
      12              : use reqwest_retry::RetryTransientMiddleware;
      13              : use reqwest_retry::policies::ExponentialBackoff;
      14              : use serde::de::Visitor;
      15              : use serde::{Deserialize, Deserializer};
      16              : use serde_json::value::RawValue;
      17              : use signature::Verifier;
      18              : use thiserror::Error;
      19              : use tokio::time::Instant;
      20              : 
      21              : use crate::auth::backend::ComputeCredentialKeys;
      22              : use crate::context::RequestContext;
      23              : use crate::control_plane::errors::GetEndpointJwksError;
      24              : use crate::http::read_body_with_limit;
      25              : use crate::intern::RoleNameInt;
      26              : use crate::types::{EndpointId, RoleName};
      27              : 
// TODO(conrad): make these configurable.
/// Tolerance applied to the `exp` and `nbf` claim checks to allow for clock
/// differences between the token issuer and this proxy.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// A JWT naming an unknown key id only attempts a JWKS refresh when the cached
/// copy is older than this.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Cached JWKS older than this get a background refresh while the cached copy
/// keeps being served.
const AUTO_RENEW: Duration = Duration::from_secs(300);
/// Cached JWKS older than this are not served; the request waits for a refresh.
const MAX_RENEW: Duration = Duration::from_secs(3600);
/// Upper bound, in bytes, on the JWKS response body we are willing to read.
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
/// `User-Agent` sent with JWKS requests.
const JWKS_USER_AGENT: &str = "neon-proxy";

/// Connect timeout passed to reqwest's `ClientBuilder::connect_timeout` for JWKS requests.
const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// Request timeout passed to reqwest's `ClientBuilder::timeout` for JWKS requests.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum number of retries of a transiently failing JWKS request.
const JWKS_FETCH_RETRIES: u32 = 3;
      39              : 
/// How to get the JWT auth rules
///
/// Implementations return the [`AuthRule`]s configured for an endpoint. Each rule
/// names a JWKS URL together with the audience and roles it applies to.
///
/// The `Clone + 'static` bounds allow the fetcher to be cloned into a spawned
/// background renewal task.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Fetches all JWT auth rules configured for `endpoint`.
    fn fetch_auth_rules(
        &self,
        ctx: &RequestContext,
        endpoint: EndpointId,
    ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
}
      48              : 
/// Errors returned by [`FetchAuthRules::fetch_auth_rules`].
#[derive(Error, Debug)]
pub(crate) enum FetchAuthRulesError {
    /// Looking up the endpoint's JWKS settings (control-plane error type) failed.
    #[error(transparent)]
    GetEndpointJwks(#[from] GetEndpointJwksError),

    /// No JWKS settings are configured for the requested role.
    #[error("JWKs settings for this role were not configured")]
    RoleJwksNotConfigured,
}
      57              : 
/// One JWT auth rule: where to fetch a JWKS from and which tokens its keys may validate.
#[derive(Clone)]
pub(crate) struct AuthRule {
    /// Rule identifier; used as the key of the cached key set.
    pub(crate) id: String,
    /// Location of the JSON Web Key Set.
    pub(crate) jwks_url: url::Url,
    /// If set, the token's `aud` claim must contain this value.
    pub(crate) audience: Option<String>,
    /// Roles that may authenticate with keys from this JWKS.
    pub(crate) role_names: Vec<RoleNameInt>,
}
      65              : 
/// Cache of fetched JWKS documents, keyed by (endpoint, role), together with the
/// HTTP client used to fetch them.
pub struct JwkCache {
    /// HTTP client (wrapped with transient-error retries) used for JWKS requests.
    client: reqwest_middleware::ClientWithMiddleware,

    /// One entry per (endpoint, role) pair, created on first use.
    // NOTE(review): no eviction of entries is visible in this section of the file —
    // confirm the map is bounded elsewhere.
    map: ClashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
      71              : 
/// A snapshot of all JWKS key sets for one (endpoint, role) pair.
pub(crate) struct JwkCacheEntry {
    /// The instant at which the refresh producing this snapshot started.
    /// `JwkCacheEntryLock` compares its age against `MAX_RENEW` (wait for a refresh),
    /// `AUTO_RENEW` (refresh in the background) and `MIN_RENEW` (whether a JWT with an
    /// unknown key id may trigger a refresh).
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    /// Keyed by auth rule id.
    key_sets: ahash::HashMap<String, KeySet>,
}
      80              : 
      81              : impl JwkCacheEntry {
      82           19 :     fn find_jwk_and_audience(
      83           19 :         &self,
      84           19 :         key_id: &str,
      85           19 :         role_name: &RoleName,
      86           19 :     ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      87           19 :         self.key_sets
      88           19 :             .values()
      89              :             // make sure our requested role has access to the key set
      90           29 :             .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
      91              :             // try and find the requested key-id in the key set
      92           22 :             .find_map(|key_set| {
      93           22 :                 key_set
      94           22 :                     .find_key(key_id)
      95           22 :                     .map(|jwk| (jwk, key_set.audience.as_deref()))
      96           22 :             })
      97           19 :     }
      98              : }
      99              : 
/// The keys fetched from one JWKS URL, plus the constraints of the rule that named it.
struct KeySet {
    /// Keys that parsed successfully and are not marked for a non-signing `use`.
    jwks: jose_jwk::JwkSet,
    /// Required `aud` claim value, if the rule specifies one.
    audience: Option<String>,
    /// Roles permitted to use these keys.
    role_names: Vec<RoleNameInt>,
}
     105              : 
     106              : impl KeySet {
     107           22 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
     108           22 :         self.jwks
     109           22 :             .keys
     110           22 :             .iter()
     111           30 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
     112           22 :     }
     113              : }
     114              : 
/// Cache slot for one (endpoint, role) pair.
pub(crate) struct JwkCacheEntryLock {
    /// The latest snapshot, or `None` until the first renewal completes.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Semaphore created with a single permit; presumably what `JwkRenewalPermit`
    /// acquires, so that renewals of `cached` do not run concurrently — confirm.
    lookup: tokio::sync::Semaphore,
}
     119              : 
     120              : impl Default for JwkCacheEntryLock {
     121            7 :     fn default() -> Self {
     122            7 :         JwkCacheEntryLock {
     123            7 :             cached: ArcSwapOption::empty(),
     124            7 :             lookup: tokio::sync::Semaphore::new(1),
     125            7 :         }
     126            7 :     }
     127              : }
     128              : 
/// Raw shape of a JWKS document (`{"keys": [...]}`), used only while parsing a response.
#[derive(Deserialize)]
struct JwkSet<'a> {
    /// we parse into raw-value because not all keys in a JWKS are ones
    /// we can parse directly, so we parse them lazily.
    #[serde(borrow)]
    keys: Vec<&'a RawValue>,
}
     136              : 
     137              : /// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
     138              : /// Returns `None` and log a warning if there are any errors.
     139            9 : async fn fetch_jwks(
     140            9 :     client: &reqwest_middleware::ClientWithMiddleware,
     141            9 :     jwks_url: url::Url,
     142            9 : ) -> Option<jose_jwk::JwkSet> {
     143            9 :     let req = client.get(jwks_url.clone());
     144              :     // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     145            9 :     let resp = req.send().await.and_then(|r| {
     146            9 :         r.error_for_status()
     147            9 :             .map_err(reqwest_middleware::Error::Reqwest)
     148            9 :     });
     149              : 
     150            9 :     let resp = match resp {
     151            9 :         Ok(r) => r,
     152              :         // TODO: should we re-insert JWKs if we want to keep this JWKs URL?
     153              :         // I expect these failures would be quite sparse.
     154            0 :         Err(e) => {
     155            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
     156            0 :             return None;
     157              :         }
     158              :     };
     159              : 
     160            9 :     let resp: http::Response<reqwest::Body> = resp.into();
     161              : 
     162            9 :     let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
     163            9 :         Ok(bytes) => bytes,
     164            0 :         Err(e) => {
     165            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     166            0 :             return None;
     167              :         }
     168              :     };
     169              : 
     170            9 :     let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
     171            9 :         Ok(jwks) => jwks,
     172            0 :         Err(e) => {
     173            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     174            0 :             return None;
     175              :         }
     176              :     };
     177              : 
     178              :     // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
     179              :     //
     180              :     // Even though we limit our responses to 64KiB, we could still receive a payload like
     181              :     // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
     182              :     // Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
     183            9 :     let mut keys = vec![];
     184              : 
     185            9 :     let mut failed = 0;
     186           23 :     for key in jwks.keys {
     187           14 :         let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
     188           13 :             Ok(key) => key,
     189            1 :             Err(e) => {
     190            1 :                 tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
     191            1 :                 failed += 1;
     192            1 :                 continue;
     193              :             }
     194              :         };
     195              : 
     196              :         // if `use` (called `cls` in rust) is specified to be something other than signing,
     197              :         // we can skip storing it.
     198           13 :         if key
     199           13 :             .prm
     200           13 :             .cls
     201           13 :             .as_ref()
     202           13 :             .is_some_and(|c| *c != jose_jwk::Class::Signing)
     203              :         {
     204            0 :             continue;
     205           13 :         }
     206              : 
     207           13 :         keys.push(key);
     208              :     }
     209              : 
     210            9 :     keys.shrink_to_fit();
     211              : 
     212            9 :     if failed > 0 {
     213            1 :         tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
     214            8 :     }
     215              : 
     216            9 :     if keys.is_empty() {
     217            0 :         tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
     218            0 :         return None;
     219            9 :     }
     220              : 
     221            9 :     Some(jose_jwk::JwkSet { keys })
     222            9 : }
     223              : 
     224              : impl JwkCacheEntryLock {
     225            7 :     async fn acquire_permit(self: &Arc<Self>) -> JwkRenewalPermit<'_> {
     226            7 :         JwkRenewalPermit::acquire_permit(self).await
     227            7 :     }
     228              : 
     229            0 :     fn try_acquire_permit(self: &Arc<Self>) -> Option<JwkRenewalPermit<'_>> {
     230            0 :         JwkRenewalPermit::try_acquire_permit(self)
     231            0 :     }
     232              : 
     233            7 :     async fn renew_jwks<F: FetchAuthRules>(
     234            7 :         &self,
     235            7 :         _permit: JwkRenewalPermit<'_>,
     236            7 :         ctx: &RequestContext,
     237            7 :         client: &reqwest_middleware::ClientWithMiddleware,
     238            7 :         endpoint: EndpointId,
     239            7 :         auth_rules: &F,
     240            7 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     241              :         // double check that no one beat us to updating the cache.
     242            7 :         let now = Instant::now();
     243            7 :         let guard = self.cached.load_full();
     244            7 :         if let Some(cached) = guard {
     245            0 :             let last_update = now.duration_since(cached.last_retrieved);
     246            0 :             if last_update < Duration::from_secs(300) {
     247            0 :                 return Ok(cached);
     248            0 :             }
     249            7 :         }
     250              : 
     251            7 :         let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
     252            7 :         let mut key_sets =
     253            7 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     254              : 
     255              :         // TODO(conrad): run concurrently
     256              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     257           16 :         for rule in rules {
     258            9 :             if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
     259            9 :                 key_sets.insert(
     260            9 :                     rule.id,
     261            9 :                     KeySet {
     262            9 :                         jwks,
     263            9 :                         audience: rule.audience,
     264            9 :                         role_names: rule.role_names,
     265            9 :                     },
     266            9 :                 );
     267            9 :             }
     268              :         }
     269              : 
     270            7 :         let entry = Arc::new(JwkCacheEntry {
     271            7 :             last_retrieved: now,
     272            7 :             key_sets,
     273            7 :         });
     274            7 :         self.cached.swap(Some(Arc::clone(&entry)));
     275              : 
     276            7 :         Ok(entry)
     277            7 :     }
     278              : 
     279           19 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     280           19 :         self: &Arc<Self>,
     281           19 :         ctx: &RequestContext,
     282           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     283           19 :         endpoint: EndpointId,
     284           19 :         fetch: &F,
     285           19 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     286           19 :         let now = Instant::now();
     287           19 :         let guard = self.cached.load_full();
     288              : 
     289              :         // if we have no cached JWKs, try and get some
     290           19 :         let Some(cached) = guard else {
     291            7 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     292            7 :             let permit = self.acquire_permit().await;
     293            7 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     294              :         };
     295              : 
     296           12 :         let last_update = now.duration_since(cached.last_retrieved);
     297              : 
     298              :         // check if the cached JWKs need updating.
     299           12 :         if last_update > MAX_RENEW {
     300            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     301            0 :             let permit = self.acquire_permit().await;
     302              : 
     303              :             // it's been too long since we checked the keys. wait for them to update.
     304            0 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     305           12 :         }
     306              : 
     307              :         // every 5 minutes we should spawn a job to eagerly update the token.
     308           12 :         if last_update > AUTO_RENEW {
     309            0 :             if let Some(permit) = self.try_acquire_permit() {
     310            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     311            0 :                 let permit = permit.into_owned();
     312            0 :                 let entry = self.clone();
     313            0 :                 let client = client.clone();
     314            0 :                 let fetch = fetch.clone();
     315            0 :                 let ctx = ctx.clone();
     316            0 :                 tokio::spawn(async move {
     317            0 :                     if let Err(e) = entry
     318            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
     319            0 :                         .await
     320              :                     {
     321            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     322            0 :                     }
     323            0 :                 });
     324              :             } else {
     325            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     326              :             }
     327           12 :         }
     328              : 
     329           12 :         Ok(cached)
     330           19 :     }
     331              : 
     332           19 :     async fn check_jwt<F: FetchAuthRules>(
     333           19 :         self: &Arc<Self>,
     334           19 :         ctx: &RequestContext,
     335           19 :         jwt: &str,
     336           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     337           19 :         endpoint: EndpointId,
     338           19 :         role_name: &RoleName,
     339           19 :         fetch: &F,
     340           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     341              :         // JWT compact form is defined to be
     342              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     343              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     344              : 
     345           19 :         let (header_payload, signature) = jwt
     346           19 :             .rsplit_once('.')
     347           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     348           19 :         let (header, payload) = header_payload
     349           19 :             .split_once('.')
     350           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     351              : 
     352           19 :         let header = BASE64_URL_SAFE_NO_PAD.decode(header)?;
     353           19 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
     354              : 
     355           19 :         let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
     356           19 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
     357              : 
     358           19 :         if let Some(iss) = &payload.issuer {
     359           12 :             ctx.set_jwt_issuer(iss.as_ref().to_owned());
     360           12 :         }
     361              : 
     362           19 :         let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?;
     363              : 
     364           19 :         let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
     365              : 
     366           19 :         let mut guard = self
     367           19 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
     368           19 :             .await?;
     369              : 
     370              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     371           18 :         let (jwk, expected_audience) = loop {
     372           19 :             match guard.find_jwk_and_audience(&kid, role_name) {
     373           18 :                 Some(jwk) => break jwk,
     374            1 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     375            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     376              : 
     377            0 :                     let permit = self.acquire_permit().await;
     378            0 :                     guard = self
     379            0 :                         .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
     380            0 :                         .await?;
     381              :                 }
     382            1 :                 _ => return Err(JwtError::JwkNotFound),
     383              :             }
     384              :         };
     385              : 
     386           18 :         if !jwk.is_supported(&header.algorithm) {
     387            0 :             return Err(JwtError::SignatureAlgorithmNotSupported);
     388           18 :         }
     389              : 
     390           18 :         match &jwk.key {
     391           13 :             jose_jwk::Key::Ec(key) => {
     392           13 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     393              :             }
     394            5 :             jose_jwk::Key::Rsa(key) => {
     395            5 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
     396              :             }
     397            0 :             key => return Err(JwtError::UnsupportedKeyType(key.into())),
     398              :         }
     399              : 
     400           17 :         tracing::debug!(?payload, "JWT signature valid with claims");
     401              : 
     402           17 :         if let Some(aud) = expected_audience
     403            7 :             && payload.audience.0.iter().all(|s| s != aud)
     404              :         {
     405            5 :             return Err(JwtError::InvalidClaims(
     406            5 :                 JwtClaimsError::InvalidJwtTokenAudience,
     407            5 :             ));
     408           12 :         }
     409              : 
     410           12 :         let now = SystemTime::now();
     411              : 
     412           12 :         if let Some(exp) = payload.expiration
     413           11 :             && now >= exp + CLOCK_SKEW_LEEWAY
     414              :         {
     415            1 :             return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
     416            1 :                 exp.duration_since(SystemTime::UNIX_EPOCH)
     417            1 :                     .unwrap_or_default()
     418            1 :                     .as_secs(),
     419            1 :             )));
     420           11 :         }
     421              : 
     422           11 :         if let Some(nbf) = payload.not_before
     423           10 :             && nbf >= now + CLOCK_SKEW_LEEWAY
     424              :         {
     425            1 :             return Err(JwtError::InvalidClaims(
     426            1 :                 JwtClaimsError::JwtTokenNotYetReadyToUse(
     427            1 :                     nbf.duration_since(SystemTime::UNIX_EPOCH)
     428            1 :                         .unwrap_or_default()
     429            1 :                         .as_secs(),
     430            1 :                 ),
     431            1 :             ));
     432           10 :         }
     433              : 
     434           10 :         Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     435           19 :     }
     436              : }
     437              : 
     438              : impl JwkCache {
     439           19 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     440           19 :         &self,
     441           19 :         ctx: &RequestContext,
     442           19 :         endpoint: EndpointId,
     443           19 :         role_name: &RoleName,
     444           19 :         fetch: &F,
     445           19 :         jwt: &str,
     446           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     447              :         // try with just a read lock first
     448           19 :         let key = (endpoint.clone(), role_name.clone());
     449           19 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     450           19 :         let entry = entry.unwrap_or_else(|| {
     451              :             // acquire a write lock after to insert.
     452            7 :             let entry = self.map.entry(key).or_default();
     453            7 :             Arc::clone(&*entry)
     454            7 :         });
     455              : 
     456           19 :         entry
     457           19 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     458           19 :             .await
     459           19 :     }
     460              : }
     461              : 
     462              : impl Default for JwkCache {
     463            9 :     fn default() -> Self {
     464            9 :         let client = Client::builder()
     465            9 :             .user_agent(JWKS_USER_AGENT)
     466            9 :             .redirect(redirect::Policy::none())
     467            9 :             .tls_built_in_native_certs(true)
     468            9 :             .connect_timeout(JWKS_CONNECT_TIMEOUT)
     469            9 :             .timeout(JWKS_FETCH_TIMEOUT)
     470            9 :             .build()
     471            9 :             .expect("client config should be valid");
     472              : 
     473              :         // Retry up to 3 times with increasing intervals between attempts.
     474            9 :         let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
     475              : 
     476            9 :         let client = reqwest_middleware::ClientBuilder::new(client)
     477            9 :             .with(RetryTransientMiddleware::new_with_policy(retry_policy))
     478            9 :             .build();
     479              : 
     480            9 :         JwkCache {
     481            9 :             client,
     482            9 :             map: ClashMap::default(),
     483            9 :         }
     484            9 :     }
     485              : }
     486              : 
     487           13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
     488              :     use ecdsa::Signature;
     489              :     use signature::Verifier;
     490              : 
     491           13 :     match key.crv {
     492              :         jose_jwk::EcCurves::P256 => {
     493           13 :             let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
     494           13 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     495           13 :             let sig = Signature::from_slice(sig)?;
     496           13 :             key.verify(data, &sig)?;
     497              :         }
     498            0 :         key => return Err(JwtError::UnsupportedEcKeyType(key)),
     499              :     }
     500              : 
     501           12 :     Ok(())
     502           13 : }
     503              : 
     504            5 : fn verify_rsa_signature(
     505            5 :     data: &[u8],
     506            5 :     sig: &[u8],
     507            5 :     key: &jose_jwk::Rsa,
     508            5 :     alg: &jose_jwa::Algorithm,
     509            5 : ) -> Result<(), JwtError> {
     510              :     use jose_jwa::{Algorithm, Signing};
     511              :     use rsa::RsaPublicKey;
     512              :     use rsa::pkcs1v15::{Signature, VerifyingKey};
     513              : 
     514            5 :     let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
     515              : 
     516            5 :     match alg {
     517              :         Algorithm::Signing(Signing::Rs256) => {
     518            5 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     519            5 :             let sig = Signature::try_from(sig)?;
     520            5 :             key.verify(data, &sig)?;
     521              :         }
     522            0 :         _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
     523              :     }
     524              : 
     525            5 :     Ok(())
     526            5 : }
     527              : 
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
///
/// The subset of the JOSE header this proxy reads.
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be a supported alg
    /// (checked against the selected JWK via `is_supported`).
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    /// (it selects the key in the cached JWKS).
    #[serde(rename = "kid", borrow)]
    key_id: Option<Cow<'a, str>>,
}
     538              : 
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
///
/// The registered claims read from a JWT payload. Other claims are ignored while
/// parsing this struct.
#[derive(serde::Deserialize, Debug)]
// Several fields below are never read; they exist only so `Debug` shows them in logs.
#[allow(dead_code)]
struct JwtPayload<'a> {
    /// Audience - Recipient for which the JWT is intended
    #[serde(rename = "aud", default)]
    audience: OneOrMany,
    /// Expiration - Time after which the JWT expires
    #[serde(rename = "exp", deserialize_with = "numeric_date_opt", default)]
    expiration: Option<SystemTime>,
    /// Not before - Time before which the JWT is not valid
    #[serde(rename = "nbf", deserialize_with = "numeric_date_opt", default)]
    not_before: Option<SystemTime>,

    // the following entries are only extracted for the sake of debug logging.
    /// Issuer of the JWT
    #[serde(rename = "iss", borrow)]
    issuer: Option<Cow<'a, str>>,
    /// Subject of the JWT (the user)
    #[serde(rename = "sub", borrow)]
    subject: Option<Cow<'a, str>>,
    /// Unique token identifier
    #[serde(rename = "jti", borrow)]
    jwt_id: Option<Cow<'a, str>>,
    /// Unique session identifier
    #[serde(rename = "sid", borrow)]
    session_id: Option<Cow<'a, str>>,
}
     567              : 
     568              : /// `OneOrMany` supports parsing either a single item or an array of items.
     569              : ///
     570              : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
     571              : ///
     572              : /// > The "aud" (audience) claim identifies the recipients that the JWT is
     573              : /// > intended for.  Each principal intended to process the JWT MUST
     574              : /// > identify itself with a value in the audience claim.  If the principal
     575              : /// > processing the claim does not identify itself with a value in the
     576              : /// > "aud" claim when this claim is present, then the JWT MUST be
     577              : /// > rejected.  In the general case, the "aud" value is **an array of case-
     578              : /// > sensitive strings**, each containing a StringOrURI value.  In the
     579              : /// > special case when the JWT has one audience, the "aud" value MAY be a
     580              : /// > **single case-sensitive string** containing a StringOrURI value.  The
     581              : /// > interpretation of audience values is generally application specific.
     582              : /// > Use of this claim is OPTIONAL.
     583              : #[derive(Default, Debug)]
     584              : struct OneOrMany(Vec<String>);
     585              : 
     586              : impl<'de> Deserialize<'de> for OneOrMany {
     587           17 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     588           17 :     where
     589           17 :         D: Deserializer<'de>,
     590              :     {
     591              :         struct OneOrManyVisitor;
     592              :         impl<'de> Visitor<'de> for OneOrManyVisitor {
     593              :             type Value = OneOrMany;
     594              : 
     595            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     596            0 :                 formatter.write_str("a single string or an array of strings")
     597            0 :             }
     598              : 
     599            2 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     600            2 :             where
     601            2 :                 E: serde::de::Error,
     602              :             {
     603            2 :                 Ok(OneOrMany(vec![v.to_owned()]))
     604            2 :             }
     605              : 
     606           15 :             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
     607           15 :             where
     608           15 :                 A: serde::de::SeqAccess<'de>,
     609              :             {
     610           15 :                 let mut v = vec![];
     611           52 :                 while let Some(s) = seq.next_element()? {
     612           37 :                     v.push(s);
     613           37 :                 }
     614           15 :                 Ok(OneOrMany(v))
     615           15 :             }
     616              :         }
     617           17 :         deserializer.deserialize_any(OneOrManyVisitor)
     618           17 :     }
     619              : }
     620              : 
     621           25 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     622           25 :     <Option<u64>>::deserialize(d)?
     623           25 :         .map(|t| {
     624           25 :             SystemTime::UNIX_EPOCH
     625           25 :                 .checked_add(Duration::from_secs(t))
     626           25 :                 .ok_or_else(|| {
     627            0 :                     serde::de::Error::custom(format_args!("timestamp out of bounds: {t}"))
     628            0 :                 })
     629           25 :         })
     630           25 :         .transpose()
     631           25 : }
     632              : 
     633              : struct JwkRenewalPermit<'a> {
     634              :     inner: Option<JwkRenewalPermitInner<'a>>,
     635              : }
     636              : 
     637              : enum JwkRenewalPermitInner<'a> {
     638              :     Owned(Arc<JwkCacheEntryLock>),
     639              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     640              : }
     641              : 
     642              : impl JwkRenewalPermit<'_> {
     643            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     644            0 :         JwkRenewalPermit {
     645            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     646            0 :         }
     647            0 :     }
     648              : 
     649            7 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     650            7 :         match from.lookup.acquire().await {
     651            7 :             Ok(permit) => {
     652            7 :                 permit.forget();
     653            7 :                 JwkRenewalPermit {
     654            7 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     655            7 :                 }
     656              :             }
     657            0 :             Err(_) => panic!("semaphore should not be closed"),
     658              :         }
     659            7 :     }
     660              : 
     661            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     662            0 :         match from.lookup.try_acquire() {
     663            0 :             Ok(permit) => {
     664            0 :                 permit.forget();
     665            0 :                 Some(JwkRenewalPermit {
     666            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     667            0 :                 })
     668              :             }
     669            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     670            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     671              :         }
     672            0 :     }
     673              : }
     674              : 
     675              : impl JwkRenewalPermitInner<'_> {
     676            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     677            0 :         match self {
     678            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     679            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     680              :         }
     681            0 :     }
     682              : }
     683              : 
     684              : impl Drop for JwkRenewalPermit<'_> {
     685            7 :     fn drop(&mut self) {
     686            7 :         let entry = match &self.inner {
     687            0 :             None => return,
     688            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     689            7 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     690              :         };
     691            7 :         entry.lookup.add_permits(1);
     692            7 :     }
     693              : }
     694              : 
     695              : #[derive(Error, Debug)]
     696              : #[non_exhaustive]
     697              : pub(crate) enum JwtError {
     698              :     #[error("jwk not found")]
     699              :     JwkNotFound,
     700              : 
     701              :     #[error("missing key id")]
     702              :     MissingKeyId,
     703              : 
     704              :     #[error("Provided authentication token is not a valid JWT encoding")]
     705              :     JwtEncoding(#[from] JwtEncodingError),
     706              : 
     707              :     #[error(transparent)]
     708              :     InvalidClaims(#[from] JwtClaimsError),
     709              : 
     710              :     #[error("invalid P256 key")]
     711              :     InvalidP256Key(jose_jwk::crypto::Error),
     712              : 
     713              :     #[error("invalid RSA key")]
     714              :     InvalidRsaKey(jose_jwk::crypto::Error),
     715              : 
     716              :     #[error("invalid RSA signing algorithm")]
     717              :     InvalidRsaSigningAlgorithm,
     718              : 
     719              :     #[error("unsupported EC key type {0:?}")]
     720              :     UnsupportedEcKeyType(jose_jwk::EcCurves),
     721              : 
     722              :     #[error("unsupported key type {0:?}")]
     723              :     UnsupportedKeyType(KeyType),
     724              : 
     725              :     #[error("signature algorithm not supported")]
     726              :     SignatureAlgorithmNotSupported,
     727              : 
     728              :     #[error("signature error: {0}")]
     729              :     Signature(#[from] signature::Error),
     730              : 
     731              :     #[error("failed to fetch auth rules: {0}")]
     732              :     FetchAuthRules(#[from] FetchAuthRulesError),
     733              : }
     734              : 
     735              : impl From<base64::DecodeError> for JwtError {
     736            0 :     fn from(err: base64::DecodeError) -> Self {
     737            0 :         JwtEncodingError::Base64Decode(err).into()
     738            0 :     }
     739              : }
     740              : 
     741              : impl From<serde_json::Error> for JwtError {
     742            0 :     fn from(err: serde_json::Error) -> Self {
     743            0 :         JwtEncodingError::SerdeJson(err).into()
     744            0 :     }
     745              : }
     746              : 
     747              : #[derive(Error, Debug)]
     748              : #[non_exhaustive]
     749              : pub enum JwtEncodingError {
     750              :     #[error(transparent)]
     751              :     Base64Decode(#[from] base64::DecodeError),
     752              : 
     753              :     #[error(transparent)]
     754              :     SerdeJson(#[from] serde_json::Error),
     755              : 
     756              :     #[error("invalid compact form")]
     757              :     InvalidCompactForm,
     758              : }
     759              : 
     760              : #[derive(Error, Debug, PartialEq)]
     761              : #[non_exhaustive]
     762              : pub enum JwtClaimsError {
     763              :     #[error("invalid JWT token audience")]
     764              :     InvalidJwtTokenAudience,
     765              : 
     766              :     #[error("JWT token has expired (exp={0})")]
     767              :     JwtTokenHasExpired(u64),
     768              : 
     769              :     #[error("JWT token is not yet ready to use (nbf={0})")]
     770              :     JwtTokenNotYetReadyToUse(u64),
     771              : }
     772              : 
     773              : #[allow(dead_code, reason = "Debug use only")]
     774              : #[derive(Debug)]
     775              : pub(crate) enum KeyType {
     776              :     Ec(jose_jwk::EcCurves),
     777              :     Rsa,
     778              :     Oct,
     779              :     Okp(jose_jwk::OkpCurves),
     780              :     Unknown,
     781              : }
     782              : 
     783              : impl From<&jose_jwk::Key> for KeyType {
     784            0 :     fn from(key: &jose_jwk::Key) -> Self {
     785            0 :         match key {
     786            0 :             jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
     787            0 :             jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
     788            0 :             jose_jwk::Key::Oct(_oct) => Self::Oct,
     789            0 :             jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
     790            0 :             _ => Self::Unknown,
     791              :         }
     792            0 :     }
     793              : }
     794              : 
     795              : #[cfg(test)]
     796              : mod tests {
     797              :     use std::future::IntoFuture;
     798              :     use std::net::SocketAddr;
     799              :     use std::time::SystemTime;
     800              : 
     801              :     use bytes::Bytes;
     802              :     use http::Response;
     803              :     use http_body_util::Full;
     804              :     use hyper::service::service_fn;
     805              :     use hyper_util::rt::TokioIo;
     806              :     use rand::rngs::OsRng;
     807              :     use rsa::pkcs8::DecodePrivateKey;
     808              :     use serde::Serialize;
     809              :     use serde_json::json;
     810              :     use signature::Signer;
     811              :     use tokio::net::TcpListener;
     812              : 
     813              :     use super::*;
     814              :     use crate::types::RoleName;
     815              : 
     816            6 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     817            6 :         let sk = p256::SecretKey::random(&mut OsRng);
     818            6 :         let pk = sk.public_key().into();
     819            6 :         let jwk = jose_jwk::Jwk {
     820            6 :             key: jose_jwk::Key::Ec(pk),
     821            6 :             prm: jose_jwk::Parameters {
     822            6 :                 kid: Some(kid),
     823            6 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     824            6 :                 ..Default::default()
     825            6 :             },
     826            6 :         };
     827            6 :         (sk, jwk)
     828            6 :     }
     829              : 
     830            4 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     831            4 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     832            4 :         let pk = sk.to_public_key().into();
     833            4 :         let jwk = jose_jwk::Jwk {
     834            4 :             key: jose_jwk::Key::Rsa(pk),
     835            4 :             prm: jose_jwk::Parameters {
     836            4 :                 kid: Some(kid),
     837            4 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     838            4 :                 ..Default::default()
     839            4 :             },
     840            4 :         };
     841            4 :         (sk, jwk)
     842            4 :     }
     843              : 
     844            8 :     fn now() -> u64 {
     845            8 :         SystemTime::now()
     846            8 :             .duration_since(SystemTime::UNIX_EPOCH)
     847            8 :             .unwrap()
     848            8 :             .as_secs()
     849            8 :     }
     850              : 
     851            7 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     852            7 :         let now = now();
     853            7 :         let body = typed_json::json! {{
     854            7 :             "exp": now + 3600,
     855            7 :             "nbf": now,
     856              :             "aud": ["audience1", "neon", "audience2"],
     857              :             "sub": "user1",
     858              :             "sid": "session1",
     859              :             "jti": "token1",
     860              :             "iss": "neon-testing",
     861              :         }};
     862            7 :         build_custom_jwt_payload(kid, body, sig)
     863            7 :     }
     864              : 
     865           15 :     fn build_custom_jwt_payload(
     866           15 :         kid: String,
     867           15 :         body: impl Serialize,
     868           15 :         sig: jose_jwa::Signing,
     869           15 :     ) -> String {
     870           15 :         let header = JwtHeader {
     871           15 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     872           15 :             key_id: Some(Cow::Owned(kid)),
     873           15 :         };
     874              : 
     875           15 :         let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
     876           15 :         let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap());
     877              : 
     878           15 :         format!("{header}.{body}")
     879           15 :     }
     880              : 
     881            3 :     fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
     882              :         use p256::ecdsa::{Signature, SigningKey};
     883              : 
     884            3 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     885            3 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     886            3 :         let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
     887              : 
     888            3 :         format!("{payload}.{sig}")
     889            3 :     }
     890              : 
     891            8 :     fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
     892              :         use p256::ecdsa::{Signature, SigningKey};
     893              : 
     894            8 :         let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
     895            8 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     896            8 :         let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
     897              : 
     898            8 :         format!("{payload}.{sig}")
     899            8 :     }
     900              : 
     901            4 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     902              :         use rsa::pkcs1v15::SigningKey;
     903              :         use rsa::signature::SignatureEncoding;
     904              : 
     905            4 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     906            4 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     907            4 :         let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
     908              : 
     909            4 :         format!("{payload}.{sig}")
     910            4 :     }
     911              : 
     912              :     // RSA key gen is slow....
     913              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     914              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     915              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     916              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     917              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     918              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     919              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     920              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     921              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     922              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     923              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     924              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     925              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     926              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     927              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     928              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     929              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     930              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     931              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     932              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     933              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     934              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     935              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     936              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     937              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     938              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     939              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     940              : -----END PRIVATE KEY-----
     941              : ";
     942              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     943              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     944              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     945              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     946              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     947              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     948              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     949              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     950              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     951              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     952              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     953              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     954              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     955              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     956              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     957              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     958              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     959              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     960              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     961              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     962              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     963              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     964              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     965              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     966              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     967              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     968              : 5JiconnI5aLek0QVPoFaVXFa
     969              : -----END PRIVATE KEY-----
     970              : ";
     971              : 
     972              :     #[derive(Clone)]
     973              :     struct Fetch(Vec<AuthRule>);
     974              : 
     975              :     impl FetchAuthRules for Fetch {
     976            7 :         async fn fetch_auth_rules(
     977            7 :             &self,
     978            7 :             _ctx: &RequestContext,
     979            7 :             _endpoint: EndpointId,
     980            7 :         ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     981            7 :             Ok(self.0.clone())
     982            7 :         }
     983              :     }
     984              : 
     985            6 :     async fn jwks_server(
     986            6 :         router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
     987            6 :     ) -> SocketAddr {
     988            6 :         let router = Arc::new(router);
     989            9 :         let service = service_fn(move |req| {
     990            9 :             let router = Arc::clone(&router);
     991            9 :             async move {
     992            9 :                 match router(req.uri().path()) {
     993            9 :                     Some(body) => Response::builder()
     994            9 :                         .status(200)
     995            9 :                         .body(Full::new(Bytes::from(body))),
     996            0 :                     None => Response::builder()
     997            0 :                         .status(404)
     998            0 :                         .body(Full::new(Bytes::new())),
     999              :                 }
    1000            9 :             }
    1001            9 :         });
    1002              : 
    1003            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
    1004            6 :         let server = hyper::server::conn::http1::Builder::new();
    1005            6 :         let addr = listener.local_addr().unwrap();
    1006            6 :         tokio::spawn(async move {
    1007              :             loop {
    1008           12 :                 let (s, _) = listener.accept().await.unwrap();
    1009            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
    1010            6 :                 tokio::spawn(serve.into_future());
    1011              :             }
    1012              :         });
    1013              : 
    1014            6 :         addr
    1015            6 :     }
    1016              : 
    1017              :     #[tokio::test]
    1018            1 :     async fn check_jwt_happy_path() {
    1019            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
    1020            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
    1021            1 :         let (ec1, jwk3) = new_ec_jwk("ec1".into());
    1022            1 :         let (ec2, jwk4) = new_ec_jwk("ec2".into());
    1023              : 
    1024            1 :         let foo_jwks = jose_jwk::JwkSet {
    1025            1 :             keys: vec![jwk1, jwk3],
    1026            1 :         };
    1027            1 :         let bar_jwks = jose_jwk::JwkSet {
    1028            1 :             keys: vec![jwk2, jwk4],
    1029            1 :         };
    1030              : 
    1031            4 :         let jwks_addr = jwks_server(move |path| match path {
    1032            4 :             "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
    1033            2 :             "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
    1034            0 :             _ => None,
    1035            4 :         })
    1036            1 :         .await;
    1037              : 
    1038            1 :         let role_name1 = RoleName::from("anonymous");
    1039            1 :         let role_name2 = RoleName::from("authenticated");
    1040              : 
    1041            1 :         let roles = vec![
    1042            1 :             RoleNameInt::from(&role_name1),
    1043            1 :             RoleNameInt::from(&role_name2),
    1044              :         ];
    1045            1 :         let rules = vec![
    1046            1 :             AuthRule {
    1047            1 :                 id: "foo".to_owned(),
    1048            1 :                 jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
    1049            1 :                 audience: None,
    1050            1 :                 role_names: roles.clone(),
    1051            1 :             },
    1052            1 :             AuthRule {
    1053            1 :                 id: "bar".to_owned(),
    1054            1 :                 jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
    1055            1 :                 audience: None,
    1056            1 :                 role_names: roles.clone(),
    1057            1 :             },
    1058              :         ];
    1059              : 
    1060            1 :         let fetch = Fetch(rules);
    1061            1 :         let jwk_cache = JwkCache::default();
    1062              : 
    1063            1 :         let endpoint = EndpointId::from("ep");
    1064              : 
    1065            1 :         let jwt1 = new_rsa_jwt("rs1".into(), rs1);
    1066            1 :         let jwt2 = new_rsa_jwt("rs2".into(), rs2);
    1067            1 :         let jwt3 = new_ec_jwt("ec1".into(), &ec1);
    1068            1 :         let jwt4 = new_ec_jwt("ec2".into(), &ec2);
    1069              : 
    1070            1 :         let tokens = [jwt1, jwt2, jwt3, jwt4];
    1071            1 :         let role_names = [role_name1, role_name2];
    1072            3 :         for role in &role_names {
    1073           10 :             for token in &tokens {
    1074            8 :                 jwk_cache
    1075            8 :                     .check_jwt(
    1076            8 :                         &RequestContext::test(),
    1077            8 :                         endpoint.clone(),
    1078            8 :                         role,
    1079            8 :                         &fetch,
    1080            8 :                         token,
    1081            8 :                     )
    1082            8 :                     .await
    1083            8 :                     .unwrap();
    1084            1 :             }
    1085            1 :         }
    1086            1 :     }
    1087              : 
    1088              :     /// AWS Cognito escapes the `/` in the URL.
    1089              :     #[tokio::test]
    1090            1 :     async fn check_jwt_regression_cognito_issuer() {
    1091            1 :         let (key, jwk) = new_ec_jwk("key".into());
    1092              : 
    1093            1 :         let now = now();
    1094            1 :         let token = new_custom_ec_jwt(
    1095            1 :             "key".into(),
    1096            1 :             &key,
    1097            1 :             typed_json::json! {{
    1098              :                 "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1099              :                 // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
    1100              :                 // to write that escape character. instead I will make a bogus URL using `\` instead.
    1101              :                 "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
    1102              :                 "client_id": "abcdefghijklmnopqrstuvwxyz",
    1103              :                 "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
    1104              :                 "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
    1105              :                 "token_use": "access",
    1106              :                 "scope": "aws.cognito.signin.user.admin",
    1107            1 :                 "auth_time":now,
    1108            1 :                 "exp":now + 60,
    1109            1 :                 "iat":now,
    1110              :                 "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
    1111              :                 "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1112              :             }},
    1113              :         );
    1114              : 
    1115            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1116              : 
    1117            1 :         let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
    1118              : 
    1119            1 :         let role_name = RoleName::from("anonymous");
    1120            1 :         let rules = vec![AuthRule {
    1121            1 :             id: "aws-cognito".to_owned(),
    1122            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1123            1 :             audience: None,
    1124            1 :             role_names: vec![RoleNameInt::from(&role_name)],
    1125            1 :         }];
    1126              : 
    1127            1 :         let fetch = Fetch(rules);
    1128            1 :         let jwk_cache = JwkCache::default();
    1129              : 
    1130            1 :         let endpoint = EndpointId::from("ep");
    1131              : 
    1132            1 :         jwk_cache
    1133            1 :             .check_jwt(
    1134            1 :                 &RequestContext::test(),
    1135            1 :                 endpoint.clone(),
    1136            1 :                 &role_name,
    1137            1 :                 &fetch,
    1138            1 :                 &token,
    1139            1 :             )
    1140            1 :             .await
    1141            1 :             .unwrap();
    1142            1 :     }
    1143              : 
    1144              :     #[tokio::test]
    1145            1 :     async fn check_jwt_invalid_signature() {
    1146            1 :         let (_, jwk) = new_ec_jwk("1".into());
    1147            1 :         let (key, _) = new_ec_jwk("1".into());
    1148              : 
    1149              :         // has a matching kid, but signed by the wrong key
    1150            1 :         let bad_jwt = new_ec_jwt("1".into(), &key);
    1151              : 
    1152            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1153            1 :         let jwks_addr = jwks_server(move |path| match path {
    1154            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1155            0 :             _ => None,
    1156            1 :         })
    1157            1 :         .await;
    1158              : 
    1159            1 :         let role = RoleName::from("authenticated");
    1160              : 
    1161            1 :         let rules = vec![AuthRule {
    1162            1 :             id: String::new(),
    1163            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1164            1 :             audience: None,
    1165            1 :             role_names: vec![RoleNameInt::from(&role)],
    1166            1 :         }];
    1167              : 
    1168            1 :         let fetch = Fetch(rules);
    1169            1 :         let jwk_cache = JwkCache::default();
    1170              : 
    1171            1 :         let ep = EndpointId::from("ep");
    1172              : 
    1173            1 :         let ctx = RequestContext::test();
    1174            1 :         let err = jwk_cache
    1175            1 :             .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
    1176            1 :             .await
    1177            1 :             .unwrap_err();
    1178            1 :         assert!(
    1179            1 :             matches!(err, JwtError::Signature(_)),
    1180            1 :             "expected \"signature error\", got {err:?}"
    1181            1 :         );
    1182            1 :     }
    1183              : 
    1184              :     #[tokio::test]
    1185            1 :     async fn check_jwt_unknown_role() {
    1186            1 :         let (key, jwk) = new_rsa_jwk(RS1, "1".into());
    1187            1 :         let jwt = new_rsa_jwt("1".into(), key);
    1188              : 
    1189            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1190            1 :         let jwks_addr = jwks_server(move |path| match path {
    1191            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1192            0 :             _ => None,
    1193            1 :         })
    1194            1 :         .await;
    1195              : 
    1196            1 :         let role = RoleName::from("authenticated");
    1197            1 :         let rules = vec![AuthRule {
    1198            1 :             id: String::new(),
    1199            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1200            1 :             audience: None,
    1201            1 :             role_names: vec![RoleNameInt::from(&role)],
    1202            1 :         }];
    1203              : 
    1204            1 :         let fetch = Fetch(rules);
    1205            1 :         let jwk_cache = JwkCache::default();
    1206              : 
    1207            1 :         let ep = EndpointId::from("ep");
    1208              : 
    1209              :         // this role_name is not accepted
    1210            1 :         let bad_role_name = RoleName::from("cloud_admin");
    1211              : 
    1212            1 :         let ctx = RequestContext::test();
    1213            1 :         let err = jwk_cache
    1214            1 :             .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
    1215            1 :             .await
    1216            1 :             .unwrap_err();
    1217              : 
    1218            1 :         assert!(
    1219            1 :             matches!(err, JwtError::JwkNotFound),
    1220            1 :             "expected \"jwk not found\", got {err:?}"
    1221            1 :         );
    1222            1 :     }
    1223              : 
    1224              :     #[tokio::test]
    1225            1 :     async fn check_jwt_invalid_claims() {
    1226            1 :         let (key, jwk) = new_ec_jwk("1".into());
    1227              : 
    1228            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1229            1 :         let jwks_addr = jwks_server(move |path| match path {
    1230            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1231            0 :             _ => None,
    1232            1 :         })
    1233            1 :         .await;
    1234              : 
    1235            1 :         let now = SystemTime::now()
    1236            1 :             .duration_since(SystemTime::UNIX_EPOCH)
    1237            1 :             .unwrap()
    1238            1 :             .as_secs();
    1239              : 
    1240              :         struct Test {
    1241              :             body: serde_json::Value,
    1242              :             error: JwtClaimsError,
    1243              :         }
    1244              : 
    1245            1 :         let table = vec![
    1246            1 :             Test {
    1247            1 :                 body: json! {{
    1248            1 :                     "nbf": now + 60,
    1249            1 :                     "aud": "neon",
    1250            1 :                 }},
    1251            1 :                 error: JwtClaimsError::JwtTokenNotYetReadyToUse(now + 60),
    1252            1 :             },
    1253            1 :             Test {
    1254            1 :                 body: json! {{
    1255            1 :                     "exp": now - 60,
    1256            1 :                     "aud": ["neon"],
    1257            1 :                 }},
    1258            1 :                 error: JwtClaimsError::JwtTokenHasExpired(now - 60),
    1259            1 :             },
    1260            1 :             Test {
    1261            1 :                 body: json! {{
    1262            1 :                 }},
    1263            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1264            1 :             },
    1265            1 :             Test {
    1266            1 :                 body: json! {{
    1267            1 :                     "aud": [],
    1268            1 :                 }},
    1269            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1270            1 :             },
    1271            1 :             Test {
    1272            1 :                 body: json! {{
    1273            1 :                     "aud": "foo",
    1274            1 :                 }},
    1275            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1276            1 :             },
    1277            1 :             Test {
    1278            1 :                 body: json! {{
    1279            1 :                     "aud": ["foo"],
    1280            1 :                 }},
    1281            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1282            1 :             },
    1283            1 :             Test {
    1284            1 :                 body: json! {{
    1285            1 :                     "aud": ["foo", "bar"],
    1286            1 :                 }},
    1287            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1288            1 :             },
    1289              :         ];
    1290              : 
    1291            1 :         let role = RoleName::from("authenticated");
    1292              : 
    1293            1 :         let rules = vec![AuthRule {
    1294            1 :             id: String::new(),
    1295            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1296            1 :             audience: Some("neon".to_string()),
    1297            1 :             role_names: vec![RoleNameInt::from(&role)],
    1298            1 :         }];
    1299              : 
    1300            1 :         let fetch = Fetch(rules);
    1301            1 :         let jwk_cache = JwkCache::default();
    1302              : 
    1303            1 :         let ep = EndpointId::from("ep");
    1304              : 
    1305            1 :         let ctx = RequestContext::test();
    1306            8 :         for test in table {
    1307            7 :             let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
    1308            1 : 
    1309            7 :             match jwk_cache
    1310            7 :                 .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
    1311            7 :                 .await
    1312            1 :             {
    1313            7 :                 Err(JwtError::InvalidClaims(error)) if error == test.error => {}
    1314            1 :                 Err(err) => {
    1315            1 :                     panic!("expected {:?}, got {err:?}", test.error)
    1316            1 :                 }
    1317            1 :                 Ok(_payload) => {
    1318            1 :                     panic!("expected {:?}, got ok", test.error)
    1319            1 :                 }
    1320            1 :             }
    1321            1 :         }
    1322            1 :     }
    1323              : 
    1324              :     #[tokio::test]
    1325            1 :     async fn check_jwk_keycloak_regression() {
    1326            1 :         let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
    1327            1 :         let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
    1328              : 
    1329              :         // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
    1330              :         // This is taken directly from keycloak.
    1331            1 :         let invalid_jwk = serde_json::json! {
    1332              :             {
    1333            1 :                 "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
    1334            1 :                 "kty": "RSA",
    1335            1 :                 "alg": "RSA-OAEP",
    1336            1 :                 "use": "enc",
    1337            1 :                 "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
    1338            1 :                 "e": "AQAB",
    1339            1 :                 "x5c": [
    1340              :                     "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
    1341              :                 ],
    1342            1 :                 "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
    1343            1 :                 "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
    1344              :             }
    1345              :         };
    1346              : 
    1347            1 :         let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
    1348            1 :         let jwks_addr = jwks_server(move |path| match path {
    1349            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1350            0 :             _ => None,
    1351            1 :         })
    1352            1 :         .await;
    1353              : 
    1354            1 :         let role_name = RoleName::from("anonymous");
    1355            1 :         let role = RoleNameInt::from(&role_name);
    1356              : 
    1357            1 :         let rules = vec![AuthRule {
    1358            1 :             id: "foo".to_owned(),
    1359            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1360            1 :             audience: None,
    1361            1 :             role_names: vec![role],
    1362            1 :         }];
    1363              : 
    1364            1 :         let fetch = Fetch(rules);
    1365            1 :         let jwk_cache = JwkCache::default();
    1366              : 
    1367            1 :         let endpoint = EndpointId::from("ep");
    1368              : 
    1369            1 :         let token = new_rsa_jwt("rs1".into(), rs);
    1370              : 
    1371            1 :         jwk_cache
    1372            1 :             .check_jwt(
    1373            1 :                 &RequestContext::test(),
    1374            1 :                 endpoint.clone(),
    1375            1 :                 &role_name,
    1376            1 :                 &fetch,
    1377            1 :                 &token,
    1378            1 :             )
    1379            1 :             .await
    1380            1 :             .unwrap();
    1381            1 :     }
    1382              : }
        

Generated by: LCOV version 2.1-beta