Line data Source code
1 : use std::{
2 : future::Future,
3 : sync::Arc,
4 : time::{Duration, SystemTime},
5 : };
6 :
7 : use anyhow::{bail, ensure, Context};
8 : use arc_swap::ArcSwapOption;
9 : use dashmap::DashMap;
10 : use jose_jwk::crypto::KeyInfo;
11 : use serde::{Deserialize, Deserializer};
12 : use signature::Verifier;
13 : use tokio::time::Instant;
14 :
15 : use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
16 :
// TODO(conrad): make these configurable.

/// Tolerance applied when comparing the `exp` / `nbf` claims against the local clock.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of the cached JWKs before an unknown key id may trigger a refetch.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age after which the cached JWKs are refreshed eagerly in a background task.
const AUTO_RENEW: Duration = Duration::from_secs(5 * 60);
/// Age after which requests wait for the cached JWKs to be refreshed.
const MAX_RENEW: Duration = Duration::from_secs(60 * 60);
/// Limit passed to `parse_json_body_with_limit` for JWKs responses (presumably bytes).
const MAX_JWK_BODY_SIZE: usize = 64 << 10;
23 :
/// How to get the JWT auth rules.
///
/// Implementors return the [`AuthRule`]s that apply to a role. Each rule names a JWKs URL
/// that is scraped whenever the key cache for that role is renewed.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Fetches the auth rules that apply to `role_name`.
    fn fetch_auth_rules(
        &self,
        role_name: RoleName,
    ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
31 :
/// A single JWT auth rule: where to fetch signing keys from, and which audience (if any)
/// tokens verified with those keys must carry.
pub(crate) struct AuthRule {
    /// Rule identifier; used as the key of the fetched key set in the cache entry.
    pub(crate) id: String,
    /// URL of the JWKs document to fetch.
    pub(crate) jwks_url: url::Url,
    /// Expected `aud` claim. `None` means the audience is not checked.
    pub(crate) audience: Option<String>,
}
37 :
/// Cache of JWKs keyed by (endpoint, role).
#[derive(Default)]
pub(crate) struct JwkCache {
    /// HTTP client used to download JWKs documents.
    client: reqwest::Client,

    /// One cache slot per (endpoint, role) pair, each with its own renewal lock.
    map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
44 :
/// A snapshot of the JWKs fetched for one (endpoint, role) pair.
pub(crate) struct JwkCacheEntry {
    /// When this snapshot's refresh started.
    ///
    /// A snapshot older than `AUTO_RENEW` is refreshed in the background, and one older
    /// than `MAX_RENEW` is refreshed before it is used. An unknown key id may trigger a
    /// refetch only once the snapshot is older than `MIN_RENEW`.
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    ///
    /// Keyed by [`AuthRule::id`]. Rules whose JWKs could not be fetched or decoded are
    /// absent.
    key_sets: ahash::HashMap<String, KeySet>,
}
53 :
54 : impl JwkCacheEntry {
55 24 : fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
56 36 : self.key_sets.values().find_map(|key_set| {
57 36 : key_set
58 36 : .find_key(key_id)
59 36 : .map(|jwk| (jwk, key_set.audience.as_deref()))
60 36 : })
61 24 : }
62 : }
63 :
/// The JWKs downloaded from one auth rule's URL, together with that rule's expected
/// audience.
struct KeySet {
    /// Decoded JWKs document.
    jwks: jose_jwk::JwkSet,
    /// Copied from [`AuthRule::audience`]. `None` means the audience is not checked.
    audience: Option<String>,
}
68 :
69 : impl KeySet {
70 36 : fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
71 36 : self.jwks
72 36 : .keys
73 36 : .iter()
74 60 : .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
75 36 : }
76 : }
77 :
/// A cache slot for one (endpoint, role) pair.
pub(crate) struct JwkCacheEntryLock {
    /// The most recently published JWKs snapshot, or empty before the first fetch.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Semaphore (created with one permit) that serialises JWKs renewals for this slot.
    lookup: tokio::sync::Semaphore,
}
82 :
83 : impl Default for JwkCacheEntryLock {
84 6 : fn default() -> Self {
85 6 : JwkCacheEntryLock {
86 6 : cached: ArcSwapOption::empty(),
87 6 : lookup: tokio::sync::Semaphore::new(1),
88 6 : }
89 6 : }
90 : }
91 :
92 : impl JwkCacheEntryLock {
93 6 : async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
94 6 : JwkRenewalPermit::acquire_permit(self).await
95 6 : }
96 :
97 0 : fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
98 0 : JwkRenewalPermit::try_acquire_permit(self)
99 0 : }
100 :
101 6 : async fn renew_jwks<F: FetchAuthRules>(
102 6 : &self,
103 6 : _permit: JwkRenewalPermit<'_>,
104 6 : client: &reqwest::Client,
105 6 : role_name: RoleName,
106 6 : auth_rules: &F,
107 6 : ) -> anyhow::Result<Arc<JwkCacheEntry>> {
108 6 : // double check that no one beat us to updating the cache.
109 6 : let now = Instant::now();
110 6 : let guard = self.cached.load_full();
111 6 : if let Some(cached) = guard {
112 0 : let last_update = now.duration_since(cached.last_retrieved);
113 0 : if last_update < Duration::from_secs(300) {
114 0 : return Ok(cached);
115 0 : }
116 6 : }
117 :
118 6 : let rules = auth_rules.fetch_auth_rules(role_name).await?;
119 6 : let mut key_sets =
120 6 : ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
121 : // TODO(conrad): run concurrently
122 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
123 18 : for rule in rules {
124 12 : let req = client.get(rule.jwks_url.clone());
125 12 : // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
126 12 : // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
127 24 : match req.send().await.and_then(|r| r.error_for_status()) {
128 : // todo: should we re-insert JWKs if we want to keep this JWKs URL?
129 : // I expect these failures would be quite sparse.
130 0 : Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
131 12 : Ok(r) => {
132 12 : let resp: http::Response<reqwest::Body> = r.into();
133 12 : match parse_json_body_with_limit::<jose_jwk::JwkSet>(
134 12 : resp.into_body(),
135 12 : MAX_JWK_BODY_SIZE,
136 12 : )
137 0 : .await
138 : {
139 0 : Err(e) => {
140 0 : tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
141 : }
142 12 : Ok(jwks) => {
143 12 : key_sets.insert(
144 12 : rule.id,
145 12 : KeySet {
146 12 : jwks,
147 12 : audience: rule.audience,
148 12 : },
149 12 : );
150 12 : }
151 : }
152 : }
153 : }
154 : }
155 :
156 6 : let entry = Arc::new(JwkCacheEntry {
157 6 : last_retrieved: now,
158 6 : key_sets,
159 6 : });
160 6 : self.cached.swap(Some(Arc::clone(&entry)));
161 6 :
162 6 : Ok(entry)
163 6 : }
164 :
165 24 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
166 24 : self: &Arc<Self>,
167 24 : ctx: &RequestMonitoring,
168 24 : client: &reqwest::Client,
169 24 : role_name: RoleName,
170 24 : fetch: &F,
171 24 : ) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
172 24 : let now = Instant::now();
173 24 : let guard = self.cached.load_full();
174 :
175 : // if we have no cached JWKs, try and get some
176 24 : let Some(cached) = guard else {
177 6 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
178 6 : let permit = self.acquire_permit().await;
179 24 : return self.renew_jwks(permit, client, role_name, fetch).await;
180 : };
181 :
182 18 : let last_update = now.duration_since(cached.last_retrieved);
183 18 :
184 18 : // check if the cached JWKs need updating.
185 18 : if last_update > MAX_RENEW {
186 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
187 0 : let permit = self.acquire_permit().await;
188 :
189 : // it's been too long since we checked the keys. wait for them to update.
190 0 : return self.renew_jwks(permit, client, role_name, fetch).await;
191 18 : }
192 18 :
193 18 : // every 5 minutes we should spawn a job to eagerly update the token.
194 18 : if last_update > AUTO_RENEW {
195 0 : if let Some(permit) = self.try_acquire_permit() {
196 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
197 0 : let permit = permit.into_owned();
198 0 : let entry = self.clone();
199 0 : let client = client.clone();
200 0 : let fetch = fetch.clone();
201 0 : tokio::spawn(async move {
202 0 : if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await {
203 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
204 0 : }
205 0 : });
206 0 : } else {
207 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
208 : }
209 18 : }
210 :
211 18 : Ok(cached)
212 24 : }
213 :
214 24 : async fn check_jwt<F: FetchAuthRules>(
215 24 : self: &Arc<Self>,
216 24 : ctx: &RequestMonitoring,
217 24 : jwt: &str,
218 24 : client: &reqwest::Client,
219 24 : role_name: RoleName,
220 24 : fetch: &F,
221 24 : ) -> Result<(), anyhow::Error> {
222 : // JWT compact form is defined to be
223 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
224 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
225 :
226 24 : let (header_payload, signature) = jwt
227 24 : .rsplit_once('.')
228 24 : .context("Provided authentication token is not a valid JWT encoding")?;
229 24 : let (header, payload) = header_payload
230 24 : .split_once('.')
231 24 : .context("Provided authentication token is not a valid JWT encoding")?;
232 :
233 24 : let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
234 24 : .context("Provided authentication token is not a valid JWT encoding")?;
235 24 : let header = serde_json::from_slice::<JwtHeader<'_>>(&header)
236 24 : .context("Provided authentication token is not a valid JWT encoding")?;
237 :
238 24 : let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
239 24 : .context("Provided authentication token is not a valid JWT encoding")?;
240 :
241 24 : ensure!(header.typ == "JWT");
242 24 : let kid = header.key_id.context("missing key id")?;
243 :
244 24 : let mut guard = self
245 24 : .get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch)
246 24 : .await?;
247 :
248 : // get the key from the JWKs if possible. If not, wait for the keys to update.
249 24 : let (jwk, expected_audience) = loop {
250 24 : match guard.find_jwk_and_audience(kid) {
251 24 : Some(jwk) => break jwk,
252 0 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
253 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
254 :
255 0 : let permit = self.acquire_permit().await;
256 0 : guard = self
257 0 : .renew_jwks(permit, client, role_name.clone(), fetch)
258 0 : .await?;
259 : }
260 : _ => {
261 0 : bail!("jwk not found");
262 : }
263 : }
264 : };
265 :
266 24 : ensure!(
267 24 : jwk.is_supported(&header.algorithm),
268 0 : "signature algorithm not supported"
269 : );
270 :
271 24 : match &jwk.key {
272 12 : jose_jwk::Key::Ec(key) => {
273 12 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
274 : }
275 12 : jose_jwk::Key::Rsa(key) => {
276 12 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
277 : }
278 0 : key => bail!("unsupported key type {key:?}"),
279 : };
280 :
281 24 : let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
282 24 : .context("Provided authentication token is not a valid JWT encoding")?;
283 24 : let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
284 24 : .context("Provided authentication token is not a valid JWT encoding")?;
285 :
286 24 : tracing::debug!(?payload, "JWT signature valid with claims");
287 :
288 24 : match (expected_audience, payload.audience) {
289 : // check the audience matches
290 0 : (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
291 : // the audience is expected but is missing
292 0 : (Some(_), None) => bail!("invalid JWT token audience"),
293 : // we don't care for the audience field
294 24 : (None, _) => {}
295 : }
296 :
297 24 : let now = SystemTime::now();
298 :
299 24 : if let Some(exp) = payload.expiration {
300 24 : ensure!(now < exp + CLOCK_SKEW_LEEWAY);
301 0 : }
302 :
303 24 : if let Some(nbf) = payload.not_before {
304 0 : ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
305 24 : }
306 :
307 24 : Ok(())
308 24 : }
309 : }
310 :
311 : impl JwkCache {
312 0 : pub(crate) async fn check_jwt<F: FetchAuthRules>(
313 0 : &self,
314 0 : ctx: &RequestMonitoring,
315 0 : endpoint: EndpointId,
316 0 : role_name: RoleName,
317 0 : fetch: &F,
318 0 : jwt: &str,
319 0 : ) -> Result<(), anyhow::Error> {
320 0 : // try with just a read lock first
321 0 : let key = (endpoint, role_name.clone());
322 0 : let entry = self.map.get(&key).as_deref().map(Arc::clone);
323 0 : let entry = entry.unwrap_or_else(|| {
324 0 : // acquire a write lock after to insert.
325 0 : let entry = self.map.entry(key).or_default();
326 0 : Arc::clone(&*entry)
327 0 : });
328 0 :
329 0 : entry
330 0 : .check_jwt(ctx, jwt, &self.client, role_name, fetch)
331 0 : .await
332 0 : }
333 : }
334 :
335 12 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
336 12 : use ecdsa::Signature;
337 12 : use signature::Verifier;
338 12 :
339 12 : match key.crv {
340 : jose_jwk::EcCurves::P256 => {
341 12 : let pk =
342 12 : p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
343 12 : let key = p256::ecdsa::VerifyingKey::from(&pk);
344 12 : let sig = Signature::from_slice(sig)?;
345 12 : key.verify(data, &sig)?;
346 : }
347 0 : key => bail!("unsupported ec key type {key:?}"),
348 : }
349 :
350 12 : Ok(())
351 12 : }
352 :
353 12 : fn verify_rsa_signature(
354 12 : data: &[u8],
355 12 : sig: &[u8],
356 12 : key: &jose_jwk::Rsa,
357 12 : alg: &Option<jose_jwa::Algorithm>,
358 12 : ) -> anyhow::Result<()> {
359 : use jose_jwa::{Algorithm, Signing};
360 : use rsa::{
361 : pkcs1v15::{Signature, VerifyingKey},
362 : RsaPublicKey,
363 : };
364 :
365 12 : let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
366 :
367 12 : match alg {
368 : Some(Algorithm::Signing(Signing::Rs256)) => {
369 12 : let key = VerifyingKey::<sha2::Sha256>::new(key);
370 12 : let sig = Signature::try_from(sig)?;
371 12 : key.verify(data, &sig)?;
372 : }
373 0 : _ => bail!("invalid RSA signing algorithm"),
374 : };
375 :
376 12 : Ok(())
377 12 : }
378 :
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
///
/// JOSE header of a compact-serialized JWT. The string fields borrow from the decoded
/// header bytes.
///
/// NOTE(review): `typ` is a required field here, so headers that omit it fail to decode.
/// Also, borrowed `&str` fields cannot be deserialized by serde_json from JSON strings
/// containing escape sequences. Confirm both restrictions are acceptable.
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be "JWT"
    #[serde(rename = "typ")]
    typ: &'a str,
    /// must be a supported alg
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    #[serde(rename = "kid")]
    key_id: Option<&'a str>,
}
392 :
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
///
/// The claims of a JWT that are checked or logged after its signature has been verified.
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
    /// Audience - Recipient for which the JWT is intended
    ///
    /// NOTE(review): RFC 7519 also allows `aud` to be an array of strings. This field
    /// only accepts a single string, so array-valued audiences fail to decode. Confirm
    /// this is intended.
    #[serde(rename = "aud")]
    audience: Option<&'a str>,
    /// Expiration - Time on or after which the JWT must not be accepted
    #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
    expiration: Option<SystemTime>,
    /// Not before - Time before which the JWT must not be accepted
    #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
    not_before: Option<SystemTime>,

    // the following entries are only extracted for the sake of debug logging.
    /// Issuer of the JWT
    #[serde(rename = "iss")]
    issuer: Option<&'a str>,
    /// Subject of the JWT (the user)
    #[serde(rename = "sub")]
    subject: Option<&'a str>,
    /// Unique token identifier
    #[serde(rename = "jti")]
    jwt_id: Option<&'a str>,
    /// Unique session identifier
    #[serde(rename = "sid")]
    session_id: Option<&'a str>,
}
420 :
421 24 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
422 24 : let d = <Option<u64>>::deserialize(d)?;
423 24 : Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
424 24 : }
425 :
/// Exclusive right to renew the JWKs of one cache slot.
///
/// On acquisition, the slot's semaphore permit is `forget`-ed. Dropping this value adds
/// the permit back. `inner` is `None` only after its contents were moved out by
/// `into_owned`; in that case dropping releases nothing.
struct JwkRenewalPermit<'a> {
    inner: Option<JwkRenewalPermitInner<'a>>,
}
429 :
/// The cache slot a [`JwkRenewalPermit`] belongs to. It is either borrowed from the caller
/// or owned, so the permit can be moved somewhere that outlives the caller (e.g. a spawned
/// task).
enum JwkRenewalPermitInner<'a> {
    Owned(Arc<JwkCacheEntryLock>),
    Borrowed(&'a Arc<JwkCacheEntryLock>),
}
434 :
435 : impl JwkRenewalPermit<'_> {
436 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
437 0 : JwkRenewalPermit {
438 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
439 0 : }
440 0 : }
441 :
442 6 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
443 6 : match from.lookup.acquire().await {
444 6 : Ok(permit) => {
445 6 : permit.forget();
446 6 : JwkRenewalPermit {
447 6 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
448 6 : }
449 : }
450 0 : Err(_) => panic!("semaphore should not be closed"),
451 : }
452 6 : }
453 :
454 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
455 0 : match from.lookup.try_acquire() {
456 0 : Ok(permit) => {
457 0 : permit.forget();
458 0 : Some(JwkRenewalPermit {
459 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
460 0 : })
461 : }
462 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
463 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
464 : }
465 0 : }
466 : }
467 :
468 : impl JwkRenewalPermitInner<'_> {
469 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
470 0 : match self {
471 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
472 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
473 : }
474 0 : }
475 : }
476 :
477 : impl Drop for JwkRenewalPermit<'_> {
478 6 : fn drop(&mut self) {
479 6 : let entry = match &self.inner {
480 0 : None => return,
481 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
482 6 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
483 : };
484 6 : entry.lookup.add_permits(1);
485 6 : }
486 : }
487 :
#[cfg(test)]
mod tests {
    use crate::RoleName;

