LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 20b6afc7b7f34578dcaab2b3acdaecfe91cd8bf1.info Lines: 89.0 % 860 765
Test Date: 2024-11-25 17:48:16 Functions: 52.7 % 201 106

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::future::Future;
       3              : use std::sync::Arc;
       4              : use std::time::{Duration, SystemTime};
       5              : 
       6              : use arc_swap::ArcSwapOption;
       7              : use dashmap::DashMap;
       8              : use jose_jwk::crypto::KeyInfo;
       9              : use reqwest::{redirect, Client};
      10              : use reqwest_retry::policies::ExponentialBackoff;
      11              : use reqwest_retry::RetryTransientMiddleware;
      12              : use serde::de::Visitor;
      13              : use serde::{Deserialize, Deserializer};
      14              : use serde_json::value::RawValue;
      15              : use signature::Verifier;
      16              : use thiserror::Error;
      17              : use tokio::time::Instant;
      18              : 
      19              : use crate::auth::backend::ComputeCredentialKeys;
      20              : use crate::context::RequestContext;
      21              : use crate::control_plane::errors::GetEndpointJwksError;
      22              : use crate::http::read_body_with_limit;
      23              : use crate::intern::RoleNameInt;
      24              : use crate::types::{EndpointId, RoleName};
      25              : 
// TODO(conrad): make these configurable.
/// Tolerance applied to the `exp` and `nbf` claim checks to absorb clock differences.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of the cached keys before an unknown key id may trigger a re-fetch.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age after which the cached keys are refreshed eagerly in a background task.
const AUTO_RENEW: Duration = Duration::from_secs(300);
/// Age after which the cached keys must be refreshed before they are used again.
const MAX_RENEW: Duration = Duration::from_secs(3600);
/// Upper bound on the size of a JWKS response body.
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
/// `User-Agent` sent with JWKS requests.
const JWKS_USER_AGENT: &str = "neon-proxy";

/// Connect timeout for JWKS requests.
const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// Overall timeout for a single JWKS request attempt.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
/// Number of retries for transient JWKS fetch failures.
const JWKS_FETCH_RETRIES: u32 = 3;
      37              : 
/// How to get the JWT auth rules
///
/// Implementations are cloned and moved into spawned background renewal tasks
/// (see `JwkCacheEntryLock::get_or_update_jwk_cache`), which is what the
/// `Clone + Send + Sync + 'static` bounds allow.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Returns the auth rules configured for `endpoint`.
    ///
    /// Each rule names one JWKS URL to scrape, an optional required audience,
    /// and the roles allowed to use the keys found there.
    fn fetch_auth_rules(
        &self,
        ctx: &RequestContext,
        endpoint: EndpointId,
    ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
}
      46              : 
/// Errors returned by [`FetchAuthRules::fetch_auth_rules`].
#[derive(Error, Debug)]
pub(crate) enum FetchAuthRulesError {
    /// Looking up the endpoint's JWKS settings failed (`GetEndpointJwksError`
    /// from the control-plane module); its message is passed through unchanged.
    #[error(transparent)]
    GetEndpointJwks(#[from] GetEndpointJwksError),

    /// No JWKS settings exist for the requested role.
    #[error("JWKs settings for this role were not configured")]
    RoleJwksNotConfigured,
}
      55              : 
/// One JWKS source configured for an endpoint, as returned by [`FetchAuthRules`].
#[derive(Clone)]
pub(crate) struct AuthRule {
    /// Identifier of the rule; used as the key for the fetched key set in the cache.
    pub(crate) id: String,
    /// URL the JWKS document is fetched from.
    pub(crate) jwks_url: url::Url,
    /// If set, tokens verified with these keys must list this value in `aud`.
    pub(crate) audience: Option<String>,
    /// Roles that are allowed to authenticate with keys from this JWKS.
    pub(crate) role_names: Vec<RoleNameInt>,
}
      63              : 
/// Cache of fetched JWKS documents, one slot per `(endpoint, role)` pair, plus
/// the HTTP client used to fetch them.
pub struct JwkCache {
    /// HTTP client (wrapped with transient-failure retries) used for JWKS requests.
    client: reqwest_middleware::ClientWithMiddleware,

