Line data Source code
1 : use std::{
2 : future::Future,
3 : sync::Arc,
4 : time::{Duration, SystemTime},
5 : };
6 :
7 : use anyhow::{bail, ensure, Context};
8 : use arc_swap::ArcSwapOption;
9 : use dashmap::DashMap;
10 : use jose_jwk::crypto::KeyInfo;
11 : use serde::{Deserialize, Deserializer};
12 : use signature::Verifier;
13 : use tokio::time::Instant;
14 :
15 : use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
16 :
// TODO(conrad): make these configurable.
/// Tolerance applied when comparing the `exp` and `nbf` claims against the local clock.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of the cached JWKs before an unknown key ID triggers a refetch.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age of the cached JWKs after which a background refresh is spawned.
const AUTO_RENEW: Duration = Duration::from_secs(300);
/// Age of the cached JWKs after which a request waits for a refresh before continuing.
const MAX_RENEW: Duration = Duration::from_secs(3600);
/// Size limit handed to `parse_json_body_with_limit` when reading a JWKs response
/// (presumably in bytes; confirm against that helper).
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
23 :
/// How to get the JWT auth rules
///
/// Implementors must be cheaply cloneable and `'static`, because the cache
/// clones the fetcher into a spawned background task when it eagerly renews keys.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Returns the auth rules (JWKs URLs and optional audiences) that apply to
    /// the given endpoint and role.
    fn fetch_auth_rules(
        &self,
        ctx: &RequestMonitoring,
        endpoint: EndpointId,
        role_name: RoleName,
    ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
33 :
/// A single JWT auth rule: where to fetch the signing keys from and which
/// audience, if any, tokens signed by those keys must carry.
pub(crate) struct AuthRule {
    /// Identifier of the rule; used as the key under which the fetched key set is cached.
    pub(crate) id: String,
    /// URL the JWK set is downloaded from.
    pub(crate) jwks_url: url::Url,
    /// When set, the `aud` claim of a token must be equal to this value.
    pub(crate) audience: Option<String>,
}
39 :
/// Cache of JWK sets, with one independently locked entry per (endpoint, role) pair.
#[derive(Default)]
pub(crate) struct JwkCache {
    /// HTTP client used to download the JWK sets.
    client: reqwest::Client,

    /// Per (endpoint, role) cache entries; entries are created lazily on first use.
    map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
46 :
/// An immutable snapshot of all key sets fetched for one (endpoint, role) pair.
pub(crate) struct JwkCacheEntry {
    /// Should refetch at least every hour to verify when old keys have been removed.
    /// Should refetch when new key IDs are seen only every 5 minutes or so
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    /// Keyed by the ID of the auth rule the key set was fetched for.
    key_sets: ahash::HashMap<String, KeySet>,
}
55 :
56 : impl JwkCacheEntry {
57 4 : fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
58 6 : self.key_sets.values().find_map(|key_set| {
59 6 : key_set
60 6 : .find_key(key_id)
61 6 : .map(|jwk| (jwk, key_set.audience.as_deref()))
62 6 : })
63 4 : }
64 : }
65 :
/// The keys downloaded for one auth rule, plus that rule's audience requirement.
struct KeySet {
    /// The JWK set as decoded from the JWKs URL response.
    jwks: jose_jwk::JwkSet,
    /// Audience that tokens signed by these keys are expected to carry, if any.
    audience: Option<String>,
}
70 :
71 : impl KeySet {
72 6 : fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
73 6 : self.jwks
74 6 : .keys
75 6 : .iter()
76 10 : .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
77 6 : }
78 : }
79 :
/// A cache slot for one (endpoint, role) pair, pairing the current snapshot
/// with the semaphore that hands out renewal permits.
pub(crate) struct JwkCacheEntryLock {
    /// The most recently stored snapshot, or empty if nothing has been fetched yet.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Semaphore from which renewal permits are taken before refetching the keys.
    lookup: tokio::sync::Semaphore,
}
84 :
85 : impl Default for JwkCacheEntryLock {
86 1 : fn default() -> Self {
87 1 : JwkCacheEntryLock {
88 1 : cached: ArcSwapOption::empty(),
89 1 : lookup: tokio::sync::Semaphore::new(1),
90 1 : }
91 1 : }
92 : }
93 :
94 : impl JwkCacheEntryLock {
95 1 : async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
96 1 : JwkRenewalPermit::acquire_permit(self).await
97 1 : }
98 :
99 0 : fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
100 0 : JwkRenewalPermit::try_acquire_permit(self)
101 0 : }
102 :
103 1 : async fn renew_jwks<F: FetchAuthRules>(
104 1 : &self,
105 1 : _permit: JwkRenewalPermit<'_>,
106 1 : ctx: &RequestMonitoring,
107 1 : client: &reqwest::Client,
108 1 : endpoint: EndpointId,
109 1 : role_name: RoleName,
110 1 : auth_rules: &F,
111 1 : ) -> anyhow::Result<Arc<JwkCacheEntry>> {
112 1 : // double check that no one beat us to updating the cache.
113 1 : let now = Instant::now();
114 1 : let guard = self.cached.load_full();
115 1 : if let Some(cached) = guard {
116 0 : let last_update = now.duration_since(cached.last_retrieved);
117 0 : if last_update < Duration::from_secs(300) {
118 0 : return Ok(cached);
119 0 : }
120 1 : }
121 :
122 1 : let rules = auth_rules
123 1 : .fetch_auth_rules(ctx, endpoint, role_name)
124 0 : .await?;
125 1 : let mut key_sets =
126 1 : ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
127 : // TODO(conrad): run concurrently
128 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
129 3 : for rule in rules {
130 2 : let req = client.get(rule.jwks_url.clone());
131 2 : // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
132 2 : // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
133 4 : match req.send().await.and_then(|r| r.error_for_status()) {
134 : // todo: should we re-insert JWKs if we want to keep this JWKs URL?
135 : // I expect these failures would be quite sparse.
136 0 : Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
137 2 : Ok(r) => {
138 2 : let resp: http::Response<reqwest::Body> = r.into();
139 2 : match parse_json_body_with_limit::<jose_jwk::JwkSet>(
140 2 : resp.into_body(),
141 2 : MAX_JWK_BODY_SIZE,
142 2 : )
143 0 : .await
144 : {
145 0 : Err(e) => {
146 0 : tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
147 : }
148 2 : Ok(jwks) => {
149 2 : key_sets.insert(
150 2 : rule.id,
151 2 : KeySet {
152 2 : jwks,
153 2 : audience: rule.audience,
154 2 : },
155 2 : );
156 2 : }
157 : }
158 : }
159 : }
160 : }
161 :
162 1 : let entry = Arc::new(JwkCacheEntry {
163 1 : last_retrieved: now,
164 1 : key_sets,
165 1 : });
166 1 : self.cached.swap(Some(Arc::clone(&entry)));
167 1 :
168 1 : Ok(entry)
169 1 : }
170 :
171 4 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
172 4 : self: &Arc<Self>,
173 4 : ctx: &RequestMonitoring,
174 4 : client: &reqwest::Client,
175 4 : endpoint: EndpointId,
176 4 : role_name: RoleName,
177 4 : fetch: &F,
178 4 : ) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
179 4 : let now = Instant::now();
180 4 : let guard = self.cached.load_full();
181 :
182 : // if we have no cached JWKs, try and get some
183 4 : let Some(cached) = guard else {
184 1 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
185 1 : let permit = self.acquire_permit().await;
186 1 : return self
187 1 : .renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
188 4 : .await;
189 : };
190 :
191 3 : let last_update = now.duration_since(cached.last_retrieved);
192 3 :
193 3 : // check if the cached JWKs need updating.
194 3 : if last_update > MAX_RENEW {
195 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
196 0 : let permit = self.acquire_permit().await;
197 :
198 : // it's been too long since we checked the keys. wait for them to update.
199 0 : return self
200 0 : .renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
201 0 : .await;
202 3 : }
203 3 :
204 3 : // every 5 minutes we should spawn a job to eagerly update the token.
205 3 : if last_update > AUTO_RENEW {
206 0 : if let Some(permit) = self.try_acquire_permit() {
207 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
208 0 : let permit = permit.into_owned();
209 0 : let entry = self.clone();
210 0 : let client = client.clone();
211 0 : let fetch = fetch.clone();
212 0 : let ctx = ctx.clone();
213 0 : tokio::spawn(async move {
214 0 : if let Err(e) = entry
215 0 : .renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
216 0 : .await
217 : {
218 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
219 0 : }
220 0 : });
221 0 : } else {
222 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
223 : }
224 3 : }
225 :
226 3 : Ok(cached)
227 4 : }
228 :
229 4 : async fn check_jwt<F: FetchAuthRules>(
230 4 : self: &Arc<Self>,
231 4 : ctx: &RequestMonitoring,
232 4 : jwt: &str,
233 4 : client: &reqwest::Client,
234 4 : endpoint: EndpointId,
235 4 : role_name: RoleName,
236 4 : fetch: &F,
237 4 : ) -> Result<(), anyhow::Error> {
238 : // JWT compact form is defined to be
239 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
240 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
241 :
242 4 : let (header_payload, signature) = jwt
243 4 : .rsplit_once('.')
244 4 : .context("Provided authentication token is not a valid JWT encoding")?;
245 4 : let (header, payload) = header_payload
246 4 : .split_once('.')
247 4 : .context("Provided authentication token is not a valid JWT encoding")?;
248 :
249 4 : let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
250 4 : .context("Provided authentication token is not a valid JWT encoding")?;
251 4 : let header = serde_json::from_slice::<JwtHeader<'_>>(&header)
252 4 : .context("Provided authentication token is not a valid JWT encoding")?;
253 :
254 4 : let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
255 4 : .context("Provided authentication token is not a valid JWT encoding")?;
256 :
257 4 : ensure!(header.typ == "JWT");
258 4 : let kid = header.key_id.context("missing key id")?;
259 :
260 4 : let mut guard = self
261 4 : .get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
262 4 : .await?;
263 :
264 : // get the key from the JWKs if possible. If not, wait for the keys to update.
265 4 : let (jwk, expected_audience) = loop {
266 4 : match guard.find_jwk_and_audience(kid) {
267 4 : Some(jwk) => break jwk,
268 0 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
269 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
270 :
271 0 : let permit = self.acquire_permit().await;
272 0 : guard = self
273 0 : .renew_jwks(
274 0 : permit,
275 0 : ctx,
276 0 : client,
277 0 : endpoint.clone(),
278 0 : role_name.clone(),
279 0 : fetch,
280 0 : )
281 0 : .await?;
282 : }
283 : _ => {
284 0 : bail!("jwk not found");
285 : }
286 : }
287 : };
288 :
289 4 : ensure!(
290 4 : jwk.is_supported(&header.algorithm),
291 0 : "signature algorithm not supported"
292 : );
293 :
294 4 : match &jwk.key {
295 2 : jose_jwk::Key::Ec(key) => {
296 2 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
297 : }
298 2 : jose_jwk::Key::Rsa(key) => {
299 2 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
300 : }
301 0 : key => bail!("unsupported key type {key:?}"),
302 : };
303 :
304 4 : let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
305 4 : .context("Provided authentication token is not a valid JWT encoding")?;
306 4 : let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
307 4 : .context("Provided authentication token is not a valid JWT encoding")?;
308 :
309 4 : tracing::debug!(?payload, "JWT signature valid with claims");
310 :
311 4 : match (expected_audience, payload.audience) {
312 : // check the audience matches
313 0 : (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
314 : // the audience is expected but is missing
315 0 : (Some(_), None) => bail!("invalid JWT token audience"),
316 : // we don't care for the audience field
317 4 : (None, _) => {}
318 : }
319 :
320 4 : let now = SystemTime::now();
321 :
322 4 : if let Some(exp) = payload.expiration {
323 4 : ensure!(now < exp + CLOCK_SKEW_LEEWAY);
324 0 : }
325 :
326 4 : if let Some(nbf) = payload.not_before {
327 0 : ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
328 4 : }
329 :
330 4 : Ok(())
331 4 : }
332 : }
333 :
334 : impl JwkCache {
335 0 : pub(crate) async fn check_jwt<F: FetchAuthRules>(
336 0 : &self,
337 0 : ctx: &RequestMonitoring,
338 0 : endpoint: EndpointId,
339 0 : role_name: RoleName,
340 0 : fetch: &F,
341 0 : jwt: &str,
342 0 : ) -> Result<(), anyhow::Error> {
343 0 : // try with just a read lock first
344 0 : let key = (endpoint.clone(), role_name.clone());
345 0 : let entry = self.map.get(&key).as_deref().map(Arc::clone);
346 0 : let entry = entry.unwrap_or_else(|| {
347 0 : // acquire a write lock after to insert.
348 0 : let entry = self.map.entry(key).or_default();
349 0 : Arc::clone(&*entry)
350 0 : });
351 0 :
352 0 : entry
353 0 : .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
354 0 : .await
355 0 : }
356 : }
357 :
358 2 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
359 : use ecdsa::Signature;
360 : use signature::Verifier;
361 :
362 2 : match key.crv {
363 : jose_jwk::EcCurves::P256 => {
364 2 : let pk =
365 2 : p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
366 2 : let key = p256::ecdsa::VerifyingKey::from(&pk);
367 2 : let sig = Signature::from_slice(sig)?;
368 2 : key.verify(data, &sig)?;
369 : }
370 0 : key => bail!("unsupported ec key type {key:?}"),
371 : }
372 :
373 2 : Ok(())
374 2 : }
375 :
376 2 : fn verify_rsa_signature(
377 2 : data: &[u8],
378 2 : sig: &[u8],
379 2 : key: &jose_jwk::Rsa,
380 2 : alg: &Option<jose_jwa::Algorithm>,
381 2 : ) -> anyhow::Result<()> {
382 : use jose_jwa::{Algorithm, Signing};
383 : use rsa::{
384 : pkcs1v15::{Signature, VerifyingKey},
385 : RsaPublicKey,
386 : };
387 :
388 2 : let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
389 :
390 2 : match alg {
391 : Some(Algorithm::Signing(Signing::Rs256)) => {
392 2 : let key = VerifyingKey::<sha2::Sha256>::new(key);
393 2 : let sig = Signature::try_from(sig)?;
394 2 : key.verify(data, &sig)?;
395 : }
396 0 : _ => bail!("invalid RSA signing algorithm"),
397 : };
398 :
399 2 : Ok(())
400 2 : }
401 :
/// The JOSE header fields of a JWT that the validation code reads.
///
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be "JWT"
    #[serde(rename = "typ")]
    typ: &'a str,
    /// must be a supported alg
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    #[serde(rename = "kid")]
    key_id: Option<&'a str>,
}
415 :
/// The registered JWT claims that are validated or logged.
///
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
    /// Audience - Recipient for which the JWT is intended
    #[serde(rename = "aud")]
    audience: Option<&'a str>,
    /// Expiration - Time after which the JWT expires
    #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
    expiration: Option<SystemTime>,
    /// Not before - Time before which the JWT must not be accepted
    #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
    not_before: Option<SystemTime>,

    // the following entries are only extracted for the sake of debug logging.
    /// Issuer of the JWT
    #[serde(rename = "iss")]
    issuer: Option<&'a str>,
    /// Subject of the JWT (the user)
    #[serde(rename = "sub")]
    subject: Option<&'a str>,
    /// Unique token identifier
    #[serde(rename = "jti")]
    jwt_id: Option<&'a str>,
    /// Unique session identifier
    #[serde(rename = "sid")]
    session_id: Option<&'a str>,
}
443 :
444 4 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
445 4 : let d = <Option<u64>>::deserialize(d)?;
446 4 : Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
447 4 : }
448 :
/// Proof that the holder has taken the renewal permit of a cache slot.
///
/// The permit is handed back to the slot's semaphore when this value is dropped.
struct JwkRenewalPermit<'a> {
    /// The slot the permit belongs to. `None` once the permit has been moved
    /// out (see `into_owned`), in which case dropping does nothing.
    inner: Option<JwkRenewalPermitInner<'a>>,
}
452 :
/// How a renewal permit refers to the cache slot it was taken from.
enum JwkRenewalPermitInner<'a> {
    /// Holds its own reference-counted handle to the slot.
    Owned(Arc<JwkCacheEntryLock>),
    /// Borrows the caller's handle to the slot.
    Borrowed(&'a Arc<JwkCacheEntryLock>),
}
457 :
458 : impl JwkRenewalPermit<'_> {
459 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
460 0 : JwkRenewalPermit {
461 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
462 0 : }
463 0 : }
464 :
465 1 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
466 1 : match from.lookup.acquire().await {
467 1 : Ok(permit) => {
468 1 : permit.forget();
469 1 : JwkRenewalPermit {
470 1 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
471 1 : }
472 : }
473 0 : Err(_) => panic!("semaphore should not be closed"),
474 : }
475 1 : }
476 :
477 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
478 0 : match from.lookup.try_acquire() {
479 0 : Ok(permit) => {
480 0 : permit.forget();
481 0 : Some(JwkRenewalPermit {
482 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
483 0 : })
484 : }
485 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
486 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
487 : }
488 0 : }
489 : }
490 :
491 : impl JwkRenewalPermitInner<'_> {
492 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
493 0 : match self {
494 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
495 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
496 : }
497 0 : }
498 : }
499 :
500 : impl Drop for JwkRenewalPermit<'_> {
501 1 : fn drop(&mut self) {
502 1 : let entry = match &self.inner {
503 0 : None => return,
504 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
505 1 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
506 : };
507 1 : entry.lookup.add_permits(1);
508 1 : }
509 : }
510 :
511 : #[cfg(test)]
512 : mod tests {
513 : use crate::RoleName;
514 :
515 : use super::*;
516 :
517 : use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
518 :
519 : use base64::URL_SAFE_NO_PAD;
520 : use bytes::Bytes;
521 : use http::Response;
522 : use http_body_util::Full;
523 : use hyper1::service::service_fn;
524 : use hyper_util::rt::TokioIo;
525 : use rand::rngs::OsRng;
526 : use rsa::pkcs8::DecodePrivateKey;
527 : use signature::Signer;
528 : use tokio::net::TcpListener;
529 :
530 2 : fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
531 2 : let sk = p256::SecretKey::random(&mut OsRng);
532 2 : let pk = sk.public_key().into();
533 2 : let jwk = jose_jwk::Jwk {
534 2 : key: jose_jwk::Key::Ec(pk),
535 2 : prm: jose_jwk::Parameters {
536 2 : kid: Some(kid),
537 2 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
538 2 : ..Default::default()
539 2 : },
540 2 : };
541 2 : (sk, jwk)
542 2 : }
543 :
544 2 : fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
545 2 : let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
546 2 : let pk = sk.to_public_key().into();
547 2 : let jwk = jose_jwk::Jwk {
548 2 : key: jose_jwk::Key::Rsa(pk),
549 2 : prm: jose_jwk::Parameters {
550 2 : kid: Some(kid),
551 2 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
552 2 : ..Default::default()
553 2 : },
554 2 : };
555 2 : (sk, jwk)
556 2 : }
557 :
558 4 : fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
559 4 : let header = JwtHeader {
560 4 : typ: "JWT",
561 4 : algorithm: jose_jwa::Algorithm::Signing(sig),
562 4 : key_id: Some(&kid),
563 4 : };
564 4 : let body = typed_json::json! {{
565 4 : "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
566 4 : }};
567 4 :
568 4 : let header =
569 4 : base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
570 4 : let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
571 4 :
572 4 : format!("{header}.{body}")
573 4 : }
574 :
575 2 : fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
576 : use p256::ecdsa::{Signature, SigningKey};
577 :
578 2 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
579 2 : let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
580 2 : let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
581 2 :
582 2 : format!("{payload}.{sig}")
583 2 : }
584 :
585 2 : fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
586 : use rsa::pkcs1v15::SigningKey;
587 : use rsa::signature::SignatureEncoding;
588 :
589 2 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
590 2 : let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
591 2 : let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
592 2 :
593 2 : format!("{payload}.{sig}")
594 2 : }
595 :
// RSA key generation is slow, so the tests use pre-generated PKCS#8 private keys.
597 : const RS1: &str = "-----BEGIN PRIVATE KEY-----
598 : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
599 : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
600 : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
601 : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
602 : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
603 : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
604 : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
605 : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
606 : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
607 : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
608 : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
609 : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
610 : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
611 : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
612 : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
613 : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
614 : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
615 : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
616 : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
617 : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
618 : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
619 : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
620 : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
621 : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
622 : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
623 : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
624 : -----END PRIVATE KEY-----
625 : ";
626 : const RS2: &str = "-----BEGIN PRIVATE KEY-----
627 : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
628 : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
629 : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
630 : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
631 : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
632 : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
633 : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
634 : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
635 : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
636 : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
637 : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
638 : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
639 : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
640 : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
641 : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
642 : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
643 : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
644 : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
645 : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
646 : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
647 : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
648 : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
649 : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
650 : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
651 : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
652 : 5JiconnI5aLek0QVPoFaVXFa
653 : -----END PRIVATE KEY-----
654 : ";
655 :
656 : #[tokio::test]
657 1 : async fn renew() {
658 1 : let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
659 1 : let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
660 1 : let (ec1, jwk3) = new_ec_jwk("3".into());
661 1 : let (ec2, jwk4) = new_ec_jwk("4".into());
662 1 :
663 1 : let jwt1 = new_rsa_jwt("1".into(), rs1);
664 1 : let jwt2 = new_rsa_jwt("2".into(), rs2);
665 1 : let jwt3 = new_ec_jwt("3".into(), ec1);
666 1 : let jwt4 = new_ec_jwt("4".into(), ec2);
667 1 :
668 1 : let foo_jwks = jose_jwk::JwkSet {
669 1 : keys: vec![jwk1, jwk3],
670 1 : };
671 1 : let bar_jwks = jose_jwk::JwkSet {
672 1 : keys: vec![jwk2, jwk4],
673 1 : };
674 1 :
675 2 : let service = service_fn(move |req| {
676 2 : let foo_jwks = foo_jwks.clone();
677 2 : let bar_jwks = bar_jwks.clone();
678 2 : async move {
679 2 : let jwks = match req.uri().path() {
680 2 : "/foo" => &foo_jwks,
681 1 : "/bar" => &bar_jwks,
682 1 : _ => {
683 1 : return Response::builder()
684 0 : .status(404)
685 0 : .body(Full::new(Bytes::new()));
686 1 : }
687 1 : };
688 2 : let body = serde_json::to_vec(jwks).unwrap();
689 2 : Response::builder()
690 2 : .status(200)
691 2 : .body(Full::new(Bytes::from(body)))
692 2 : }
693 2 : });
694 1 :
695 1 : let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
696 1 : let server = hyper1::server::conn::http1::Builder::new();
697 1 : let addr = listener.local_addr().unwrap();
698 1 : tokio::spawn(async move {
699 1 : loop {
700 2 : let (s, _) = listener.accept().await.unwrap();
701 1 : let serve = server.serve_connection(TokioIo::new(s), service.clone());
702 1 : tokio::spawn(serve.into_future());
703 1 : }
704 1 : });
705 1 :
706 1 : let client = reqwest::Client::new();
707 1 :
708 1 : #[derive(Clone)]
709 1 : struct Fetch(SocketAddr);
710 1 :
711 1 : impl FetchAuthRules for Fetch {
712 1 : async fn fetch_auth_rules(
713 1 : &self,
714 1 : _ctx: &RequestMonitoring,
715 1 : _endpoint: EndpointId,
716 1 : _role_name: RoleName,
717 1 : ) -> anyhow::Result<Vec<AuthRule>> {
718 1 : Ok(vec![
719 1 : AuthRule {
720 1 : id: "foo".to_owned(),
721 1 : jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
722 1 : audience: None,
723 1 : },
724 1 : AuthRule {
725 1 : id: "bar".to_owned(),
726 1 : jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
727 1 : audience: None,
728 1 : },
729 1 : ])
730 1 : }
731 1 : }
732 1 :
733 1 : let role_name = RoleName::from("user");
734 1 : let endpoint = EndpointId::from("ep");
735 1 :
736 1 : let jwk_cache = Arc::new(JwkCacheEntryLock::default());
737 1 :
738 4 : for token in [jwt1, jwt2, jwt3, jwt4] {
739 4 : jwk_cache
740 4 : .check_jwt(
741 4 : &RequestMonitoring::test(),
742 4 : &token,
743 4 : &client,
744 4 : endpoint.clone(),
745 4 : role_name.clone(),
746 4 : &Fetch(addr),
747 4 : )
748 4 : .await
749 4 : .unwrap();
750 1 : }
751 1 : }
752 : }
|