LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 89.3 % 859 767
Test Date: 2025-03-12 00:01:28 Functions: 58.7 % 167 98

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::future::Future;
       3              : use std::sync::Arc;
       4              : use std::time::{Duration, SystemTime};
       5              : 
       6              : use arc_swap::ArcSwapOption;
       7              : use clashmap::ClashMap;
       8              : use jose_jwk::crypto::KeyInfo;
       9              : use reqwest::{Client, redirect};
      10              : use reqwest_retry::RetryTransientMiddleware;
      11              : use reqwest_retry::policies::ExponentialBackoff;
      12              : use serde::de::Visitor;
      13              : use serde::{Deserialize, Deserializer};
      14              : use serde_json::value::RawValue;
      15              : use signature::Verifier;
      16              : use thiserror::Error;
      17              : use tokio::time::Instant;
      18              : 
      19              : use crate::auth::backend::ComputeCredentialKeys;
      20              : use crate::context::RequestContext;
      21              : use crate::control_plane::errors::GetEndpointJwksError;
      22              : use crate::http::read_body_with_limit;
      23              : use crate::intern::RoleNameInt;
      24              : use crate::types::{EndpointId, RoleName};
      25              : 
      26              : // TODO(conrad): make these configurable.
      27              : const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
      28              : const MIN_RENEW: Duration = Duration::from_secs(30);
      29              : const AUTO_RENEW: Duration = Duration::from_secs(300);
      30              : const MAX_RENEW: Duration = Duration::from_secs(3600);
      31              : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
      32              : const JWKS_USER_AGENT: &str = "neon-proxy";
      33              : 
      34              : const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
      35              : const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
      36              : const JWKS_FETCH_RETRIES: u32 = 3;
      37              : 
      38              : /// How to get the JWT auth rules
      39              : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
      40              :     fn fetch_auth_rules(
      41              :         &self,
      42              :         ctx: &RequestContext,
      43              :         endpoint: EndpointId,
      44              :     ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
      45              : }
      46              : 
      47              : #[derive(Error, Debug)]
      48              : pub(crate) enum FetchAuthRulesError {
      49              :     #[error(transparent)]
      50              :     GetEndpointJwks(#[from] GetEndpointJwksError),
      51              : 
      52              :     #[error("JWKs settings for this role were not configured")]
      53              :     RoleJwksNotConfigured,
      54              : }
      55              : 
      56              : #[derive(Clone)]
      57              : pub(crate) struct AuthRule {
      58              :     pub(crate) id: String,
      59              :     pub(crate) jwks_url: url::Url,
      60              :     pub(crate) audience: Option<String>,
      61              :     pub(crate) role_names: Vec<RoleNameInt>,
      62              : }
      63              : 
      64              : pub struct JwkCache {
      65              :     client: reqwest_middleware::ClientWithMiddleware,
      66              : 
      67              :     map: ClashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
      68              : }
      69              : 
      70              : pub(crate) struct JwkCacheEntry {
      71              :     /// Should refetch at least every hour to verify when old keys have been removed.
      72              :     /// Should refetch when new key IDs are seen only every 5 minutes or so
      73              :     last_retrieved: Instant,
      74              : 
      75              :     /// cplane will return multiple JWKs urls that we need to scrape.
      76              :     key_sets: ahash::HashMap<String, KeySet>,
      77              : }
      78              : 
      79              : impl JwkCacheEntry {
      80           19 :     fn find_jwk_and_audience(
      81           19 :         &self,
      82           19 :         key_id: &str,
      83           19 :         role_name: &RoleName,
      84           19 :     ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      85           19 :         self.key_sets
      86           19 :             .values()
      87           19 :             // make sure our requested role has access to the key set
      88           29 :             .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
      89           19 :             // try and find the requested key-id in the key set
      90           22 :             .find_map(|key_set| {
      91           22 :                 key_set
      92           22 :                     .find_key(key_id)
      93           22 :                     .map(|jwk| (jwk, key_set.audience.as_deref()))
      94           22 :             })
      95           19 :     }
      96              : }
      97              : 
      98              : struct KeySet {
      99              :     jwks: jose_jwk::JwkSet,
     100              :     audience: Option<String>,
     101              :     role_names: Vec<RoleNameInt>,
     102              : }
     103              : 
     104              : impl KeySet {
     105           22 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
     106           22 :         self.jwks
     107           22 :             .keys
     108           22 :             .iter()
     109           30 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
     110           22 :     }
     111              : }
     112              : 
     113              : pub(crate) struct JwkCacheEntryLock {
     114              :     cached: ArcSwapOption<JwkCacheEntry>,
     115              :     lookup: tokio::sync::Semaphore,
     116              : }
     117              : 
     118              : impl Default for JwkCacheEntryLock {
     119            7 :     fn default() -> Self {
     120            7 :         JwkCacheEntryLock {
     121            7 :             cached: ArcSwapOption::empty(),
     122            7 :             lookup: tokio::sync::Semaphore::new(1),
     123            7 :         }
     124            7 :     }
     125              : }
     126              : 
     127            9 : #[derive(Deserialize)]
     128              : struct JwkSet<'a> {
     129              :     /// we parse into raw-value because not all keys in a JWKS are ones
     130              :     /// we can parse directly, so we parse them lazily.
     131              :     #[serde(borrow)]
     132              :     keys: Vec<&'a RawValue>,
     133              : }
     134              : 
     135              : /// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
     136              : /// Returns `None` and log a warning if there are any errors.
     137            9 : async fn fetch_jwks(
     138            9 :     client: &reqwest_middleware::ClientWithMiddleware,
     139            9 :     jwks_url: url::Url,
     140            9 : ) -> Option<jose_jwk::JwkSet> {
     141            9 :     let req = client.get(jwks_url.clone());
     142              :     // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     143            9 :     let resp = req.send().await.and_then(|r| {
     144            9 :         r.error_for_status()
     145            9 :             .map_err(reqwest_middleware::Error::Reqwest)
     146            9 :     });
     147              : 
     148            9 :     let resp = match resp {
     149            9 :         Ok(r) => r,
     150              :         // TODO: should we re-insert JWKs if we want to keep this JWKs URL?
     151              :         // I expect these failures would be quite sparse.
     152            0 :         Err(e) => {
     153            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
     154            0 :             return None;
     155              :         }
     156              :     };
     157              : 
     158            9 :     let resp: http::Response<reqwest::Body> = resp.into();
     159              : 
     160            9 :     let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
     161            9 :         Ok(bytes) => bytes,
     162            0 :         Err(e) => {
     163            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     164            0 :             return None;
     165              :         }
     166              :     };
     167              : 
     168            9 :     let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
     169            9 :         Ok(jwks) => jwks,
     170            0 :         Err(e) => {
     171            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     172            0 :             return None;
     173              :         }
     174              :     };
     175              : 
     176              :     // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
     177              :     //
     178              :     // Even though we limit our responses to 64KiB, we could still receive a payload like
     179              :     // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
     180              :     // Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
     181            9 :     let mut keys = vec![];
     182            9 : 
     183            9 :     let mut failed = 0;
     184           23 :     for key in jwks.keys {
     185           14 :         let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
     186           13 :             Ok(key) => key,
     187            1 :             Err(e) => {
     188            1 :                 tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
     189            1 :                 failed += 1;
     190            1 :                 continue;
     191              :             }
     192              :         };
     193              : 
     194              :         // if `use` (called `cls` in rust) is specified to be something other than signing,
     195              :         // we can skip storing it.
     196           13 :         if key
     197           13 :             .prm
     198           13 :             .cls
     199           13 :             .as_ref()
     200           13 :             .is_some_and(|c| *c != jose_jwk::Class::Signing)
     201              :         {
     202            0 :             continue;
     203           13 :         }
     204           13 : 
     205           13 :         keys.push(key);
     206              :     }
     207              : 
     208            9 :     keys.shrink_to_fit();
     209            9 : 
     210            9 :     if failed > 0 {
     211            1 :         tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
     212            8 :     }
     213              : 
     214            9 :     if keys.is_empty() {
     215            0 :         tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
     216            0 :         return None;
     217            9 :     }
     218            9 : 
     219            9 :     Some(jose_jwk::JwkSet { keys })
     220            9 : }
     221              : 
     222              : impl JwkCacheEntryLock {
     223            7 :     async fn acquire_permit(self: &Arc<Self>) -> JwkRenewalPermit<'_> {
     224            7 :         JwkRenewalPermit::acquire_permit(self).await
     225            7 :     }
     226              : 
     227            0 :     fn try_acquire_permit(self: &Arc<Self>) -> Option<JwkRenewalPermit<'_>> {
     228            0 :         JwkRenewalPermit::try_acquire_permit(self)
     229            0 :     }
     230              : 
     231            7 :     async fn renew_jwks<F: FetchAuthRules>(
     232            7 :         &self,
     233            7 :         _permit: JwkRenewalPermit<'_>,
     234            7 :         ctx: &RequestContext,
     235            7 :         client: &reqwest_middleware::ClientWithMiddleware,
     236            7 :         endpoint: EndpointId,
     237            7 :         auth_rules: &F,
     238            7 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     239            7 :         // double check that no one beat us to updating the cache.
     240            7 :         let now = Instant::now();
     241            7 :         let guard = self.cached.load_full();
     242            7 :         if let Some(cached) = guard {
     243            0 :             let last_update = now.duration_since(cached.last_retrieved);
     244            0 :             if last_update < Duration::from_secs(300) {
     245            0 :                 return Ok(cached);
     246            0 :             }
     247            7 :         }
     248              : 
     249            7 :         let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
     250            7 :         let mut key_sets =
     251            7 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     252              : 
     253              :         // TODO(conrad): run concurrently
     254              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     255           16 :         for rule in rules {
     256            9 :             if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
     257            9 :                 key_sets.insert(
     258            9 :                     rule.id,
     259            9 :                     KeySet {
     260            9 :                         jwks,
     261            9 :                         audience: rule.audience,
     262            9 :                         role_names: rule.role_names,
     263            9 :                     },
     264            9 :                 );
     265            9 :             }
     266              :         }
     267              : 
     268            7 :         let entry = Arc::new(JwkCacheEntry {
     269            7 :             last_retrieved: now,
     270            7 :             key_sets,
     271            7 :         });
     272            7 :         self.cached.swap(Some(Arc::clone(&entry)));
     273            7 : 
     274            7 :         Ok(entry)
     275            7 :     }
     276              : 
     277           19 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     278           19 :         self: &Arc<Self>,
     279           19 :         ctx: &RequestContext,
     280           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     281           19 :         endpoint: EndpointId,
     282           19 :         fetch: &F,
     283           19 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     284           19 :         let now = Instant::now();
     285           19 :         let guard = self.cached.load_full();
     286              : 
     287              :         // if we have no cached JWKs, try and get some
     288           19 :         let Some(cached) = guard else {
     289            7 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     290            7 :             let permit = self.acquire_permit().await;
     291            7 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     292              :         };
     293              : 
     294           12 :         let last_update = now.duration_since(cached.last_retrieved);
     295           12 : 
     296           12 :         // check if the cached JWKs need updating.
     297           12 :         if last_update > MAX_RENEW {
     298            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     299            0 :             let permit = self.acquire_permit().await;
     300              : 
     301              :             // it's been too long since we checked the keys. wait for them to update.
     302            0 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     303           12 :         }
     304           12 : 
     305           12 :         // every 5 minutes we should spawn a job to eagerly update the token.
     306           12 :         if last_update > AUTO_RENEW {
     307            0 :             if let Some(permit) = self.try_acquire_permit() {
     308            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     309            0 :                 let permit = permit.into_owned();
     310            0 :                 let entry = self.clone();
     311            0 :                 let client = client.clone();
     312            0 :                 let fetch = fetch.clone();
     313            0 :                 let ctx = ctx.clone();
     314            0 :                 tokio::spawn(async move {
     315            0 :                     if let Err(e) = entry
     316            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
     317            0 :                         .await
     318              :                     {
     319            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     320            0 :                     }
     321            0 :                 });
     322            0 :             } else {
     323            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     324              :             }
     325           12 :         }
     326              : 
     327           12 :         Ok(cached)
     328           19 :     }
     329              : 
     330           19 :     async fn check_jwt<F: FetchAuthRules>(
     331           19 :         self: &Arc<Self>,
     332           19 :         ctx: &RequestContext,
     333           19 :         jwt: &str,
     334           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     335           19 :         endpoint: EndpointId,
     336           19 :         role_name: &RoleName,
     337           19 :         fetch: &F,
     338           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     339              :         // JWT compact form is defined to be
     340              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     341              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     342              : 
     343           19 :         let (header_payload, signature) = jwt
     344           19 :             .rsplit_once('.')
     345           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     346           19 :         let (header, payload) = header_payload
     347           19 :             .split_once('.')
     348           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     349              : 
     350           19 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
     351           19 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
     352              : 
     353           19 :         let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
     354           19 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
     355              : 
     356           19 :         if let Some(iss) = &payload.issuer {
     357           12 :             ctx.set_jwt_issuer(iss.as_ref().to_owned());
     358           12 :         }
     359              : 
     360           19 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
     361              : 
     362           19 :         let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
     363              : 
     364           19 :         let mut guard = self
     365           19 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
     366           19 :             .await?;
     367              : 
     368              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     369           18 :         let (jwk, expected_audience) = loop {
     370           19 :             match guard.find_jwk_and_audience(&kid, role_name) {
     371           18 :                 Some(jwk) => break jwk,
     372            1 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     373            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     374              : 
     375            0 :                     let permit = self.acquire_permit().await;
     376            0 :                     guard = self
     377            0 :                         .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
     378            0 :                         .await?;
     379              :                 }
     380            1 :                 _ => return Err(JwtError::JwkNotFound),
     381              :             }
     382              :         };
     383              : 
     384           18 :         if !jwk.is_supported(&header.algorithm) {
     385            0 :             return Err(JwtError::SignatureAlgorithmNotSupported);
     386           18 :         }
     387           18 : 
     388           18 :         match &jwk.key {
     389           13 :             jose_jwk::Key::Ec(key) => {
     390           13 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     391              :             }
     392            5 :             jose_jwk::Key::Rsa(key) => {
     393            5 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
     394              :             }
     395            0 :             key => return Err(JwtError::UnsupportedKeyType(key.into())),
     396              :         }
     397              : 
     398           17 :         tracing::debug!(?payload, "JWT signature valid with claims");
     399              : 
     400           17 :         if let Some(aud) = expected_audience {
     401            7 :             if payload.audience.0.iter().all(|s| s != aud) {
     402            5 :                 return Err(JwtError::InvalidClaims(
     403            5 :                     JwtClaimsError::InvalidJwtTokenAudience,
     404            5 :                 ));
     405            2 :             }
     406           10 :         }
     407              : 
     408           12 :         let now = SystemTime::now();
     409              : 
     410           12 :         if let Some(exp) = payload.expiration {
     411           11 :             if now >= exp + CLOCK_SKEW_LEEWAY {
     412            1 :                 return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
     413           10 :             }
     414            1 :         }
     415              : 
     416           11 :         if let Some(nbf) = payload.not_before {
     417           10 :             if nbf >= now + CLOCK_SKEW_LEEWAY {
     418            1 :                 return Err(JwtError::InvalidClaims(
     419            1 :                     JwtClaimsError::JwtTokenNotYetReadyToUse,
     420            1 :                 ));
     421            9 :             }
     422            1 :         }
     423              : 
     424           10 :         Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     425           19 :     }
     426              : }
     427              : 
     428              : impl JwkCache {
     429           19 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     430           19 :         &self,
     431           19 :         ctx: &RequestContext,
     432           19 :         endpoint: EndpointId,
     433           19 :         role_name: &RoleName,
     434           19 :         fetch: &F,
     435           19 :         jwt: &str,
     436           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     437           19 :         // try with just a read lock first
     438           19 :         let key = (endpoint.clone(), role_name.clone());
     439           19 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     440           19 :         let entry = entry.unwrap_or_else(|| {
     441            7 :             // acquire a write lock after to insert.
     442            7 :             let entry = self.map.entry(key).or_default();
     443            7 :             Arc::clone(&*entry)
     444           19 :         });
     445           19 : 
     446           19 :         entry
     447           19 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     448           19 :             .await
     449           19 :     }
     450              : }
     451              : 
     452              : impl Default for JwkCache {
     453            9 :     fn default() -> Self {
     454            9 :         let client = Client::builder()
     455            9 :             .user_agent(JWKS_USER_AGENT)
     456            9 :             .redirect(redirect::Policy::none())
     457            9 :             .tls_built_in_native_certs(true)
     458            9 :             .connect_timeout(JWKS_CONNECT_TIMEOUT)
     459            9 :             .timeout(JWKS_FETCH_TIMEOUT)
     460            9 :             .build()
     461            9 :             .expect("client config should be valid");
     462            9 : 
     463            9 :         // Retry up to 3 times with increasing intervals between attempts.
     464            9 :         let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
     465            9 : 
     466            9 :         let client = reqwest_middleware::ClientBuilder::new(client)
     467            9 :             .with(RetryTransientMiddleware::new_with_policy(retry_policy))
     468            9 :             .build();
     469            9 : 
     470            9 :         JwkCache {
     471            9 :             client,
     472            9 :             map: ClashMap::default(),
     473            9 :         }
     474            9 :     }
     475              : }
     476              : 
     477           13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
     478              :     use ecdsa::Signature;
     479              :     use signature::Verifier;
     480              : 
     481           13 :     match key.crv {
     482              :         jose_jwk::EcCurves::P256 => {
     483           13 :             let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
     484           13 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     485           13 :             let sig = Signature::from_slice(sig)?;
     486           13 :             key.verify(data, &sig)?;
     487              :         }
     488            0 :         key => return Err(JwtError::UnsupportedEcKeyType(key)),
     489              :     }
     490              : 
     491           12 :     Ok(())
     492           13 : }
     493              : 
     494            5 : fn verify_rsa_signature(
     495            5 :     data: &[u8],
     496            5 :     sig: &[u8],
     497            5 :     key: &jose_jwk::Rsa,
     498            5 :     alg: &jose_jwa::Algorithm,
     499            5 : ) -> Result<(), JwtError> {
     500              :     use jose_jwa::{Algorithm, Signing};
     501              :     use rsa::RsaPublicKey;
     502              :     use rsa::pkcs1v15::{Signature, VerifyingKey};
     503              : 
     504            5 :     let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
     505              : 
     506            5 :     match alg {
     507              :         Algorithm::Signing(Signing::Rs256) => {
     508            5 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     509            5 :             let sig = Signature::try_from(sig)?;
     510            5 :             key.verify(data, &sig)?;
     511              :         }
     512            0 :         _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
     513              :     }
     514              : 
     515            5 :     Ok(())
     516            5 : }
     517              : 
     518              : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
     519           38 : #[derive(serde::Deserialize, serde::Serialize)]
     520              : struct JwtHeader<'a> {
     521              :     /// must be a supported alg
     522              :     #[serde(rename = "alg")]
     523              :     algorithm: jose_jwa::Algorithm,
     524              :     /// key id, must be provided for our usecase
     525              :     #[serde(rename = "kid", borrow)]
     526              :     key_id: Option<Cow<'a, str>>,
     527              : }
     528              : 
     529              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     530           97 : #[derive(serde::Deserialize, Debug)]
     531              : #[allow(dead_code)]
     532              : struct JwtPayload<'a> {
     533              :     /// Audience - Recipient for which the JWT is intended
     534              :     #[serde(rename = "aud", default)]
     535              :     audience: OneOrMany,
     536              :     /// Expiration - Time after which the JWT expires
     537              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     538              :     expiration: Option<SystemTime>,
     539              :     /// Not before - Time after which the JWT expires
     540              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     541              :     not_before: Option<SystemTime>,
     542              : 
     543              :     // the following entries are only extracted for the sake of debug logging.
     544              :     /// Issuer of the JWT
     545              :     #[serde(rename = "iss", borrow)]
     546              :     issuer: Option<Cow<'a, str>>,
     547              :     /// Subject of the JWT (the user)
     548              :     #[serde(rename = "sub", borrow)]
     549              :     subject: Option<Cow<'a, str>>,
     550              :     /// Unique token identifier
     551              :     #[serde(rename = "jti", borrow)]
     552              :     jwt_id: Option<Cow<'a, str>>,
     553              :     /// Unique session identifier
     554              :     #[serde(rename = "sid", borrow)]
     555              :     session_id: Option<Cow<'a, str>>,
     556              : }
     557              : 
     558              : /// `OneOrMany` supports parsing either a single item or an array of items.
     559              : ///
     560              : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
     561              : ///
     562              : /// > The "aud" (audience) claim identifies the recipients that the JWT is
     563              : /// > intended for.  Each principal intended to process the JWT MUST
     564              : /// > identify itself with a value in the audience claim.  If the principal
     565              : /// > processing the claim does not identify itself with a value in the
     566              : /// > "aud" claim when this claim is present, then the JWT MUST be
     567              : /// > rejected.  In the general case, the "aud" value is **an array of case-
     568              : /// > sensitive strings**, each containing a StringOrURI value.  In the
     569              : /// > special case when the JWT has one audience, the "aud" value MAY be a
     570              : /// > **single case-sensitive string** containing a StringOrURI value.  The
     571              : /// > interpretation of audience values is generally application specific.
     572              : /// > Use of this claim is OPTIONAL.
     573              : #[derive(Default, Debug)]
     574              : struct OneOrMany(Vec<String>);
     575              : 
     576              : impl<'de> Deserialize<'de> for OneOrMany {
     577           17 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     578           17 :     where
     579           17 :         D: Deserializer<'de>,
     580           17 :     {
     581              :         struct OneOrManyVisitor;
     582              :         impl<'de> Visitor<'de> for OneOrManyVisitor {
     583              :             type Value = OneOrMany;
     584              : 
     585            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     586            0 :                 formatter.write_str("a single string or an array of strings")
     587            0 :             }
     588              : 
     589            2 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     590            2 :             where
     591            2 :                 E: serde::de::Error,
     592            2 :             {
     593            2 :                 Ok(OneOrMany(vec![v.to_owned()]))
     594            2 :             }
     595              : 
     596           15 :             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
     597           15 :             where
     598           15 :                 A: serde::de::SeqAccess<'de>,
     599           15 :             {
     600           15 :                 let mut v = vec![];
     601           52 :                 while let Some(s) = seq.next_element()? {
     602           37 :                     v.push(s);
     603           37 :                 }
     604           15 :                 Ok(OneOrMany(v))
     605           15 :             }
     606              :         }
     607           17 :         deserializer.deserialize_any(OneOrManyVisitor)
     608           17 :     }
     609              : }
     610              : 
     611           25 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     612           25 :     let d = <Option<u64>>::deserialize(d)?;
     613           25 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     614           25 : }
     615              : 
     616              : struct JwkRenewalPermit<'a> {
     617              :     inner: Option<JwkRenewalPermitInner<'a>>,
     618              : }
     619              : 
     620              : enum JwkRenewalPermitInner<'a> {
     621              :     Owned(Arc<JwkCacheEntryLock>),
     622              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     623              : }
     624              : 
     625              : impl JwkRenewalPermit<'_> {
     626            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     627            0 :         JwkRenewalPermit {
     628            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     629            0 :         }
     630            0 :     }
     631              : 
     632            7 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     633            7 :         match from.lookup.acquire().await {
     634            7 :             Ok(permit) => {
     635            7 :                 permit.forget();
     636            7 :                 JwkRenewalPermit {
     637            7 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     638            7 :                 }
     639              :             }
     640            0 :             Err(_) => panic!("semaphore should not be closed"),
     641              :         }
     642            7 :     }
     643              : 
     644            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     645            0 :         match from.lookup.try_acquire() {
     646            0 :             Ok(permit) => {
     647            0 :                 permit.forget();
     648            0 :                 Some(JwkRenewalPermit {
     649            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     650            0 :                 })
     651              :             }
     652            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     653            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     654              :         }
     655            0 :     }
     656              : }
     657              : 
     658              : impl JwkRenewalPermitInner<'_> {
     659            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     660            0 :         match self {
     661            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     662            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     663              :         }
     664            0 :     }
     665              : }
     666              : 
     667              : impl Drop for JwkRenewalPermit<'_> {
     668            7 :     fn drop(&mut self) {
     669            7 :         let entry = match &self.inner {
     670            0 :             None => return,
     671            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     672            7 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     673              :         };
     674            7 :         entry.lookup.add_permits(1);
     675            7 :     }
     676              : }
     677              : 
     678              : #[derive(Error, Debug)]
     679              : #[non_exhaustive]
     680              : pub(crate) enum JwtError {
     681              :     #[error("jwk not found")]
     682              :     JwkNotFound,
     683              : 
     684              :     #[error("missing key id")]
     685              :     MissingKeyId,
     686              : 
     687              :     #[error("Provided authentication token is not a valid JWT encoding")]
     688              :     JwtEncoding(#[from] JwtEncodingError),
     689              : 
     690              :     #[error(transparent)]
     691              :     InvalidClaims(#[from] JwtClaimsError),
     692              : 
     693              :     #[error("invalid P256 key")]
     694              :     InvalidP256Key(jose_jwk::crypto::Error),
     695              : 
     696              :     #[error("invalid RSA key")]
     697              :     InvalidRsaKey(jose_jwk::crypto::Error),
     698              : 
     699              :     #[error("invalid RSA signing algorithm")]
     700              :     InvalidRsaSigningAlgorithm,
     701              : 
     702              :     #[error("unsupported EC key type {0:?}")]
     703              :     UnsupportedEcKeyType(jose_jwk::EcCurves),
     704              : 
     705              :     #[error("unsupported key type {0:?}")]
     706              :     UnsupportedKeyType(KeyType),
     707              : 
     708              :     #[error("signature algorithm not supported")]
     709              :     SignatureAlgorithmNotSupported,
     710              : 
     711              :     #[error("signature error: {0}")]
     712              :     Signature(#[from] signature::Error),
     713              : 
     714              :     #[error("failed to fetch auth rules: {0}")]
     715              :     FetchAuthRules(#[from] FetchAuthRulesError),
     716              : }
     717              : 
     718              : impl From<base64::DecodeError> for JwtError {
     719            0 :     fn from(err: base64::DecodeError) -> Self {
     720            0 :         JwtEncodingError::Base64Decode(err).into()
     721            0 :     }
     722              : }
     723              : 
     724              : impl From<serde_json::Error> for JwtError {
     725            0 :     fn from(err: serde_json::Error) -> Self {
     726            0 :         JwtEncodingError::SerdeJson(err).into()
     727            0 :     }
     728              : }
     729              : 
     730              : #[derive(Error, Debug)]
     731              : #[non_exhaustive]
     732              : pub enum JwtEncodingError {
     733              :     #[error(transparent)]
     734              :     Base64Decode(#[from] base64::DecodeError),
     735              : 
     736              :     #[error(transparent)]
     737              :     SerdeJson(#[from] serde_json::Error),
     738              : 
     739              :     #[error("invalid compact form")]
     740              :     InvalidCompactForm,
     741              : }
     742              : 
     743              : #[derive(Error, Debug, PartialEq)]
     744              : #[non_exhaustive]
     745              : pub enum JwtClaimsError {
     746              :     #[error("invalid JWT token audience")]
     747              :     InvalidJwtTokenAudience,
     748              : 
     749              :     #[error("JWT token has expired")]
     750              :     JwtTokenHasExpired,
     751              : 
     752              :     #[error("JWT token is not yet ready to use")]
     753              :     JwtTokenNotYetReadyToUse,
     754              : }
     755              : 
     756              : #[allow(dead_code, reason = "Debug use only")]
     757              : #[derive(Debug)]
     758              : pub(crate) enum KeyType {
     759              :     Ec(jose_jwk::EcCurves),
     760              :     Rsa,
     761              :     Oct,
     762              :     Okp(jose_jwk::OkpCurves),
     763              :     Unknown,
     764              : }
     765              : 
     766              : impl From<&jose_jwk::Key> for KeyType {
     767            0 :     fn from(key: &jose_jwk::Key) -> Self {
     768            0 :         match key {
     769            0 :             jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
     770            0 :             jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
     771            0 :             jose_jwk::Key::Oct(_oct) => Self::Oct,
     772            0 :             jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
     773            0 :             _ => Self::Unknown,
     774              :         }
     775            0 :     }
     776              : }
     777              : 
     778              : #[cfg(test)]
     779              : #[expect(clippy::unwrap_used)]
     780              : mod tests {
     781              :     use std::future::IntoFuture;
     782              :     use std::net::SocketAddr;
     783              :     use std::time::SystemTime;
     784              : 
     785              :     use base64::URL_SAFE_NO_PAD;
     786              :     use bytes::Bytes;
     787              :     use http::Response;
     788              :     use http_body_util::Full;
     789              :     use hyper::service::service_fn;
     790              :     use hyper_util::rt::TokioIo;
     791              :     use rand::rngs::OsRng;
     792              :     use rsa::pkcs8::DecodePrivateKey;
     793              :     use serde::Serialize;
     794              :     use serde_json::json;
     795              :     use signature::Signer;
     796              :     use tokio::net::TcpListener;
     797              : 
     798              :     use super::*;
     799              :     use crate::types::RoleName;
     800              : 
     801            6 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     802            6 :         let sk = p256::SecretKey::random(&mut OsRng);
     803            6 :         let pk = sk.public_key().into();
     804            6 :         let jwk = jose_jwk::Jwk {
     805            6 :             key: jose_jwk::Key::Ec(pk),
     806            6 :             prm: jose_jwk::Parameters {
     807            6 :                 kid: Some(kid),
     808            6 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     809            6 :                 ..Default::default()
     810            6 :             },
     811            6 :         };
     812            6 :         (sk, jwk)
     813            6 :     }
     814              : 
     815            4 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     816            4 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     817            4 :         let pk = sk.to_public_key().into();
     818            4 :         let jwk = jose_jwk::Jwk {
     819            4 :             key: jose_jwk::Key::Rsa(pk),
     820            4 :             prm: jose_jwk::Parameters {
     821            4 :                 kid: Some(kid),
     822            4 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     823            4 :                 ..Default::default()
     824            4 :             },
     825            4 :         };
     826            4 :         (sk, jwk)
     827            4 :     }
     828              : 
     829            8 :     fn now() -> u64 {
     830            8 :         SystemTime::now()
     831            8 :             .duration_since(SystemTime::UNIX_EPOCH)
     832            8 :             .unwrap()
     833            8 :             .as_secs()
     834            8 :     }
     835              : 
     836            7 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     837            7 :         let now = now();
     838            7 :         let body = typed_json::json! {{
     839            7 :             "exp": now + 3600,
     840            7 :             "nbf": now,
     841            7 :             "aud": ["audience1", "neon", "audience2"],
     842            7 :             "sub": "user1",
     843            7 :             "sid": "session1",
     844            7 :             "jti": "token1",
     845            7 :             "iss": "neon-testing",
     846            7 :         }};
     847            7 :         build_custom_jwt_payload(kid, body, sig)
     848            7 :     }
     849              : 
     850           15 :     fn build_custom_jwt_payload(
     851           15 :         kid: String,
     852           15 :         body: impl Serialize,
     853           15 :         sig: jose_jwa::Signing,
     854           15 :     ) -> String {
     855           15 :         let header = JwtHeader {
     856           15 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     857           15 :             key_id: Some(Cow::Owned(kid)),
     858           15 :         };
     859           15 : 
     860           15 :         let header =
     861           15 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     862           15 :         let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
     863           15 : 
     864           15 :         format!("{header}.{body}")
     865           15 :     }
     866              : 
     867            3 :     fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
     868              :         use p256::ecdsa::{Signature, SigningKey};
     869              : 
     870            3 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     871            3 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     872            3 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     873            3 : 
     874            3 :         format!("{payload}.{sig}")
     875            3 :     }
     876              : 
     877            8 :     fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
     878              :         use p256::ecdsa::{Signature, SigningKey};
     879              : 
     880            8 :         let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
     881            8 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     882            8 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     883            8 : 
     884            8 :         format!("{payload}.{sig}")
     885            8 :     }
     886              : 
     887            4 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     888              :         use rsa::pkcs1v15::SigningKey;
     889              :         use rsa::signature::SignatureEncoding;
     890              : 
     891            4 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     892            4 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     893            4 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     894            4 : 
     895            4 :         format!("{payload}.{sig}")
     896            4 :     }
     897              : 
     898              :     // RSA key gen is slow....
     899              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     900              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     901              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     902              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     903              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     904              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     905              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     906              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     907              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     908              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     909              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     910              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     911              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     912              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     913              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     914              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     915              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     916              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     917              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     918              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     919              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     920              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     921              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     922              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     923              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     924              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     925              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     926              : -----END PRIVATE KEY-----
     927              : ";
     928              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     929              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     930              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     931              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     932              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     933              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     934              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     935              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     936              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     937              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     938              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     939              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     940              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     941              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     942              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     943              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     944              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     945              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     946              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     947              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     948              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     949              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     950              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     951              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     952              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     953              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     954              : 5JiconnI5aLek0QVPoFaVXFa
     955              : -----END PRIVATE KEY-----
     956              : ";
     957              : 
     958              :     #[derive(Clone)]
     959              :     struct Fetch(Vec<AuthRule>);
     960              : 
     961              :     impl FetchAuthRules for Fetch {
     962            7 :         async fn fetch_auth_rules(
     963            7 :             &self,
     964            7 :             _ctx: &RequestContext,
     965            7 :             _endpoint: EndpointId,
     966            7 :         ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     967            7 :             Ok(self.0.clone())
     968            7 :         }
     969              :     }
     970              : 
     971            6 :     async fn jwks_server(
     972            6 :         router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
     973            6 :     ) -> SocketAddr {
     974            6 :         let router = Arc::new(router);
     975            9 :         let service = service_fn(move |req| {
     976            9 :             let router = Arc::clone(&router);
     977            9 :             async move {
     978            9 :                 match router(req.uri().path()) {
     979            9 :                     Some(body) => Response::builder()
     980            9 :                         .status(200)
     981            9 :                         .body(Full::new(Bytes::from(body))),
     982            0 :                     None => Response::builder()
     983            0 :                         .status(404)
     984            0 :                         .body(Full::new(Bytes::new())),
     985              :                 }
     986            9 :             }
     987            9 :         });
     988              : 
     989            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     990            6 :         let server = hyper::server::conn::http1::Builder::new();
     991            6 :         let addr = listener.local_addr().unwrap();
     992            6 :         tokio::spawn(async move {
     993              :             loop {
     994           12 :                 let (s, _) = listener.accept().await.unwrap();
     995            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
     996            6 :                 tokio::spawn(serve.into_future());
     997              :             }
     998            6 :         });
     999            6 : 
    1000            6 :         addr
    1001            6 :     }
    1002              : 
    1003              :     #[tokio::test]
    1004            1 :     async fn check_jwt_happy_path() {
    1005            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
    1006            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
    1007            1 :         let (ec1, jwk3) = new_ec_jwk("ec1".into());
    1008            1 :         let (ec2, jwk4) = new_ec_jwk("ec2".into());
    1009            1 : 
    1010            1 :         let foo_jwks = jose_jwk::JwkSet {
    1011            1 :             keys: vec![jwk1, jwk3],
    1012            1 :         };
    1013            1 :         let bar_jwks = jose_jwk::JwkSet {
    1014            1 :             keys: vec![jwk2, jwk4],
    1015            1 :         };
    1016            1 : 
    1017            4 :         let jwks_addr = jwks_server(move |path| match path {
    1018            4 :             "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
    1019            2 :             "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
    1020            1 :             _ => None,
    1021            4 :         })
    1022            1 :         .await;
    1023            1 : 
    1024            1 :         let role_name1 = RoleName::from("anonymous");
    1025            1 :         let role_name2 = RoleName::from("authenticated");
    1026            1 : 
    1027            1 :         let roles = vec![
    1028            1 :             RoleNameInt::from(&role_name1),
    1029            1 :             RoleNameInt::from(&role_name2),
    1030            1 :         ];
    1031            1 :         let rules = vec![
    1032            1 :             AuthRule {
    1033            1 :                 id: "foo".to_owned(),
    1034            1 :                 jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
    1035            1 :                 audience: None,
    1036            1 :                 role_names: roles.clone(),
    1037            1 :             },
    1038            1 :             AuthRule {
    1039            1 :                 id: "bar".to_owned(),
    1040            1 :                 jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
    1041            1 :                 audience: None,
    1042            1 :                 role_names: roles.clone(),
    1043            1 :             },
    1044            1 :         ];
    1045            1 : 
    1046            1 :         let fetch = Fetch(rules);
    1047            1 :         let jwk_cache = JwkCache::default();
    1048            1 : 
    1049            1 :         let endpoint = EndpointId::from("ep");
    1050            1 : 
    1051            1 :         let jwt1 = new_rsa_jwt("rs1".into(), rs1);
    1052            1 :         let jwt2 = new_rsa_jwt("rs2".into(), rs2);
    1053            1 :         let jwt3 = new_ec_jwt("ec1".into(), &ec1);
    1054            1 :         let jwt4 = new_ec_jwt("ec2".into(), &ec2);
    1055            1 : 
    1056            1 :         let tokens = [jwt1, jwt2, jwt3, jwt4];
    1057            1 :         let role_names = [role_name1, role_name2];
    1058            3 :         for role in &role_names {
    1059           10 :             for token in &tokens {
    1060            8 :                 jwk_cache
    1061            8 :                     .check_jwt(
    1062            8 :                         &RequestContext::test(),
    1063            8 :                         endpoint.clone(),
    1064            8 :                         role,
    1065            8 :                         &fetch,
    1066            8 :                         token,
    1067            8 :                     )
    1068            8 :                     .await
    1069            8 :                     .unwrap();
    1070            1 :             }
    1071            1 :         }
    1072            1 :     }
    1073              : 
    1074              :     /// AWS Cognito escapes the `/` in the URL.
    1075              :     #[tokio::test]
    1076            1 :     async fn check_jwt_regression_cognito_issuer() {
    1077            1 :         let (key, jwk) = new_ec_jwk("key".into());
    1078            1 : 
    1079            1 :         let now = now();
    1080            1 :         let token = new_custom_ec_jwt(
    1081            1 :             "key".into(),
    1082            1 :             &key,
    1083            1 :             typed_json::json! {{
    1084            1 :                 "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1085            1 :                 // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
    1086            1 :                 // to write that escape character. instead I will make a bogus URL using `\` instead.
    1087            1 :                 "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
    1088            1 :                 "client_id": "abcdefghijklmnopqrstuvwxyz",
    1089            1 :                 "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
    1090            1 :                 "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
    1091            1 :                 "token_use": "access",
    1092            1 :                 "scope": "aws.cognito.signin.user.admin",
    1093            1 :                 "auth_time":now,
    1094            1 :                 "exp":now + 60,
    1095            1 :                 "iat":now,
    1096            1 :                 "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
    1097            1 :                 "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1098            1 :             }},
    1099            1 :         );
    1100            1 : 
    1101            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1102            1 : 
    1103            1 :         let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
    1104            1 : 
    1105            1 :         let role_name = RoleName::from("anonymous");
    1106            1 :         let rules = vec![AuthRule {
    1107            1 :             id: "aws-cognito".to_owned(),
    1108            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1109            1 :             audience: None,
    1110            1 :             role_names: vec![RoleNameInt::from(&role_name)],
    1111            1 :         }];
    1112            1 : 
    1113            1 :         let fetch = Fetch(rules);
    1114            1 :         let jwk_cache = JwkCache::default();
    1115            1 : 
    1116            1 :         let endpoint = EndpointId::from("ep");
    1117            1 : 
    1118            1 :         jwk_cache
    1119            1 :             .check_jwt(
    1120            1 :                 &RequestContext::test(),
    1121            1 :                 endpoint.clone(),
    1122            1 :                 &role_name,
    1123            1 :                 &fetch,
    1124            1 :                 &token,
    1125            1 :             )
    1126            1 :             .await
    1127            1 :             .unwrap();
    1128            1 :     }
    1129              : 
    1130              :     #[tokio::test]
    1131            1 :     async fn check_jwt_invalid_signature() {
    1132            1 :         let (_, jwk) = new_ec_jwk("1".into());
    1133            1 :         let (key, _) = new_ec_jwk("1".into());
    1134            1 : 
    1135            1 :         // has a matching kid, but signed by the wrong key
    1136            1 :         let bad_jwt = new_ec_jwt("1".into(), &key);
    1137            1 : 
    1138            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1139            1 :         let jwks_addr = jwks_server(move |path| match path {
    1140            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1141            1 :             _ => None,
    1142            1 :         })
    1143            1 :         .await;
    1144            1 : 
    1145            1 :         let role = RoleName::from("authenticated");
    1146            1 : 
    1147            1 :         let rules = vec![AuthRule {
    1148            1 :             id: String::new(),
    1149            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1150            1 :             audience: None,
    1151            1 :             role_names: vec![RoleNameInt::from(&role)],
    1152            1 :         }];
    1153            1 : 
    1154            1 :         let fetch = Fetch(rules);
    1155            1 :         let jwk_cache = JwkCache::default();
    1156            1 : 
    1157            1 :         let ep = EndpointId::from("ep");
    1158            1 : 
    1159            1 :         let ctx = RequestContext::test();
    1160            1 :         let err = jwk_cache
    1161            1 :             .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
    1162            1 :             .await
    1163            1 :             .unwrap_err();
    1164            1 :         assert!(
    1165            1 :             matches!(err, JwtError::Signature(_)),
    1166            1 :             "expected \"signature error\", got {err:?}"
    1167            1 :         );
    1168            1 :     }
    1169              : 
    1170              :     #[tokio::test]
    1171            1 :     async fn check_jwt_unknown_role() {
    1172            1 :         let (key, jwk) = new_rsa_jwk(RS1, "1".into());
    1173            1 :         let jwt = new_rsa_jwt("1".into(), key);
    1174            1 : 
    1175            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1176            1 :         let jwks_addr = jwks_server(move |path| match path {
    1177            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1178            1 :             _ => None,
    1179            1 :         })
    1180            1 :         .await;
    1181            1 : 
    1182            1 :         let role = RoleName::from("authenticated");
    1183            1 :         let rules = vec![AuthRule {
    1184            1 :             id: String::new(),
    1185            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1186            1 :             audience: None,
    1187            1 :             role_names: vec![RoleNameInt::from(&role)],
    1188            1 :         }];
    1189            1 : 
    1190            1 :         let fetch = Fetch(rules);
    1191            1 :         let jwk_cache = JwkCache::default();
    1192            1 : 
    1193            1 :         let ep = EndpointId::from("ep");
    1194            1 : 
    1195            1 :         // this role_name is not accepted
    1196            1 :         let bad_role_name = RoleName::from("cloud_admin");
    1197            1 : 
    1198            1 :         let ctx = RequestContext::test();
    1199            1 :         let err = jwk_cache
    1200            1 :             .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
    1201            1 :             .await
    1202            1 :             .unwrap_err();
    1203            1 : 
    1204            1 :         assert!(
    1205            1 :             matches!(err, JwtError::JwkNotFound),
    1206            1 :             "expected \"jwk not found\", got {err:?}"
    1207            1 :         );
    1208            1 :     }
    1209              : 
    1210              :     #[tokio::test]
    1211            1 :     async fn check_jwt_invalid_claims() {
    1212            1 :         let (key, jwk) = new_ec_jwk("1".into());
    1213            1 : 
    1214            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1215            1 :         let jwks_addr = jwks_server(move |path| match path {
    1216            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1217            1 :             _ => None,
    1218            1 :         })
    1219            1 :         .await;
    1220            1 : 
    1221            1 :         let now = SystemTime::now()
    1222            1 :             .duration_since(SystemTime::UNIX_EPOCH)
    1223            1 :             .unwrap()
    1224            1 :             .as_secs();
    1225            1 : 
    1226            1 :         struct Test {
    1227            1 :             body: serde_json::Value,
    1228            1 :             error: JwtClaimsError,
    1229            1 :         }
    1230            1 : 
    1231            1 :         let table = vec![
    1232            1 :             Test {
    1233            1 :                 body: json! {{
    1234            1 :                     "nbf": now + 60,
    1235            1 :                     "aud": "neon",
    1236            1 :                 }},
    1237            1 :                 error: JwtClaimsError::JwtTokenNotYetReadyToUse,
    1238            1 :             },
    1239            1 :             Test {
    1240            1 :                 body: json! {{
    1241            1 :                     "exp": now - 60,
    1242            1 :                     "aud": ["neon"],
    1243            1 :                 }},
    1244            1 :                 error: JwtClaimsError::JwtTokenHasExpired,
    1245            1 :             },
    1246            1 :             Test {
    1247            1 :                 body: json! {{
    1248            1 :                 }},
    1249            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1250            1 :             },
    1251            1 :             Test {
    1252            1 :                 body: json! {{
    1253            1 :                     "aud": [],
    1254            1 :                 }},
    1255            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1256            1 :             },
    1257            1 :             Test {
    1258            1 :                 body: json! {{
    1259            1 :                     "aud": "foo",
    1260            1 :                 }},
    1261            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1262            1 :             },
    1263            1 :             Test {
    1264            1 :                 body: json! {{
    1265            1 :                     "aud": ["foo"],
    1266            1 :                 }},
    1267            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1268            1 :             },
    1269            1 :             Test {
    1270            1 :                 body: json! {{
    1271            1 :                     "aud": ["foo", "bar"],
    1272            1 :                 }},
    1273            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1274            1 :             },
    1275            1 :         ];
    1276            1 : 
    1277            1 :         let role = RoleName::from("authenticated");
    1278            1 : 
    1279            1 :         let rules = vec![AuthRule {
    1280            1 :             id: String::new(),
    1281            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1282            1 :             audience: Some("neon".to_string()),
    1283            1 :             role_names: vec![RoleNameInt::from(&role)],
    1284            1 :         }];
    1285            1 : 
    1286            1 :         let fetch = Fetch(rules);
    1287            1 :         let jwk_cache = JwkCache::default();
    1288            1 : 
    1289            1 :         let ep = EndpointId::from("ep");
    1290            1 : 
    1291            1 :         let ctx = RequestContext::test();
    1292            8 :         for test in table {
    1293            7 :             let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
    1294            7 : 
    1295            7 :             match jwk_cache
    1296            7 :                 .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
    1297            7 :                 .await
    1298            1 :             {
    1299            7 :                 Err(JwtError::InvalidClaims(error)) if error == test.error => {}
    1300            1 :                 Err(err) => {
    1301            0 :                     panic!("expected {:?}, got {err:?}", test.error)
    1302            1 :                 }
    1303            1 :                 Ok(_payload) => {
    1304            0 :                     panic!("expected {:?}, got ok", test.error)
    1305            1 :                 }
    1306            1 :             }
    1307            1 :         }
    1308            1 :     }
    1309              : 
    1310              :     #[tokio::test]
    1311            1 :     async fn check_jwk_keycloak_regression() {
    1312            1 :         let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
    1313            1 :         let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
    1314            1 : 
    1315            1 :         // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
    1316            1 :         // This is taken directly from keycloak.
    1317            1 :         let invalid_jwk = serde_json::json! {
    1318            1 :             {
    1319            1 :                 "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
    1320            1 :                 "kty": "RSA",
    1321            1 :                 "alg": "RSA-OAEP",
    1322            1 :                 "use": "enc",
    1323            1 :                 "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
    1324            1 :                 "e": "AQAB",
    1325            1 :                 "x5c": [
    1326            1 :                     "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
    1327            1 :                 ],
    1328            1 :                 "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
    1329            1 :                 "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
    1330            1 :             }
    1331            1 :         };
    1332            1 : 
    1333            1 :         let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
    1334            1 :         let jwks_addr = jwks_server(move |path| match path {
    1335            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1336            1 :             _ => None,
    1337            1 :         })
    1338            1 :         .await;
    1339            1 : 
    1340            1 :         let role_name = RoleName::from("anonymous");
    1341            1 :         let role = RoleNameInt::from(&role_name);
    1342            1 : 
    1343            1 :         let rules = vec![AuthRule {
    1344            1 :             id: "foo".to_owned(),
    1345            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1346            1 :             audience: None,
    1347            1 :             role_names: vec![role],
    1348            1 :         }];
    1349            1 : 
    1350            1 :         let fetch = Fetch(rules);
    1351            1 :         let jwk_cache = JwkCache::default();
    1352            1 : 
    1353            1 :         let endpoint = EndpointId::from("ep");
    1354            1 : 
    1355            1 :         let token = new_rsa_jwt("rs1".into(), rs);
    1356            1 : 
    1357            1 :         jwk_cache
    1358            1 :             .check_jwt(
    1359            1 :                 &RequestContext::test(),
    1360            1 :                 endpoint.clone(),
    1361            1 :                 &role_name,
    1362            1 :                 &fetch,
    1363            1 :                 &token,
    1364            1 :             )
    1365            1 :             .await
    1366            1 :             .unwrap();
    1367            1 :     }
    1368              : }
        

Generated by: LCOV version 2.1-beta