    /// Lazily created cache slots, keyed by endpoint and role.
    map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
      69              : 
/// Snapshot of every key set fetched for one cache slot.
pub(crate) struct JwkCacheEntry {
    /// When this snapshot was fetched (taken at the start of the renewal).
    ///
    /// Drives refreshing: an unknown key id may trigger a re-fetch once this is
    /// older than `MIN_RENEW`; a background refresh is started after
    /// `AUTO_RENEW`; a blocking refresh is required after `MAX_RENEW`, so keys
    /// removed upstream stop being trusted within about an hour.
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    /// Keyed by the id of the auth rule each key set came from.
    key_sets: ahash::HashMap<String, KeySet>,
}
      78              : 
      79              : impl JwkCacheEntry {
      80           19 :     fn find_jwk_and_audience(
      81           19 :         &self,
      82           19 :         key_id: &str,
      83           19 :         role_name: &RoleName,
      84           19 :     ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      85           19 :         self.key_sets
      86           19 :             .values()
      87           19 :             // make sure our requested role has access to the key set
      88           29 :             .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
      89           19 :             // try and find the requested key-id in the key set
      90           22 :             .find_map(|key_set| {
      91           22 :                 key_set
      92           22 :                     .find_key(key_id)
      93           22 :                     .map(|jwk| (jwk, key_set.audience.as_deref()))
      94           22 :             })
      95           19 :     }
      96              : }
      97              : 
/// Signing keys fetched from one JWKS URL, together with the restrictions of
/// the auth rule that named that URL.
struct KeySet {
    /// Parsed signing keys from the JWKS document.
    jwks: jose_jwk::JwkSet,
    /// Required `aud` claim value for tokens verified with these keys, if any.
    audience: Option<String>,
    /// Roles allowed to use these keys.
    role_names: Vec<RoleNameInt>,
}
     103              : 
     104              : impl KeySet {
     105           22 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
     106           22 :         self.jwks
     107           22 :             .keys
     108           22 :             .iter()
     109           30 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
     110           22 :     }
     111              : }
     112              : 
/// Cache slot for one `(endpoint, role)` pair.
pub(crate) struct JwkCacheEntryLock {
    /// Latest fetched keys; empty until the first renewal stores an entry.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Single-permit semaphore (see `Default`) guarding renewals of this slot;
    /// presumably acquired through `JwkRenewalPermit` — TODO confirm.
    lookup: tokio::sync::Semaphore,
}
     117              : 
     118              : impl Default for JwkCacheEntryLock {
     119            7 :     fn default() -> Self {
     120            7 :         JwkCacheEntryLock {
     121            7 :             cached: ArcSwapOption::empty(),
     122            7 :             lookup: tokio::sync::Semaphore::new(1),
     123            7 :         }
     124            7 :     }
     125              : }
     126              : 
/// Minimal view of a JWKS document used for the first parsing pass; the
/// individual keys are decoded afterwards in `fetch_jwks`.
#[derive(Deserialize)]
struct JwkSet<'a> {
    /// we parse into raw-value because not all keys in a JWKS are ones
    /// we can parse directly, so we parse them lazily.
    #[serde(borrow)]
    keys: Vec<&'a RawValue>,
}
     134              : 
     135              : /// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
     136              : /// Returns `None` and log a warning if there are any errors.
     137            9 : async fn fetch_jwks(
     138            9 :     client: &reqwest_middleware::ClientWithMiddleware,
     139            9 :     jwks_url: url::Url,
     140            9 : ) -> Option<jose_jwk::JwkSet> {
     141            9 :     let req = client.get(jwks_url.clone());
     142              :     // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     143           21 :     let resp = req.send().await.and_then(|r| {
     144            9 :         r.error_for_status()
     145            9 :             .map_err(reqwest_middleware::Error::Reqwest)
     146            9 :     });
     147              : 
     148            9 :     let resp = match resp {
     149            9 :         Ok(r) => r,
     150              :         // TODO: should we re-insert JWKs if we want to keep this JWKs URL?
     151              :         // I expect these failures would be quite sparse.
     152            0 :         Err(e) => {
     153            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
     154            0 :             return None;
     155              :         }
     156              :     };
     157              : 
     158            9 :     let resp: http::Response<reqwest::Body> = resp.into();
     159              : 
     160            9 :     let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
     161            9 :         Ok(bytes) => bytes,
     162            0 :         Err(e) => {
     163            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     164            0 :             return None;
     165              :         }
     166              :     };
     167              : 
     168            9 :     let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
     169            9 :         Ok(jwks) => jwks,
     170            0 :         Err(e) => {
     171            0 :             tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
     172            0 :             return None;
     173              :         }
     174              :     };
     175              : 
     176              :     // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
     177              :     //
     178              :     // Even though we limit our responses to 64KiB, we could still receive a payload like
     179              :     // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
     180              :     // Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
     181            9 :     let mut keys = vec![];
     182            9 : 
     183            9 :     let mut failed = 0;
     184           23 :     for key in jwks.keys {
     185           14 :         let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
     186           13 :             Ok(key) => key,
     187            1 :             Err(e) => {
     188            1 :                 tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
     189            1 :                 failed += 1;
     190            1 :                 continue;
     191              :             }
     192              :         };
     193              : 
     194              :         // if `use` (called `cls` in rust) is specified to be something other than signing,
     195              :         // we can skip storing it.
     196           13 :         if key
     197           13 :             .prm
     198           13 :             .cls
     199           13 :             .as_ref()
     200           13 :             .is_some_and(|c| *c != jose_jwk::Class::Signing)
     201              :         {
     202            0 :             continue;
     203           13 :         }
     204           13 : 
     205           13 :         keys.push(key);
     206              :     }
     207              : 
     208            9 :     keys.shrink_to_fit();
     209            9 : 
     210            9 :     if failed > 0 {
     211            1 :         tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
     212            8 :     }
     213              : 
     214            9 :     if keys.is_empty() {
     215            0 :         tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
     216            0 :         return None;
     217            9 :     }
     218            9 : 
     219            9 :     Some(jose_jwk::JwkSet { keys })
     220            9 : }
     221              : 
     222              : impl JwkCacheEntryLock {
     223            7 :     async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
     224            7 :         JwkRenewalPermit::acquire_permit(self).await
     225            7 :     }
     226              : 
     227            0 :     fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
     228            0 :         JwkRenewalPermit::try_acquire_permit(self)
     229            0 :     }
     230              : 
     231            7 :     async fn renew_jwks<F: FetchAuthRules>(
     232            7 :         &self,
     233            7 :         _permit: JwkRenewalPermit<'_>,
     234            7 :         ctx: &RequestContext,
     235            7 :         client: &reqwest_middleware::ClientWithMiddleware,
     236            7 :         endpoint: EndpointId,
     237            7 :         auth_rules: &F,
     238            7 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     239            7 :         // double check that no one beat us to updating the cache.
     240            7 :         let now = Instant::now();
     241            7 :         let guard = self.cached.load_full();
     242            7 :         if let Some(cached) = guard {
     243            0 :             let last_update = now.duration_since(cached.last_retrieved);
     244            0 :             if last_update < Duration::from_secs(300) {
     245            0 :                 return Ok(cached);
     246            0 :             }
     247            7 :         }
     248              : 
     249            7 :         let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
     250            7 :         let mut key_sets =
     251            7 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     252              : 
     253              :         // TODO(conrad): run concurrently
     254              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     255           16 :         for rule in rules {
     256           21 :             if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
     257            9 :                 key_sets.insert(
     258            9 :                     rule.id,
     259            9 :                     KeySet {
     260            9 :                         jwks,
     261            9 :                         audience: rule.audience,
     262            9 :                         role_names: rule.role_names,
     263            9 :                     },
     264            9 :                 );
     265            9 :             }
     266              :         }
     267              : 
     268            7 :         let entry = Arc::new(JwkCacheEntry {
     269            7 :             last_retrieved: now,
     270            7 :             key_sets,
     271            7 :         });
     272            7 :         self.cached.swap(Some(Arc::clone(&entry)));
     273            7 : 
     274            7 :         Ok(entry)
     275            7 :     }
     276              : 
     277           19 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     278           19 :         self: &Arc<Self>,
     279           19 :         ctx: &RequestContext,
     280           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     281           19 :         endpoint: EndpointId,
     282           19 :         fetch: &F,
     283           19 :     ) -> Result<Arc<JwkCacheEntry>, JwtError> {
     284           19 :         let now = Instant::now();
     285           19 :         let guard = self.cached.load_full();
     286              : 
     287              :         // if we have no cached JWKs, try and get some
     288           19 :         let Some(cached) = guard else {
     289            7 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     290            7 :             let permit = self.acquire_permit().await;
     291           21 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     292              :         };
     293              : 
     294           12 :         let last_update = now.duration_since(cached.last_retrieved);
     295           12 : 
     296           12 :         // check if the cached JWKs need updating.
     297           12 :         if last_update > MAX_RENEW {
     298            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     299            0 :             let permit = self.acquire_permit().await;
     300              : 
     301              :             // it's been too long since we checked the keys. wait for them to update.
     302            0 :             return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
     303           12 :         }
     304           12 : 
     305           12 :         // every 5 minutes we should spawn a job to eagerly update the token.
     306           12 :         if last_update > AUTO_RENEW {
     307            0 :             if let Some(permit) = self.try_acquire_permit() {
     308            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     309            0 :                 let permit = permit.into_owned();
     310            0 :                 let entry = self.clone();
     311            0 :                 let client = client.clone();
     312            0 :                 let fetch = fetch.clone();
     313            0 :                 let ctx = ctx.clone();
     314            0 :                 tokio::spawn(async move {
     315            0 :                     if let Err(e) = entry
     316            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
     317            0 :                         .await
     318              :                     {
     319            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     320            0 :                     }
     321            0 :                 });
     322            0 :             } else {
     323            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     324              :             }
     325           12 :         }
     326              : 
     327           12 :         Ok(cached)
     328           19 :     }
     329              : 
     330           19 :     async fn check_jwt<F: FetchAuthRules>(
     331           19 :         self: &Arc<Self>,
     332           19 :         ctx: &RequestContext,
     333           19 :         jwt: &str,
     334           19 :         client: &reqwest_middleware::ClientWithMiddleware,
     335           19 :         endpoint: EndpointId,
     336           19 :         role_name: &RoleName,
     337           19 :         fetch: &F,
     338           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     339              :         // JWT compact form is defined to be
     340              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     341              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     342              : 
     343           19 :         let (header_payload, signature) = jwt
     344           19 :             .rsplit_once('.')
     345           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     346           19 :         let (header, payload) = header_payload
     347           19 :             .split_once('.')
     348           19 :             .ok_or(JwtEncodingError::InvalidCompactForm)?;
     349              : 
     350           19 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
     351           19 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
     352              : 
     353           19 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
     354              : 
     355           19 :         let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
     356              : 
     357           19 :         let mut guard = self
     358           19 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
     359           21 :             .await?;
     360              : 
     361              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     362           18 :         let (jwk, expected_audience) = loop {
     363           19 :             match guard.find_jwk_and_audience(&kid, role_name) {
     364           18 :                 Some(jwk) => break jwk,
     365            1 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     366            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     367              : 
     368            0 :                     let permit = self.acquire_permit().await;
     369            0 :                     guard = self
     370            0 :                         .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
     371            0 :                         .await?;
     372              :                 }
     373            1 :                 _ => return Err(JwtError::JwkNotFound),
     374              :             }
     375              :         };
     376              : 
     377           18 :         if !jwk.is_supported(&header.algorithm) {
     378            0 :             return Err(JwtError::SignatureAlgorithmNotSupported);
     379           18 :         }
     380           18 : 
     381           18 :         match &jwk.key {
     382           13 :             jose_jwk::Key::Ec(key) => {
     383           13 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     384              :             }
     385            5 :             jose_jwk::Key::Rsa(key) => {
     386            5 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
     387              :             }
     388            0 :             key => return Err(JwtError::UnsupportedKeyType(key.into())),
     389              :         };
     390              : 
     391           17 :         let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
     392           17 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
     393              : 
     394           17 :         tracing::debug!(?payload, "JWT signature valid with claims");
     395              : 
     396           17 :         if let Some(aud) = expected_audience {
     397            7 :             if payload.audience.0.iter().all(|s| s != aud) {
     398            5 :                 return Err(JwtError::InvalidClaims(
     399            5 :                     JwtClaimsError::InvalidJwtTokenAudience,
     400            5 :                 ));
     401            2 :             }
     402           10 :         }
     403              : 
     404           12 :         let now = SystemTime::now();
     405              : 
     406           12 :         if let Some(exp) = payload.expiration {
     407           11 :             if now >= exp + CLOCK_SKEW_LEEWAY {
     408            1 :                 return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
     409           10 :             }
     410            1 :         }
     411              : 
     412           11 :         if let Some(nbf) = payload.not_before {
     413           10 :             if nbf >= now + CLOCK_SKEW_LEEWAY {
     414            1 :                 return Err(JwtError::InvalidClaims(
     415            1 :                     JwtClaimsError::JwtTokenNotYetReadyToUse,
     416            1 :                 ));
     417            9 :             }
     418            1 :         }
     419              : 
     420           10 :         Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     421           19 :     }
     422              : }
     423              : 
     424              : impl JwkCache {
     425           19 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     426           19 :         &self,
     427           19 :         ctx: &RequestContext,
     428           19 :         endpoint: EndpointId,
     429           19 :         role_name: &RoleName,
     430           19 :         fetch: &F,
     431           19 :         jwt: &str,
     432           19 :     ) -> Result<ComputeCredentialKeys, JwtError> {
     433           19 :         // try with just a read lock first
     434           19 :         let key = (endpoint.clone(), role_name.clone());
     435           19 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     436           19 :         let entry = entry.unwrap_or_else(|| {
     437            7 :             // acquire a write lock after to insert.
     438            7 :             let entry = self.map.entry(key).or_default();
     439            7 :             Arc::clone(&*entry)
     440           19 :         });
     441           19 : 
     442           19 :         entry
     443           19 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     444           21 :             .await
     445           19 :     }
     446              : }
     447              : 
     448              : impl Default for JwkCache {
     449            9 :     fn default() -> Self {
     450            9 :         let client = Client::builder()
     451            9 :             .user_agent(JWKS_USER_AGENT)
     452            9 :             .redirect(redirect::Policy::none())
     453            9 :             .tls_built_in_native_certs(true)
     454            9 :             .connect_timeout(JWKS_CONNECT_TIMEOUT)
     455            9 :             .timeout(JWKS_FETCH_TIMEOUT)
     456            9 :             .build()
     457            9 :             .expect("client config should be valid");
     458            9 : 
     459            9 :         // Retry up to 3 times with increasing intervals between attempts.
     460            9 :         let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
     461            9 : 
     462            9 :         let client = reqwest_middleware::ClientBuilder::new(client)
     463            9 :             .with(RetryTransientMiddleware::new_with_policy(retry_policy))
     464            9 :             .build();
     465            9 : 
     466            9 :         JwkCache {
     467            9 :             client,
     468            9 :             map: DashMap::default(),
     469            9 :         }
     470            9 :     }
     471              : }
     472              : 
     473           13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
     474              :     use ecdsa::Signature;
     475              :     use signature::Verifier;
     476              : 
     477           13 :     match key.crv {
     478              :         jose_jwk::EcCurves::P256 => {
     479           13 :             let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
     480           13 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     481           13 :             let sig = Signature::from_slice(sig)?;
     482           13 :             key.verify(data, &sig)?;
     483              :         }
     484            0 :         key => return Err(JwtError::UnsupportedEcKeyType(key)),
     485              :     }
     486              : 
     487           12 :     Ok(())
     488           13 : }
     489              : 
     490            5 : fn verify_rsa_signature(
     491            5 :     data: &[u8],
     492            5 :     sig: &[u8],
     493            5 :     key: &jose_jwk::Rsa,
     494            5 :     alg: &jose_jwa::Algorithm,
     495            5 : ) -> Result<(), JwtError> {
     496              :     use jose_jwa::{Algorithm, Signing};
     497              :     use rsa::pkcs1v15::{Signature, VerifyingKey};
     498              :     use rsa::RsaPublicKey;
     499              : 
     500            5 :     let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
     501              : 
     502            5 :     match alg {
     503              :         Algorithm::Signing(Signing::Rs256) => {
     504            5 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     505            5 :             let sig = Signature::try_from(sig)?;
     506            5 :             key.verify(data, &sig)?;
     507              :         }
     508            0 :         _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
     509              :     };
     510              : 
     511            5 :     Ok(())
     512            5 : }
     513              : 
