LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 046155f5c3321e806c1c5acca9ccd26414587b38.info Lines: 89.4 % 864 772
Test Date: 2025-03-27 12:42:09 Functions: 58.7 % 167 98

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::future::Future;
       3              : use std::sync::Arc;
       4              : use std::time::{Duration, SystemTime};
       5              : 
       6              : use arc_swap::ArcSwapOption;
       7              : use clashmap::ClashMap;
       8              : use jose_jwk::crypto::KeyInfo;
       9              : use reqwest::{Client, redirect};
      10              : use reqwest_retry::RetryTransientMiddleware;
      11              : use reqwest_retry::policies::ExponentialBackoff;
      12              : use serde::de::Visitor;
      13              : use serde::{Deserialize, Deserializer};
      14              : use serde_json::value::RawValue;
      15              : use signature::Verifier;
      16              : use thiserror::Error;
      17              : use tokio::time::Instant;
      18              : 
      19              : use crate::auth::backend::ComputeCredentialKeys;
      20              : use crate::context::RequestContext;
      21              : use crate::control_plane::errors::GetEndpointJwksError;
      22              : use crate::http::read_body_with_limit;
      23              : use crate::intern::RoleNameInt;
      24              : use crate::types::{EndpointId, RoleName};
      25              : 
      26              : // TODO(conrad): make these configurable.
      27              : const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
      28              : const MIN_RENEW: Duration = Duration::from_secs(30);
      29              : const AUTO_RENEW: Duration = Duration::from_secs(300);
      30              : const MAX_RENEW: Duration = Duration::from_secs(3600);
      31              : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
      32              : const JWKS_USER_AGENT: &str = "neon-proxy";
      33              : 
      34              : const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
      35              : const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
      36              : const JWKS_FETCH_RETRIES: u32 = 3;
      37              : 
      38              : /// How to get the JWT auth rules
      39              : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
      40              :     fn fetch_auth_rules(
      41              :         &self,
      42              :         ctx: &RequestContext,
      43              :         endpoint: EndpointId,
      44              :     ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
      45              : }
      46              : 
      47              : #[derive(Error, Debug)]
      48              : pub(crate) enum FetchAuthRulesError {
      49              :     #[error(transparent)]
      50              :     GetEndpointJwks(#[from] GetEndpointJwksError),
      51              : 
      52              :     #[error("JWKs settings for this role were not configured")]
      53              :     RoleJwksNotConfigured,
      54              : }
      55              : 
      56              : #[derive(Clone)]
      57              : pub(crate) struct AuthRule {
      58              :     pub(crate) id: String,
      59              :     pub(crate) jwks_url: url::Url,
      60              :     pub(crate) audience: Option<String>,
      61              :     pub(crate) role_names: Vec<RoleNameInt>,
      62              : }
      63              : 
      64              : pub struct JwkCache {
      65              :     client: reqwest_middleware::ClientWithMiddleware,
      66              : 
      67              :     map: ClashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
      68              : }
      69              : 
      70              : pub(crate) struct JwkCacheEntry {
      71              :     /// Should refetch at least every hour to verify when old keys have been removed.
      72              :     /// Should refetch when new key IDs are seen only every 5 minutes or so
      73              :     last_retrieved: Instant,
      74              : 
      75              :     /// cplane will return multiple JWKs urls that we need to scrape.
      76              :     key_sets: ahash::HashMap<String, KeySet>,
      77              : }
      78              : 
      79              : impl JwkCacheEntry {
      80           19 :     fn find_jwk_and_audience(
      81           19 :         &self,
      82           19 :         key_id: &str,
      83           19 :         role_name: &RoleName,
      84           19 :     ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      85           19 :         self.key_sets
      86           19 :             .values()
      87           19 :             // make sure our requested role has access to the key set
      88           23 :             .filter(|key_set| {
      89           23 :                 key_set
      90           23 :                     .role_names
      91           23 :                     .iter()
      92           29 :                     .any(|role| *role.as_str() == **role_name)
      93           23 :             })
      94           19 :             // try and find the requested key-id in the key set
      95           22 :             .find_map(|key_set| {
      96           22 :                 key_set
      97           22 :                     .find_key(key_id)
      98           22 :                     .map(|jwk| (jwk, key_set.audience.as_deref()))
      99           22 :             })
     100           19 :     }
     101              : }
     102              : 
     103              : struct KeySet {
     104              :     jwks: jose_jwk::JwkSet,
     105              :     audience: Option<String>,
     106              :     role_names: Vec<RoleNameInt>,
     107              : }
     108              : 
     109              : impl KeySet {
     110           22 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
     111           22 :         self.jwks
     112           22 :             .keys
     113           22 :             .iter()
     114           30 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
     115           22 :     }
     116              : }
     117              : 
     118              : pub(crate) struct JwkCacheEntryLock {
     119              :     cached: ArcSwapOption<JwkCacheEntry>,
     120              :     lookup: tokio::sync::Semaphore,
     121              : }
     122              : 
     123              : impl Default for JwkCacheEntryLock {
     124            7 :     fn default() -> Self {
     125            7 :         JwkCacheEntryLock {
     126            7 :             cached: ArcSwapOption::empty(),
     127            7 :             lookup: tokio::sync::Semaphore::new(1),
     128            7 :         }
     129            7 :     }
     130              : }
     131              : 
     132            9 : #[derive(Deserialize)]
     133              : struct JwkSet<'a> {
     134              :     /// we parse into raw-value because not all keys in a JWKS are ones
     135              :     /// we can parse directly, so we parse them lazily.
     136              :     #[serde(borrow)]
     137              :     keys: Vec<&'a RawValue>,
     138              : }
     139              : 
     140              : /// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
     141              : /// Returns `None` and log a warning if there are any errors.
     142            9 : async fn fetch_jwks(
     143            9 :     client: &reqwest_middleware::ClientWithMiddleware,
     144            9 :     jwks_url: url::Url,
     145            9 : ) -> Option<jose_jwk::JwkSet> {
     146            9 :     let req = client.get(jwks_url.clone());
     147              :     // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     148            9 :     let resp = req.send().await.and_then(|r| {
     149            9 :         r.error_for_status()
     150            9 :             .map_err(reqwest_middleware::Error::Reqwest)
     151            9 :     });
     152              : 
     153            9 :     let resp = match resp {
     154            9 :         Ok(r) => r,
     155              :         // TODO: should we re-insert JWKs if we want to keep this JWKs URL?
     156              :         // I expect these failures would be quite sparse.
     157            0 :         Err(e) => {
     158            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
     159            0 :             return None;
     160              :         }
     161              :     };
     162              : 
     163            9 :     let resp: http::Response<reqwest::Body> = resp.into();
     164              : 
     165            9 :     let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
     166            9 :         Ok(bytes) => bytes,
     167            0 :         Err(e) => {
     168            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     169            0 :             return None;
     170              :         }
     171              :     };
     172              : 
     173            9 :     let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
     174            9 :         Ok(jwks) => jwks,
     175            0 :         Err(e) => {
     176            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     177            0 :             return None;
     178              :         }
     179              :     };
     180              : 
     181              :     // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
     182              :     //
     183              :     // Even though we limit our responses to 64KiB, we could still receive a payload like
     184              :     // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
     185              :     // Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
     186            9 :     let mut keys = vec![];
     187            9 : 
     188            9 :     let mut failed = 0;
     189           23 :     for key in jwks.keys {
     190           14 :         let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
     191           13 :             Ok(key) => key,
     192            1 :             Err(e) => {
     193            1 :                 tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
     194            1 :                 failed += 1;
     195            1 :                 continue;
     196              :             }
     197              :         };
     198              : 
     199              :         // if `use` (called `cls` in rust) is specified to be something other than signing,
     200              :         // we can skip storing it.
     201           13 :         if key
     202           13 :             .prm
     203           13 :             .cls
     204           13 :             .as_ref()
     205           13 :             .is_some_and(|c| *c != jose_jwk::Class::Signing)
     206              :         {
     207            0 :             continue;
     208           13 :         }
     209           13 : 
     210           13 :         keys.push(key);
     211              :     }
     212              : 
     213            9 :     keys.shrink_to_fit();
     214            9 : 
     215            9 :     if failed > 0 {
     216            1 :         tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
     217            8 :     }
     218              : 
     219            9 :     if keys.is_empty() {
     220            0 :         tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
     221            0 :         return None;
     222            9 :     }
     223            9 : 
     224            9 :     Some(jose_jwk::JwkSet { keys })
     225            9 : }
     226              : 
     227              : impl JwkCacheEntryLock {
     228            7 :     async fn acquire_permit(self: &Arc<Self>) -> JwkRenewalPermit<'_> {
     229            7 :         JwkRenewalPermit::acquire_permit(self).await
     230            7 :     }
     231              : 
     232            0 :     fn try_acquire_permit(self: &Arc<Self>) -> Option<JwkRenewalPermit<'_>> {
     233            0 :         JwkRenewalPermit::try_acquire_permit(self)
     234            0 :     }
     235              : 
     236            7 :     async fn renew_jwks<F: FetchAuthRules>(
     237            7 :         &self,
     238            7 :         _permit: JwkRenewalPermit<'_>,
     239            7 :         ctx: &RequestContext,
     240            7 :         client: &reqwest_middleware::ClientWithMiddleware,
     241            7 :         endpoint: EndpointId,
     242            7 :         auth_rules: &F,
     243            7 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     244            7 :         // double check that no one beat us to updating the cache.
     245            7 :         let now = Instant::now();
     246            7 :         let guard = self.cached.load_full();
     247            7 :         if let Some(cached) = guard {
     248            0 :             let last_update = now.duration_since(cached.last_retrieved);
     249            0 :             if last_update < Duration::from_secs(300) {
     250            0 :                 return Ok(cached);
     251            0 :             }
     252            7 :         }
     253              : 
     254            7 :         let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
     255            7 :         let mut key_sets =
     256            7 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     257              : 
     258              :         // TODO(conrad): run concurrently
     259              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     260           16 :         for rule in rules {
     261            9 :             if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
     262            9 :                 key_sets.insert(
     263            9 :                     rule.id,
     264            9 :                     KeySet {
     265            9 :                         jwks,
     266            9 :                         audience: rule.audience,
     267            9 :                         role_names: rule.role_names,
     268            9 :                     },
     269            9 :                 );
     270            9 :             }
     271              :         }
     272              : 
     273            7 :         let entry = Arc::new(JwkCacheEntry {
     274            7 :             last_retrieved: now,
     275            7 :             key_sets,
     276            7 :         });
     277            7 :         self.cached.swap(Some(Arc::clone(&entry)));
     278            7 : 
     279            7 :         Ok(entry)
     280            7 :     }
     281              : 
     282           19 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     283           19 :         self: &Arc<Self>,
     284           19 :         ctx: &RequestContext,
     285           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     286           19 :         endpoint: EndpointId,
     287           19 :         fetch: &F,
     288           19 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     289           19 :         let now = Instant::now();
     290           19 :         let guard = self.cached.load_full();
     291              : 
     292              :         // if we have no cached JWKs, try and get some
     293           19 :         let Some(cached) = guard else {
     294            7 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     295            7 :             let permit = self.acquire_permit().await;
     296            7 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     297              :         };
     298              : 
     299           12 :         let last_update = now.duration_since(cached.last_retrieved);
     300           12 : 
     301           12 :         // check if the cached JWKs need updating.
     302           12 :         if last_update > MAX_RENEW {
     303            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     304            0 :             let permit = self.acquire_permit().await;
     305              : 
     306              :             // it's been too long since we checked the keys. wait for them to update.
     307            0 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     308           12 :         }
     309           12 : 
     310           12 :         // every 5 minutes we should spawn a job to eagerly update the token.
     311           12 :         if last_update > AUTO_RENEW {
     312            0 :             if let Some(permit) = self.try_acquire_permit() {
     313            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     314            0 :                 let permit = permit.into_owned();
     315            0 :                 let entry = self.clone();
     316            0 :                 let client = client.clone();
     317            0 :                 let fetch = fetch.clone();
     318            0 :                 let ctx = ctx.clone();
     319            0 :                 tokio::spawn(async move {
     320            0 :                     if let Err(e) = entry
     321            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
     322            0 :                         .await
     323              :                     {
     324            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     325            0 :                     }
     326            0 :                 });
     327            0 :             } else {
     328            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     329              :             }
     330           12 :         }
     331              : 
     332           12 :         Ok(cached)
     333           19 :     }
     334              : 
     335           19 :     async fn check_jwt<F: FetchAuthRules>(
     336           19 :         self: &Arc<Self>,
     337           19 :         ctx: &RequestContext,
     338           19 :         jwt: &str,
     339           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     340           19 :         endpoint: EndpointId,
     341           19 :         role_name: &RoleName,
     342           19 :         fetch: &F,
     343           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     344              :         // JWT compact form is defined to be
     345              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     346              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     347              : 
     348           19 :         let (header_payload, signature) = jwt
     349           19 :             .rsplit_once('.')
     350           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     351           19 :         let (header, payload) = header_payload
     352           19 :             .split_once('.')
     353           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     354              : 
     355           19 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
     356           19 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
     357              : 
     358           19 :         let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
     359           19 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
     360              : 
     361           19 :         if let Some(iss) = &payload.issuer {
     362           12 :             ctx.set_jwt_issuer(iss.as_ref().to_owned());
     363           12 :         }
     364              : 
     365           19 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
     366              : 
     367           19 :         let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
     368              : 
     369           19 :         let mut guard = self
     370           19 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
     371           19 :             .await?;
     372              : 
     373              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     374           18 :         let (jwk, expected_audience) = loop {
     375           19 :             match guard.find_jwk_and_audience(&kid, role_name) {
     376           18 :                 Some(jwk) => break jwk,
     377            1 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     378            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     379              : 
     380            0 :                     let permit = self.acquire_permit().await;
     381            0 :                     guard = self
     382            0 :                         .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
     383            0 :                         .await?;
     384              :                 }
     385            1 :                 _ => return Err(JwtError::JwkNotFound),
     386              :             }
     387              :         };
     388              : 
     389           18 :         if !jwk.is_supported(&header.algorithm) {
     390            0 :             return Err(JwtError::SignatureAlgorithmNotSupported);
     391           18 :         }
     392           18 : 
     393           18 :         match &jwk.key {
     394           13 :             jose_jwk::Key::Ec(key) => {
     395           13 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     396              :             }
     397            5 :             jose_jwk::Key::Rsa(key) => {
     398            5 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
     399              :             }
     400            0 :             key => return Err(JwtError::UnsupportedKeyType(key.into())),
     401              :         }
     402              : 
     403           17 :         tracing::debug!(?payload, "JWT signature valid with claims");
     404              : 
     405           17 :         if let Some(aud) = expected_audience {
     406            7 :             if payload.audience.0.iter().all(|s| s != aud) {
     407            5 :                 return Err(JwtError::InvalidClaims(
     408            5 :                     JwtClaimsError::InvalidJwtTokenAudience,
     409            5 :                 ));
     410            2 :             }
     411           10 :         }
     412              : 
     413           12 :         let now = SystemTime::now();
     414              : 
     415           12 :         if let Some(exp) = payload.expiration {
     416           11 :             if now >= exp + CLOCK_SKEW_LEEWAY {
     417            1 :                 return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
     418           10 :             }
     419            1 :         }
     420              : 
     421           11 :         if let Some(nbf) = payload.not_before {
     422           10 :             if nbf >= now + CLOCK_SKEW_LEEWAY {
     423            1 :                 return Err(JwtError::InvalidClaims(
     424            1 :                     JwtClaimsError::JwtTokenNotYetReadyToUse,
     425            1 :                 ));
     426            9 :             }
     427            1 :         }
     428              : 
     429           10 :         Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     430           19 :     }
     431              : }
     432              : 
     433              : impl JwkCache {
     434           19 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     435           19 :         &self,
     436           19 :         ctx: &RequestContext,
     437           19 :         endpoint: EndpointId,
     438           19 :         role_name: &RoleName,
     439           19 :         fetch: &F,
     440           19 :         jwt: &str,
     441           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     442           19 :         // try with just a read lock first
     443           19 :         let key = (endpoint.clone(), role_name.clone());
     444           19 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     445           19 :         let entry = entry.unwrap_or_else(|| {
     446            7 :             // acquire a write lock after to insert.
     447            7 :             let entry = self.map.entry(key).or_default();
     448            7 :             Arc::clone(&*entry)
     449           19 :         });
     450           19 : 
     451           19 :         entry
     452           19 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     453           19 :             .await
     454           19 :     }
     455              : }
     456              : 
     457              : impl Default for JwkCache {
     458            9 :     fn default() -> Self {
     459            9 :         let client = Client::builder()
     460            9 :             .user_agent(JWKS_USER_AGENT)
     461            9 :             .redirect(redirect::Policy::none())
     462            9 :             .tls_built_in_native_certs(true)
     463            9 :             .connect_timeout(JWKS_CONNECT_TIMEOUT)
     464            9 :             .timeout(JWKS_FETCH_TIMEOUT)
     465            9 :             .build()
     466            9 :             .expect("client config should be valid");
     467            9 : 
     468            9 :         // Retry up to 3 times with increasing intervals between attempts.
     469            9 :         let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
     470            9 : 
     471            9 :         let client = reqwest_middleware::ClientBuilder::new(client)
     472            9 :             .with(RetryTransientMiddleware::new_with_policy(retry_policy))
     473            9 :             .build();
     474            9 : 
     475            9 :         JwkCache {
     476            9 :             client,
     477            9 :             map: ClashMap::default(),
     478            9 :         }
     479            9 :     }
     480              : }
     481              : 
     482           13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
     483              :     use ecdsa::Signature;
     484              :     use signature::Verifier;
     485              : 
     486           13 :     match key.crv {
     487              :         jose_jwk::EcCurves::P256 => {
     488           13 :             let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
     489           13 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     490           13 :             let sig = Signature::from_slice(sig)?;
     491           13 :             key.verify(data, &sig)?;
     492              :         }
     493            0 :         key => return Err(JwtError::UnsupportedEcKeyType(key)),
     494              :     }
     495              : 
     496           12 :     Ok(())
     497           13 : }
     498              : 
     499            5 : fn verify_rsa_signature(
     500            5 :     data: &[u8],
     501            5 :     sig: &[u8],
     502            5 :     key: &jose_jwk::Rsa,
     503            5 :     alg: &jose_jwa::Algorithm,
     504            5 : ) -> Result<(), JwtError> {
     505              :     use jose_jwa::{Algorithm, Signing};
     506              :     use rsa::RsaPublicKey;
     507              :     use rsa::pkcs1v15::{Signature, VerifyingKey};
     508              : 
     509            5 :     let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
     510              : 
     511            5 :     match alg {
     512              :         Algorithm::Signing(Signing::Rs256) => {
     513            5 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     514            5 :             let sig = Signature::try_from(sig)?;
     515            5 :             key.verify(data, &sig)?;
     516              :         }
     517            0 :         _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
     518              :     }
     519              : 
     520            5 :     Ok(())
     521            5 : }
     522              : 
     523              : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
     524           38 : #[derive(serde::Deserialize, serde::Serialize)]
     525              : struct JwtHeader<'a> {
     526              :     /// must be a supported alg
     527              :     #[serde(rename = "alg")]
     528              :     algorithm: jose_jwa::Algorithm,
     529              :     /// key id, must be provided for our usecase
     530              :     #[serde(rename = "kid", borrow)]
     531              :     key_id: Option<Cow<'a, str>>,
     532              : }
     533              : 
     534              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     535           97 : #[derive(serde::Deserialize, Debug)]
     536              : #[allow(dead_code)]
     537              : struct JwtPayload<'a> {
     538              :     /// Audience - Recipient for which the JWT is intended
     539              :     #[serde(rename = "aud", default)]
     540              :     audience: OneOrMany,
     541              :     /// Expiration - Time after which the JWT expires
     542              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     543              :     expiration: Option<SystemTime>,
     544              :     /// Not before - Time after which the JWT expires
     545              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     546              :     not_before: Option<SystemTime>,
     547              : 
     548              :     // the following entries are only extracted for the sake of debug logging.
     549              :     /// Issuer of the JWT
     550              :     #[serde(rename = "iss", borrow)]
     551              :     issuer: Option<Cow<'a, str>>,
     552              :     /// Subject of the JWT (the user)
     553              :     #[serde(rename = "sub", borrow)]
     554              :     subject: Option<Cow<'a, str>>,
     555              :     /// Unique token identifier
     556              :     #[serde(rename = "jti", borrow)]
     557              :     jwt_id: Option<Cow<'a, str>>,
     558              :     /// Unique session identifier
     559              :     #[serde(rename = "sid", borrow)]
     560              :     session_id: Option<Cow<'a, str>>,
     561              : }
     562              : 
     563              : /// `OneOrMany` supports parsing either a single item or an array of items.
     564              : ///
     565              : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
     566              : ///
     567              : /// > The "aud" (audience) claim identifies the recipients that the JWT is
     568              : /// > intended for.  Each principal intended to process the JWT MUST
     569              : /// > identify itself with a value in the audience claim.  If the principal
     570              : /// > processing the claim does not identify itself with a value in the
     571              : /// > "aud" claim when this claim is present, then the JWT MUST be
     572              : /// > rejected.  In the general case, the "aud" value is **an array of case-
     573              : /// > sensitive strings**, each containing a StringOrURI value.  In the
     574              : /// > special case when the JWT has one audience, the "aud" value MAY be a
     575              : /// > **single case-sensitive string** containing a StringOrURI value.  The
     576              : /// > interpretation of audience values is generally application specific.
     577              : /// > Use of this claim is OPTIONAL.
     578              : #[derive(Default, Debug)]
     579              : struct OneOrMany(Vec<String>);
     580              : 
     581              : impl<'de> Deserialize<'de> for OneOrMany {
     582           17 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     583           17 :     where
     584           17 :         D: Deserializer<'de>,
     585           17 :     {
     586              :         struct OneOrManyVisitor;
     587              :         impl<'de> Visitor<'de> for OneOrManyVisitor {
     588              :             type Value = OneOrMany;
     589              : 
     590            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     591            0 :                 formatter.write_str("a single string or an array of strings")
     592            0 :             }
     593              : 
     594            2 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     595            2 :             where
     596            2 :                 E: serde::de::Error,
     597            2 :             {
     598            2 :                 Ok(OneOrMany(vec![v.to_owned()]))
     599            2 :             }
     600              : 
     601           15 :             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
     602           15 :             where
     603           15 :                 A: serde::de::SeqAccess<'de>,
     604           15 :             {
     605           15 :                 let mut v = vec![];
     606           52 :                 while let Some(s) = seq.next_element()? {
     607           37 :                     v.push(s);
     608           37 :                 }
     609           15 :                 Ok(OneOrMany(v))
     610           15 :             }
     611              :         }
     612           17 :         deserializer.deserialize_any(OneOrManyVisitor)
     613           17 :     }
     614              : }
     615              : 
     616           25 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     617           25 :     let d = <Option<u64>>::deserialize(d)?;
     618           25 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     619           25 : }
     620              : 
     621              : struct JwkRenewalPermit<'a> {
     622              :     inner: Option<JwkRenewalPermitInner<'a>>,
     623              : }
     624              : 
     625              : enum JwkRenewalPermitInner<'a> {
     626              :     Owned(Arc<JwkCacheEntryLock>),
     627              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     628              : }
     629              : 
     630              : impl JwkRenewalPermit<'_> {
     631            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     632            0 :         JwkRenewalPermit {
     633            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     634            0 :         }
     635            0 :     }
     636              : 
     637            7 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     638            7 :         match from.lookup.acquire().await {
     639            7 :             Ok(permit) => {
     640            7 :                 permit.forget();
     641            7 :                 JwkRenewalPermit {
     642            7 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     643            7 :                 }
     644              :             }
     645            0 :             Err(_) => panic!("semaphore should not be closed"),
     646              :         }
     647            7 :     }
     648              : 
     649            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     650            0 :         match from.lookup.try_acquire() {
     651            0 :             Ok(permit) => {
     652            0 :                 permit.forget();
     653            0 :                 Some(JwkRenewalPermit {
     654            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     655            0 :                 })
     656              :             }
     657            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     658            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     659              :         }
     660            0 :     }
     661              : }
     662              : 
     663              : impl JwkRenewalPermitInner<'_> {
     664            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     665            0 :         match self {
     666            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     667            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     668              :         }
     669            0 :     }
     670              : }
     671              : 
     672              : impl Drop for JwkRenewalPermit<'_> {
     673            7 :     fn drop(&mut self) {
     674            7 :         let entry = match &self.inner {
     675            0 :             None => return,
     676            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     677            7 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     678              :         };
     679            7 :         entry.lookup.add_permits(1);
     680            7 :     }
     681              : }
     682              : 
     683              : #[derive(Error, Debug)]
     684              : #[non_exhaustive]
     685              : pub(crate) enum JwtError {
     686              :     #[error("jwk not found")]
     687              :     JwkNotFound,
     688              : 
     689              :     #[error("missing key id")]
     690              :     MissingKeyId,
     691              : 
     692              :     #[error("Provided authentication token is not a valid JWT encoding")]
     693              :     JwtEncoding(#[from] JwtEncodingError),
     694              : 
     695              :     #[error(transparent)]
     696              :     InvalidClaims(#[from] JwtClaimsError),
     697              : 
     698              :     #[error("invalid P256 key")]
     699              :     InvalidP256Key(jose_jwk::crypto::Error),
     700              : 
     701              :     #[error("invalid RSA key")]
     702              :     InvalidRsaKey(jose_jwk::crypto::Error),
     703              : 
     704              :     #[error("invalid RSA signing algorithm")]
     705              :     InvalidRsaSigningAlgorithm,
     706              : 
     707              :     #[error("unsupported EC key type {0:?}")]
     708              :     UnsupportedEcKeyType(jose_jwk::EcCurves),
     709              : 
     710              :     #[error("unsupported key type {0:?}")]
     711              :     UnsupportedKeyType(KeyType),
     712              : 
     713              :     #[error("signature algorithm not supported")]
     714              :     SignatureAlgorithmNotSupported,
     715              : 
     716              :     #[error("signature error: {0}")]
     717              :     Signature(#[from] signature::Error),
     718              : 
     719              :     #[error("failed to fetch auth rules: {0}")]
     720              :     FetchAuthRules(#[from] FetchAuthRulesError),
     721              : }
     722              : 
     723              : impl From<base64::DecodeError> for JwtError {
     724            0 :     fn from(err: base64::DecodeError) -> Self {
     725            0 :         JwtEncodingError::Base64Decode(err).into()
     726            0 :     }
     727              : }
     728              : 
     729              : impl From<serde_json::Error> for JwtError {
     730            0 :     fn from(err: serde_json::Error) -> Self {
     731            0 :         JwtEncodingError::SerdeJson(err).into()
     732            0 :     }
     733              : }
     734              : 
     735              : #[derive(Error, Debug)]
     736              : #[non_exhaustive]
     737              : pub enum JwtEncodingError {
     738              :     #[error(transparent)]
     739              :     Base64Decode(#[from] base64::DecodeError),
     740              : 
     741              :     #[error(transparent)]
     742              :     SerdeJson(#[from] serde_json::Error),
     743              : 
     744              :     #[error("invalid compact form")]
     745              :     InvalidCompactForm,
     746              : }
     747              : 
     748              : #[derive(Error, Debug, PartialEq)]
     749              : #[non_exhaustive]
     750              : pub enum JwtClaimsError {
     751              :     #[error("invalid JWT token audience")]
     752              :     InvalidJwtTokenAudience,
     753              : 
     754              :     #[error("JWT token has expired")]
     755              :     JwtTokenHasExpired,
     756              : 
     757              :     #[error("JWT token is not yet ready to use")]
     758              :     JwtTokenNotYetReadyToUse,
     759              : }
     760              : 
     761              : #[allow(dead_code, reason = "Debug use only")]
     762              : #[derive(Debug)]
     763              : pub(crate) enum KeyType {
     764              :     Ec(jose_jwk::EcCurves),
     765              :     Rsa,
     766              :     Oct,
     767              :     Okp(jose_jwk::OkpCurves),
     768              :     Unknown,
     769              : }
     770              : 
     771              : impl From<&jose_jwk::Key> for KeyType {
     772            0 :     fn from(key: &jose_jwk::Key) -> Self {
     773            0 :         match key {
     774            0 :             jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
     775            0 :             jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
     776            0 :             jose_jwk::Key::Oct(_oct) => Self::Oct,
     777            0 :             jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
     778            0 :             _ => Self::Unknown,
     779              :         }
     780            0 :     }
     781              : }
     782              : 
     783              : #[cfg(test)]
     784              : #[expect(clippy::unwrap_used)]
     785              : mod tests {
     786              :     use std::future::IntoFuture;
     787              :     use std::net::SocketAddr;
     788              :     use std::time::SystemTime;
     789              : 
     790              :     use base64::URL_SAFE_NO_PAD;
     791              :     use bytes::Bytes;
     792              :     use http::Response;
     793              :     use http_body_util::Full;
     794              :     use hyper::service::service_fn;
     795              :     use hyper_util::rt::TokioIo;
     796              :     use rand::rngs::OsRng;
     797              :     use rsa::pkcs8::DecodePrivateKey;
     798              :     use serde::Serialize;
     799              :     use serde_json::json;
     800              :     use signature::Signer;
     801              :     use tokio::net::TcpListener;
     802              : 
     803              :     use super::*;
     804              :     use crate::types::RoleName;
     805              : 
     806            6 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     807            6 :         let sk = p256::SecretKey::random(&mut OsRng);
     808            6 :         let pk = sk.public_key().into();
     809            6 :         let jwk = jose_jwk::Jwk {
     810            6 :             key: jose_jwk::Key::Ec(pk),
     811            6 :             prm: jose_jwk::Parameters {
     812            6 :                 kid: Some(kid),
     813            6 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     814            6 :                 ..Default::default()
     815            6 :             },
     816            6 :         };
     817            6 :         (sk, jwk)
     818            6 :     }
     819              : 
     820            4 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     821            4 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     822            4 :         let pk = sk.to_public_key().into();
     823            4 :         let jwk = jose_jwk::Jwk {
     824            4 :             key: jose_jwk::Key::Rsa(pk),
     825            4 :             prm: jose_jwk::Parameters {
     826            4 :                 kid: Some(kid),
     827            4 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     828            4 :                 ..Default::default()
     829            4 :             },
     830            4 :         };
     831            4 :         (sk, jwk)
     832            4 :     }
     833              : 
     834            8 :     fn now() -> u64 {
     835            8 :         SystemTime::now()
     836            8 :             .duration_since(SystemTime::UNIX_EPOCH)
     837            8 :             .unwrap()
     838            8 :             .as_secs()
     839            8 :     }
     840              : 
     841            7 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     842            7 :         let now = now();
     843            7 :         let body = typed_json::json! {{
     844            7 :             "exp": now + 3600,
     845            7 :             "nbf": now,
     846            7 :             "aud": ["audience1", "neon", "audience2"],
     847            7 :             "sub": "user1",
     848            7 :             "sid": "session1",
     849            7 :             "jti": "token1",
     850            7 :             "iss": "neon-testing",
     851            7 :         }};
     852            7 :         build_custom_jwt_payload(kid, body, sig)
     853            7 :     }
     854              : 
     855           15 :     fn build_custom_jwt_payload(
     856           15 :         kid: String,
     857           15 :         body: impl Serialize,
     858           15 :         sig: jose_jwa::Signing,
     859           15 :     ) -> String {
     860           15 :         let header = JwtHeader {
     861           15 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     862           15 :             key_id: Some(Cow::Owned(kid)),
     863           15 :         };
     864           15 : 
     865           15 :         let header =
     866           15 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     867           15 :         let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
     868           15 : 
     869           15 :         format!("{header}.{body}")
     870           15 :     }
     871              : 
     872            3 :     fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
     873              :         use p256::ecdsa::{Signature, SigningKey};
     874              : 
     875            3 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     876            3 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     877            3 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     878            3 : 
     879            3 :         format!("{payload}.{sig}")
     880            3 :     }
     881              : 
     882            8 :     fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
     883              :         use p256::ecdsa::{Signature, SigningKey};
     884              : 
     885            8 :         let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
     886            8 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     887            8 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     888            8 : 
     889            8 :         format!("{payload}.{sig}")
     890            8 :     }
     891              : 
     892            4 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     893              :         use rsa::pkcs1v15::SigningKey;
     894              :         use rsa::signature::SignatureEncoding;
     895              : 
     896            4 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     897            4 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     898            4 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     899            4 : 
     900            4 :         format!("{payload}.{sig}")
     901            4 :     }
     902              : 
     903              :     // RSA key gen is slow....
     904              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     905              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     906              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     907              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     908              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     909              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     910              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     911              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     912              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     913              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     914              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     915              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     916              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     917              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     918              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     919              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     920              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     921              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     922              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     923              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     924              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     925              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     926              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     927              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     928              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     929              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     930              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     931              : -----END PRIVATE KEY-----
     932              : ";
     933              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     934              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     935              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     936              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     937              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     938              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     939              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     940              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     941              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     942              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     943              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     944              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     945              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     946              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     947              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     948              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     949              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     950              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     951              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     952              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     953              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     954              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     955              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     956              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     957              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     958              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     959              : 5JiconnI5aLek0QVPoFaVXFa
     960              : -----END PRIVATE KEY-----
     961              : ";
     962              : 
     963              :     #[derive(Clone)]
     964              :     struct Fetch(Vec<AuthRule>);
     965              : 
     966              :     impl FetchAuthRules for Fetch {
     967            7 :         async fn fetch_auth_rules(
     968            7 :             &self,
     969            7 :             _ctx: &RequestContext,
     970            7 :             _endpoint: EndpointId,
     971            7 :         ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     972            7 :             Ok(self.0.clone())
     973            7 :         }
     974              :     }
     975              : 
     976            6 :     async fn jwks_server(
     977            6 :         router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
     978            6 :     ) -> SocketAddr {
     979            6 :         let router = Arc::new(router);
     980            9 :         let service = service_fn(move |req| {
     981            9 :             let router = Arc::clone(&router);
     982            9 :             async move {
     983            9 :                 match router(req.uri().path()) {
     984            9 :                     Some(body) => Response::builder()
     985            9 :                         .status(200)
     986            9 :                         .body(Full::new(Bytes::from(body))),
     987            0 :                     None => Response::builder()
     988            0 :                         .status(404)
     989            0 :                         .body(Full::new(Bytes::new())),
     990              :                 }
     991            9 :             }
     992            9 :         });
     993              : 
     994            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     995            6 :         let server = hyper::server::conn::http1::Builder::new();
     996            6 :         let addr = listener.local_addr().unwrap();
     997            6 :         tokio::spawn(async move {
     998              :             loop {
     999           12 :                 let (s, _) = listener.accept().await.unwrap();
    1000            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
    1001            6 :                 tokio::spawn(serve.into_future());
    1002              :             }
    1003            6 :         });
    1004            6 : 
    1005            6 :         addr
    1006            6 :     }
    1007              : 
    1008              :     #[tokio::test]
    1009            1 :     async fn check_jwt_happy_path() {
    1010            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
    1011            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
    1012            1 :         let (ec1, jwk3) = new_ec_jwk("ec1".into());
    1013            1 :         let (ec2, jwk4) = new_ec_jwk("ec2".into());
    1014            1 : 
    1015            1 :         let foo_jwks = jose_jwk::JwkSet {
    1016            1 :             keys: vec![jwk1, jwk3],
    1017            1 :         };
    1018            1 :         let bar_jwks = jose_jwk::JwkSet {
    1019            1 :             keys: vec![jwk2, jwk4],
    1020            1 :         };
    1021            1 : 
    1022            4 :         let jwks_addr = jwks_server(move |path| match path {
    1023            4 :             "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
    1024            2 :             "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
    1025            1 :             _ => None,
    1026            4 :         })
    1027            1 :         .await;
    1028            1 : 
    1029            1 :         let role_name1 = RoleName::from("anonymous");
    1030            1 :         let role_name2 = RoleName::from("authenticated");
    1031            1 : 
    1032            1 :         let roles = vec![
    1033            1 :             RoleNameInt::from(&role_name1),
    1034            1 :             RoleNameInt::from(&role_name2),
    1035            1 :         ];
    1036            1 :         let rules = vec![
    1037            1 :             AuthRule {
    1038            1 :                 id: "foo".to_owned(),
    1039            1 :                 jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
    1040            1 :                 audience: None,
    1041            1 :                 role_names: roles.clone(),
    1042            1 :             },
    1043            1 :             AuthRule {
    1044            1 :                 id: "bar".to_owned(),
    1045            1 :                 jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
    1046            1 :                 audience: None,
    1047            1 :                 role_names: roles.clone(),
    1048            1 :             },
    1049            1 :         ];
    1050            1 : 
    1051            1 :         let fetch = Fetch(rules);
    1052            1 :         let jwk_cache = JwkCache::default();
    1053            1 : 
    1054            1 :         let endpoint = EndpointId::from("ep");
    1055            1 : 
    1056            1 :         let jwt1 = new_rsa_jwt("rs1".into(), rs1);
    1057            1 :         let jwt2 = new_rsa_jwt("rs2".into(), rs2);
    1058            1 :         let jwt3 = new_ec_jwt("ec1".into(), &ec1);
    1059            1 :         let jwt4 = new_ec_jwt("ec2".into(), &ec2);
    1060            1 : 
    1061            1 :         let tokens = [jwt1, jwt2, jwt3, jwt4];
    1062            1 :         let role_names = [role_name1, role_name2];
    1063            3 :         for role in &role_names {
    1064           10 :             for token in &tokens {
    1065            8 :                 jwk_cache
    1066            8 :                     .check_jwt(
    1067            8 :                         &RequestContext::test(),
    1068            8 :                         endpoint.clone(),
    1069            8 :                         role,
    1070            8 :                         &fetch,
    1071            8 :                         token,
    1072            8 :                     )
    1073            8 :                     .await
    1074            8 :                     .unwrap();
    1075            1 :             }
    1076            1 :         }
    1077            1 :     }
    1078              : 
    1079              :     /// AWS Cognito escapes the `/` in the URL.
    1080              :     #[tokio::test]
    1081            1 :     async fn check_jwt_regression_cognito_issuer() {
    1082            1 :         let (key, jwk) = new_ec_jwk("key".into());
    1083            1 : 
    1084            1 :         let now = now();
    1085            1 :         let token = new_custom_ec_jwt(
    1086            1 :             "key".into(),
    1087            1 :             &key,
    1088            1 :             typed_json::json! {{
    1089            1 :                 "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1090            1 :                 // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
    1091            1 :                 // to write that escape character. instead I will make a bogus URL using `\` instead.
    1092            1 :                 "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
    1093            1 :                 "client_id": "abcdefghijklmnopqrstuvwxyz",
    1094            1 :                 "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
    1095            1 :                 "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
    1096            1 :                 "token_use": "access",
    1097            1 :                 "scope": "aws.cognito.signin.user.admin",
    1098            1 :                 "auth_time":now,
    1099            1 :                 "exp":now + 60,
    1100            1 :                 "iat":now,
    1101            1 :                 "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
    1102            1 :                 "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1103            1 :             }},
    1104            1 :         );
    1105            1 : 
    1106            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1107            1 : 
    1108            1 :         let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
    1109            1 : 
    1110            1 :         let role_name = RoleName::from("anonymous");
    1111            1 :         let rules = vec![AuthRule {
    1112            1 :             id: "aws-cognito".to_owned(),
    1113            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1114            1 :             audience: None,
    1115            1 :             role_names: vec![RoleNameInt::from(&role_name)],
    1116            1 :         }];
    1117            1 : 
    1118            1 :         let fetch = Fetch(rules);
    1119            1 :         let jwk_cache = JwkCache::default();
    1120            1 : 
    1121            1 :         let endpoint = EndpointId::from("ep");
    1122            1 : 
    1123            1 :         jwk_cache
    1124            1 :             .check_jwt(
    1125            1 :                 &RequestContext::test(),
    1126            1 :                 endpoint.clone(),
    1127            1 :                 &role_name,
    1128            1 :                 &fetch,
    1129            1 :                 &token,
    1130            1 :             )
    1131            1 :             .await
    1132            1 :             .unwrap();
    1133            1 :     }
    1134              : 
    1135              :     #[tokio::test]
    1136            1 :     async fn check_jwt_invalid_signature() {
    1137            1 :         let (_, jwk) = new_ec_jwk("1".into());
    1138            1 :         let (key, _) = new_ec_jwk("1".into());
    1139            1 : 
    1140            1 :         // has a matching kid, but signed by the wrong key
    1141            1 :         let bad_jwt = new_ec_jwt("1".into(), &key);
    1142            1 : 
    1143            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1144            1 :         let jwks_addr = jwks_server(move |path| match path {
    1145            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1146            1 :             _ => None,
    1147            1 :         })
    1148            1 :         .await;
    1149            1 : 
    1150            1 :         let role = RoleName::from("authenticated");
    1151            1 : 
    1152            1 :         let rules = vec![AuthRule {
    1153            1 :             id: String::new(),
    1154            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1155            1 :             audience: None,
    1156            1 :             role_names: vec![RoleNameInt::from(&role)],
    1157            1 :         }];
    1158            1 : 
    1159            1 :         let fetch = Fetch(rules);
    1160            1 :         let jwk_cache = JwkCache::default();
    1161            1 : 
    1162            1 :         let ep = EndpointId::from("ep");
    1163            1 : 
    1164            1 :         let ctx = RequestContext::test();
    1165            1 :         let err = jwk_cache
    1166            1 :             .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
    1167            1 :             .await
    1168            1 :             .unwrap_err();
    1169            1 :         assert!(
    1170            1 :             matches!(err, JwtError::Signature(_)),
    1171            1 :             "expected \"signature error\", got {err:?}"
    1172            1 :         );
    1173            1 :     }
    1174              : 
    1175              :     #[tokio::test]
    1176            1 :     async fn check_jwt_unknown_role() {
    1177            1 :         let (key, jwk) = new_rsa_jwk(RS1, "1".into());
    1178            1 :         let jwt = new_rsa_jwt("1".into(), key);
    1179            1 : 
    1180            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1181            1 :         let jwks_addr = jwks_server(move |path| match path {
    1182            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1183            1 :             _ => None,
    1184            1 :         })
    1185            1 :         .await;
    1186            1 : 
    1187            1 :         let role = RoleName::from("authenticated");
    1188            1 :         let rules = vec![AuthRule {
    1189            1 :             id: String::new(),
    1190            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1191            1 :             audience: None,
    1192            1 :             role_names: vec![RoleNameInt::from(&role)],
    1193            1 :         }];
    1194            1 : 
    1195            1 :         let fetch = Fetch(rules);
    1196            1 :         let jwk_cache = JwkCache::default();
    1197            1 : 
    1198            1 :         let ep = EndpointId::from("ep");
    1199            1 : 
    1200            1 :         // this role_name is not accepted
    1201            1 :         let bad_role_name = RoleName::from("cloud_admin");
    1202            1 : 
    1203            1 :         let ctx = RequestContext::test();
    1204            1 :         let err = jwk_cache
    1205            1 :             .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
    1206            1 :             .await
    1207            1 :             .unwrap_err();
    1208            1 : 
    1209            1 :         assert!(
    1210            1 :             matches!(err, JwtError::JwkNotFound),
    1211            1 :             "expected \"jwk not found\", got {err:?}"
    1212            1 :         );
    1213            1 :     }
    1214              : 
    1215              :     #[tokio::test]
    1216            1 :     async fn check_jwt_invalid_claims() {
    1217            1 :         let (key, jwk) = new_ec_jwk("1".into());
    1218            1 : 
    1219            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1220            1 :         let jwks_addr = jwks_server(move |path| match path {
    1221            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1222            1 :             _ => None,
    1223            1 :         })
    1224            1 :         .await;
    1225            1 : 
    1226            1 :         let now = SystemTime::now()
    1227            1 :             .duration_since(SystemTime::UNIX_EPOCH)
    1228            1 :             .unwrap()
    1229            1 :             .as_secs();
    1230            1 : 
    1231            1 :         struct Test {
    1232            1 :             body: serde_json::Value,
    1233            1 :             error: JwtClaimsError,
    1234            1 :         }
    1235            1 : 
    1236            1 :         let table = vec![
    1237            1 :             Test {
    1238            1 :                 body: json! {{
    1239            1 :                     "nbf": now + 60,
    1240            1 :                     "aud": "neon",
    1241            1 :                 }},
    1242            1 :                 error: JwtClaimsError::JwtTokenNotYetReadyToUse,
    1243            1 :             },
    1244            1 :             Test {
    1245            1 :                 body: json! {{
    1246            1 :                     "exp": now - 60,
    1247            1 :                     "aud": ["neon"],
    1248            1 :                 }},
    1249            1 :                 error: JwtClaimsError::JwtTokenHasExpired,
    1250            1 :             },
    1251            1 :             Test {
    1252            1 :                 body: json! {{
    1253            1 :                 }},
    1254            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1255            1 :             },
    1256            1 :             Test {
    1257            1 :                 body: json! {{
    1258            1 :                     "aud": [],
    1259            1 :                 }},
    1260            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1261            1 :             },
    1262            1 :             Test {
    1263            1 :                 body: json! {{
    1264            1 :                     "aud": "foo",
    1265            1 :                 }},
    1266            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1267            1 :             },
    1268            1 :             Test {
    1269            1 :                 body: json! {{
    1270            1 :                     "aud": ["foo"],
    1271            1 :                 }},
    1272            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1273            1 :             },
    1274            1 :             Test {
    1275            1 :                 body: json! {{
    1276            1 :                     "aud": ["foo", "bar"],
    1277            1 :                 }},
    1278            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1279            1 :             },
    1280            1 :         ];
    1281            1 : 
    1282            1 :         let role = RoleName::from("authenticated");
    1283            1 : 
    1284            1 :         let rules = vec![AuthRule {
    1285            1 :             id: String::new(),
    1286            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1287            1 :             audience: Some("neon".to_string()),
    1288            1 :             role_names: vec![RoleNameInt::from(&role)],
    1289            1 :         }];
    1290            1 : 
    1291            1 :         let fetch = Fetch(rules);
    1292            1 :         let jwk_cache = JwkCache::default();
    1293            1 : 
    1294            1 :         let ep = EndpointId::from("ep");
    1295            1 : 
    1296            1 :         let ctx = RequestContext::test();
    1297            8 :         for test in table {
    1298            7 :             let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
    1299            7 : 
    1300            7 :             match jwk_cache
    1301            7 :                 .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
    1302            7 :                 .await
    1303            1 :             {
    1304            7 :                 Err(JwtError::InvalidClaims(error)) if error == test.error => {}
    1305            1 :                 Err(err) => {
    1306            0 :                     panic!("expected {:?}, got {err:?}", test.error)
    1307            1 :                 }
    1308            1 :                 Ok(_payload) => {
    1309            0 :                     panic!("expected {:?}, got ok", test.error)
    1310            1 :                 }
    1311            1 :             }
    1312            1 :         }
    1313            1 :     }
    1314              : 
    1315              :     #[tokio::test]
    1316            1 :     async fn check_jwk_keycloak_regression() {
    1317            1 :         let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
    1318            1 :         let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
    1319            1 : 
    1320            1 :         // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
    1321            1 :         // This is taken directly from keycloak.
    1322            1 :         let invalid_jwk = serde_json::json! {
    1323            1 :             {
    1324            1 :                 "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
    1325            1 :                 "kty": "RSA",
    1326            1 :                 "alg": "RSA-OAEP",
    1327            1 :                 "use": "enc",
    1328            1 :                 "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
    1329            1 :                 "e": "AQAB",
    1330            1 :                 "x5c": [
    1331            1 :                     "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
    1332            1 :                 ],
    1333            1 :                 "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
    1334            1 :                 "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
    1335            1 :             }
    1336            1 :         };
    1337            1 : 
    1338            1 :         let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
    1339            1 :         let jwks_addr = jwks_server(move |path| match path {
    1340            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1341            1 :             _ => None,
    1342            1 :         })
    1343            1 :         .await;
    1344            1 : 
    1345            1 :         let role_name = RoleName::from("anonymous");
    1346            1 :         let role = RoleNameInt::from(&role_name);
    1347            1 : 
    1348            1 :         let rules = vec![AuthRule {
    1349            1 :             id: "foo".to_owned(),
    1350            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1351            1 :             audience: None,
    1352            1 :             role_names: vec![role],
    1353            1 :         }];
    1354            1 : 
    1355            1 :         let fetch = Fetch(rules);
    1356            1 :         let jwk_cache = JwkCache::default();
    1357            1 : 
    1358            1 :         let endpoint = EndpointId::from("ep");
    1359            1 : 
    1360            1 :         let token = new_rsa_jwt("rs1".into(), rs);
    1361            1 : 
    1362            1 :         jwk_cache
    1363            1 :             .check_jwt(
    1364            1 :                 &RequestContext::test(),
    1365            1 :                 endpoint.clone(),
    1366            1 :                 &role_name,
    1367            1 :                 &fetch,
    1368            1 :                 &token,
    1369            1 :             )
    1370            1 :             .await
    1371            1 :             .unwrap();
    1372            1 :     }
    1373              : }
        

Generated by: LCOV version 2.1-beta