LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 2aa98e37cd3250b9a68c97ef6050b16fe702ab33.info Lines: 79.0 % 420 332
Test Date: 2024-08-29 11:33:10 Functions: 42.3 % 97 41

            Line data    Source code
       1              : use std::{
       2              :     future::Future,
       3              :     sync::Arc,
       4              :     time::{Duration, SystemTime},
       5              : };
       6              : 
       7              : use anyhow::{bail, ensure, Context};
       8              : use arc_swap::ArcSwapOption;
       9              : use dashmap::DashMap;
      10              : use jose_jwk::crypto::KeyInfo;
      11              : use serde::{Deserialize, Deserializer};
      12              : use signature::Verifier;
      13              : use tokio::time::Instant;
      14              : 
      15              : use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
      16              : 
      17              : // TODO(conrad): make these configurable.
      18              : const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
      19              : const MIN_RENEW: Duration = Duration::from_secs(30);
      20              : const AUTO_RENEW: Duration = Duration::from_secs(300);
      21              : const MAX_RENEW: Duration = Duration::from_secs(3600);
      22              : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
      23              : 
      24              : /// How to get the JWT auth rules
      25              : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
      26              :     fn fetch_auth_rules(
      27              :         &self,
      28              :         role_name: RoleName,
      29              :     ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
      30              : }
      31              : 
      32              : pub(crate) struct AuthRule {
      33              :     pub(crate) id: String,
      34              :     pub(crate) jwks_url: url::Url,
      35              :     pub(crate) audience: Option<String>,
      36              : }
      37              : 
      38              : #[derive(Default)]
      39              : pub(crate) struct JwkCache {
      40              :     client: reqwest::Client,
      41              : 
      42              :     map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
      43              : }
      44              : 
      45              : pub(crate) struct JwkCacheEntry {
      46              :     /// Should refetch at least every hour to verify when old keys have been removed.
      47              :     /// Should refetch when new key IDs are seen only every 5 minutes or so
      48              :     last_retrieved: Instant,
      49              : 
      50              :     /// cplane will return multiple JWKs urls that we need to scrape.
      51              :     key_sets: ahash::HashMap<String, KeySet>,
      52              : }
      53              : 
      54              : impl JwkCacheEntry {
      55           24 :     fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      56           36 :         self.key_sets.values().find_map(|key_set| {
      57           36 :             key_set
      58           36 :                 .find_key(key_id)
      59           36 :                 .map(|jwk| (jwk, key_set.audience.as_deref()))
      60           36 :         })
      61           24 :     }
      62              : }
      63              : 
      64              : struct KeySet {
      65              :     jwks: jose_jwk::JwkSet,
      66              :     audience: Option<String>,
      67              : }
      68              : 
      69              : impl KeySet {
      70           36 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
      71           36 :         self.jwks
      72           36 :             .keys
      73           36 :             .iter()
      74           60 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
      75           36 :     }
      76              : }
      77              : 
      78              : pub(crate) struct JwkCacheEntryLock {
      79              :     cached: ArcSwapOption<JwkCacheEntry>,
      80              :     lookup: tokio::sync::Semaphore,
      81              : }
      82              : 
      83              : impl Default for JwkCacheEntryLock {
      84            6 :     fn default() -> Self {
      85            6 :         JwkCacheEntryLock {
      86            6 :             cached: ArcSwapOption::empty(),
      87            6 :             lookup: tokio::sync::Semaphore::new(1),
      88            6 :         }
      89            6 :     }
      90              : }
      91              : 
      92              : impl JwkCacheEntryLock {
      93            6 :     async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
      94            6 :         JwkRenewalPermit::acquire_permit(self).await
      95            6 :     }
      96              : 
      97            0 :     fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
      98            0 :         JwkRenewalPermit::try_acquire_permit(self)
      99            0 :     }
     100              : 
     101            6 :     async fn renew_jwks<F: FetchAuthRules>(
     102            6 :         &self,
     103            6 :         _permit: JwkRenewalPermit<'_>,
     104            6 :         client: &reqwest::Client,
     105            6 :         role_name: RoleName,
     106            6 :         auth_rules: &F,
     107            6 :     ) -> anyhow::Result<Arc<JwkCacheEntry>> {
     108            6 :         // double check that no one beat us to updating the cache.
     109            6 :         let now = Instant::now();
     110            6 :         let guard = self.cached.load_full();
     111            6 :         if let Some(cached) = guard {
     112            0 :             let last_update = now.duration_since(cached.last_retrieved);
     113            0 :             if last_update < Duration::from_secs(300) {
     114            0 :                 return Ok(cached);
     115            0 :             }
     116            6 :         }
     117              : 
     118            6 :         let rules = auth_rules.fetch_auth_rules(role_name).await?;
     119            6 :         let mut key_sets =
     120            6 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     121              :         // TODO(conrad): run concurrently
     122              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     123           18 :         for rule in rules {
     124           12 :             let req = client.get(rule.jwks_url.clone());
     125           12 :             // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
     126           12 :             // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     127           24 :             match req.send().await.and_then(|r| r.error_for_status()) {
     128              :                 // todo: should we re-insert JWKs if we want to keep this JWKs URL?
     129              :                 // I expect these failures would be quite sparse.
     130            0 :                 Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
     131           12 :                 Ok(r) => {
     132           12 :                     let resp: http::Response<reqwest::Body> = r.into();
     133           12 :                     match parse_json_body_with_limit::<jose_jwk::JwkSet>(
     134           12 :                         resp.into_body(),
     135           12 :                         MAX_JWK_BODY_SIZE,
     136           12 :                     )
     137            0 :                     .await
     138              :                     {
     139            0 :                         Err(e) => {
     140            0 :                             tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
     141              :                         }
     142           12 :                         Ok(jwks) => {
     143           12 :                             key_sets.insert(
     144           12 :                                 rule.id,
     145           12 :                                 KeySet {
     146           12 :                                     jwks,
     147           12 :                                     audience: rule.audience,
     148           12 :                                 },
     149           12 :                             );
     150           12 :                         }
     151              :                     }
     152              :                 }
     153              :             }
     154              :         }
     155              : 
     156            6 :         let entry = Arc::new(JwkCacheEntry {
     157            6 :             last_retrieved: now,
     158            6 :             key_sets,
     159            6 :         });
     160            6 :         self.cached.swap(Some(Arc::clone(&entry)));
     161            6 : 
     162            6 :         Ok(entry)
     163            6 :     }
     164              : 
     165           24 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     166           24 :         self: &Arc<Self>,
     167           24 :         ctx: &RequestMonitoring,
     168           24 :         client: &reqwest::Client,
     169           24 :         role_name: RoleName,
     170           24 :         fetch: &F,
     171           24 :     ) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
     172           24 :         let now = Instant::now();
     173           24 :         let guard = self.cached.load_full();
     174              : 
     175              :         // if we have no cached JWKs, try and get some
     176           24 :         let Some(cached) = guard else {
     177            6 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     178            6 :             let permit = self.acquire_permit().await;
     179           24 :             return self.renew_jwks(permit, client, role_name, fetch).await;
     180              :         };
     181              : 
     182           18 :         let last_update = now.duration_since(cached.last_retrieved);
     183           18 : 
     184           18 :         // check if the cached JWKs need updating.
     185           18 :         if last_update > MAX_RENEW {
     186            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     187            0 :             let permit = self.acquire_permit().await;
     188              : 
     189              :             // it's been too long since we checked the keys. wait for them to update.
     190            0 :             return self.renew_jwks(permit, client, role_name, fetch).await;
     191           18 :         }
     192           18 : 
     193           18 :         // every 5 minutes we should spawn a job to eagerly update the token.
     194           18 :         if last_update > AUTO_RENEW {
     195            0 :             if let Some(permit) = self.try_acquire_permit() {
     196            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     197            0 :                 let permit = permit.into_owned();
     198            0 :                 let entry = self.clone();
     199            0 :                 let client = client.clone();
     200            0 :                 let fetch = fetch.clone();
     201            0 :                 tokio::spawn(async move {
     202            0 :                     if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await {
     203            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     204            0 :                     }
     205            0 :                 });
     206            0 :             } else {
     207            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     208              :             }
     209           18 :         }
     210              : 
     211           18 :         Ok(cached)
     212           24 :     }
     213              : 
     214           24 :     async fn check_jwt<F: FetchAuthRules>(
     215           24 :         self: &Arc<Self>,
     216           24 :         ctx: &RequestMonitoring,
     217           24 :         jwt: &str,
     218           24 :         client: &reqwest::Client,
     219           24 :         role_name: RoleName,
     220           24 :         fetch: &F,
     221           24 :     ) -> Result<(), anyhow::Error> {
     222              :         // JWT compact form is defined to be
     223              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     224              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     225              : 
     226           24 :         let (header_payload, signature) = jwt
     227           24 :             .rsplit_once('.')
     228           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     229           24 :         let (header, payload) = header_payload
     230           24 :             .split_once('.')
     231           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     232              : 
     233           24 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
     234           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     235           24 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)
     236           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     237              : 
     238           24 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
     239           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     240              : 
     241           24 :         ensure!(header.typ == "JWT");
     242           24 :         let kid = header.key_id.context("missing key id")?;
     243              : 
     244           24 :         let mut guard = self
     245           24 :             .get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch)
     246           24 :             .await?;
     247              : 
     248              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     249           24 :         let (jwk, expected_audience) = loop {
     250           24 :             match guard.find_jwk_and_audience(kid) {
     251           24 :                 Some(jwk) => break jwk,
     252            0 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     253            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     254              : 
     255            0 :                     let permit = self.acquire_permit().await;
     256            0 :                     guard = self
     257            0 :                         .renew_jwks(permit, client, role_name.clone(), fetch)
     258            0 :                         .await?;
     259              :                 }
     260              :                 _ => {
     261            0 :                     bail!("jwk not found");
     262              :                 }
     263              :             }
     264              :         };
     265              : 
     266           24 :         ensure!(
     267           24 :             jwk.is_supported(&header.algorithm),
     268            0 :             "signature algorithm not supported"
     269              :         );
     270              : 
     271           24 :         match &jwk.key {
     272           12 :             jose_jwk::Key::Ec(key) => {
     273           12 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     274              :             }
     275           12 :             jose_jwk::Key::Rsa(key) => {
     276           12 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
     277              :             }
     278            0 :             key => bail!("unsupported key type {key:?}"),
     279              :         };
     280              : 
     281           24 :         let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
     282           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     283           24 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
     284           24 :             .context("Provided authentication token is not a valid JWT encoding")?;
     285              : 
     286           24 :         tracing::debug!(?payload, "JWT signature valid with claims");
     287              : 
     288           24 :         match (expected_audience, payload.audience) {
     289              :             // check the audience matches
     290            0 :             (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
     291              :             // the audience is expected but is missing
     292            0 :             (Some(_), None) => bail!("invalid JWT token audience"),
     293              :             // we don't care for the audience field
     294           24 :             (None, _) => {}
     295              :         }
     296              : 
     297           24 :         let now = SystemTime::now();
     298              : 
     299           24 :         if let Some(exp) = payload.expiration {
     300           24 :             ensure!(now < exp + CLOCK_SKEW_LEEWAY);
     301            0 :         }
     302              : 
     303           24 :         if let Some(nbf) = payload.not_before {
     304            0 :             ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
     305           24 :         }
     306              : 
     307           24 :         Ok(())
     308           24 :     }
     309              : }
     310              : 
     311              : impl JwkCache {
     312            0 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     313            0 :         &self,
     314            0 :         ctx: &RequestMonitoring,
     315            0 :         endpoint: EndpointId,
     316            0 :         role_name: RoleName,
     317            0 :         fetch: &F,
     318            0 :         jwt: &str,
     319            0 :     ) -> Result<(), anyhow::Error> {
     320            0 :         // try with just a read lock first
     321            0 :         let key = (endpoint, role_name.clone());
     322            0 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     323            0 :         let entry = entry.unwrap_or_else(|| {
     324            0 :             // acquire a write lock after to insert.
     325            0 :             let entry = self.map.entry(key).or_default();
     326            0 :             Arc::clone(&*entry)
     327            0 :         });
     328            0 : 
     329            0 :         entry
     330            0 :             .check_jwt(ctx, jwt, &self.client, role_name, fetch)
     331            0 :             .await
     332            0 :     }
     333              : }
     334              : 
     335           12 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
     336           12 :     use ecdsa::Signature;
     337           12 :     use signature::Verifier;
     338           12 : 
     339           12 :     match key.crv {
     340              :         jose_jwk::EcCurves::P256 => {
     341           12 :             let pk =
     342           12 :                 p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
     343           12 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     344           12 :             let sig = Signature::from_slice(sig)?;
     345           12 :             key.verify(data, &sig)?;
     346              :         }
     347            0 :         key => bail!("unsupported ec key type {key:?}"),
     348              :     }
     349              : 
     350           12 :     Ok(())
     351           12 : }
     352              : 
     353           12 : fn verify_rsa_signature(
     354           12 :     data: &[u8],
     355           12 :     sig: &[u8],
     356           12 :     key: &jose_jwk::Rsa,
     357           12 :     alg: &Option<jose_jwa::Algorithm>,
     358           12 : ) -> anyhow::Result<()> {
     359              :     use jose_jwa::{Algorithm, Signing};
     360              :     use rsa::{
     361              :         pkcs1v15::{Signature, VerifyingKey},
     362              :         RsaPublicKey,
     363              :     };
     364              : 
     365           12 :     let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
     366              : 
     367           12 :     match alg {
     368              :         Some(Algorithm::Signing(Signing::Rs256)) => {
     369           12 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     370           12 :             let sig = Signature::try_from(sig)?;
     371           12 :             key.verify(data, &sig)?;
     372              :         }
     373            0 :         _ => bail!("invalid RSA signing algorithm"),
     374              :     };
     375              : 
     376           12 :     Ok(())
     377           12 : }
     378              : 