/// JOSE header of a JWT. Only the fields this module reads are declared here:
/// the signing algorithm and the key id.
///
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be a supported alg
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    #[serde(rename = "kid", borrow)]
    key_id: Option<Cow<'a, str>>,
}
     524              : 
     525              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     526          121 : #[derive(serde::Deserialize, Debug)]
     527              : #[allow(dead_code)]
     528              : struct JwtPayload<'a> {
     529              :     /// Audience - Recipient for which the JWT is intended
     530              :     #[serde(rename = "aud", default)]
     531              :     audience: OneOrMany,
     532              :     /// Expiration - Time after which the JWT expires
     533              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     534              :     expiration: Option<SystemTime>,
     535              :     /// Not before - Time after which the JWT expires
     536              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     537              :     not_before: Option<SystemTime>,
     538              : 
     539              :     // the following entries are only extracted for the sake of debug logging.
     540              :     /// Issuer of the JWT
     541              :     #[serde(rename = "iss", borrow)]
     542              :     issuer: Option<Cow<'a, str>>,
     543              :     /// Subject of the JWT (the user)
     544              :     #[serde(rename = "sub", borrow)]
     545              :     subject: Option<Cow<'a, str>>,
     546              :     /// Unique token identifier
     547              :     #[serde(rename = "jti", borrow)]
     548              :     jwt_id: Option<Cow<'a, str>>,
     549              :     /// Unique session identifier
     550              :     #[serde(rename = "sid", borrow)]
     551              :     session_id: Option<Cow<'a, str>>,
     552              : }
     553              : 
     554              : /// `OneOrMany` supports parsing either a single item or an array of items.
     555              : ///
     556              : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
     557              : ///
     558              : /// > The "aud" (audience) claim identifies the recipients that the JWT is
     559              : /// > intended for.  Each principal intended to process the JWT MUST
     560              : /// > identify itself with a value in the audience claim.  If the principal
     561              : /// > processing the claim does not identify itself with a value in the
     562              : /// > "aud" claim when this claim is present, then the JWT MUST be
     563              : /// > rejected.  In the general case, the "aud" value is **an array of case-
     564              : /// > sensitive strings**, each containing a StringOrURI value.  In the
     565              : /// > special case when the JWT has one audience, the "aud" value MAY be a
     566              : /// > **single case-sensitive string** containing a StringOrURI value.  The
     567              : /// > interpretation of audience values is generally application specific.
     568              : /// > Use of this claim is OPTIONAL.
     569              : #[derive(Default, Debug)]
     570              : struct OneOrMany(Vec<String>);
     571              : 
     572              : impl<'de> Deserialize<'de> for OneOrMany {
     573           15 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     574           15 :     where
     575           15 :         D: Deserializer<'de>,
     576           15 :     {
     577              :         struct OneOrManyVisitor;
     578              :         impl<'de> Visitor<'de> for OneOrManyVisitor {
     579              :             type Value = OneOrMany;
     580              : 
     581            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     582            0 :                 formatter.write_str("a single string or an array of strings")
     583            0 :             }
     584              : 
     585            2 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     586            2 :             where
     587            2 :                 E: serde::de::Error,
     588            2 :             {
     589            2 :                 Ok(OneOrMany(vec![v.to_owned()]))
     590            2 :             }
     591              : 
     592           13 :             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
     593           13 :             where
     594           13 :                 A: serde::de::SeqAccess<'de>,
     595           13 :             {
     596           13 :                 let mut v = vec![];
     597           44 :                 while let Some(s) = seq.next_element()? {
     598           31 :                     v.push(s);
     599           31 :                 }
     600           13 :                 Ok(OneOrMany(v))
     601           13 :             }
     602              :         }
     603           15 :         deserializer.deserialize_any(OneOrManyVisitor)
     604           15 :     }
     605              : }
     606              : 
     607           21 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     608           21 :     let d = <Option<u64>>::deserialize(d)?;
     609           21 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     610           21 : }
     611              : 
