Line data Source code
1 : use std::borrow::Cow;
2 : use std::future::Future;
3 : use std::sync::Arc;
4 : use std::time::{Duration, SystemTime};
5 :
6 : use arc_swap::ArcSwapOption;
7 : use base64::Engine as _;
8 : use base64::prelude::BASE64_URL_SAFE_NO_PAD;
9 : use clashmap::ClashMap;
10 : use jose_jwk::crypto::KeyInfo;
11 : use reqwest::{Client, redirect};
12 : use reqwest_retry::RetryTransientMiddleware;
13 : use reqwest_retry::policies::ExponentialBackoff;
14 : use serde::de::Visitor;
15 : use serde::{Deserialize, Deserializer};
16 : use serde_json::value::RawValue;
17 : use signature::Verifier;
18 : use thiserror::Error;
19 : use tokio::time::Instant;
20 :
21 : use crate::auth::backend::ComputeCredentialKeys;
22 : use crate::context::RequestContext;
23 : use crate::control_plane::errors::GetEndpointJwksError;
24 : use crate::http::read_body_with_limit;
25 : use crate::intern::RoleNameInt;
26 : use crate::types::{EndpointId, RoleName};
27 :
// TODO(conrad): make these configurable.

/// Tolerance applied to the `exp` and `nbf` checks to absorb clock drift with the issuer.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of the cached JWKs before an unknown key id may trigger a re-fetch.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age after which cached JWKs are refreshed by a spawned background task.
const AUTO_RENEW: Duration = Duration::from_secs(5 * 60);
/// Age after which requests wait for the cached JWKs to be refreshed before proceeding.
const MAX_RENEW: Duration = Duration::from_secs(60 * 60);
/// Upper bound on the size of a JWKS response body (64 KiB).
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
/// `User-Agent` sent with JWKS requests.
const JWKS_USER_AGENT: &str = "neon-proxy";

/// Connect timeout for JWKS requests.
const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// Overall per-request timeout for JWKS requests.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum number of retries for transient JWKS fetch failures.
const JWKS_FETCH_RETRIES: u32 = 3;
39 :
/// How to get the JWT auth rules
///
/// An implementation returns, for one endpoint, the [`AuthRule`]s that say which JWKS URLs
/// to trust, which audience to expect, and which roles may use each key set.
///
/// The `Clone + Send + Sync + 'static` bounds let a clone of the fetcher be moved into the
/// spawned background JWKS renewal task.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Fetches the auth rules configured for `endpoint`.
    fn fetch_auth_rules(
        &self,
        ctx: &RequestContext,
        endpoint: EndpointId,
    ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
}
48 :
/// Errors returned by [`FetchAuthRules::fetch_auth_rules`].
#[derive(Error, Debug)]
pub(crate) enum FetchAuthRulesError {
    /// The control-plane lookup of the endpoint's JWKS settings failed.
    #[error(transparent)]
    GetEndpointJwks(#[from] GetEndpointJwksError),

    /// No JWKS settings exist for the requested role.
    #[error("JWKs settings for this role were not configured")]
    RoleJwksNotConfigured,
}
57 :
/// One JWKS source configured for an endpoint.
#[derive(Clone)]
pub(crate) struct AuthRule {
    /// Identifier of the rule; the fetched key set is stored in the cache under this id.
    pub(crate) id: String,
    /// URL the JWKS document is fetched from.
    pub(crate) jwks_url: url::Url,
    /// If set, tokens verified with this key set must list this value in their `aud` claim.
    pub(crate) audience: Option<String>,
    /// Roles allowed to authenticate with keys from this JWKS.
    pub(crate) role_names: Vec<RoleNameInt>,
}
65 :
/// Cache of fetched JWKS, keyed by `(endpoint, role)`, used to verify JWTs.
pub struct JwkCache {
    /// HTTP client (wrapped with transient-error retries) used to fetch JWKS documents.
    client: reqwest_middleware::ClientWithMiddleware,

