LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - jwt.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 76.8 % 436 335
Test Date: 2024-09-20 13:14:58 Functions: 42.3 % 97 41

            Line data    Source code
       1              : use std::{
       2              :     future::Future,
       3              :     sync::Arc,
       4              :     time::{Duration, SystemTime},
       5              : };
       6              : 
       7              : use anyhow::{bail, ensure, Context};
       8              : use arc_swap::ArcSwapOption;
       9              : use dashmap::DashMap;
      10              : use jose_jwk::crypto::KeyInfo;
      11              : use serde::{Deserialize, Deserializer};
      12              : use signature::Verifier;
      13              : use tokio::time::Instant;
      14              : 
      15              : use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
      16              : 
      17              : // TODO(conrad): make these configurable.
      18              : const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
      19              : const MIN_RENEW: Duration = Duration::from_secs(30);
      20              : const AUTO_RENEW: Duration = Duration::from_secs(300);
      21              : const MAX_RENEW: Duration = Duration::from_secs(3600);
      22              : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
      23              : 
      24              : /// How to get the JWT auth rules
      25              : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
      26              :     fn fetch_auth_rules(
      27              :         &self,
      28              :         ctx: &RequestMonitoring,
      29              :         endpoint: EndpointId,
      30              :         role_name: RoleName,
      31              :     ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
      32              : }
      33              : 
      34              : pub(crate) struct AuthRule {
      35              :     pub(crate) id: String,
      36              :     pub(crate) jwks_url: url::Url,
      37              :     pub(crate) audience: Option<String>,
      38              : }
      39              : 
      40              : #[derive(Default)]
      41              : pub(crate) struct JwkCache {
      42              :     client: reqwest::Client,
      43              : 
      44              :     map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
      45              : }
      46              : 
      47              : pub(crate) struct JwkCacheEntry {
      48              :     /// Should refetch at least every hour to verify when old keys have been removed.
      49              :     /// Should refetch when new key IDs are seen only every 5 minutes or so
      50              :     last_retrieved: Instant,
      51              : 
      52              :     /// cplane will return multiple JWKs urls that we need to scrape.
      53              :     key_sets: ahash::HashMap<String, KeySet>,
      54              : }
      55              : 
      56              : impl JwkCacheEntry {
      57            4 :     fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
      58            6 :         self.key_sets.values().find_map(|key_set| {
      59            6 :             key_set
      60            6 :                 .find_key(key_id)
      61            6 :                 .map(|jwk| (jwk, key_set.audience.as_deref()))
      62            6 :         })
      63            4 :     }
      64              : }
      65              : 
      66              : struct KeySet {
      67              :     jwks: jose_jwk::JwkSet,
      68              :     audience: Option<String>,
      69              : }
      70              : 
      71              : impl KeySet {
      72            6 :     fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
      73            6 :         self.jwks
      74            6 :             .keys
      75            6 :             .iter()
      76           10 :             .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
      77            6 :     }
      78              : }
      79              : 
      80              : pub(crate) struct JwkCacheEntryLock {
      81              :     cached: ArcSwapOption<JwkCacheEntry>,
      82              :     lookup: tokio::sync::Semaphore,
      83              : }
      84              : 
      85              : impl Default for JwkCacheEntryLock {
      86            1 :     fn default() -> Self {
      87            1 :         JwkCacheEntryLock {
      88            1 :             cached: ArcSwapOption::empty(),
      89            1 :             lookup: tokio::sync::Semaphore::new(1),
      90            1 :         }
      91            1 :     }
      92              : }
      93              : 
      94              : impl JwkCacheEntryLock {
      95            1 :     async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
      96            1 :         JwkRenewalPermit::acquire_permit(self).await
      97            1 :     }
      98              : 
      99            0 :     fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
     100            0 :         JwkRenewalPermit::try_acquire_permit(self)
     101            0 :     }
     102              : 
     103            1 :     async fn renew_jwks<F: FetchAuthRules>(
     104            1 :         &self,
     105            1 :         _permit: JwkRenewalPermit<'_>,
     106            1 :         ctx: &RequestMonitoring,
     107            1 :         client: &reqwest::Client,
     108            1 :         endpoint: EndpointId,
     109            1 :         role_name: RoleName,
     110            1 :         auth_rules: &F,
     111            1 :     ) -> anyhow::Result<Arc<JwkCacheEntry>> {
     112            1 :         // double check that no one beat us to updating the cache.
     113            1 :         let now = Instant::now();
     114            1 :         let guard = self.cached.load_full();
     115            1 :         if let Some(cached) = guard {
     116            0 :             let last_update = now.duration_since(cached.last_retrieved);
     117            0 :             if last_update < Duration::from_secs(300) {
     118            0 :                 return Ok(cached);
     119            0 :             }
     120            1 :         }
     121              : 
     122            1 :         let rules = auth_rules
     123            1 :             .fetch_auth_rules(ctx, endpoint, role_name)
     124            0 :             .await?;
     125            1 :         let mut key_sets =
     126            1 :             ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
     127              :         // TODO(conrad): run concurrently
     128              :         // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
     129            3 :         for rule in rules {
     130            2 :             let req = client.get(rule.jwks_url.clone());
     131            2 :             // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
     132            2 :             // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
     133            4 :             match req.send().await.and_then(|r| r.error_for_status()) {
     134              :                 // todo: should we re-insert JWKs if we want to keep this JWKs URL?
     135              :                 // I expect these failures would be quite sparse.
     136            0 :                 Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
     137            2 :                 Ok(r) => {
     138            2 :                     let resp: http::Response<reqwest::Body> = r.into();
     139            2 :                     match parse_json_body_with_limit::<jose_jwk::JwkSet>(
     140            2 :                         resp.into_body(),
     141            2 :                         MAX_JWK_BODY_SIZE,
     142            2 :                     )
     143            0 :                     .await
     144              :                     {
     145            0 :                         Err(e) => {
     146            0 :                             tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
     147              :                         }
     148            2 :                         Ok(jwks) => {
     149            2 :                             key_sets.insert(
     150            2 :                                 rule.id,
     151            2 :                                 KeySet {
     152            2 :                                     jwks,
     153            2 :                                     audience: rule.audience,
     154            2 :                                 },
     155            2 :                             );
     156            2 :                         }
     157              :                     }
     158              :                 }
     159              :             }
     160              :         }
     161              : 
     162            1 :         let entry = Arc::new(JwkCacheEntry {
     163            1 :             last_retrieved: now,
     164            1 :             key_sets,
     165            1 :         });
     166            1 :         self.cached.swap(Some(Arc::clone(&entry)));
     167            1 : 
     168            1 :         Ok(entry)
     169            1 :     }
     170              : 
     171            4 :     async fn get_or_update_jwk_cache<F: FetchAuthRules>(
     172            4 :         self: &Arc<Self>,
     173            4 :         ctx: &RequestMonitoring,
     174            4 :         client: &reqwest::Client,
     175            4 :         endpoint: EndpointId,
     176            4 :         role_name: RoleName,
     177            4 :         fetch: &F,
     178            4 :     ) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
     179            4 :         let now = Instant::now();
     180            4 :         let guard = self.cached.load_full();
     181              : 
     182              :         // if we have no cached JWKs, try and get some
     183            4 :         let Some(cached) = guard else {
     184            1 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     185            1 :             let permit = self.acquire_permit().await;
     186            1 :             return self
     187            1 :                 .renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
     188            4 :                 .await;
     189              :         };
     190              : 
     191            3 :         let last_update = now.duration_since(cached.last_retrieved);
     192            3 : 
     193            3 :         // check if the cached JWKs need updating.
     194            3 :         if last_update > MAX_RENEW {
     195            0 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     196            0 :             let permit = self.acquire_permit().await;
     197              : 
     198              :             // it's been too long since we checked the keys. wait for them to update.
     199            0 :             return self
     200            0 :                 .renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
     201            0 :                 .await;
     202            3 :         }
     203            3 : 
     204            3 :         // every 5 minutes we should spawn a job to eagerly update the token.
     205            3 :         if last_update > AUTO_RENEW {
     206            0 :             if let Some(permit) = self.try_acquire_permit() {
     207            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit acquired");
     208            0 :                 let permit = permit.into_owned();
     209            0 :                 let entry = self.clone();
     210            0 :                 let client = client.clone();
     211            0 :                 let fetch = fetch.clone();
     212            0 :                 let ctx = ctx.clone();
     213            0 :                 tokio::spawn(async move {
     214            0 :                     if let Err(e) = entry
     215            0 :                         .renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
     216            0 :                         .await
     217              :                     {
     218            0 :                         tracing::warn!(error=?e, "could not fetch JWKs in background job");
     219            0 :                     }
     220            0 :                 });
     221            0 :             } else {
     222            0 :                 tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
     223              :             }
     224            3 :         }
     225              : 
     226            3 :         Ok(cached)
     227            4 :     }
     228              : 
     229            4 :     async fn check_jwt<F: FetchAuthRules>(
     230            4 :         self: &Arc<Self>,
     231            4 :         ctx: &RequestMonitoring,
     232            4 :         jwt: &str,
     233            4 :         client: &reqwest::Client,
     234            4 :         endpoint: EndpointId,
     235            4 :         role_name: RoleName,
     236            4 :         fetch: &F,
     237            4 :     ) -> Result<(), anyhow::Error> {
     238              :         // JWT compact form is defined to be
     239              :         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
     240              :         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
     241              : 
     242            4 :         let (header_payload, signature) = jwt
     243            4 :             .rsplit_once('.')
     244            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     245            4 :         let (header, payload) = header_payload
     246            4 :             .split_once('.')
     247            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     248              : 
     249            4 :         let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
     250            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     251            4 :         let header = serde_json::from_slice::<JwtHeader<'_>>(&header)
     252            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     253              : 
     254            4 :         let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
     255            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     256              : 
     257            4 :         ensure!(header.typ == "JWT");
     258            4 :         let kid = header.key_id.context("missing key id")?;
     259              : 
     260            4 :         let mut guard = self
     261            4 :             .get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
     262            4 :             .await?;
     263              : 
     264              :         // get the key from the JWKs if possible. If not, wait for the keys to update.
     265            4 :         let (jwk, expected_audience) = loop {
     266            4 :             match guard.find_jwk_and_audience(kid) {
     267            4 :                 Some(jwk) => break jwk,
     268            0 :                 None if guard.last_retrieved.elapsed() > MIN_RENEW => {
     269            0 :                     let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     270              : 
     271            0 :                     let permit = self.acquire_permit().await;
     272            0 :                     guard = self
     273            0 :                         .renew_jwks(
     274            0 :                             permit,
     275            0 :                             ctx,
     276            0 :                             client,
     277            0 :                             endpoint.clone(),
     278            0 :                             role_name.clone(),
     279            0 :                             fetch,
     280            0 :                         )
     281            0 :                         .await?;
     282              :                 }
     283              :                 _ => {
     284            0 :                     bail!("jwk not found");
     285              :                 }
     286              :             }
     287              :         };
     288              : 
     289            4 :         ensure!(
     290            4 :             jwk.is_supported(&header.algorithm),
     291            0 :             "signature algorithm not supported"
     292              :         );
     293              : 
     294            4 :         match &jwk.key {
     295            2 :             jose_jwk::Key::Ec(key) => {
     296            2 :                 verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
     297              :             }
     298            2 :             jose_jwk::Key::Rsa(key) => {
     299            2 :                 verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
     300              :             }
     301            0 :             key => bail!("unsupported key type {key:?}"),
     302              :         };
     303              : 
     304            4 :         let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
     305            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     306            4 :         let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
     307            4 :             .context("Provided authentication token is not a valid JWT encoding")?;
     308              : 
     309            4 :         tracing::debug!(?payload, "JWT signature valid with claims");
     310              : 
     311            4 :         match (expected_audience, payload.audience) {
     312              :             // check the audience matches
     313            0 :             (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
     314              :             // the audience is expected but is missing
     315            0 :             (Some(_), None) => bail!("invalid JWT token audience"),
     316              :             // we don't care for the audience field
     317            4 :             (None, _) => {}
     318              :         }
     319              : 
     320            4 :         let now = SystemTime::now();
     321              : 
     322            4 :         if let Some(exp) = payload.expiration {
     323            4 :             ensure!(now < exp + CLOCK_SKEW_LEEWAY);
     324            0 :         }
     325              : 
     326            4 :         if let Some(nbf) = payload.not_before {
     327            0 :             ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
     328            4 :         }
     329              : 
     330            4 :         Ok(())
     331            4 :     }
     332              : }
     333              : 
     334              : impl JwkCache {
     335            0 :     pub(crate) async fn check_jwt<F: FetchAuthRules>(
     336            0 :         &self,
     337            0 :         ctx: &RequestMonitoring,
     338            0 :         endpoint: EndpointId,
     339            0 :         role_name: RoleName,
     340            0 :         fetch: &F,
     341            0 :         jwt: &str,
     342            0 :     ) -> Result<(), anyhow::Error> {
     343            0 :         // try with just a read lock first
     344            0 :         let key = (endpoint.clone(), role_name.clone());
     345            0 :         let entry = self.map.get(&key).as_deref().map(Arc::clone);
     346            0 :         let entry = entry.unwrap_or_else(|| {
     347            0 :             // acquire a write lock after to insert.
     348            0 :             let entry = self.map.entry(key).or_default();
     349            0 :             Arc::clone(&*entry)
     350            0 :         });
     351            0 : 
     352            0 :         entry
     353            0 :             .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
     354            0 :             .await
     355            0 :     }
     356              : }
     357              : 
     358            2 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
     359              :     use ecdsa::Signature;
     360              :     use signature::Verifier;
     361              : 
     362            2 :     match key.crv {
     363              :         jose_jwk::EcCurves::P256 => {
     364            2 :             let pk =
     365            2 :                 p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
     366            2 :             let key = p256::ecdsa::VerifyingKey::from(&pk);
     367            2 :             let sig = Signature::from_slice(sig)?;
     368            2 :             key.verify(data, &sig)?;
     369              :         }
     370            0 :         key => bail!("unsupported ec key type {key:?}"),
     371              :     }
     372              : 
     373            2 :     Ok(())
     374            2 : }
     375              : 
     376            2 : fn verify_rsa_signature(
     377            2 :     data: &[u8],
     378            2 :     sig: &[u8],
     379            2 :     key: &jose_jwk::Rsa,
     380            2 :     alg: &Option<jose_jwa::Algorithm>,
     381            2 : ) -> anyhow::Result<()> {
     382              :     use jose_jwa::{Algorithm, Signing};
     383              :     use rsa::{
     384              :         pkcs1v15::{Signature, VerifyingKey},
     385              :         RsaPublicKey,
     386              :     };
     387              : 
     388            2 :     let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
     389              : 
     390            2 :     match alg {
     391              :         Some(Algorithm::Signing(Signing::Rs256)) => {
     392            2 :             let key = VerifyingKey::<sha2::Sha256>::new(key);
     393            2 :             let sig = Signature::try_from(sig)?;
     394            2 :             key.verify(data, &sig)?;
     395              :         }
     396            0 :         _ => bail!("invalid RSA signing algorithm"),
     397              :     };
     398              : 
     399            2 :     Ok(())
     400            2 : }
     401              : 
     402              : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
     403           16 : #[derive(serde::Deserialize, serde::Serialize)]
     404              : struct JwtHeader<'a> {
     405              :     /// must be "JWT"
     406              :     #[serde(rename = "typ")]
     407              :     typ: &'a str,
     408              :     /// must be a supported alg
     409              :     #[serde(rename = "alg")]
     410              :     algorithm: jose_jwa::Algorithm,
     411              :     /// key id, must be provided for our usecase
     412              :     #[serde(rename = "kid")]
     413              :     key_id: Option<&'a str>,
     414              : }
     415              : 
     416              : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
     417           12 : #[derive(serde::Deserialize, serde::Serialize, Debug)]
     418              : struct JwtPayload<'a> {
     419              :     /// Audience - Recipient for which the JWT is intended
     420              :     #[serde(rename = "aud")]
     421              :     audience: Option<&'a str>,
     422              :     /// Expiration - Time after which the JWT expires
     423              :     #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
     424              :     expiration: Option<SystemTime>,
     425              :     /// Not before - Time after which the JWT expires
     426              :     #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
     427              :     not_before: Option<SystemTime>,
     428              : 
     429              :     // the following entries are only extracted for the sake of debug logging.
     430              :     /// Issuer of the JWT
     431              :     #[serde(rename = "iss")]
     432              :     issuer: Option<&'a str>,
     433              :     /// Subject of the JWT (the user)
     434              :     #[serde(rename = "sub")]
     435              :     subject: Option<&'a str>,
     436              :     /// Unique token identifier
     437              :     #[serde(rename = "jti")]
     438              :     jwt_id: Option<&'a str>,
     439              :     /// Unique session identifier
     440              :     #[serde(rename = "sid")]
     441              :     session_id: Option<&'a str>,
     442              : }
     443              : 
     444            4 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
     445            4 :     let d = <Option<u64>>::deserialize(d)?;
     446            4 :     Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
     447            4 : }
     448              : 
     449              : struct JwkRenewalPermit<'a> {
     450              :     inner: Option<JwkRenewalPermitInner<'a>>,
     451              : }
     452              : 
     453              : enum JwkRenewalPermitInner<'a> {
     454              :     Owned(Arc<JwkCacheEntryLock>),
     455              :     Borrowed(&'a Arc<JwkCacheEntryLock>),
     456              : }
     457              : 
     458              : impl JwkRenewalPermit<'_> {
     459            0 :     fn into_owned(mut self) -> JwkRenewalPermit<'static> {
     460            0 :         JwkRenewalPermit {
     461            0 :             inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
     462            0 :         }
     463            0 :     }
     464              : 
     465            1 :     async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
     466            1 :         match from.lookup.acquire().await {
     467            1 :             Ok(permit) => {
     468            1 :                 permit.forget();
     469            1 :                 JwkRenewalPermit {
     470            1 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     471            1 :                 }
     472              :             }
     473            0 :             Err(_) => panic!("semaphore should not be closed"),
     474              :         }
     475            1 :     }
     476              : 
     477            0 :     fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
     478            0 :         match from.lookup.try_acquire() {
     479            0 :             Ok(permit) => {
     480            0 :                 permit.forget();
     481            0 :                 Some(JwkRenewalPermit {
     482            0 :                     inner: Some(JwkRenewalPermitInner::Borrowed(from)),
     483            0 :                 })
     484              :             }
     485            0 :             Err(tokio::sync::TryAcquireError::NoPermits) => None,
     486            0 :             Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
     487              :         }
     488            0 :     }
     489              : }
     490              : 
     491              : impl JwkRenewalPermitInner<'_> {
     492            0 :     fn into_owned(self) -> JwkRenewalPermitInner<'static> {
     493            0 :         match self {
     494            0 :             JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
     495            0 :             JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
     496              :         }
     497            0 :     }
     498              : }
     499              : 
     500              : impl Drop for JwkRenewalPermit<'_> {
     501            1 :     fn drop(&mut self) {
     502            1 :         let entry = match &self.inner {
     503            0 :             None => return,
     504            0 :             Some(JwkRenewalPermitInner::Owned(p)) => p,
     505            1 :             Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
     506              :         };
     507            1 :         entry.lookup.add_permits(1);
     508            1 :     }
     509              : }
     510              : 
     511              : #[cfg(test)]
     512              : mod tests {
     513              :     use crate::RoleName;
     514              : 
     515              :     use super::*;
     516              : 
     517              :     use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
     518              : 
     519              :     use base64::URL_SAFE_NO_PAD;
     520              :     use bytes::Bytes;
     521              :     use http::Response;
     522              :     use http_body_util::Full;
     523              :     use hyper1::service::service_fn;
     524              :     use hyper_util::rt::TokioIo;
     525              :     use rand::rngs::OsRng;
     526              :     use rsa::pkcs8::DecodePrivateKey;
     527              :     use signature::Signer;
     528              :     use tokio::net::TcpListener;
     529              : 
     530            2 :     fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
     531            2 :         let sk = p256::SecretKey::random(&mut OsRng);
     532            2 :         let pk = sk.public_key().into();
     533            2 :         let jwk = jose_jwk::Jwk {
     534            2 :             key: jose_jwk::Key::Ec(pk),
     535            2 :             prm: jose_jwk::Parameters {
     536            2 :                 kid: Some(kid),
     537            2 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
     538            2 :                 ..Default::default()
     539            2 :             },
     540            2 :         };
     541            2 :         (sk, jwk)
     542            2 :     }
     543              : 
     544            2 :     fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
     545            2 :         let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
     546            2 :         let pk = sk.to_public_key().into();
     547            2 :         let jwk = jose_jwk::Jwk {
     548            2 :             key: jose_jwk::Key::Rsa(pk),
     549            2 :             prm: jose_jwk::Parameters {
     550            2 :                 kid: Some(kid),
     551            2 :                 alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
     552            2 :                 ..Default::default()
     553            2 :             },
     554            2 :         };
     555            2 :         (sk, jwk)
     556            2 :     }
     557              : 
     558            4 :     fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
     559            4 :         let header = JwtHeader {
     560            4 :             typ: "JWT",
     561            4 :             algorithm: jose_jwa::Algorithm::Signing(sig),
     562            4 :             key_id: Some(&kid),
     563            4 :         };
     564            4 :         let body = typed_json::json! {{
     565            4 :             "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
     566            4 :         }};
     567            4 : 
     568            4 :         let header =
     569            4 :             base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
     570            4 :         let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
     571            4 : 
     572            4 :         format!("{header}.{body}")
     573            4 :     }
     574              : 
     575            2 :     fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
     576              :         use p256::ecdsa::{Signature, SigningKey};
     577              : 
     578            2 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
     579            2 :         let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
     580            2 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     581            2 : 
     582            2 :         format!("{payload}.{sig}")
     583            2 :     }
     584              : 
     585            2 :     fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
     586              :         use rsa::pkcs1v15::SigningKey;
     587              :         use rsa::signature::SignatureEncoding;
     588              : 
     589            2 :         let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
     590            2 :         let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
     591            2 :         let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
     592            2 : 
     593            2 :         format!("{payload}.{sig}")
     594            2 :     }
     595              : 
     596              :     // RSA key gen is slow....
     597              :     const RS1: &str = "-----BEGIN PRIVATE KEY-----
     598              : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
     599              : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
     600              : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
     601              : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
     602              : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
     603              : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
     604              : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
     605              : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
     606              : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
     607              : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
     608              : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
     609              : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
     610              : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
     611              : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
     612              : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
     613              : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
     614              : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
     615              : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
     616              : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
     617              : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
     618              : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
     619              : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
     620              : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
     621              : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
     622              : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
     623              : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
     624              : -----END PRIVATE KEY-----
     625              : ";
     626              :     const RS2: &str = "-----BEGIN PRIVATE KEY-----
     627              : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
     628              : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
     629              : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
     630              : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
     631              : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
     632              : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
     633              : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
     634              : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
     635              : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
     636              : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
     637              : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
     638              : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
     639              : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
     640              : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
     641              : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
     642              : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
     643              : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
     644              : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
     645              : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
     646              : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
     647              : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
     648              : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
     649              : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
     650              : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
     651              : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
     652              : 5JiconnI5aLek0QVPoFaVXFa
     653              : -----END PRIVATE KEY-----
     654              : ";
     655              : 
     656              :     #[tokio::test]
     657            1 :     async fn renew() {
     658            1 :         let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
     659            1 :         let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
     660            1 :         let (ec1, jwk3) = new_ec_jwk("3".into());
     661            1 :         let (ec2, jwk4) = new_ec_jwk("4".into());
     662            1 : 
     663            1 :         let jwt1 = new_rsa_jwt("1".into(), rs1);
     664            1 :         let jwt2 = new_rsa_jwt("2".into(), rs2);
     665            1 :         let jwt3 = new_ec_jwt("3".into(), ec1);
     666            1 :         let jwt4 = new_ec_jwt("4".into(), ec2);
     667            1 : 
     668            1 :         let foo_jwks = jose_jwk::JwkSet {
     669            1 :             keys: vec![jwk1, jwk3],
     670            1 :         };
     671            1 :         let bar_jwks = jose_jwk::JwkSet {
     672            1 :             keys: vec![jwk2, jwk4],
     673            1 :         };
     674            1 : 
     675            2 :         let service = service_fn(move |req| {
     676            2 :             let foo_jwks = foo_jwks.clone();
     677            2 :             let bar_jwks = bar_jwks.clone();
     678            2 :             async move {
     679            2 :                 let jwks = match req.uri().path() {
     680            2 :                     "/foo" => &foo_jwks,
     681            1 :                     "/bar" => &bar_jwks,
     682            1 :                     _ => {
     683            1 :                         return Response::builder()
     684            0 :                             .status(404)
     685            0 :                             .body(Full::new(Bytes::new()));
     686            1 :                     }
     687            1 :                 };
     688            2 :                 let body = serde_json::to_vec(jwks).unwrap();
     689            2 :                 Response::builder()
     690            2 :                     .status(200)
     691            2 :                     .body(Full::new(Bytes::from(body)))
     692            2 :             }
     693            2 :         });
     694            1 : 
     695            1 :         let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
     696            1 :         let server = hyper1::server::conn::http1::Builder::new();
     697            1 :         let addr = listener.local_addr().unwrap();
     698            1 :         tokio::spawn(async move {
     699            1 :             loop {
     700            2 :                 let (s, _) = listener.accept().await.unwrap();
     701            1 :                 let serve = server.serve_connection(TokioIo::new(s), service.clone());
     702            1 :                 tokio::spawn(serve.into_future());
     703            1 :             }
     704            1 :         });
     705            1 : 
     706            1 :         let client = reqwest::Client::new();
     707            1 : 
     708            1 :         #[derive(Clone)]
     709            1 :         struct Fetch(SocketAddr);
     710            1 : 
     711            1 :         impl FetchAuthRules for Fetch {
     712            1 :             async fn fetch_auth_rules(
     713            1 :                 &self,
     714            1 :                 _ctx: &RequestMonitoring,
     715            1 :                 _endpoint: EndpointId,
     716            1 :                 _role_name: RoleName,
     717            1 :             ) -> anyhow::Result<Vec<AuthRule>> {
     718            1 :                 Ok(vec![
     719            1 :                     AuthRule {
     720            1 :                         id: "foo".to_owned(),
     721            1 :                         jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
     722            1 :                         audience: None,
     723            1 :                     },
     724            1 :                     AuthRule {
     725            1 :                         id: "bar".to_owned(),
     726            1 :                         jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
     727            1 :                         audience: None,
     728            1 :                     },
     729            1 :                 ])
     730            1 :             }
     731            1 :         }
     732            1 : 
     733            1 :         let role_name = RoleName::from("user");
     734            1 :         let endpoint = EndpointId::from("ep");
     735            1 : 
     736            1 :         let jwk_cache = Arc::new(JwkCacheEntryLock::default());
     737            1 : 
     738            4 :         for token in [jwt1, jwt2, jwt3, jwt4] {
     739            4 :             jwk_cache
     740            4 :                 .check_jwt(
     741            4 :                     &RequestMonitoring::test(),
     742            4 :                     &token,
     743            4 :                     &client,
     744            4 :                     endpoint.clone(),
     745            4 :                     role_name.clone(),
     746            4 :                     &Fetch(addr),
     747            4 :                 )
     748            4 :                 .await
     749            4 :                 .unwrap();
     750            1 :         }
     751            1 :     }
     752              : }
        

Generated by: LCOV version 2.1-beta