/// A permit taken from a `JwkCacheEntryLock`'s `lookup` semaphore.
///
/// The underlying tokio permit is `forget()`-ed when acquired, and the permit is
/// handed back with `add_permits(1)` when this value is dropped. Judging by the
/// name, holding it presumably grants the right to renew that entry's JWKs
/// (NOTE(review): confirm against the callers).
struct JwkRenewalPermit<'a> {
    /// The lock entry the permit belongs to. `None` once `into_owned` has moved the
    /// contents out, in which case dropping this value releases nothing.
    inner: Option<JwkRenewalPermitInner<'a>>,
}
     615              : 
/// Handle to the lock entry a `JwkRenewalPermit` was acquired from. It is either
/// borrowed from the caller's `Arc`, or owned (an `Arc` clone, usable as `'static`).
enum JwkRenewalPermitInner<'a> {
    /// Holds its own reference-counted handle to the entry.
    Owned(Arc<JwkCacheEntryLock>),
    /// Borrows the caller's handle to the entry.
    Borrowed(&'a Arc<JwkCacheEntryLock>),
}
     620              : 
     621              : impl JwkRenewalPermit<'_> {
     622            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     623            0 :         JwkRenewalPermit {
     624            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     625            0 :         }
     626            0 :     }
     627              : 
     628            7 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     629            7 :         match from.lookup.acquire().await {
     630            7 :             Ok(permit) => {
     631            7 :                 permit.forget();
     632            7 :                 JwkRenewalPermit {
     633            7 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     634            7 :                 }
     635              :             }
     636            0 :             Err(_) => panic!("semaphore should not be closed"),
     637              :         }
     638            7 :     }
     639              : 
     640            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     641            0 :         match from.lookup.try_acquire() {
     642            0 :             Ok(permit) => {
     643            0 :                 permit.forget();
     644            0 :                 Some(JwkRenewalPermit {
     645            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     646            0 :                 })
     647              :             }
     648            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     649            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     650              :         }
     651            0 :     }
     652              : }
     653              : 
     654              : impl JwkRenewalPermitInner<'_> {
     655            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     656            0 :         match self {
     657            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     658            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     659              :         }
     660            0 :     }
     661              : }
     662              : 
     663              : impl Drop for JwkRenewalPermit<'_> {
     664            7 :     fn drop(&mut self) {
     665            7 :         let entry = match &self.inner {
     666            0 :             None => return,
     667            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     668            7 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     669              :         };
     670            7 :         entry.lookup.add_permits(1);
     671            7 :     }
     672              : }
     673              : 