/// JOSE header of a JWT in compact serialisation.
///
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
///
/// NOTE(review): the string fields borrow `&'a str` from the decoded header bytes. serde_json
/// cannot produce a borrowed `&str` for a JSON string containing escape sequences, so such
/// headers fail to parse. Also, RFC 7519 section 5.1 makes `typ` optional, but here it is a
/// required field.
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be "JWT"
    #[serde(rename = "typ")]
    typ: &'a str,
    /// must be a supported alg (checked against the JWK selected by `kid`)
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase (RFC 7515 itself makes it optional)
    #[serde(rename = "kid")]
    key_id: Option<&'a str>,
}
     392              : 
     393              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     394           72 : #[derive(serde::Deserialize, serde::Serialize, Debug)]
     395              : struct JwtPayload<'a> {
     396              :     /// Audience - Recipient for which the JWT is intended
     397              :     #[serde(rename = "aud")]
     398              :     audience: Option<&'a str>,
     399              :     /// Expiration - Time after which the JWT expires
     400              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     401              :     expiration: Option<SystemTime>,
     402              :     /// Not before - Time after which the JWT expires
     403              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     404              :     not_before: Option<SystemTime>,
     405              : 
     406              :     // the following entries are only extracted for the sake of debug logging.
     407              :     /// Issuer of the JWT
     408              :     #[serde(rename = "iss")]
     409              :     issuer: Option<&'a str>,
     410              :     /// Subject of the JWT (the user)
     411              :     #[serde(rename = "sub")]
     412              :     subject: Option<&'a str>,
     413              :     /// Unique token identifier
     414              :     #[serde(rename = "jti")]
     415              :     jwt_id: Option<&'a str>,
     416              :     /// Unique session identifier
     417              :     #[serde(rename = "sid")]
     418              :     session_id: Option<&'a str>,
     419              : }
     420              : 
     421           24 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     422           24 :     let d = <Option<u64>>::deserialize(d)?;
     423           24 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     424           24 : }
     425              : 
     426              : struct JwkRenewalPermit<'a> {
     427              :     inner: Option<JwkRenewalPermitInner<'a>>,
     428              : }
     429              : 
     430              : enum JwkRenewalPermitInner<'a> {
     431              :     Owned(Arc<JwkCacheEntryLock>),
     432              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     433              : }
     434              : 
     435              : impl JwkRenewalPermit<'_> {
     436            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     437            0 :         JwkRenewalPermit {
     438            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     439            0 :         }
     440            0 :     }
     441              : 
     442            6 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     443            6 :         match from.lookup.acquire().await {
     444            6 :             Ok(permit) => {
     445            6 :                 permit.forget();
     446            6 :                 JwkRenewalPermit {
     447            6 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     448            6 :                 }
     449              :             }
     450            0 :             Err(_) => panic!("semaphore should not be closed"),
     451              :         }
     452            6 :     }
     453              : 
     454            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     455            0 :         match from.lookup.try_acquire() {
     456            0 :             Ok(permit) => {
     457            0 :                 permit.forget();
     458            0 :                 Some(JwkRenewalPermit {
     459            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     460            0 :                 })
     461              :             }
     462            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     463            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     464              :         }
     465            0 :     }
     466              : }
     467              : 
     468              : impl JwkRenewalPermitInner<'_> {
     469            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     470            0 :         match self {
     471            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     472            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     473              :         }
     474            0 :     }
     475              : }
     476              : 
     477              : impl Drop for JwkRenewalPermit<'_> {
     478            6 :     fn drop(&mut self) {
     479            6 :         let entry = match &self.inner {
     480            0 :             None => return,
     481            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     482            6 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     483              :         };
     484            6 :         entry.lookup.add_permits(1);
     485            6 :     }
     486              : }
     487              : 
     488              : #[cfg(test)]
     489              : mod tests {
     490              :     use crate::RoleName;
     491              : 
     492              :     use super::*;
     493              : 
     494              :     use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
     495              : 
     496              :     use base64::URL_SAFE_NO_PAD;
     497              :     use bytes::Bytes;
     498              :     use http::Response;
     499              :     use http_body_util::Full;
     500              :     use hyper1::service::service_fn;
     501              :     use hyper_util::rt::TokioIo;
     502              :     use rand::rngs::OsRng;
     503              :     use signature::Signer;
     504              :     use tokio::net::TcpListener;
     505              : 
     506           12 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     507           12 :         let sk = p256::SecretKey::random(&mut OsRng);
     508           12 :         let pk = sk.public_key().into();
     509           12 :         let jwk = jose_jwk::Jwk {
     510           12 :             key: jose_jwk::Key::Ec(pk),
     511           12 :             prm: jose_jwk::Parameters {
     512           12 :                 kid: Some(kid),
     513           12 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     514           12 :                 ..Default::default()
     515           12 :             },
     516           12 :         };
     517           12 :         (sk, jwk)
     518           12 :     }
     519              : 
     520           12 :     fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     521           12 :         let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
     522           12 :         let pk = sk.to_public_key().into();
     523           12 :         let jwk = jose_jwk::Jwk {
     524           12 :             key: jose_jwk::Key::Rsa(pk),
     525           12 :             prm: jose_jwk::Parameters {
     526           12 :                 kid: Some(kid),
     527           12 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     528           12 :                 ..Default::default()
     529           12 :             },
     530           12 :         };
     531           12 :         (sk, jwk)
     532           12 :     }
     533              : 
     534           24 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     535           24 :         let header = JwtHeader {
     536           24 :             typ: "JWT",
     537           24 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     538           24 :             key_id: Some(&kid),
     539           24 :         };
     540           24 :         let body = typed_json::json! {{
     541           24 :             "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
     542           24 :         }};
     543           24 : 
     544           24 :         let header =
     545           24 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     546           24 :         let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
     547           24 : 
     548           24 :         format!("{header}.{body}")
     549           24 :     }
     550              : 
     551           12 :     fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
     552           12 :         use p256::ecdsa::{Signature, SigningKey};
     553           12 : 
     554           12 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     555           12 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     556           12 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     557           12 : 
     558           12 :         format!("{payload}.{sig}")
     559           12 :     }
     560              : 
     561           12 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     562           12 :         use rsa::pkcs1v15::SigningKey;
     563           12 :         use rsa::signature::SignatureEncoding;
     564           12 : 
     565           12 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     566           12 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     567           12 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     568           12 : 
     569           12 :         format!("{payload}.{sig}")
     570           12 :     }
     571              : 
     572              :     #[tokio::test]
     573            6 :     async fn renew() {
     574            6 :         let (rs1, jwk1) = new_rsa_jwk("1".into());
     575            6 :         let (rs2, jwk2) = new_rsa_jwk("2".into());
     576            6 :         let (ec1, jwk3) = new_ec_jwk("3".into());
     577            6 :         let (ec2, jwk4) = new_ec_jwk("4".into());
     578            6 : 
     579            6 :         let jwt1 = new_rsa_jwt("1".into(), rs1);
     580            6 :         let jwt2 = new_rsa_jwt("2".into(), rs2);
     581            6 :         let jwt3 = new_ec_jwt("3".into(), ec1);
     582            6 :         let jwt4 = new_ec_jwt("4".into(), ec2);
     583            6 : 
     584            6 :         let foo_jwks = jose_jwk::JwkSet {
     585            6 :             keys: vec![jwk1, jwk3],
     586            6 :         };
     587            6 :         let bar_jwks = jose_jwk::JwkSet {
     588            6 :             keys: vec![jwk2, jwk4],
     589            6 :         };
     590            6 : 
     591           12 :         let service = service_fn(move |req| {
     592           12 :             let foo_jwks = foo_jwks.clone();
     593           12 :             let bar_jwks = bar_jwks.clone();
     594           12 :             async move {
     595           12 :                 let jwks = match req.uri().path() {
     596           12 :                     "/foo" => &foo_jwks,
     597            6 :                     "/bar" => &bar_jwks,
     598            6 :                     _ => {
     599            6 :                         return Response::builder()
     600            0 :                             .status(404)
     601            0 :                             .body(Full::new(Bytes::new()));
     602            6 :                     }
     603            6 :                 };
     604           12 :                 let body = serde_json::to_vec(jwks).unwrap();
     605           12 :                 Response::builder()
     606           12 :                     .status(200)
     607           12 :                     .body(Full::new(Bytes::from(body)))
     608           12 :             }
     609           12 :         });
     610            6 : 
     611            6 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     612            6 :         let server = hyper1::server::conn::http1::Builder::new();
     613            6 :         let addr = listener.local_addr().unwrap();
     614            6 :         tokio::spawn(async move {
     615            6 :             loop {
     616           12 :                 let (s, _) = listener.accept().await.unwrap();
     617            6 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
     618            6 :                 tokio::spawn(serve.into_future());
     619            6 :             }
     620            6 :         });
     621            6 : 
     622            6 :         let client = reqwest::Client::new();
     623            6 : 
     624            6 :         #[derive(Clone)]
     625            6 :         struct Fetch(SocketAddr);
     626            6 : 
     627            6 :         impl FetchAuthRules for Fetch {
     628            6 :             async fn fetch_auth_rules(
     629            6 :                 &self,
     630            6 :                 _role_name: RoleName,
     631            6 :             ) -> anyhow::Result<Vec<AuthRule>> {
     632            6 :                 Ok(vec![
     633            6 :                     AuthRule {
     634            6 :                         id: "foo".to_owned(),
     635            6 :                         jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
     636            6 :                         audience: None,
     637            6 :                     },
     638            6 :                     AuthRule {
     639            6 :                         id: "bar".to_owned(),
     640            6 :                         jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
     641            6 :                         audience: None,
     642            6 :                     },
     643            6 :                 ])
     644            6 :             }
     645            6 :         }
     646            6 : 
     647            6 :         let role_name = RoleName::from("user");
     648            6 : 
     649            6 :         let jwk_cache = Arc::new(JwkCacheEntryLock::default());
     650            6 : 
     651           24 :         for token in [jwt1, jwt2, jwt3, jwt4] {
     652           24 :             jwk_cache
     653           24 :                 .check_jwt(
     654           24 :                     &RequestMonitoring::test(),
     655           24 :                     &token,
     656           24 :                     &client,
     657           24 :                     role_name.clone(),
     658           24 :                     &Fetch(addr),
     659           24 :                 )
     660           24 :                 .await
     661           24 :                 .unwrap();
     662            6 :         }
     663            6 :     }
     664              : }
        

Generated by: LCOV version 2.1-beta