    /// One lazily created cache slot per `(endpoint, role)` pair.
    map: ClashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
71 :
/// A snapshot of all the JWKS fetched for one cache slot.
pub(crate) struct JwkCacheEntry {
    /// Time at which the renewal that produced this snapshot started.
    ///
    /// Should refetch at least every hour to verify when old keys have been removed.
    /// Should refetch when new key IDs are seen only every 5 minutes or so
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    ///
    /// Keyed by [`AuthRule::id`]; rules whose JWKS could not be fetched are absent.
    key_sets: ahash::HashMap<String, KeySet>,
}
80 :
81 : impl JwkCacheEntry {
82 19 : fn find_jwk_and_audience(
83 19 : &self,
84 19 : key_id: &str,
85 19 : role_name: &RoleName,
86 19 : ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
87 19 : self.key_sets
88 19 : .values()
89 : // make sure our requested role has access to the key set
90 29 : .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
91 : // try and find the requested key-id in the key set
92 22 : .find_map(|key_set| {
93 22 : key_set
94 22 : .find_key(key_id)
95 22 : .map(|jwk| (jwk, key_set.audience.as_deref()))
96 22 : })
97 19 : }
98 : }
99 :
/// The signing keys from one auth rule's JWKS, plus that rule's audience and allowed roles.
struct KeySet {
    /// Parsed keys; keys that failed to parse or are not for signing are excluded.
    jwks: jose_jwk::JwkSet,
    /// Expected `aud` claim value for tokens verified with these keys, if any.
    audience: Option<String>,
    /// Roles permitted to use these keys.
    role_names: Vec<RoleNameInt>,
}
105 :
106 : impl KeySet {
107 22 : fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
108 22 : self.jwks
109 22 : .keys
110 22 : .iter()
111 30 : .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
112 22 : }
113 : }
114 :
/// A single `(endpoint, role)` cache slot.
pub(crate) struct JwkCacheEntryLock {
    /// The current snapshot, if one has been fetched yet; replaced atomically on renewal.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Single-permit semaphore; its permit must be held to renew `cached`.
    lookup: tokio::sync::Semaphore,
}
119 :
120 : impl Default for JwkCacheEntryLock {
121 7 : fn default() -> Self {
122 7 : JwkCacheEntryLock {
123 7 : cached: ArcSwapOption::empty(),
124 7 : lookup: tokio::sync::Semaphore::new(1),
125 7 : }
126 7 : }
127 : }
128 :
/// Minimal view of a JWK Set document: only its `keys` array.
#[derive(Deserialize)]
struct JwkSet<'a> {
    /// we parse into raw-value because not all keys in a JWKS are ones
    /// we can parse directly, so we parse them lazily.
    #[serde(borrow)]
    keys: Vec<&'a RawValue>,
}
136 :
137 : /// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
138 : /// Returns `None` and log a warning if there are any errors.
139 9 : async fn fetch_jwks(
140 9 : client: &reqwest_middleware::ClientWithMiddleware,
141 9 : jwks_url: url::Url,
142 9 : ) -> Option<jose_jwk::JwkSet> {
143 9 : let req = client.get(jwks_url.clone());
144 : // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
145 9 : let resp = req.send().await.and_then(|r| {
146 9 : r.error_for_status()
147 9 : .map_err(reqwest_middleware::Error::Reqwest)
148 9 : });
149 :
150 9 : let resp = match resp {
151 9 : Ok(r) => r,
152 : // TODO: should we re-insert JWKs if we want to keep this JWKs URL?
153 : // I expect these failures would be quite sparse.
154 0 : Err(e) => {
155 0 : tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
156 0 : return None;
157 : }
158 : };
159 :
160 9 : let resp: http::Response<reqwest::Body> = resp.into();
161 :
162 9 : let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
163 9 : Ok(bytes) => bytes,
164 0 : Err(e) => {
165 0 : tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
166 0 : return None;
167 : }
168 : };
169 :
170 9 : let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
171 9 : Ok(jwks) => jwks,
172 0 : Err(e) => {
173 0 : tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
174 0 : return None;
175 : }
176 : };
177 :
178 : // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
179 : //
180 : // Even though we limit our responses to 64KiB, we could still receive a payload like
181 : // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
182 : // Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
183 9 : let mut keys = vec![];
184 :
185 9 : let mut failed = 0;
186 23 : for key in jwks.keys {
187 14 : let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
188 13 : Ok(key) => key,
189 1 : Err(e) => {
190 1 : tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
191 1 : failed += 1;
192 1 : continue;
193 : }
194 : };
195 :
196 : // if `use` (called `cls` in rust) is specified to be something other than signing,
197 : // we can skip storing it.
198 13 : if key
199 13 : .prm
200 13 : .cls
201 13 : .as_ref()
202 13 : .is_some_and(|c| *c != jose_jwk::Class::Signing)
203 : {
204 0 : continue;
205 13 : }
206 :
207 13 : keys.push(key);
208 : }
209 :
210 9 : keys.shrink_to_fit();
211 :
212 9 : if failed > 0 {
213 1 : tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
214 8 : }
215 :
216 9 : if keys.is_empty() {
217 0 : tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
218 0 : return None;
219 9 : }
220 :
221 9 : Some(jose_jwk::JwkSet { keys })
222 9 : }
223 :
224 : impl JwkCacheEntryLock {
225 7 : async fn acquire_permit(self: &Arc<Self>) -> JwkRenewalPermit<'_> {
226 7 : JwkRenewalPermit::acquire_permit(self).await
227 7 : }
228 :
229 0 : fn try_acquire_permit(self: &Arc<Self>) -> Option<JwkRenewalPermit<'_>> {
230 0 : JwkRenewalPermit::try_acquire_permit(self)
231 0 : }
232 :
233 7 : async fn renew_jwks<F: FetchAuthRules>(
234 7 : &self,
235 7 : _permit: JwkRenewalPermit<'_>,
236 7 : ctx: &RequestContext,
237 7 : client: &reqwest_middleware::ClientWithMiddleware,
238 7 : endpoint: EndpointId,
239 7 : auth_rules: &F,
240 7 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
241 : // double check that no one beat us to updating the cache.
242 7 : let now = Instant::now();
243 7 : let guard = self.cached.load_full();
244 7 : if let Some(cached) = guard {
245 0 : let last_update = now.duration_since(cached.last_retrieved);
246 0 : if last_update < Duration::from_secs(300) {
247 0 : return Ok(cached);
248 0 : }
249 7 : }
250 :
251 7 : let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
252 7 : let mut key_sets =
253 7 : ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
254 :
255 : // TODO(conrad): run concurrently
256 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
257 16 : for rule in rules {
258 9 : if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
259 9 : key_sets.insert(
260 9 : rule.id,
261 9 : KeySet {
262 9 : jwks,
263 9 : audience: rule.audience,
264 9 : role_names: rule.role_names,
265 9 : },
266 9 : );
267 9 : }
268 : }
269 :
270 7 : let entry = Arc::new(JwkCacheEntry {
271 7 : last_retrieved: now,
272 7 : key_sets,
273 7 : });
274 7 : self.cached.swap(Some(Arc::clone(&entry)));
275 :
276 7 : Ok(entry)
277 7 : }
278 :
279 19 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
280 19 : self: &Arc<Self>,
281 19 : ctx: &RequestContext,
282 19 : client: &reqwest_middleware::ClientWithMiddleware,
283 19 : endpoint: EndpointId,
284 19 : fetch: &F,
285 19 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
286 19 : let now = Instant::now();
287 19 : let guard = self.cached.load_full();
288 :
289 : // if we have no cached JWKs, try and get some
290 19 : let Some(cached) = guard else {
291 7 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
292 7 : let permit = self.acquire_permit().await;
293 7 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
294 : };
295 :
296 12 : let last_update = now.duration_since(cached.last_retrieved);
297 :
298 : // check if the cached JWKs need updating.
299 12 : if last_update > MAX_RENEW {
300 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
301 0 : let permit = self.acquire_permit().await;
302 :
303 : // it's been too long since we checked the keys. wait for them to update.
304 0 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
305 12 : }
306 :
307 : // every 5 minutes we should spawn a job to eagerly update the token.
308 12 : if last_update > AUTO_RENEW {
309 0 : if let Some(permit) = self.try_acquire_permit() {
310 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
311 0 : let permit = permit.into_owned();
312 0 : let entry = self.clone();
313 0 : let client = client.clone();
314 0 : let fetch = fetch.clone();
315 0 : let ctx = ctx.clone();
316 0 : tokio::spawn(async move {
317 0 : if let Err(e) = entry
318 0 : .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
319 0 : .await
320 : {
321 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
322 0 : }
323 0 : });
324 : } else {
325 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
326 : }
327 12 : }
328 :
329 12 : Ok(cached)
330 19 : }
331 :
332 19 : async fn check_jwt<F: FetchAuthRules>(
333 19 : self: &Arc<Self>,
334 19 : ctx: &RequestContext,
335 19 : jwt: &str,
336 19 : client: &reqwest_middleware::ClientWithMiddleware,
337 19 : endpoint: EndpointId,
338 19 : role_name: &RoleName,
339 19 : fetch: &F,
340 19 : ) -> Result<ComputeCredentialKeys, JwtError> {
341 : // JWT compact form is defined to be
342 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
343 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
344 :
345 19 : let (header_payload, signature) = jwt
346 19 : .rsplit_once('.')
347 19 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
348 19 : let (header, payload) = header_payload
349 19 : .split_once('.')
350 19 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
351 :
352 19 : let header = BASE64_URL_SAFE_NO_PAD.decode(header)?;
353 19 : let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
354 :
355 19 : let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
356 19 : let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
357 :
358 19 : if let Some(iss) = &payload.issuer {
359 12 : ctx.set_jwt_issuer(iss.as_ref().to_owned());
360 12 : }
361 :
362 19 : let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?;
363 :
364 19 : let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
365 :
366 19 : let mut guard = self
367 19 : .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
368 19 : .await?;
369 :
370 : // get the key from the JWKs if possible. If not, wait for the keys to update.
371 18 : let (jwk, expected_audience) = loop {
372 19 : match guard.find_jwk_and_audience(&kid, role_name) {
373 18 : Some(jwk) => break jwk,
374 1 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
375 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
376 :
377 0 : let permit = self.acquire_permit().await;
378 0 : guard = self
379 0 : .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
380 0 : .await?;
381 : }
382 1 : _ => return Err(JwtError::JwkNotFound),
383 : }
384 : };
385 :
386 18 : if !jwk.is_supported(&header.algorithm) {
387 0 : return Err(JwtError::SignatureAlgorithmNotSupported);
388 18 : }
389 :
390 18 : match &jwk.key {
391 13 : jose_jwk::Key::Ec(key) => {
392 13 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
393 : }
394 5 : jose_jwk::Key::Rsa(key) => {
395 5 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
396 : }
397 0 : key => return Err(JwtError::UnsupportedKeyType(key.into())),
398 : }
399 :
400 17 : tracing::debug!(?payload, "JWT signature valid with claims");
401 :
402 17 : if let Some(aud) = expected_audience
403 7 : && payload.audience.0.iter().all(|s| s != aud)
404 : {
405 5 : return Err(JwtError::InvalidClaims(
406 5 : JwtClaimsError::InvalidJwtTokenAudience,
407 5 : ));
408 12 : }
409 :
410 12 : let now = SystemTime::now();
411 :
412 12 : if let Some(exp) = payload.expiration
413 11 : && now >= exp + CLOCK_SKEW_LEEWAY
414 : {
415 1 : return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
416 1 : exp.duration_since(SystemTime::UNIX_EPOCH)
417 1 : .unwrap_or_default()
418 1 : .as_secs(),
419 1 : )));
420 11 : }
421 :
422 11 : if let Some(nbf) = payload.not_before
423 10 : && nbf >= now + CLOCK_SKEW_LEEWAY
424 : {
425 1 : return Err(JwtError::InvalidClaims(
426 1 : JwtClaimsError::JwtTokenNotYetReadyToUse(
427 1 : nbf.duration_since(SystemTime::UNIX_EPOCH)
428 1 : .unwrap_or_default()
429 1 : .as_secs(),
430 1 : ),
431 1 : ));
432 10 : }
433 :
434 10 : Ok(ComputeCredentialKeys::JwtPayload(payloadb))
435 19 : }
436 : }
437 :
438 : impl JwkCache {
439 19 : pub(crate) async fn check_jwt<F: FetchAuthRules>(
440 19 : &self,
441 19 : ctx: &RequestContext,
442 19 : endpoint: EndpointId,
443 19 : role_name: &RoleName,
444 19 : fetch: &F,
445 19 : jwt: &str,
446 19 : ) -> Result<ComputeCredentialKeys, JwtError> {
447 : // try with just a read lock first
448 19 : let key = (endpoint.clone(), role_name.clone());
449 19 : let entry = self.map.get(&key).as_deref().map(Arc::clone);
450 19 : let entry = entry.unwrap_or_else(|| {
451 : // acquire a write lock after to insert.
452 7 : let entry = self.map.entry(key).or_default();
453 7 : Arc::clone(&*entry)
454 7 : });
455 :
456 19 : entry
457 19 : .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
458 19 : .await
459 19 : }
460 : }
461 :
462 : impl Default for JwkCache {
463 9 : fn default() -> Self {
464 9 : let client = Client::builder()
465 9 : .user_agent(JWKS_USER_AGENT)
466 9 : .redirect(redirect::Policy::none())
467 9 : .tls_built_in_native_certs(true)
468 9 : .connect_timeout(JWKS_CONNECT_TIMEOUT)
469 9 : .timeout(JWKS_FETCH_TIMEOUT)
470 9 : .build()
471 9 : .expect("client config should be valid");
472 :
473 : // Retry up to 3 times with increasing intervals between attempts.
474 9 : let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
475 :
476 9 : let client = reqwest_middleware::ClientBuilder::new(client)
477 9 : .with(RetryTransientMiddleware::new_with_policy(retry_policy))
478 9 : .build();
479 :
480 9 : JwkCache {
481 9 : client,
482 9 : map: ClashMap::default(),
483 9 : }
484 9 : }
485 : }
486 :
487 13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
488 : use ecdsa::Signature;
489 : use signature::Verifier;
490 :
491 13 : match key.crv {
492 : jose_jwk::EcCurves::P256 => {
493 13 : let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
494 13 : let key = p256::ecdsa::VerifyingKey::from(&pk);
495 13 : let sig = Signature::from_slice(sig)?;
496 13 : key.verify(data, &sig)?;
497 : }
498 0 : key => return Err(JwtError::UnsupportedEcKeyType(key)),
499 : }
500 :
501 12 : Ok(())
502 13 : }
503 :
504 5 : fn verify_rsa_signature(
505 5 : data: &[u8],
506 5 : sig: &[u8],
507 5 : key: &jose_jwk::Rsa,
508 5 : alg: &jose_jwa::Algorithm,
509 5 : ) -> Result<(), JwtError> {
510 : use jose_jwa::{Algorithm, Signing};
511 : use rsa::RsaPublicKey;
512 : use rsa::pkcs1v15::{Signature, VerifyingKey};
513 :
514 5 : let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
515 :
516 5 : match alg {
517 : Algorithm::Signing(Signing::Rs256) => {
518 5 : let key = VerifyingKey::<sha2::Sha256>::new(key);
519 5 : let sig = Signature::try_from(sig)?;
520 5 : key.verify(data, &sig)?;
521 : }
522 0 : _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
523 : }
524 :
525 5 : Ok(())
526 5 : }
527 :
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
///
/// The subset of the JOSE header we read. `Serialize` is used by the tests to build tokens.
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be a supported alg
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    #[serde(rename = "kid", borrow)]
    key_id: Option<Cow<'a, str>>,
}
538 :
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
///
/// The registered claims we read from a JWT payload; other claims are ignored.
#[derive(serde::Deserialize, Debug)]
// Several fields are only ever read through the `Debug` impl (for logging).
#[allow(dead_code)]
struct JwtPayload<'a> {
    /// Audience - Recipient for which the JWT is intended (empty when the claim is absent)
    #[serde(rename = "aud", default)]
    audience: OneOrMany,
    /// Expiration - Time after which the JWT expires (NumericDate: seconds since the epoch)
    #[serde(rename = "exp", deserialize_with = "numeric_date_opt", default)]
    expiration: Option<SystemTime>,
    /// Not before - Time before which the JWT is not valid (NumericDate: seconds since the epoch)
    #[serde(rename = "nbf", deserialize_with = "numeric_date_opt", default)]
    not_before: Option<SystemTime>,

    // the following entries are only extracted for the sake of debug logging.
    /// Issuer of the JWT
    #[serde(rename = "iss", borrow)]
    issuer: Option<Cow<'a, str>>,
    /// Subject of the JWT (the user)
    #[serde(rename = "sub", borrow)]
    subject: Option<Cow<'a, str>>,
    /// Unique token identifier
    #[serde(rename = "jti", borrow)]
    jwt_id: Option<Cow<'a, str>>,
    /// Unique session identifier
    #[serde(rename = "sid", borrow)]
    session_id: Option<Cow<'a, str>>,
}
567 :
/// `OneOrMany` supports parsing either a single item or an array of items.
///
/// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
///
/// > The "aud" (audience) claim identifies the recipients that the JWT is
/// > intended for. Each principal intended to process the JWT MUST
/// > identify itself with a value in the audience claim. If the principal
/// > processing the claim does not identify itself with a value in the
/// > "aud" claim when this claim is present, then the JWT MUST be
/// > rejected. In the general case, the "aud" value is **an array of case-
/// > sensitive strings**, each containing a StringOrURI value. In the
/// > special case when the JWT has one audience, the "aud" value MAY be a
/// > **single case-sensitive string** containing a StringOrURI value. The
/// > interpretation of audience values is generally application specific.
/// > Use of this claim is OPTIONAL.
///
/// A single string is stored as a one-element list. The `Default` value is an empty list.
#[derive(Default, Debug)]
struct OneOrMany(Vec<String>);
585 :
586 : impl<'de> Deserialize<'de> for OneOrMany {
587 17 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
588 17 : where
589 17 : D: Deserializer<'de>,
590 : {
591 : struct OneOrManyVisitor;
592 : impl<'de> Visitor<'de> for OneOrManyVisitor {
593 : type Value = OneOrMany;
594 :
595 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
596 0 : formatter.write_str("a single string or an array of strings")
597 0 : }
598 :
599 2 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
600 2 : where
601 2 : E: serde::de::Error,
602 : {
603 2 : Ok(OneOrMany(vec![v.to_owned()]))
604 2 : }
605 :
606 15 : fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
607 15 : where
608 15 : A: serde::de::SeqAccess<'de>,
609 : {
610 15 : let mut v = vec![];
611 52 : while let Some(s) = seq.next_element()? {
612 37 : v.push(s);
613 37 : }
614 15 : Ok(OneOrMany(v))
615 15 : }
616 : }
617 17 : deserializer.deserialize_any(OneOrManyVisitor)
618 17 : }
619 : }
620 :
621 25 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
622 25 : <Option<u64>>::deserialize(d)?
623 25 : .map(|t| {
624 25 : SystemTime::UNIX_EPOCH
625 25 : .checked_add(Duration::from_secs(t))
626 25 : .ok_or_else(|| {
627 0 : serde::de::Error::custom(format_args!("timestamp out of bounds: {t}"))
628 0 : })
629 25 : })
630 25 : .transpose()
631 25 : }
632 :
/// Holds the single renewal permit of a [`JwkCacheEntryLock`].
///
/// The underlying semaphore permit is `forget`-ed when acquired and handed back with
/// `add_permits(1)` when this value is dropped.
struct JwkRenewalPermit<'a> {
    /// `None` only after the permit has been moved out by `into_owned`.
    inner: Option<JwkRenewalPermitInner<'a>>,
}
636 :
/// How a [`JwkRenewalPermit`] refers to the cache slot whose permit it holds.
enum JwkRenewalPermitInner<'a> {
    /// Owns a reference-counted handle, so the permit can be moved into a spawned task.
    Owned(Arc<JwkCacheEntryLock>),
    /// Borrows the caller's handle.
    Borrowed(&'a Arc<JwkCacheEntryLock>),
}
641 :
642 : impl JwkRenewalPermit<'_> {
643 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
644 0 : JwkRenewalPermit {
645 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
646 0 : }
647 0 : }
648 :
649 7 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
650 7 : match from.lookup.acquire().await {
651 7 : Ok(permit) => {
652 7 : permit.forget();
653 7 : JwkRenewalPermit {
654 7 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
655 7 : }
656 : }
657 0 : Err(_) => panic!("semaphore should not be closed"),
658 : }
659 7 : }
660 :
661 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
662 0 : match from.lookup.try_acquire() {
663 0 : Ok(permit) => {
664 0 : permit.forget();
665 0 : Some(JwkRenewalPermit {
666 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
667 0 : })
668 : }
669 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
670 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
671 : }
672 0 : }
673 : }
674 :
675 : impl JwkRenewalPermitInner<'_> {
676 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
677 0 : match self {
678 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
679 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
680 : }
681 0 : }
682 : }
683 :
684 : impl Drop for JwkRenewalPermit<'_> {
685 7 : fn drop(&mut self) {
686 7 : let entry = match &self.inner {
687 0 : None => return,
688 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
689 7 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
690 : };
691 7 : entry.lookup.add_permits(1);
692 7 : }
693 : }
694 :
/// Errors produced while validating a JWT.
#[derive(Error, Debug)]
#[non_exhaustive]
pub(crate) enum JwtError {
    /// No key with the token's `kid` is available to the requested role.
    #[error("jwk not found")]
    JwkNotFound,