    use super::*;

    use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};

    use base64::URL_SAFE_NO_PAD;
    use bytes::Bytes;
    use http::Response;
    use http_body_util::Full;
    use hyper1::service::service_fn;
    use hyper_util::rt::TokioIo;
    use rand::rngs::OsRng;
    use signature::Signer;
    use tokio::net::TcpListener;

    /// Generates a random P-256 key pair and the public JWK (alg `ES256`, id `kid`) for it.
    fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
        let sk = p256::SecretKey::random(&mut OsRng);
        let pk = sk.public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Ec(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Generates a random 2048-bit RSA key pair and the public JWK (alg `RS256`, id `kid`)
    /// for it.
    fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
        let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
        let pk = sk.to_public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Rsa(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Builds the unsigned `<B64(header)>.<B64(payload)>` part of a JWT.
    ///
    /// The payload carries only an `exp` claim one hour in the future.
    fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
        let header = JwtHeader {
            typ: "JWT",
            algorithm: jose_jwa::Algorithm::Signing(sig),
            key_id: Some(&kid),
        };
        let body = typed_json::json! {{
            "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
        }};

        let header =
            base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
        let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);

        format!("{header}.{body}")
    }

    /// Builds a complete JWT signed with the given P-256 key (ES256).
    fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
        use p256::ecdsa::{Signature, SigningKey};

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
        let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    /// Builds a complete JWT signed with the given RSA key (RS256, PKCS#1 v1.5).
    fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
        use rsa::pkcs1v15::SigningKey;
        use rsa::signature::SignatureEncoding;

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
        let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    /// End-to-end check through a single cache slot.
    ///
    /// Two JWKs documents (`/foo` with keys 1 and 3, `/bar` with keys 2 and 4) are served
    /// from a local HTTP server. One RSA and one EC token per document are then verified.
    #[tokio::test]
    async fn renew() {
        let (rs1, jwk1) = new_rsa_jwk("1".into());
        let (rs2, jwk2) = new_rsa_jwk("2".into());
        let (ec1, jwk3) = new_ec_jwk("3".into());
        let (ec2, jwk4) = new_ec_jwk("4".into());

        let jwt1 = new_rsa_jwt("1".into(), rs1);
        let jwt2 = new_rsa_jwt("2".into(), rs2);
        let jwt3 = new_ec_jwt("3".into(), ec1);
        let jwt4 = new_ec_jwt("4".into(), ec2);

        let foo_jwks = jose_jwk::JwkSet {
            keys: vec![jwk1, jwk3],
        };
        let bar_jwks = jose_jwk::JwkSet {
            keys: vec![jwk2, jwk4],
        };

        // Serves `/foo` and `/bar` as JSON JWKs documents; any other path is a 404.
        let service = service_fn(move |req| {
            let foo_jwks = foo_jwks.clone();
            let bar_jwks = bar_jwks.clone();
            async move {
                let jwks = match req.uri().path() {
                    "/foo" => &foo_jwks,
                    "/bar" => &bar_jwks,
                    _ => {
                        return Response::builder()
                            .status(404)
                            .body(Full::new(Bytes::new()));
                    }
                };
                let body = serde_json::to_vec(jwks).unwrap();
                Response::builder()
                    .status(200)
                    .body(Full::new(Bytes::from(body)))
            }
        });

        // Bind to an ephemeral port and serve every accepted connection on its own task.
        // NOTE(review): the client later connects to the 0.0.0.0 address returned here,
        // which presumably relies on the host OS treating it as loopback — confirm on CI.
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let server = hyper1::server::conn::http1::Builder::new();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (s, _) = listener.accept().await.unwrap();
                let serve = server.serve_connection(TokioIo::new(s), service.clone());
                tokio::spawn(serve.into_future());
            }
        });

        let client = reqwest::Client::new();

        /// Auth-rule source pointing both rules at the local test server.
        #[derive(Clone)]
        struct Fetch(SocketAddr);

        impl FetchAuthRules for Fetch {
            async fn fetch_auth_rules(
                &self,
                _role_name: RoleName,
            ) -> anyhow::Result<Vec<AuthRule>> {
                Ok(vec![
                    AuthRule {
                        id: "foo".to_owned(),
                        jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
                        audience: None,
                    },
                    AuthRule {
                        id: "bar".to_owned(),
                        jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
                        audience: None,
                    },
                ])
            }
        }

        let role_name = RoleName::from("user");

        let jwk_cache = Arc::new(JwkCacheEntryLock::default());

        // The first token populates the cache; the rest are verified from the cached key
        // sets.
        for token in [jwt1, jwt2, jwt3, jwt4] {
            jwk_cache
                .check_jwt(
                    &RequestMonitoring::test(),
                    &token,
                    &client,
                    role_name.clone(),
                    &Fetch(addr),
                )
                .await
                .unwrap();
        }
    }
}
|