/// Errors that can occur while checking a JWT.
#[derive(Error, Debug)]
#[non_exhaustive]
pub(crate) enum JwtError {
    /// No usable JWK was found for the token (presumably none matching its `kid`).
    #[error("jwk not found")]
    JwkNotFound,

    /// The token header carries no key id (`kid`), which is required here.
    #[error("missing key id")]
    MissingKeyId,

    /// The token could not be decoded (base64, JSON, or compact-form errors).
    #[error("Provided authentication token is not a valid JWT encoding")]
    JwtEncoding(#[from] JwtEncodingError),

    /// The token's claims failed validation (audience, expiry, not-before).
    #[error(transparent)]
    InvalidClaims(#[from] JwtClaimsError),

    /// The JWK could not be used as a P-256 key.
    #[error("invalid P256 key")]
    InvalidP256Key(jose_jwk::crypto::Error),

    /// The JWK could not be used as an RSA key.
    #[error("invalid RSA key")]
    InvalidRsaKey(jose_jwk::crypto::Error),

    /// The signing algorithm is not a supported RSA algorithm.
    #[error("invalid RSA signing algorithm")]
    InvalidRsaSigningAlgorithm,

    /// The JWK is an EC key on an unsupported curve.
    #[error("unsupported EC key type {0:?}")]
    UnsupportedEcKeyType(jose_jwk::EcCurves),

    /// The JWK's key type is not supported.
    #[error("unsupported key type {0:?}")]
    UnsupportedKeyType(KeyType),

    /// The token's signature algorithm is not supported.
    #[error("signature algorithm not supported")]
    SignatureAlgorithmNotSupported,

    /// Signature verification failed.
    #[error("signature error: {0}")]
    Signature(#[from] signature::Error),

    /// Fetching the endpoint's auth rules failed.
    #[error("failed to fetch auth rules: {0}")]
    FetchAuthRules(#[from] FetchAuthRulesError),
}
     707              :     #[error("signature error: {0}")]
     708              :     Signature(#[from] signature::Error),
     709              : 
     710              :     #[error("failed to fetch auth rules: {0}")]
     711              :     FetchAuthRules(#[from] FetchAuthRulesError),
     712              : }
     713              : 
     714              : impl From<base64::DecodeError> for JwtError {
     715            0 :     fn from(err: base64::DecodeError) -> Self {
     716            0 :         JwtEncodingError::Base64Decode(err).into()
     717            0 :     }
     718              : }
     719              : 
     720              : impl From<serde_json::Error> for JwtError {
     721            0 :     fn from(err: serde_json::Error) -> Self {
     722            0 :         JwtEncodingError::SerdeJson(err).into()
     723            0 :     }
     724              : }
     725              : 
     726            0 : #[derive(Error, Debug)]
     727              : #[non_exhaustive]
     728              : pub enum JwtEncodingError {
     729              :     #[error(transparent)]
     730              :     Base64Decode(#[from] base64::DecodeError),
     731              : 
     732              :     #[error(transparent)]
     733              :     SerdeJson(#[from] serde_json::Error),
     734              : 
     735              :     #[error("invalid compact form")]
     736              :     InvalidCompactForm,
     737              : }
     738              : 
     739            0 : #[derive(Error, Debug, PartialEq)]
     740              : #[non_exhaustive]
     741              : pub enum JwtClaimsError {
     742              :     #[error("invalid JWT token audience")]
     743              :     InvalidJwtTokenAudience,
     744              : 
     745              :     #[error("JWT token has expired")]
     746              :     JwtTokenHasExpired,
     747              : 
     748              :     #[error("JWT token is not yet ready to use")]
     749              :     JwtTokenNotYetReadyToUse,
     750              : }
     751              : 
     752              : #[allow(dead_code, reason = "Debug use only")]
     753              : #[derive(Debug)]
     754              : pub(crate) enum KeyType {
     755              :     Ec(jose_jwk::EcCurves),
     756              :     Rsa,
     757              :     Oct,
     758              :     Okp(jose_jwk::OkpCurves),
     759              :     Unknown,
     760              : }
     761              : 
     762              : impl From<&jose_jwk::Key> for KeyType {
     763            0 :     fn from(key: &jose_jwk::Key) -> Self {
     764            0 :         match key {
     765            0 :             jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
     766            0 :             jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
     767            0 :             jose_jwk::Key::Oct(_oct) => Self::Oct,
     768            0 :             jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
     769            0 :             _ => Self::Unknown,
     770              :         }
     771            0 :     }
     772              : }
     773              : 
     774              : #[cfg(test)]
     775              : mod tests {
     776              :     use std::future::IntoFuture;
     777              :     use std::net::SocketAddr;
     778              :     use std::time::SystemTime;
     779              : 
     780              :     use base64::URL_SAFE_NO_PAD;
     781              :     use bytes::Bytes;
     782              :     use http::Response;
     783              :     use http_body_util::Full;
     784              :     use hyper::service::service_fn;
     785              :     use hyper_util::rt::TokioIo;
     786              :     use rand::rngs::OsRng;
     787              :     use rsa::pkcs8::DecodePrivateKey;
     788              :     use serde::Serialize;
     789              :     use serde_json::json;
     790              :     use signature::Signer;
     791              :     use tokio::net::TcpListener;
     792              : 
     793              :     use super::*;
     794              :     use crate::types::RoleName;
     795              : 
     796            6 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     797            6 :         let sk = p256::SecretKey::random(&mut OsRng);
     798            6 :         let pk = sk.public_key().into();
     799            6 :         let jwk = jose_jwk::Jwk {
     800            6 :             key: jose_jwk::Key::Ec(pk),
     801            6 :             prm: jose_jwk::Parameters {
     802            6 :                 kid: Some(kid),
     803            6 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     804            6 :                 ..Default::default()
     805            6 :             },
     806            6 :         };
     807            6 :         (sk, jwk)
     808            6 :     }
     809              : 
     810            4 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     811            4 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     812            4 :         let pk = sk.to_public_key().into();
     813            4 :         let jwk = jose_jwk::Jwk {
     814            4 :             key: jose_jwk::Key::Rsa(pk),
     815            4 :             prm: jose_jwk::Parameters {
     816            4 :                 kid: Some(kid),
     817            4 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     818            4 :                 ..Default::default()
     819            4 :             },
     820            4 :         };
     821            4 :         (sk, jwk)
     822            4 :     }
     823              : 
     824            8 :     fn now() -> u64 {
     825            8 :         SystemTime::now()
     826            8 :             .duration_since(SystemTime::UNIX_EPOCH)
     827            8 :             .unwrap()
     828            8 :             .as_secs()
     829            8 :     }
     830              : 
     831            7 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     832            7 :         let now = now();
     833            7 :         let body = typed_json::json! {{
     834            7 :             "exp": now + 3600,
     835            7 :             "nbf": now,
     836            7 :             "aud": ["audience1", "neon", "audience2"],
     837            7 :             "sub": "user1",
     838            7 :             "sid": "session1",
     839            7 :             "jti": "token1",
     840            7 :             "iss": "neon-testing",
     841            7 :         }};
     842            7 :         build_custom_jwt_payload(kid, body, sig)
     843            7 :     }
     844              : 
     845           15 :     fn build_custom_jwt_payload(
     846           15 :         kid: String,
     847           15 :         body: impl Serialize,
     848           15 :         sig: jose_jwa::Signing,
     849           15 :     ) -> String {
     850           15 :         let header = JwtHeader {
     851           15 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     852           15 :             key_id: Some(Cow::Owned(kid)),
     853           15 :         };
     854           15 : 
     855           15 :         let header =
     856           15 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     857           15 :         let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
     858           15 : 
     859           15 :         format!("{header}.{body}")
     860           15 :     }
     861              : 
     862            3 :     fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
     863              :         use p256::ecdsa::{Signature, SigningKey};
     864              : 
     865            3 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     866            3 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     867            3 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     868            3 : 
     869            3 :         format!("{payload}.{sig}")
     870            3 :     }
     871              : 
     872            8 :     fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
     873              :         use p256::ecdsa::{Signature, SigningKey};
     874              : 
     875            8 :         let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
     876            8 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     877            8 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     878            8 : 
     879            8 :         format!("{payload}.{sig}")
     880            8 :     }
     881              : 
     882            4 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     883              :         use rsa::pkcs1v15::SigningKey;
     884              :         use rsa::signature::SignatureEncoding;
     885              : 
     886            4 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     887            4 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     888            4 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     889            4 : 
     890            4 :         format!("{payload}.{sig}")
     891            4 :     }
     892              : 
     893              :     // RSA key gen is slow....
     894              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     895              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     896              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     897              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     898              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     899              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     900              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     901              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     902              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     903              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     904              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     905              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     906              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     907              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     908              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     909              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     910              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     911              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     912              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     913              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     914              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     915              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     916              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     917              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     918              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     919              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     920              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     921              : -----END PRIVATE KEY-----
     922              : ";
     923              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     924              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     925              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     926              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     927              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     928              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     929              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     930              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     931              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     932              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     933              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     934              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     935              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     936              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     937              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     938              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     939              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     940              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     941              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     942              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     943              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     944              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     945              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     946              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     947              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     948              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     949              : 5JiconnI5aLek0QVPoFaVXFa
     950              : -----END PRIVATE KEY-----
     951              : ";
     952              : 
     953              :     #[derive(Clone)]
     954              :     struct Fetch(Vec<AuthRule>);
     955              : 
     956              :     impl FetchAuthRules for Fetch {
     957            7 :         async fn fetch_auth_rules(
     958            7 :             &self,
     959            7 :             _ctx: &RequestContext,
     960            7 :             _endpoint: EndpointId,
     961            7 :         ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     962            7 :             Ok(self.0.clone())
     963            7 :         }
     964              :     }
     965              : 
     966            6 :     async fn jwks_server(
     967            6 :         router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
     968            6 :     ) -> SocketAddr {
     969            6 :         let router = Arc::new(router);
     970            9 :         let service = service_fn(move |req| {
     971            9 :             let router = Arc::clone(&router);
     972            9 :             async move {
     973            9 :                 match router(req.uri().path()) {
     974            9 :                     Some(body) => Response::builder()
     975            9 :                         .status(200)
     976            9 :                         .body(Full::new(Bytes::from(body))),
     977            0 :                     None => Response::builder()
     978            0 :                         .status(404)
     979            0 :                         .body(Full::new(Bytes::new())),
     980              :                 }
     981            9 :             }
     982            9 :         });
     983              : 
     984            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     985            6 :         let server = hyper::server::conn::http1::Builder::new();
     986            6 :         let addr = listener.local_addr().unwrap();
     987            6 :         tokio::spawn(async move {
     988              :             loop {
     989           12 :                 let (s, _) = listener.accept().await.unwrap();
     990            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
     991            6 :                 tokio::spawn(serve.into_future());
     992              :             }
     993            6 :         });
     994            6 : 
     995            6 :         addr
     996            6 :     }
     997              : 
     998              :     #[tokio::test]
     999            1 :     async fn check_jwt_happy_path() {
    1000            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
    1001            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
    1002            1 :         let (ec1, jwk3) = new_ec_jwk("ec1".into());
    1003            1 :         let (ec2, jwk4) = new_ec_jwk("ec2".into());
    1004            1 : 
    1005            1 :         let foo_jwks = jose_jwk::JwkSet {
    1006            1 :             keys: vec![jwk1, jwk3],
    1007            1 :         };
    1008            1 :         let bar_jwks = jose_jwk::JwkSet {
    1009            1 :             keys: vec![jwk2, jwk4],
    1010            1 :         };
    1011            1 : 
    1012            4 :         let jwks_addr = jwks_server(move |path| match path {
    1013            4 :             "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
    1014            2 :             "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
    1015            1 :             _ => None,
    1016            4 :         })
    1017            1 :         .await;
    1018            1 : 
    1019            1 :         let role_name1 = RoleName::from("anonymous");
    1020            1 :         let role_name2 = RoleName::from("authenticated");
    1021            1 : 
    1022            1 :         let roles = vec![
    1023            1 :             RoleNameInt::from(&role_name1),
    1024            1 :             RoleNameInt::from(&role_name2),
    1025            1 :         ];
    1026            1 :         let rules = vec![
    1027            1 :             AuthRule {
    1028            1 :                 id: "foo".to_owned(),
    1029            1 :                 jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
    1030            1 :                 audience: None,
    1031            1 :                 role_names: roles.clone(),
    1032            1 :             },
    1033            1 :             AuthRule {
    1034            1 :                 id: "bar".to_owned(),
    1035            1 :                 jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
    1036            1 :                 audience: None,
    1037            1 :                 role_names: roles.clone(),
    1038            1 :             },
    1039            1 :         ];
    1040            1 : 
    1041            1 :         let fetch = Fetch(rules);
    1042            1 :         let jwk_cache = JwkCache::default();
    1043            1 : 
    1044            1 :         let endpoint = EndpointId::from("ep");
    1045            1 : 
    1046            1 :         let jwt1 = new_rsa_jwt("rs1".into(), rs1);
    1047            1 :         let jwt2 = new_rsa_jwt("rs2".into(), rs2);
    1048            1 :         let jwt3 = new_ec_jwt("ec1".into(), &ec1);
    1049            1 :         let jwt4 = new_ec_jwt("ec2".into(), &ec2);
    1050            1 : 
    1051            1 :         let tokens = [jwt1, jwt2, jwt3, jwt4];
    1052            1 :         let role_names = [role_name1, role_name2];
    1053            3 :         for role in &role_names {
    1054           10 :             for token in &tokens {
    1055            8 :                 jwk_cache
    1056            8 :                     .check_jwt(
    1057            8 :                         &RequestContext::test(),
    1058            8 :                         endpoint.clone(),
    1059            8 :                         role,
    1060            8 :                         &fetch,
    1061            8 :                         token,
    1062            8 :                     )
    1063            6 :                     .await
    1064            8 :                     .unwrap();
    1065            1 :             }
    1066            1 :         }
    1067            1 :     }
    1068              : 
    1069              :     /// AWS Cognito escapes the `/` in the URL.
    1070              :     #[tokio::test]
    1071            1 :     async fn check_jwt_regression_cognito_issuer() {
    1072            1 :         let (key, jwk) = new_ec_jwk("key".into());
    1073            1 : 
    1074            1 :         let now = now();
    1075            1 :         let token = new_custom_ec_jwt(
    1076            1 :             "key".into(),
    1077            1 :             &key,
    1078            1 :             typed_json::json! {{
    1079            1 :                 "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1080            1 :                 // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
    1081            1 :                 // to write that escape character. instead I will make a bogus URL using `\` instead.
    1082            1 :                 "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
    1083            1 :                 "client_id": "abcdefghijklmnopqrstuvwxyz",
    1084            1 :                 "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
    1085            1 :                 "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
    1086            1 :                 "token_use": "access",
    1087            1 :                 "scope": "aws.cognito.signin.user.admin",
    1088            1 :                 "auth_time":now,
    1089            1 :                 "exp":now + 60,
    1090            1 :                 "iat":now,
    1091            1 :                 "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
    1092            1 :                 "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
    1093            1 :             }},
    1094            1 :         );
    1095            1 : 
    1096            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1097            1 : 
    1098            1 :         let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
    1099            1 : 
    1100            1 :         let role_name = RoleName::from("anonymous");
    1101            1 :         let rules = vec![AuthRule {
    1102            1 :             id: "aws-cognito".to_owned(),
    1103            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1104            1 :             audience: None,
    1105            1 :             role_names: vec![RoleNameInt::from(&role_name)],
    1106            1 :         }];
    1107            1 : 
    1108            1 :         let fetch = Fetch(rules);
    1109            1 :         let jwk_cache = JwkCache::default();
    1110            1 : 
    1111            1 :         let endpoint = EndpointId::from("ep");
    1112            1 : 
    1113            1 :         jwk_cache
    1114            1 :             .check_jwt(
    1115            1 :                 &RequestContext::test(),
    1116            1 :                 endpoint.clone(),
    1117            1 :                 &role_name,
    1118            1 :                 &fetch,
    1119            1 :                 &token,
    1120            1 :             )
    1121            3 :             .await
    1122            1 :             .unwrap();
    1123            1 :     }
    1124              : 
    1125              :     #[tokio::test]
    1126            1 :     async fn check_jwt_invalid_signature() {
    1127            1 :         let (_, jwk) = new_ec_jwk("1".into());
    1128            1 :         let (key, _) = new_ec_jwk("1".into());
    1129            1 : 
    1130            1 :         // has a matching kid, but signed by the wrong key
    1131            1 :         let bad_jwt = new_ec_jwt("1".into(), &key);
    1132            1 : 
    1133            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1134            1 :         let jwks_addr = jwks_server(move |path| match path {
    1135            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1136            1 :             _ => None,
    1137            1 :         })
    1138            1 :         .await;
    1139            1 : 
    1140            1 :         let role = RoleName::from("authenticated");
    1141            1 : 
    1142            1 :         let rules = vec![AuthRule {
    1143            1 :             id: String::new(),
    1144            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1145            1 :             audience: None,
    1146            1 :             role_names: vec![RoleNameInt::from(&role)],
    1147            1 :         }];
    1148            1 : 
    1149            1 :         let fetch = Fetch(rules);
    1150            1 :         let jwk_cache = JwkCache::default();
    1151            1 : 
    1152            1 :         let ep = EndpointId::from("ep");
    1153            1 : 
    1154            1 :         let ctx = RequestContext::test();
    1155            1 :         let err = jwk_cache
    1156            1 :             .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
    1157            3 :             .await
    1158            1 :             .unwrap_err();
    1159            1 :         assert!(
    1160            1 :             matches!(err, JwtError::Signature(_)),
    1161            1 :             "expected \"signature error\", got {err:?}"
    1162            1 :         );
    1163            1 :     }
    1164              : 
    1165              :     #[tokio::test]
    1166            1 :     async fn check_jwt_unknown_role() {
    1167            1 :         let (key, jwk) = new_rsa_jwk(RS1, "1".into());
    1168            1 :         let jwt = new_rsa_jwt("1".into(), key);
    1169            1 : 
    1170            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1171            1 :         let jwks_addr = jwks_server(move |path| match path {
    1172            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1173            1 :             _ => None,
    1174            1 :         })
    1175            1 :         .await;
    1176            1 : 
    1177            1 :         let role = RoleName::from("authenticated");
    1178            1 :         let rules = vec![AuthRule {
    1179            1 :             id: String::new(),
    1180            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1181            1 :             audience: None,
    1182            1 :             role_names: vec![RoleNameInt::from(&role)],
    1183            1 :         }];
    1184            1 : 
    1185            1 :         let fetch = Fetch(rules);
    1186            1 :         let jwk_cache = JwkCache::default();
    1187            1 : 
    1188            1 :         let ep = EndpointId::from("ep");
    1189            1 : 
    1190            1 :         // this role_name is not accepted
    1191            1 :         let bad_role_name = RoleName::from("cloud_admin");
    1192            1 : 
    1193            1 :         let ctx = RequestContext::test();
    1194            1 :         let err = jwk_cache
    1195            1 :             .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
    1196            3 :             .await
    1197            1 :             .unwrap_err();
    1198            1 : 
    1199            1 :         assert!(
    1200            1 :             matches!(err, JwtError::JwkNotFound),
    1201            1 :             "expected \"jwk not found\", got {err:?}"
    1202            1 :         );
    1203            1 :     }
    1204              : 
    1205              :     #[tokio::test]
    1206            1 :     async fn check_jwt_invalid_claims() {
    1207            1 :         let (key, jwk) = new_ec_jwk("1".into());
    1208            1 : 
    1209            1 :         let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    1210            1 :         let jwks_addr = jwks_server(move |path| match path {
    1211            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1212            1 :             _ => None,
    1213            1 :         })
    1214            1 :         .await;
    1215            1 : 
    1216            1 :         let now = SystemTime::now()
    1217            1 :             .duration_since(SystemTime::UNIX_EPOCH)
    1218            1 :             .unwrap()
    1219            1 :             .as_secs();
    1220            1 : 
    1221            1 :         struct Test {
    1222            1 :             body: serde_json::Value,
    1223            1 :             error: JwtClaimsError,
    1224            1 :         }
    1225            1 : 
    1226            1 :         let table = vec![
    1227            1 :             Test {
    1228            1 :                 body: json! {{
    1229            1 :                     "nbf": now + 60,
    1230            1 :                     "aud": "neon",
    1231            1 :                 }},
    1232            1 :                 error: JwtClaimsError::JwtTokenNotYetReadyToUse,
    1233            1 :             },
    1234            1 :             Test {
    1235            1 :                 body: json! {{
    1236            1 :                     "exp": now - 60,
    1237            1 :                     "aud": ["neon"],
    1238            1 :                 }},
    1239            1 :                 error: JwtClaimsError::JwtTokenHasExpired,
    1240            1 :             },
    1241            1 :             Test {
    1242            1 :                 body: json! {{
    1243            1 :                 }},
    1244            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1245            1 :             },
    1246            1 :             Test {
    1247            1 :                 body: json! {{
    1248            1 :                     "aud": [],
    1249            1 :                 }},
    1250            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1251            1 :             },
    1252            1 :             Test {
    1253            1 :                 body: json! {{
    1254            1 :                     "aud": "foo",
    1255            1 :                 }},
    1256            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1257            1 :             },
    1258            1 :             Test {
    1259            1 :                 body: json! {{
    1260            1 :                     "aud": ["foo"],
    1261            1 :                 }},
    1262            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1263            1 :             },
    1264            1 :             Test {
    1265            1 :                 body: json! {{
    1266            1 :                     "aud": ["foo", "bar"],
    1267            1 :                 }},
    1268            1 :                 error: JwtClaimsError::InvalidJwtTokenAudience,
    1269            1 :             },
    1270            1 :         ];
    1271            1 : 
    1272            1 :         let role = RoleName::from("authenticated");
    1273            1 : 
    1274            1 :         let rules = vec![AuthRule {
    1275            1 :             id: String::new(),
    1276            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1277            1 :             audience: Some("neon".to_string()),
    1278            1 :             role_names: vec![RoleNameInt::from(&role)],
    1279            1 :         }];
    1280            1 : 
    1281            1 :         let fetch = Fetch(rules);
    1282            1 :         let jwk_cache = JwkCache::default();
    1283            1 : 
    1284            1 :         let ep = EndpointId::from("ep");
    1285            1 : 
    1286            1 :         let ctx = RequestContext::test();
    1287            8 :         for test in table {
    1288            7 :             let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
    1289            7 : 
    1290            7 :             match jwk_cache
    1291            7 :                 .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
    1292            3 :                 .await
    1293            1 :             {
    1294            7 :                 Err(JwtError::InvalidClaims(error)) if error == test.error => {}
    1295            1 :                 Err(err) => {
    1296            0 :                     panic!("expected {:?}, got {err:?}", test.error)
    1297            1 :                 }
    1298            1 :                 Ok(_payload) => {
    1299            0 :                     panic!("expected {:?}, got ok", test.error)
    1300            1 :                 }
    1301            1 :             }
    1302            1 :         }
    1303            1 :     }
    1304              : 
    1305              :     #[tokio::test]
    1306            1 :     async fn check_jwk_keycloak_regression() {
    1307            1 :         let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
    1308            1 :         let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
    1309            1 : 
    1310            1 :         // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
    1311            1 :         // This is taken directly from keycloak.
    1312            1 :         let invalid_jwk = serde_json::json! {
    1313            1 :             {
    1314            1 :                 "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
    1315            1 :                 "kty": "RSA",
    1316            1 :                 "alg": "RSA-OAEP",
    1317            1 :                 "use": "enc",
    1318            1 :                 "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
    1319            1 :                 "e": "AQAB",
    1320            1 :                 "x5c": [
    1321            1 :                     "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
    1322            1 :                 ],
    1323            1 :                 "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
    1324            1 :                 "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
    1325            1 :             }
    1326            1 :         };
    1327            1 : 
    1328            1 :         let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
    1329            1 :         let jwks_addr = jwks_server(move |path| match path {
    1330            1 :             "/" => Some(serde_json::to_vec(&jwks).unwrap()),
    1331            1 :             _ => None,
    1332            1 :         })
    1333            1 :         .await;
    1334            1 : 
    1335            1 :         let role_name = RoleName::from("anonymous");
    1336            1 :         let role = RoleNameInt::from(&role_name);
    1337            1 : 
    1338            1 :         let rules = vec![AuthRule {
    1339            1 :             id: "foo".to_owned(),
    1340            1 :             jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
    1341            1 :             audience: None,
    1342            1 :             role_names: vec![role],
    1343            1 :         }];
    1344            1 : 
    1345            1 :         let fetch = Fetch(rules);
    1346            1 :         let jwk_cache = JwkCache::default();
    1347            1 : 
    1348            1 :         let endpoint = EndpointId::from("ep");
    1349            1 : 
    1350            1 :         let token = new_rsa_jwt("rs1".into(), rs);
    1351            1 : 
    1352            1 :         jwk_cache
    1353            1 :             .check_jwt(
    1354            1 :                 &RequestContext::test(),
    1355            1 :                 endpoint.clone(),
    1356            1 :                 &role_name,
    1357            1 :                 &fetch,
    1358            1 :                 &token,
    1359            1 :             )
    1360            3 :             .await
    1361            1 :             .unwrap();
    1362            1 :     }
    1363              : }
        

Generated by: LCOV version 2.1-beta