    /// The token header has no `kid`.
    #[error("missing key id")]
    MissingKeyId,

    /// The token could not be split or decoded (base64url or JSON).
    #[error("Provided authentication token is not a valid JWT encoding")]
    JwtEncoding(#[from] JwtEncodingError),

    /// The signature verified, but an `aud`, `exp` or `nbf` check failed.
    #[error(transparent)]
    InvalidClaims(#[from] JwtClaimsError),

    /// The EC JWK could not be converted into a P-256 public key.
    #[error("invalid P256 key")]
    InvalidP256Key(jose_jwk::crypto::Error),

    /// The RSA JWK could not be converted into an RSA public key.
    #[error("invalid RSA key")]
    InvalidRsaKey(jose_jwk::crypto::Error),

    /// An RSA key was used with an algorithm other than RS256.
    #[error("invalid RSA signing algorithm")]
    InvalidRsaSigningAlgorithm,

    /// An EC key on a curve other than P-256.
    #[error("unsupported EC key type {0:?}")]
    UnsupportedEcKeyType(jose_jwk::EcCurves),

    /// A key type other than EC or RSA.
    #[error("unsupported key type {0:?}")]
    UnsupportedKeyType(KeyType),

    /// The selected key does not support the token's `alg`.
    #[error("signature algorithm not supported")]
    SignatureAlgorithmNotSupported,

    /// The signature bytes were malformed or did not verify.
    #[error("signature error: {0}")]
    Signature(#[from] signature::Error),

    /// The endpoint's auth rules could not be fetched.
    #[error("failed to fetch auth rules: {0}")]
    FetchAuthRules(#[from] FetchAuthRulesError),
}
734 :
735 : impl From<base64::DecodeError> for JwtError {
736 0 : fn from(err: base64::DecodeError) -> Self {
737 0 : JwtEncodingError::Base64Decode(err).into()
738 0 : }
739 : }
740 :
741 : impl From<serde_json::Error> for JwtError {
742 0 : fn from(err: serde_json::Error) -> Self {
743 0 : JwtEncodingError::SerdeJson(err).into()
744 0 : }
745 : }
746 :
/// Reasons a token could not be decoded into header, payload and signature.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum JwtEncodingError {
    /// A segment was not valid unpadded base64url.
    #[error(transparent)]
    Base64Decode(#[from] base64::DecodeError),

    /// The header or payload was not the expected JSON.
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),

    /// The token does not have the `header.payload.signature` shape.
    #[error("invalid compact form")]
    InvalidCompactForm,
}
759 :
/// Claim checks that failed after the signature was verified.
#[derive(Error, Debug, PartialEq)]
#[non_exhaustive]
pub enum JwtClaimsError {
    /// None of the token's `aud` values matches the expected audience.
    #[error("invalid JWT token audience")]
    InvalidJwtTokenAudience,

    /// The token has expired; carries `exp` in seconds since the Unix epoch.
    #[error("JWT token has expired (exp={0})")]
    JwtTokenHasExpired(u64),

    /// The token is not valid yet; carries `nbf` in seconds since the Unix epoch.
    #[error("JWT token is not yet ready to use (nbf={0})")]
    JwtTokenNotYetReadyToUse(u64),
}
772 :
/// Coarse description of a JWK's key type, reported in [`JwtError::UnsupportedKeyType`].
#[allow(dead_code, reason = "Debug use only")]
#[derive(Debug)]
pub(crate) enum KeyType {
    /// Elliptic-curve key on the given curve.
    Ec(jose_jwk::EcCurves),
    /// RSA key.
    Rsa,
    /// Symmetric (octet sequence) key.
    Oct,
    /// Octet key pair on the given curve.
    Okp(jose_jwk::OkpCurves),
    /// Any other key type.
    Unknown,
}
782 :
783 : impl From<&jose_jwk::Key> for KeyType {
784 0 : fn from(key: &jose_jwk::Key) -> Self {
785 0 : match key {
786 0 : jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
787 0 : jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
788 0 : jose_jwk::Key::Oct(_oct) => Self::Oct,
789 0 : jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
790 0 : _ => Self::Unknown,
791 : }
792 0 : }
793 : }
794 :
795 : #[cfg(test)]
796 : mod tests {
797 : use std::future::IntoFuture;
798 : use std::net::SocketAddr;
799 : use std::time::SystemTime;
800 :
801 : use bytes::Bytes;
802 : use http::Response;
803 : use http_body_util::Full;
804 : use hyper::service::service_fn;
805 : use hyper_util::rt::TokioIo;
806 : use rand::rngs::OsRng;
807 : use rsa::pkcs8::DecodePrivateKey;
808 : use serde::Serialize;
809 : use serde_json::json;
810 : use signature::Signer;
811 : use tokio::net::TcpListener;
812 :
813 : use super::*;
814 : use crate::types::RoleName;
815 :
816 6 : fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
817 6 : let sk = p256::SecretKey::random(&mut OsRng);
818 6 : let pk = sk.public_key().into();
819 6 : let jwk = jose_jwk::Jwk {
820 6 : key: jose_jwk::Key::Ec(pk),
821 6 : prm: jose_jwk::Parameters {
822 6 : kid: Some(kid),
823 6 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
824 6 : ..Default::default()
825 6 : },
826 6 : };
827 6 : (sk, jwk)
828 6 : }
829 :
830 4 : fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
831 4 : let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
832 4 : let pk = sk.to_public_key().into();
833 4 : let jwk = jose_jwk::Jwk {
834 4 : key: jose_jwk::Key::Rsa(pk),
835 4 : prm: jose_jwk::Parameters {
836 4 : kid: Some(kid),
837 4 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
838 4 : ..Default::default()
839 4 : },
840 4 : };
841 4 : (sk, jwk)
842 4 : }
843 :
844 8 : fn now() -> u64 {
845 8 : SystemTime::now()
846 8 : .duration_since(SystemTime::UNIX_EPOCH)
847 8 : .unwrap()
848 8 : .as_secs()
849 8 : }
850 :
851 7 : fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
852 7 : let now = now();
853 7 : let body = typed_json::json! {{
854 7 : "exp": now + 3600,
855 7 : "nbf": now,
856 : "aud": ["audience1", "neon", "audience2"],
857 : "sub": "user1",
858 : "sid": "session1",
859 : "jti": "token1",
860 : "iss": "neon-testing",
861 : }};
862 7 : build_custom_jwt_payload(kid, body, sig)
863 7 : }
864 :
865 15 : fn build_custom_jwt_payload(
866 15 : kid: String,
867 15 : body: impl Serialize,
868 15 : sig: jose_jwa::Signing,
869 15 : ) -> String {
870 15 : let header = JwtHeader {
871 15 : algorithm: jose_jwa::Algorithm::Signing(sig),
872 15 : key_id: Some(Cow::Owned(kid)),
873 15 : };
874 :
875 15 : let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
876 15 : let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap());
877 :
878 15 : format!("{header}.{body}")
879 15 : }
880 :
881 3 : fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
882 : use p256::ecdsa::{Signature, SigningKey};
883 :
884 3 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
885 3 : let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
886 3 : let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
887 :
888 3 : format!("{payload}.{sig}")
889 3 : }
890 :
891 8 : fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
892 : use p256::ecdsa::{Signature, SigningKey};
893 :
894 8 : let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
895 8 : let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
896 8 : let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
897 :
898 8 : format!("{payload}.{sig}")
899 8 : }
900 :
901 4 : fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
902 : use rsa::pkcs1v15::SigningKey;
903 : use rsa::signature::SignatureEncoding;
904 :
905 4 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
906 4 : let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
907 4 : let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
908 :
909 4 : format!("{payload}.{sig}")
910 4 : }
911 :
912 : // RSA key gen is slow....
913 : const RS1: &str = "-----BEGIN PRIVATE KEY-----
914 : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
915 : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
916 : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
917 : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
918 : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
919 : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
920 : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
921 : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
922 : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
923 : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
924 : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
925 : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
926 : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
927 : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
928 : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
929 : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
930 : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
931 : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
932 : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
933 : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
934 : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
935 : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
936 : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
937 : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
938 : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
939 : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
940 : -----END PRIVATE KEY-----
941 : ";
942 : const RS2: &str = "-----BEGIN PRIVATE KEY-----
943 : MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
944 : HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
945 : zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
946 : qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
947 : IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
948 : ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
949 : uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
950 : K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
951 : LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
952 : QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
953 : zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
954 : 0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
955 : +Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
956 : P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
957 : x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
958 : FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
959 : iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
960 : 4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
961 : iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
962 : N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
963 : uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
964 : qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
965 : WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
966 : rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
967 : X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
968 : 5JiconnI5aLek0QVPoFaVXFa
969 : -----END PRIVATE KEY-----
970 : ";
971 :
972 : #[derive(Clone)]
973 : struct Fetch(Vec<AuthRule>);
974 :
975 : impl FetchAuthRules for Fetch {
976 7 : async fn fetch_auth_rules(
977 7 : &self,
978 7 : _ctx: &RequestContext,
979 7 : _endpoint: EndpointId,
980 7 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
981 7 : Ok(self.0.clone())
982 7 : }
983 : }
984 :
985 6 : async fn jwks_server(
986 6 : router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
987 6 : ) -> SocketAddr {
988 6 : let router = Arc::new(router);
989 9 : let service = service_fn(move |req| {
990 9 : let router = Arc::clone(&router);
991 9 : async move {
992 9 : match router(req.uri().path()) {
993 9 : Some(body) => Response::builder()
994 9 : .status(200)
995 9 : .body(Full::new(Bytes::from(body))),
996 0 : None => Response::builder()
997 0 : .status(404)
998 0 : .body(Full::new(Bytes::new())),
999 : }
1000 9 : }
1001 9 : });
1002 :
1003 6 : let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
1004 6 : let server = hyper::server::conn::http1::Builder::new();
1005 6 : let addr = listener.local_addr().unwrap();
1006 6 : tokio::spawn(async move {
1007 : loop {
1008 12 : let (s, _) = listener.accept().await.unwrap();
1009 6 : let serve = server.serve_connection(TokioIo::new(s), service.clone());
1010 6 : tokio::spawn(serve.into_future());
1011 : }
1012 : });
1013 :
1014 6 : addr
1015 6 : }
1016 :
1017 : #[tokio::test]
1018 1 : async fn check_jwt_happy_path() {
1019 1 : let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
1020 1 : let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
1021 1 : let (ec1, jwk3) = new_ec_jwk("ec1".into());
1022 1 : let (ec2, jwk4) = new_ec_jwk("ec2".into());
1023 :
1024 1 : let foo_jwks = jose_jwk::JwkSet {
1025 1 : keys: vec![jwk1, jwk3],
1026 1 : };
1027 1 : let bar_jwks = jose_jwk::JwkSet {
1028 1 : keys: vec![jwk2, jwk4],
1029 1 : };
1030 :
1031 4 : let jwks_addr = jwks_server(move |path| match path {
1032 4 : "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
1033 2 : "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
1034 0 : _ => None,
1035 4 : })
1036 1 : .await;
1037 :
1038 1 : let role_name1 = RoleName::from("anonymous");
1039 1 : let role_name2 = RoleName::from("authenticated");
1040 :
1041 1 : let roles = vec![
1042 1 : RoleNameInt::from(&role_name1),
1043 1 : RoleNameInt::from(&role_name2),
1044 : ];
1045 1 : let rules = vec![
1046 1 : AuthRule {
1047 1 : id: "foo".to_owned(),
1048 1 : jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
1049 1 : audience: None,
1050 1 : role_names: roles.clone(),
1051 1 : },
1052 1 : AuthRule {
1053 1 : id: "bar".to_owned(),
1054 1 : jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
1055 1 : audience: None,
1056 1 : role_names: roles.clone(),
1057 1 : },
1058 : ];
1059 :
1060 1 : let fetch = Fetch(rules);
1061 1 : let jwk_cache = JwkCache::default();
1062 :
1063 1 : let endpoint = EndpointId::from("ep");
1064 :
1065 1 : let jwt1 = new_rsa_jwt("rs1".into(), rs1);
1066 1 : let jwt2 = new_rsa_jwt("rs2".into(), rs2);
1067 1 : let jwt3 = new_ec_jwt("ec1".into(), &ec1);
1068 1 : let jwt4 = new_ec_jwt("ec2".into(), &ec2);
1069 :
1070 1 : let tokens = [jwt1, jwt2, jwt3, jwt4];
1071 1 : let role_names = [role_name1, role_name2];
1072 3 : for role in &role_names {
1073 10 : for token in &tokens {
1074 8 : jwk_cache
1075 8 : .check_jwt(
1076 8 : &RequestContext::test(),
1077 8 : endpoint.clone(),
1078 8 : role,
1079 8 : &fetch,
1080 8 : token,
1081 8 : )
1082 8 : .await
1083 8 : .unwrap();
1084 1 : }
1085 1 : }
1086 1 : }
1087 :
1088 : /// AWS Cognito escapes the `/` in the URL.
1089 : #[tokio::test]
1090 1 : async fn check_jwt_regression_cognito_issuer() {
1091 1 : let (key, jwk) = new_ec_jwk("key".into());
1092 :
1093 1 : let now = now();
1094 1 : let token = new_custom_ec_jwt(
1095 1 : "key".into(),
1096 1 : &key,
1097 1 : typed_json::json! {{
1098 : "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
1099 : // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
1100 : // to write that escape character. instead I will make a bogus URL using `\` instead.
1101 : "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
1102 : "client_id": "abcdefghijklmnopqrstuvwxyz",
1103 : "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
1104 : "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
1105 : "token_use": "access",
1106 : "scope": "aws.cognito.signin.user.admin",
1107 1 : "auth_time":now,
1108 1 : "exp":now + 60,
1109 1 : "iat":now,
1110 : "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
1111 : "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
1112 : }},
1113 : );
1114 :
1115 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1116 :
1117 1 : let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
1118 :
1119 1 : let role_name = RoleName::from("anonymous");
1120 1 : let rules = vec![AuthRule {
1121 1 : id: "aws-cognito".to_owned(),
1122 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1123 1 : audience: None,
1124 1 : role_names: vec![RoleNameInt::from(&role_name)],
1125 1 : }];
1126 :
1127 1 : let fetch = Fetch(rules);
1128 1 : let jwk_cache = JwkCache::default();
1129 :
1130 1 : let endpoint = EndpointId::from("ep");
1131 :
1132 1 : jwk_cache
1133 1 : .check_jwt(
1134 1 : &RequestContext::test(),
1135 1 : endpoint.clone(),
1136 1 : &role_name,
1137 1 : &fetch,
1138 1 : &token,
1139 1 : )
1140 1 : .await
1141 1 : .unwrap();
1142 1 : }
1143 :
1144 : #[tokio::test]
1145 1 : async fn check_jwt_invalid_signature() {
1146 1 : let (_, jwk) = new_ec_jwk("1".into());
1147 1 : let (key, _) = new_ec_jwk("1".into());
1148 :
1149 : // has a matching kid, but signed by the wrong key
1150 1 : let bad_jwt = new_ec_jwt("1".into(), &key);
1151 :
1152 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1153 1 : let jwks_addr = jwks_server(move |path| match path {
1154 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1155 0 : _ => None,
1156 1 : })
1157 1 : .await;
1158 :
1159 1 : let role = RoleName::from("authenticated");
1160 :
1161 1 : let rules = vec![AuthRule {
1162 1 : id: String::new(),
1163 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1164 1 : audience: None,
1165 1 : role_names: vec![RoleNameInt::from(&role)],
1166 1 : }];
1167 :
1168 1 : let fetch = Fetch(rules);
1169 1 : let jwk_cache = JwkCache::default();
1170 :
1171 1 : let ep = EndpointId::from("ep");
1172 :
1173 1 : let ctx = RequestContext::test();
1174 1 : let err = jwk_cache
1175 1 : .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
1176 1 : .await
1177 1 : .unwrap_err();
1178 1 : assert!(
1179 1 : matches!(err, JwtError::Signature(_)),
1180 1 : "expected \"signature error\", got {err:?}"
1181 1 : );
1182 1 : }
1183 :
1184 : #[tokio::test]
1185 1 : async fn check_jwt_unknown_role() {
1186 1 : let (key, jwk) = new_rsa_jwk(RS1, "1".into());
1187 1 : let jwt = new_rsa_jwt("1".into(), key);
1188 :
1189 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1190 1 : let jwks_addr = jwks_server(move |path| match path {
1191 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1192 0 : _ => None,
1193 1 : })
1194 1 : .await;
1195 :
1196 1 : let role = RoleName::from("authenticated");
1197 1 : let rules = vec![AuthRule {
1198 1 : id: String::new(),
1199 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1200 1 : audience: None,
1201 1 : role_names: vec![RoleNameInt::from(&role)],
1202 1 : }];
1203 :
1204 1 : let fetch = Fetch(rules);
1205 1 : let jwk_cache = JwkCache::default();
1206 :
1207 1 : let ep = EndpointId::from("ep");
1208 :
1209 : // this role_name is not accepted
1210 1 : let bad_role_name = RoleName::from("cloud_admin");
1211 :
1212 1 : let ctx = RequestContext::test();
1213 1 : let err = jwk_cache
1214 1 : .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
1215 1 : .await
1216 1 : .unwrap_err();
1217 :
1218 1 : assert!(
1219 1 : matches!(err, JwtError::JwkNotFound),
1220 1 : "expected \"jwk not found\", got {err:?}"
1221 1 : );
1222 1 : }
1223 :
1224 : #[tokio::test]
1225 1 : async fn check_jwt_invalid_claims() {
1226 1 : let (key, jwk) = new_ec_jwk("1".into());
1227 :
1228 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1229 1 : let jwks_addr = jwks_server(move |path| match path {
1230 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1231 0 : _ => None,
1232 1 : })
1233 1 : .await;
1234 :
1235 1 : let now = SystemTime::now()
1236 1 : .duration_since(SystemTime::UNIX_EPOCH)
1237 1 : .unwrap()
1238 1 : .as_secs();
1239 :
1240 : struct Test {
1241 : body: serde_json::Value,
1242 : error: JwtClaimsError,
1243 : }
1244 :
1245 1 : let table = vec![
1246 1 : Test {
1247 1 : body: json! {{
1248 1 : "nbf": now + 60,
1249 1 : "aud": "neon",
1250 1 : }},
1251 1 : error: JwtClaimsError::JwtTokenNotYetReadyToUse(now + 60),
1252 1 : },
1253 1 : Test {
1254 1 : body: json! {{
1255 1 : "exp": now - 60,
1256 1 : "aud": ["neon"],
1257 1 : }},
1258 1 : error: JwtClaimsError::JwtTokenHasExpired(now - 60),
1259 1 : },
1260 1 : Test {
1261 1 : body: json! {{
1262 1 : }},
1263 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1264 1 : },
1265 1 : Test {
1266 1 : body: json! {{
1267 1 : "aud": [],
1268 1 : }},
1269 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1270 1 : },
1271 1 : Test {
1272 1 : body: json! {{
1273 1 : "aud": "foo",
1274 1 : }},
1275 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1276 1 : },
1277 1 : Test {
1278 1 : body: json! {{
1279 1 : "aud": ["foo"],
1280 1 : }},
1281 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1282 1 : },
1283 1 : Test {
1284 1 : body: json! {{
1285 1 : "aud": ["foo", "bar"],
1286 1 : }},
1287 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1288 1 : },
1289 : ];
1290 :
1291 1 : let role = RoleName::from("authenticated");
1292 :
1293 1 : let rules = vec![AuthRule {
1294 1 : id: String::new(),
1295 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1296 1 : audience: Some("neon".to_string()),
1297 1 : role_names: vec![RoleNameInt::from(&role)],
1298 1 : }];
1299 :
1300 1 : let fetch = Fetch(rules);
1301 1 : let jwk_cache = JwkCache::default();
1302 :
1303 1 : let ep = EndpointId::from("ep");
1304 :
1305 1 : let ctx = RequestContext::test();
1306 8 : for test in table {
1307 7 : let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
1308 1 :
1309 7 : match jwk_cache
1310 7 : .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
1311 7 : .await
1312 1 : {
1313 7 : Err(JwtError::InvalidClaims(error)) if error == test.error => {}
1314 1 : Err(err) => {
1315 1 : panic!("expected {:?}, got {err:?}", test.error)
1316 1 : }
1317 1 : Ok(_payload) => {
1318 1 : panic!("expected {:?}, got ok", test.error)
1319 1 : }
1320 1 : }
1321 1 : }
1322 1 : }
1323 :
1324 : #[tokio::test]
1325 1 : async fn check_jwk_keycloak_regression() {
1326 1 : let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
1327 1 : let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
1328 :
1329 : // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
1330 : // This is taken directly from keycloak.
1331 1 : let invalid_jwk = serde_json::json! {
1332 : {
1333 1 : "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
1334 1 : "kty": "RSA",
1335 1 : "alg": "RSA-OAEP",
1336 1 : "use": "enc",
1337 1 : "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
1338 1 : "e": "AQAB",
1339 1 : "x5c": [
1340 : "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
1341 : ],
1342 1 : "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
1343 1 : "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
1344 : }
1345 : };
1346 :
1347 1 : let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
1348 1 : let jwks_addr = jwks_server(move |path| match path {
1349 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1350 0 : _ => None,
1351 1 : })
1352 1 : .await;
1353 :
1354 1 : let role_name = RoleName::from("anonymous");
1355 1 : let role = RoleNameInt::from(&role_name);
1356 :
1357 1 : let rules = vec![AuthRule {
1358 1 : id: "foo".to_owned(),
1359 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1360 1 : audience: None,
1361 1 : role_names: vec![role],
1362 1 : }];
1363 :
1364 1 : let fetch = Fetch(rules);
1365 1 : let jwk_cache = JwkCache::default();
1366 :
1367 1 : let endpoint = EndpointId::from("ep");
1368 :
1369 1 : let token = new_rsa_jwt("rs1".into(), rs);
1370 :
1371 1 : jwk_cache
1372 1 : .check_jwt(
1373 1 : &RequestContext::test(),
1374 1 : endpoint.clone(),
1375 1 : &role_name,
1376 1 : &fetch,
1377 1 : &token,
1378 1 : )
1379 1 : .await
1380 1 : .unwrap();
1381 1 : }
1382 : }
|