Line data Source code
1 : use std::borrow::Cow;
2 : use std::future::Future;
3 : use std::sync::Arc;
4 : use std::time::{Duration, SystemTime};
5 :
6 : use arc_swap::ArcSwapOption;
7 : use dashmap::DashMap;
8 : use jose_jwk::crypto::KeyInfo;
9 : use reqwest::{redirect, Client};
10 : use reqwest_retry::policies::ExponentialBackoff;
11 : use reqwest_retry::RetryTransientMiddleware;
12 : use serde::de::Visitor;
13 : use serde::{Deserialize, Deserializer};
14 : use serde_json::value::RawValue;
15 : use signature::Verifier;
16 : use thiserror::Error;
17 : use tokio::time::Instant;
18 :
19 : use crate::auth::backend::ComputeCredentialKeys;
20 : use crate::context::RequestContext;
21 : use crate::control_plane::errors::GetEndpointJwksError;
22 : use crate::http::read_body_with_limit;
23 : use crate::intern::RoleNameInt;
24 : use crate::types::{EndpointId, RoleName};
25 :
// TODO(conrad): make these configurable.
/// Tolerance applied to the `exp` and `nbf` claim checks.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of a cache entry before an unknown key id may trigger a renewal.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age after which a cache entry is refreshed by a background task.
const AUTO_RENEW: Duration = Duration::from_secs(5 * 60);
/// Age after which a cache entry is refreshed before it is used again.
const MAX_RENEW: Duration = Duration::from_secs(60 * 60);
/// Upper bound, in bytes, on a JWKS response body.
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
/// `User-Agent` of the JWKS HTTP client.
const JWKS_USER_AGENT: &str = "neon-proxy";

/// Connect timeout of the JWKS HTTP client.
const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// Overall request timeout of the JWKS HTTP client.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum number of retries configured on the JWKS retry policy.
const JWKS_FETCH_RETRIES: u32 = 3;
37 :
38 : /// How to get the JWT auth rules
39 : pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
40 : fn fetch_auth_rules(
41 : &self,
42 : ctx: &RequestContext,
43 : endpoint: EndpointId,
44 : ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
45 : }
46 :
47 0 : #[derive(Error, Debug)]
48 : pub(crate) enum FetchAuthRulesError {
49 : #[error(transparent)]
50 : GetEndpointJwks(#[from] GetEndpointJwksError),
51 :
52 : #[error("JWKs settings for this role were not configured")]
53 : RoleJwksNotConfigured,
54 : }
55 :
56 : #[derive(Clone)]
57 : pub(crate) struct AuthRule {
58 : pub(crate) id: String,
59 : pub(crate) jwks_url: url::Url,
60 : pub(crate) audience: Option<String>,
61 : pub(crate) role_names: Vec<RoleNameInt>,
62 : }
63 :
64 : pub struct JwkCache {
65 : client: reqwest_middleware::ClientWithMiddleware,
66 :
67 : map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
68 : }
69 :
70 : pub(crate) struct JwkCacheEntry {
71 : /// Should refetch at least every hour to verify when old keys have been removed.
72 : /// Should refetch when new key IDs are seen only every 5 minutes or so
73 : last_retrieved: Instant,
74 :
75 : /// cplane will return multiple JWKs urls that we need to scrape.
76 : key_sets: ahash::HashMap<String, KeySet>,
77 : }
78 :
79 : impl JwkCacheEntry {
80 19 : fn find_jwk_and_audience(
81 19 : &self,
82 19 : key_id: &str,
83 19 : role_name: &RoleName,
84 19 : ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
85 19 : self.key_sets
86 19 : .values()
87 19 : // make sure our requested role has access to the key set
88 29 : .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
89 19 : // try and find the requested key-id in the key set
90 22 : .find_map(|key_set| {
91 22 : key_set
92 22 : .find_key(key_id)
93 22 : .map(|jwk| (jwk, key_set.audience.as_deref()))
94 22 : })
95 19 : }
96 : }
97 :
98 : struct KeySet {
99 : jwks: jose_jwk::JwkSet,
100 : audience: Option<String>,
101 : role_names: Vec<RoleNameInt>,
102 : }
103 :
104 : impl KeySet {
105 22 : fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
106 22 : self.jwks
107 22 : .keys
108 22 : .iter()
109 30 : .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
110 22 : }
111 : }
112 :
113 : pub(crate) struct JwkCacheEntryLock {
114 : cached: ArcSwapOption<JwkCacheEntry>,
115 : lookup: tokio::sync::Semaphore,
116 : }
117 :
118 : impl Default for JwkCacheEntryLock {
119 7 : fn default() -> Self {
120 7 : JwkCacheEntryLock {
121 7 : cached: ArcSwapOption::empty(),
122 7 : lookup: tokio::sync::Semaphore::new(1),
123 7 : }
124 7 : }
125 : }
126 :
127 18 : #[derive(Deserialize)]
128 : struct JwkSet<'a> {
129 : /// we parse into raw-value because not all keys in a JWKS are ones
130 : /// we can parse directly, so we parse them lazily.
131 : #[serde(borrow)]
132 : keys: Vec<&'a RawValue>,
133 : }
134 :
135 : impl JwkCacheEntryLock {
136 7 : async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
137 7 : JwkRenewalPermit::acquire_permit(self).await
138 7 : }
139 :
140 0 : fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
141 0 : JwkRenewalPermit::try_acquire_permit(self)
142 0 : }
143 :
144 7 : async fn renew_jwks<F: FetchAuthRules>(
145 7 : &self,
146 7 : _permit: JwkRenewalPermit<'_>,
147 7 : ctx: &RequestContext,
148 7 : client: &reqwest_middleware::ClientWithMiddleware,
149 7 : endpoint: EndpointId,
150 7 : auth_rules: &F,
151 7 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
152 7 : // double check that no one beat us to updating the cache.
153 7 : let now = Instant::now();
154 7 : let guard = self.cached.load_full();
155 7 : if let Some(cached) = guard {
156 0 : let last_update = now.duration_since(cached.last_retrieved);
157 0 : if last_update < Duration::from_secs(300) {
158 0 : return Ok(cached);
159 0 : }
160 7 : }
161 :
162 7 : let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
163 7 : let mut key_sets =
164 7 : ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
165 :
166 : // TODO(conrad): run concurrently
167 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
168 16 : for rule in rules {
169 9 : let req = client.get(rule.jwks_url.clone());
170 9 : // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
171 9 : // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
172 21 : match req.send().await.and_then(|r| {
173 9 : r.error_for_status()
174 9 : .map_err(reqwest_middleware::Error::Reqwest)
175 9 : }) {
176 : // todo: should we re-insert JWKs if we want to keep this JWKs URL?
177 : // I expect these failures would be quite sparse.
178 0 : Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
179 9 : Ok(r) => {
180 9 : let resp: http::Response<reqwest::Body> = r.into();
181 :
182 9 : let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
183 0 : .await
184 : {
185 9 : Ok(bytes) => bytes,
186 0 : Err(e) => {
187 0 : tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
188 0 : continue;
189 : }
190 : };
191 :
192 9 : match serde_json::from_slice::<JwkSet>(&bytes) {
193 0 : Err(e) => {
194 0 : tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
195 : }
196 9 : Ok(jwks) => {
197 9 : // size_of::<&RawValue>() == 16
198 9 : // size_of::<jose_jwk::Jwk>() == 288
199 9 : // better to not pre-allocate this as it might be pretty large - especially if it has many
200 9 : // keys we don't want or need.
201 9 : // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
202 9 : // this would consume 8MiB just like that!
203 9 : let mut keys = vec![];
204 9 : let mut failed = 0;
205 23 : for key in jwks.keys {
206 14 : match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
207 13 : Ok(key) => {
208 13 : // if `use` (called `cls` in rust) is specified to be something other than signing,
209 13 : // we can skip storing it.
210 13 : if key
211 13 : .prm
212 13 : .cls
213 13 : .as_ref()
214 13 : .is_some_and(|c| *c != jose_jwk::Class::Signing)
215 : {
216 0 : continue;
217 13 : }
218 13 :
219 13 : keys.push(key);
220 : }
221 1 : Err(e) => {
222 1 : tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
223 1 : failed += 1;
224 : }
225 : }
226 : }
227 9 : keys.shrink_to_fit();
228 9 :
229 9 : if failed > 0 {
230 1 : tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
231 8 : }
232 :
233 9 : if keys.is_empty() {
234 0 : tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
235 0 : continue;
236 9 : }
237 9 :
238 9 : let jwks = jose_jwk::JwkSet { keys };
239 9 : key_sets.insert(
240 9 : rule.id,
241 9 : KeySet {
242 9 : jwks,
243 9 : audience: rule.audience,
244 9 : role_names: rule.role_names,
245 9 : },
246 9 : );
247 : }
248 : };
249 : }
250 : }
251 : }
252 :
253 7 : let entry = Arc::new(JwkCacheEntry {
254 7 : last_retrieved: now,
255 7 : key_sets,
256 7 : });
257 7 : self.cached.swap(Some(Arc::clone(&entry)));
258 7 :
259 7 : Ok(entry)
260 7 : }
261 :
262 19 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
263 19 : self: &Arc<Self>,
264 19 : ctx: &RequestContext,
265 19 : client: &reqwest_middleware::ClientWithMiddleware,
266 19 : endpoint: EndpointId,
267 19 : fetch: &F,
268 19 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
269 19 : let now = Instant::now();
270 19 : let guard = self.cached.load_full();
271 :
272 : // if we have no cached JWKs, try and get some
273 19 : let Some(cached) = guard else {
274 7 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
275 7 : let permit = self.acquire_permit().await;
276 21 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
277 : };
278 :
279 12 : let last_update = now.duration_since(cached.last_retrieved);
280 12 :
281 12 : // check if the cached JWKs need updating.
282 12 : if last_update > MAX_RENEW {
283 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
284 0 : let permit = self.acquire_permit().await;
285 :
286 : // it's been too long since we checked the keys. wait for them to update.
287 0 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
288 12 : }
289 12 :
290 12 : // every 5 minutes we should spawn a job to eagerly update the token.
291 12 : if last_update > AUTO_RENEW {
292 0 : if let Some(permit) = self.try_acquire_permit() {
293 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
294 0 : let permit = permit.into_owned();
295 0 : let entry = self.clone();
296 0 : let client = client.clone();
297 0 : let fetch = fetch.clone();
298 0 : let ctx = ctx.clone();
299 0 : tokio::spawn(async move {
300 0 : if let Err(e) = entry
301 0 : .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
302 0 : .await
303 : {
304 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
305 0 : }
306 0 : });
307 0 : } else {
308 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
309 : }
310 12 : }
311 :
312 12 : Ok(cached)
313 19 : }
314 :
315 19 : async fn check_jwt<F: FetchAuthRules>(
316 19 : self: &Arc<Self>,
317 19 : ctx: &RequestContext,
318 19 : jwt: &str,
319 19 : client: &reqwest_middleware::ClientWithMiddleware,
320 19 : endpoint: EndpointId,
321 19 : role_name: &RoleName,
322 19 : fetch: &F,
323 19 : ) -> Result<ComputeCredentialKeys, JwtError> {
324 : // JWT compact form is defined to be
325 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
326 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
327 :
328 19 : let (header_payload, signature) = jwt
329 19 : .rsplit_once('.')
330 19 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
331 19 : let (header, payload) = header_payload
332 19 : .split_once('.')
333 19 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
334 :
335 19 : let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
336 19 : let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
337 :
338 19 : let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
339 :
340 19 : let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
341 :
342 19 : let mut guard = self
343 19 : .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
344 21 : .await?;
345 :
346 : // get the key from the JWKs if possible. If not, wait for the keys to update.
347 18 : let (jwk, expected_audience) = loop {
348 19 : match guard.find_jwk_and_audience(&kid, role_name) {
349 18 : Some(jwk) => break jwk,
350 1 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
351 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
352 :
353 0 : let permit = self.acquire_permit().await;
354 0 : guard = self
355 0 : .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
356 0 : .await?;
357 : }
358 1 : _ => return Err(JwtError::JwkNotFound),
359 : }
360 : };
361 :
362 18 : if !jwk.is_supported(&header.algorithm) {
363 0 : return Err(JwtError::SignatureAlgorithmNotSupported);
364 18 : }
365 18 :
366 18 : match &jwk.key {
367 13 : jose_jwk::Key::Ec(key) => {
368 13 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
369 : }
370 5 : jose_jwk::Key::Rsa(key) => {
371 5 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
372 : }
373 0 : key => return Err(JwtError::UnsupportedKeyType(key.into())),
374 : };
375 :
376 17 : let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
377 17 : let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
378 :
379 17 : tracing::debug!(?payload, "JWT signature valid with claims");
380 :
381 17 : if let Some(aud) = expected_audience {
382 7 : if payload.audience.0.iter().all(|s| s != aud) {
383 5 : return Err(JwtError::InvalidClaims(
384 5 : JwtClaimsError::InvalidJwtTokenAudience,
385 5 : ));
386 2 : }
387 10 : }
388 :
389 12 : let now = SystemTime::now();
390 :
391 12 : if let Some(exp) = payload.expiration {
392 11 : if now >= exp + CLOCK_SKEW_LEEWAY {
393 1 : return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
394 10 : }
395 1 : }
396 :
397 11 : if let Some(nbf) = payload.not_before {
398 10 : if nbf >= now + CLOCK_SKEW_LEEWAY {
399 1 : return Err(JwtError::InvalidClaims(
400 1 : JwtClaimsError::JwtTokenNotYetReadyToUse,
401 1 : ));
402 9 : }
403 1 : }
404 :
405 10 : Ok(ComputeCredentialKeys::JwtPayload(payloadb))
406 19 : }
407 : }
408 :
409 : impl JwkCache {
410 19 : pub(crate) async fn check_jwt<F: FetchAuthRules>(
411 19 : &self,
412 19 : ctx: &RequestContext,
413 19 : endpoint: EndpointId,
414 19 : role_name: &RoleName,
415 19 : fetch: &F,
416 19 : jwt: &str,
417 19 : ) -> Result<ComputeCredentialKeys, JwtError> {
418 19 : // try with just a read lock first
419 19 : let key = (endpoint.clone(), role_name.clone());
420 19 : let entry = self.map.get(&key).as_deref().map(Arc::clone);
421 19 : let entry = entry.unwrap_or_else(|| {
422 7 : // acquire a write lock after to insert.
423 7 : let entry = self.map.entry(key).or_default();
424 7 : Arc::clone(&*entry)
425 19 : });
426 19 :
427 19 : entry
428 19 : .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
429 21 : .await
430 19 : }
431 : }
432 :
433 : impl Default for JwkCache {
434 9 : fn default() -> Self {
435 9 : let client = Client::builder()
436 9 : .user_agent(JWKS_USER_AGENT)
437 9 : .redirect(redirect::Policy::none())
438 9 : .tls_built_in_native_certs(true)
439 9 : .connect_timeout(JWKS_CONNECT_TIMEOUT)
440 9 : .timeout(JWKS_FETCH_TIMEOUT)
441 9 : .build()
442 9 : .expect("client config should be valid");
443 9 :
444 9 : // Retry up to 3 times with increasing intervals between attempts.
445 9 : let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);
446 9 :
447 9 : let client = reqwest_middleware::ClientBuilder::new(client)
448 9 : .with(RetryTransientMiddleware::new_with_policy(retry_policy))
449 9 : .build();
450 9 :
451 9 : JwkCache {
452 9 : client,
453 9 : map: DashMap::default(),
454 9 : }
455 9 : }
456 : }
457 :
458 13 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
459 : use ecdsa::Signature;
460 : use signature::Verifier;
461 :
462 13 : match key.crv {
463 : jose_jwk::EcCurves::P256 => {
464 13 : let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
465 13 : let key = p256::ecdsa::VerifyingKey::from(&pk);
466 13 : let sig = Signature::from_slice(sig)?;
467 13 : key.verify(data, &sig)?;
468 : }
469 0 : key => return Err(JwtError::UnsupportedEcKeyType(key)),
470 : }
471 :
472 12 : Ok(())
473 13 : }
474 :
475 5 : fn verify_rsa_signature(
476 5 : data: &[u8],
477 5 : sig: &[u8],
478 5 : key: &jose_jwk::Rsa,
479 5 : alg: &jose_jwa::Algorithm,
480 5 : ) -> Result<(), JwtError> {
481 : use jose_jwa::{Algorithm, Signing};
482 : use rsa::pkcs1v15::{Signature, VerifyingKey};
483 : use rsa::RsaPublicKey;
484 :
485 5 : let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
486 :
487 5 : match alg {
488 : Algorithm::Signing(Signing::Rs256) => {
489 5 : let key = VerifyingKey::<sha2::Sha256>::new(key);
490 5 : let sig = Signature::try_from(sig)?;
491 5 : key.verify(data, &sig)?;
492 : }
493 0 : _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
494 : };
495 :
496 5 : Ok(())
497 5 : }
498 :
499 : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
500 57 : #[derive(serde::Deserialize, serde::Serialize)]
501 : struct JwtHeader<'a> {
502 : /// must be a supported alg
503 : #[serde(rename = "alg")]
504 : algorithm: jose_jwa::Algorithm,
505 : /// key id, must be provided for our usecase
506 : #[serde(rename = "kid", borrow)]
507 : key_id: Option<Cow<'a, str>>,
508 : }
509 :
510 : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
511 121 : #[derive(serde::Deserialize, Debug)]
512 : #[allow(dead_code)]
513 : struct JwtPayload<'a> {
514 : /// Audience - Recipient for which the JWT is intended
515 : #[serde(rename = "aud", default)]
516 : audience: OneOrMany,
517 : /// Expiration - Time after which the JWT expires
518 : #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
519 : expiration: Option<SystemTime>,
520 : /// Not before - Time after which the JWT expires
521 : #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
522 : not_before: Option<SystemTime>,
523 :
524 : // the following entries are only extracted for the sake of debug logging.
525 : /// Issuer of the JWT
526 : #[serde(rename = "iss", borrow)]
527 : issuer: Option<Cow<'a, str>>,
528 : /// Subject of the JWT (the user)
529 : #[serde(rename = "sub", borrow)]
530 : subject: Option<Cow<'a, str>>,
531 : /// Unique token identifier
532 : #[serde(rename = "jti", borrow)]
533 : jwt_id: Option<Cow<'a, str>>,
534 : /// Unique session identifier
535 : #[serde(rename = "sid", borrow)]
536 : session_id: Option<Cow<'a, str>>,
537 : }
538 :
539 : /// `OneOrMany` supports parsing either a single item or an array of items.
540 : ///
541 : /// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
542 : ///
543 : /// > The "aud" (audience) claim identifies the recipients that the JWT is
544 : /// > intended for. Each principal intended to process the JWT MUST
545 : /// > identify itself with a value in the audience claim. If the principal
546 : /// > processing the claim does not identify itself with a value in the
547 : /// > "aud" claim when this claim is present, then the JWT MUST be
548 : /// > rejected. In the general case, the "aud" value is **an array of case-
549 : /// > sensitive strings**, each containing a StringOrURI value. In the
550 : /// > special case when the JWT has one audience, the "aud" value MAY be a
551 : /// > **single case-sensitive string** containing a StringOrURI value. The
552 : /// > interpretation of audience values is generally application specific.
553 : /// > Use of this claim is OPTIONAL.
554 : #[derive(Default, Debug)]
555 : struct OneOrMany(Vec<String>);
556 :
557 : impl<'de> Deserialize<'de> for OneOrMany {
558 15 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
559 15 : where
560 15 : D: Deserializer<'de>,
561 15 : {
562 : struct OneOrManyVisitor;
563 : impl<'de> Visitor<'de> for OneOrManyVisitor {
564 : type Value = OneOrMany;
565 :
566 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
567 0 : formatter.write_str("a single string or an array of strings")
568 0 : }
569 :
570 2 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
571 2 : where
572 2 : E: serde::de::Error,
573 2 : {
574 2 : Ok(OneOrMany(vec![v.to_owned()]))
575 2 : }
576 :
577 13 : fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
578 13 : where
579 13 : A: serde::de::SeqAccess<'de>,
580 13 : {
581 13 : let mut v = vec![];
582 44 : while let Some(s) = seq.next_element()? {
583 31 : v.push(s);
584 31 : }
585 13 : Ok(OneOrMany(v))
586 13 : }
587 : }
588 15 : deserializer.deserialize_any(OneOrManyVisitor)
589 15 : }
590 : }
591 :
592 21 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
593 21 : let d = <Option<u64>>::deserialize(d)?;
594 21 : Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
595 21 : }
596 :
597 : struct JwkRenewalPermit<'a> {
598 : inner: Option<JwkRenewalPermitInner<'a>>,
599 : }
600 :
601 : enum JwkRenewalPermitInner<'a> {
602 : Owned(Arc<JwkCacheEntryLock>),
603 : Borrowed(&'a Arc<JwkCacheEntryLock>),
604 : }
605 :
606 : impl JwkRenewalPermit<'_> {
607 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
608 0 : JwkRenewalPermit {
609 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
610 0 : }
611 0 : }
612 :
613 7 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
614 7 : match from.lookup.acquire().await {
615 7 : Ok(permit) => {
616 7 : permit.forget();
617 7 : JwkRenewalPermit {
618 7 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
619 7 : }
620 : }
621 0 : Err(_) => panic!("semaphore should not be closed"),
622 : }
623 7 : }
624 :
625 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
626 0 : match from.lookup.try_acquire() {
627 0 : Ok(permit) => {
628 0 : permit.forget();
629 0 : Some(JwkRenewalPermit {
630 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
631 0 : })
632 : }
633 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
634 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
635 : }
636 0 : }
637 : }
638 :
639 : impl JwkRenewalPermitInner<'_> {
640 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
641 0 : match self {
642 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
643 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
644 : }
645 0 : }
646 : }
647 :
648 : impl Drop for JwkRenewalPermit<'_> {
649 7 : fn drop(&mut self) {
650 7 : let entry = match &self.inner {
651 0 : None => return,
652 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
653 7 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
654 : };
655 7 : entry.lookup.add_permits(1);
656 7 : }
657 : }
658 :
659 1 : #[derive(Error, Debug)]
660 : #[non_exhaustive]
661 : pub(crate) enum JwtError {
662 : #[error("jwk not found")]
663 : JwkNotFound,
664 :
665 : #[error("missing key id")]
666 : MissingKeyId,
667 :
668 : #[error("Provided authentication token is not a valid JWT encoding")]
669 : JwtEncoding(#[from] JwtEncodingError),
670 :
671 : #[error(transparent)]
672 : InvalidClaims(#[from] JwtClaimsError),
673 :
674 : #[error("invalid P256 key")]
675 : InvalidP256Key(jose_jwk::crypto::Error),
676 :
677 : #[error("invalid RSA key")]
678 : InvalidRsaKey(jose_jwk::crypto::Error),
679 :
680 : #[error("invalid RSA signing algorithm")]
681 : InvalidRsaSigningAlgorithm,
682 :
683 : #[error("unsupported EC key type {0:?}")]
684 : UnsupportedEcKeyType(jose_jwk::EcCurves),
685 :
686 : #[error("unsupported key type {0:?}")]
687 : UnsupportedKeyType(KeyType),
688 :
689 : #[error("signature algorithm not supported")]
690 : SignatureAlgorithmNotSupported,
691 :
692 : #[error("signature error: {0}")]
693 : Signature(#[from] signature::Error),
694 :
695 : #[error("failed to fetch auth rules: {0}")]
696 : FetchAuthRules(#[from] FetchAuthRulesError),
697 : }
698 :
699 : impl From<base64::DecodeError> for JwtError {
700 0 : fn from(err: base64::DecodeError) -> Self {
701 0 : JwtEncodingError::Base64Decode(err).into()
702 0 : }
703 : }
704 :
705 : impl From<serde_json::Error> for JwtError {
706 0 : fn from(err: serde_json::Error) -> Self {
707 0 : JwtEncodingError::SerdeJson(err).into()
708 0 : }
709 : }
710 :
711 0 : #[derive(Error, Debug)]
712 : #[non_exhaustive]
713 : pub enum JwtEncodingError {
714 : #[error(transparent)]
715 : Base64Decode(#[from] base64::DecodeError),
716 :
717 : #[error(transparent)]
718 : SerdeJson(#[from] serde_json::Error),
719 :
720 : #[error("invalid compact form")]
721 : InvalidCompactForm,
722 : }
723 :
724 0 : #[derive(Error, Debug, PartialEq)]
725 : #[non_exhaustive]
726 : pub enum JwtClaimsError {
727 : #[error("invalid JWT token audience")]
728 : InvalidJwtTokenAudience,
729 :
730 : #[error("JWT token has expired")]
731 : JwtTokenHasExpired,
732 :
733 : #[error("JWT token is not yet ready to use")]
734 : JwtTokenNotYetReadyToUse,
735 : }
736 :
737 : #[allow(dead_code, reason = "Debug use only")]
738 : #[derive(Debug)]
739 : pub(crate) enum KeyType {
740 : Ec(jose_jwk::EcCurves),
741 : Rsa,
742 : Oct,
743 : Okp(jose_jwk::OkpCurves),
744 : Unknown,
745 : }
746 :
747 : impl From<&jose_jwk::Key> for KeyType {
748 0 : fn from(key: &jose_jwk::Key) -> Self {
749 0 : match key {
750 0 : jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
751 0 : jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
752 0 : jose_jwk::Key::Oct(_oct) => Self::Oct,
753 0 : jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
754 0 : _ => Self::Unknown,
755 : }
756 0 : }
757 : }
758 :
759 : #[cfg(test)]
760 : mod tests {
761 : use std::future::IntoFuture;
762 : use std::net::SocketAddr;
763 : use std::time::SystemTime;
764 :
765 : use base64::URL_SAFE_NO_PAD;
766 : use bytes::Bytes;
767 : use http::Response;
768 : use http_body_util::Full;
769 : use hyper::service::service_fn;
770 : use hyper_util::rt::TokioIo;
771 : use rand::rngs::OsRng;
772 : use rsa::pkcs8::DecodePrivateKey;
773 : use serde::Serialize;
774 : use serde_json::json;
775 : use signature::Signer;
776 : use tokio::net::TcpListener;
777 :
778 : use super::*;
779 : use crate::types::RoleName;
780 :
781 6 : fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
782 6 : let sk = p256::SecretKey::random(&mut OsRng);
783 6 : let pk = sk.public_key().into();
784 6 : let jwk = jose_jwk::Jwk {
785 6 : key: jose_jwk::Key::Ec(pk),
786 6 : prm: jose_jwk::Parameters {
787 6 : kid: Some(kid),
788 6 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
789 6 : ..Default::default()
790 6 : },
791 6 : };
792 6 : (sk, jwk)
793 6 : }
794 :
795 4 : fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
796 4 : let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
797 4 : let pk = sk.to_public_key().into();
798 4 : let jwk = jose_jwk::Jwk {
799 4 : key: jose_jwk::Key::Rsa(pk),
800 4 : prm: jose_jwk::Parameters {
801 4 : kid: Some(kid),
802 4 : alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
803 4 : ..Default::default()
804 4 : },
805 4 : };
806 4 : (sk, jwk)
807 4 : }
808 :
/// Current wall-clock time as whole seconds since the Unix epoch.
fn now() -> u64 {
    let since_epoch = SystemTime::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
815 :
816 7 : fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
817 7 : let now = now();
818 7 : let body = typed_json::json! {{
819 7 : "exp": now + 3600,
820 7 : "nbf": now,
821 7 : "aud": ["audience1", "neon", "audience2"],
822 7 : "sub": "user1",
823 7 : "sid": "session1",
824 7 : "jti": "token1",
825 7 : "iss": "neon-testing",
826 7 : }};
827 7 : build_custom_jwt_payload(kid, body, sig)
828 7 : }
829 :
830 15 : fn build_custom_jwt_payload(
831 15 : kid: String,
832 15 : body: impl Serialize,
833 15 : sig: jose_jwa::Signing,
834 15 : ) -> String {
835 15 : let header = JwtHeader {
836 15 : algorithm: jose_jwa::Algorithm::Signing(sig),
837 15 : key_id: Some(Cow::Owned(kid)),
838 15 : };
839 15 :
840 15 : let header =
841 15 : base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
842 15 : let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
843 15 :
844 15 : format!("{header}.{body}")
845 15 : }
846 :
847 3 : fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
848 : use p256::ecdsa::{Signature, SigningKey};
849 :
850 3 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
851 3 : let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
852 3 : let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
853 3 :
854 3 : format!("{payload}.{sig}")
855 3 : }
856 :
857 8 : fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
858 : use p256::ecdsa::{Signature, SigningKey};
859 :
860 8 : let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
861 8 : let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
862 8 : let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
863 8 :
864 8 : format!("{payload}.{sig}")
865 8 : }
866 :
867 4 : fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
868 : use rsa::pkcs1v15::SigningKey;
869 : use rsa::signature::SignatureEncoding;
870 :
871 4 : let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
872 4 : let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
873 4 : let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
874 4 :
875 4 : format!("{payload}.{sig}")
876 4 : }
877 :
878 : // RSA key gen is slow....
879 : const RS1: &str = "-----BEGIN PRIVATE KEY-----
880 : MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
881 : aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
882 : SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
883 : cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
884 : mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
885 : PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
886 : nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
887 : ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
888 : 5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
889 : KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
890 : 1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
891 : We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
892 : UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
893 : GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
894 : HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
895 : f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
896 : DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
897 : 9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
898 : /4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
899 : 3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
900 : qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
901 : ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
902 : rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
903 : JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
904 : rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
905 : 8LMFlz1QPqSZYN/A/kOcLBfa3A==
906 : -----END PRIVATE KEY-----
907 : ";
/// PEM-armoured `PRIVATE KEY` fixture, companion to `RS1`.
///
/// `check_jwt_happy_path` hands it to `new_rsa_jwk` under the key id `"rs2"` to obtain a signing
/// key and the JWK that is published alongside it.
// NOTE(review): presumably an RSA key generated solely for these tests; confirm it is never
// used outside of test code.
const RS2: &str = "-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
+Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
5JiconnI5aLek0QVPoFaVXFa
-----END PRIVATE KEY-----
";
937 :
938 : #[derive(Clone)]
939 : struct Fetch(Vec<AuthRule>);
940 :
941 : impl FetchAuthRules for Fetch {
942 7 : async fn fetch_auth_rules(
943 7 : &self,
944 7 : _ctx: &RequestContext,
945 7 : _endpoint: EndpointId,
946 7 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
947 7 : Ok(self.0.clone())
948 7 : }
949 : }
950 :
951 6 : async fn jwks_server(
952 6 : router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
953 6 : ) -> SocketAddr {
954 6 : let router = Arc::new(router);
955 9 : let service = service_fn(move |req| {
956 9 : let router = Arc::clone(&router);
957 9 : async move {
958 9 : match router(req.uri().path()) {
959 9 : Some(body) => Response::builder()
960 9 : .status(200)
961 9 : .body(Full::new(Bytes::from(body))),
962 0 : None => Response::builder()
963 0 : .status(404)
964 0 : .body(Full::new(Bytes::new())),
965 : }
966 9 : }
967 9 : });
968 :
969 6 : let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
970 6 : let server = hyper::server::conn::http1::Builder::new();
971 6 : let addr = listener.local_addr().unwrap();
972 6 : tokio::spawn(async move {
973 : loop {
974 12 : let (s, _) = listener.accept().await.unwrap();
975 6 : let serve = server.serve_connection(TokioIo::new(s), service.clone());
976 6 : tokio::spawn(serve.into_future());
977 : }
978 6 : });
979 6 :
980 6 : addr
981 6 : }
982 :
983 : #[tokio::test]
984 1 : async fn check_jwt_happy_path() {
985 1 : let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
986 1 : let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
987 1 : let (ec1, jwk3) = new_ec_jwk("ec1".into());
988 1 : let (ec2, jwk4) = new_ec_jwk("ec2".into());
989 1 :
990 1 : let foo_jwks = jose_jwk::JwkSet {
991 1 : keys: vec![jwk1, jwk3],
992 1 : };
993 1 : let bar_jwks = jose_jwk::JwkSet {
994 1 : keys: vec![jwk2, jwk4],
995 1 : };
996 1 :
997 4 : let jwks_addr = jwks_server(move |path| match path {
998 4 : "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
999 2 : "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
1000 1 : _ => None,
1001 4 : })
1002 1 : .await;
1003 1 :
1004 1 : let role_name1 = RoleName::from("anonymous");
1005 1 : let role_name2 = RoleName::from("authenticated");
1006 1 :
1007 1 : let roles = vec![
1008 1 : RoleNameInt::from(&role_name1),
1009 1 : RoleNameInt::from(&role_name2),
1010 1 : ];
1011 1 : let rules = vec![
1012 1 : AuthRule {
1013 1 : id: "foo".to_owned(),
1014 1 : jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
1015 1 : audience: None,
1016 1 : role_names: roles.clone(),
1017 1 : },
1018 1 : AuthRule {
1019 1 : id: "bar".to_owned(),
1020 1 : jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
1021 1 : audience: None,
1022 1 : role_names: roles.clone(),
1023 1 : },
1024 1 : ];
1025 1 :
1026 1 : let fetch = Fetch(rules);
1027 1 : let jwk_cache = JwkCache::default();
1028 1 :
1029 1 : let endpoint = EndpointId::from("ep");
1030 1 :
1031 1 : let jwt1 = new_rsa_jwt("rs1".into(), rs1);
1032 1 : let jwt2 = new_rsa_jwt("rs2".into(), rs2);
1033 1 : let jwt3 = new_ec_jwt("ec1".into(), &ec1);
1034 1 : let jwt4 = new_ec_jwt("ec2".into(), &ec2);
1035 1 :
1036 1 : let tokens = [jwt1, jwt2, jwt3, jwt4];
1037 1 : let role_names = [role_name1, role_name2];
1038 3 : for role in &role_names {
1039 10 : for token in &tokens {
1040 8 : jwk_cache
1041 8 : .check_jwt(
1042 8 : &RequestContext::test(),
1043 8 : endpoint.clone(),
1044 8 : role,
1045 8 : &fetch,
1046 8 : token,
1047 8 : )
1048 6 : .await
1049 8 : .unwrap();
1050 1 : }
1051 1 : }
1052 1 : }
1053 :
1054 : /// AWS Cognito escapes the `/` in the URL.
1055 : #[tokio::test]
1056 1 : async fn check_jwt_regression_cognito_issuer() {
1057 1 : let (key, jwk) = new_ec_jwk("key".into());
1058 1 :
1059 1 : let now = now();
1060 1 : let token = new_custom_ec_jwt(
1061 1 : "key".into(),
1062 1 : &key,
1063 1 : typed_json::json! {{
1064 1 : "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
1065 1 : // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse
1066 1 : // to write that escape character. instead I will make a bogus URL using `\` instead.
1067 1 : "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh",
1068 1 : "client_id": "abcdefghijklmnopqrstuvwxyz",
1069 1 : "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893",
1070 1 : "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065",
1071 1 : "token_use": "access",
1072 1 : "scope": "aws.cognito.signin.user.admin",
1073 1 : "auth_time":now,
1074 1 : "exp":now + 60,
1075 1 : "iat":now,
1076 1 : "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0",
1077 1 : "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d",
1078 1 : }},
1079 1 : );
1080 1 :
1081 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1082 1 :
1083 1 : let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await;
1084 1 :
1085 1 : let role_name = RoleName::from("anonymous");
1086 1 : let rules = vec![AuthRule {
1087 1 : id: "aws-cognito".to_owned(),
1088 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1089 1 : audience: None,
1090 1 : role_names: vec![RoleNameInt::from(&role_name)],
1091 1 : }];
1092 1 :
1093 1 : let fetch = Fetch(rules);
1094 1 : let jwk_cache = JwkCache::default();
1095 1 :
1096 1 : let endpoint = EndpointId::from("ep");
1097 1 :
1098 1 : jwk_cache
1099 1 : .check_jwt(
1100 1 : &RequestContext::test(),
1101 1 : endpoint.clone(),
1102 1 : &role_name,
1103 1 : &fetch,
1104 1 : &token,
1105 1 : )
1106 3 : .await
1107 1 : .unwrap();
1108 1 : }
1109 :
1110 : #[tokio::test]
1111 1 : async fn check_jwt_invalid_signature() {
1112 1 : let (_, jwk) = new_ec_jwk("1".into());
1113 1 : let (key, _) = new_ec_jwk("1".into());
1114 1 :
1115 1 : // has a matching kid, but signed by the wrong key
1116 1 : let bad_jwt = new_ec_jwt("1".into(), &key);
1117 1 :
1118 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1119 1 : let jwks_addr = jwks_server(move |path| match path {
1120 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1121 1 : _ => None,
1122 1 : })
1123 1 : .await;
1124 1 :
1125 1 : let role = RoleName::from("authenticated");
1126 1 :
1127 1 : let rules = vec![AuthRule {
1128 1 : id: String::new(),
1129 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1130 1 : audience: None,
1131 1 : role_names: vec![RoleNameInt::from(&role)],
1132 1 : }];
1133 1 :
1134 1 : let fetch = Fetch(rules);
1135 1 : let jwk_cache = JwkCache::default();
1136 1 :
1137 1 : let ep = EndpointId::from("ep");
1138 1 :
1139 1 : let ctx = RequestContext::test();
1140 1 : let err = jwk_cache
1141 1 : .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
1142 3 : .await
1143 1 : .unwrap_err();
1144 1 : assert!(
1145 1 : matches!(err, JwtError::Signature(_)),
1146 1 : "expected \"signature error\", got {err:?}"
1147 1 : );
1148 1 : }
1149 :
1150 : #[tokio::test]
1151 1 : async fn check_jwt_unknown_role() {
1152 1 : let (key, jwk) = new_rsa_jwk(RS1, "1".into());
1153 1 : let jwt = new_rsa_jwt("1".into(), key);
1154 1 :
1155 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1156 1 : let jwks_addr = jwks_server(move |path| match path {
1157 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1158 1 : _ => None,
1159 1 : })
1160 1 : .await;
1161 1 :
1162 1 : let role = RoleName::from("authenticated");
1163 1 : let rules = vec![AuthRule {
1164 1 : id: String::new(),
1165 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1166 1 : audience: None,
1167 1 : role_names: vec![RoleNameInt::from(&role)],
1168 1 : }];
1169 1 :
1170 1 : let fetch = Fetch(rules);
1171 1 : let jwk_cache = JwkCache::default();
1172 1 :
1173 1 : let ep = EndpointId::from("ep");
1174 1 :
1175 1 : // this role_name is not accepted
1176 1 : let bad_role_name = RoleName::from("cloud_admin");
1177 1 :
1178 1 : let ctx = RequestContext::test();
1179 1 : let err = jwk_cache
1180 1 : .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
1181 3 : .await
1182 1 : .unwrap_err();
1183 1 :
1184 1 : assert!(
1185 1 : matches!(err, JwtError::JwkNotFound),
1186 1 : "expected \"jwk not found\", got {err:?}"
1187 1 : );
1188 1 : }
1189 :
1190 : #[tokio::test]
1191 1 : async fn check_jwt_invalid_claims() {
1192 1 : let (key, jwk) = new_ec_jwk("1".into());
1193 1 :
1194 1 : let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
1195 1 : let jwks_addr = jwks_server(move |path| match path {
1196 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1197 1 : _ => None,
1198 1 : })
1199 1 : .await;
1200 1 :
1201 1 : let now = SystemTime::now()
1202 1 : .duration_since(SystemTime::UNIX_EPOCH)
1203 1 : .unwrap()
1204 1 : .as_secs();
1205 1 :
1206 1 : struct Test {
1207 1 : body: serde_json::Value,
1208 1 : error: JwtClaimsError,
1209 1 : }
1210 1 :
1211 1 : let table = vec![
1212 1 : Test {
1213 1 : body: json! {{
1214 1 : "nbf": now + 60,
1215 1 : "aud": "neon",
1216 1 : }},
1217 1 : error: JwtClaimsError::JwtTokenNotYetReadyToUse,
1218 1 : },
1219 1 : Test {
1220 1 : body: json! {{
1221 1 : "exp": now - 60,
1222 1 : "aud": ["neon"],
1223 1 : }},
1224 1 : error: JwtClaimsError::JwtTokenHasExpired,
1225 1 : },
1226 1 : Test {
1227 1 : body: json! {{
1228 1 : }},
1229 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1230 1 : },
1231 1 : Test {
1232 1 : body: json! {{
1233 1 : "aud": [],
1234 1 : }},
1235 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1236 1 : },
1237 1 : Test {
1238 1 : body: json! {{
1239 1 : "aud": "foo",
1240 1 : }},
1241 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1242 1 : },
1243 1 : Test {
1244 1 : body: json! {{
1245 1 : "aud": ["foo"],
1246 1 : }},
1247 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1248 1 : },
1249 1 : Test {
1250 1 : body: json! {{
1251 1 : "aud": ["foo", "bar"],
1252 1 : }},
1253 1 : error: JwtClaimsError::InvalidJwtTokenAudience,
1254 1 : },
1255 1 : ];
1256 1 :
1257 1 : let role = RoleName::from("authenticated");
1258 1 :
1259 1 : let rules = vec![AuthRule {
1260 1 : id: String::new(),
1261 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1262 1 : audience: Some("neon".to_string()),
1263 1 : role_names: vec![RoleNameInt::from(&role)],
1264 1 : }];
1265 1 :
1266 1 : let fetch = Fetch(rules);
1267 1 : let jwk_cache = JwkCache::default();
1268 1 :
1269 1 : let ep = EndpointId::from("ep");
1270 1 :
1271 1 : let ctx = RequestContext::test();
1272 8 : for test in table {
1273 7 : let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
1274 7 :
1275 7 : match jwk_cache
1276 7 : .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
1277 3 : .await
1278 1 : {
1279 7 : Err(JwtError::InvalidClaims(error)) if error == test.error => {}
1280 1 : Err(err) => {
1281 0 : panic!("expected {:?}, got {err:?}", test.error)
1282 1 : }
1283 1 : Ok(_payload) => {
1284 0 : panic!("expected {:?}, got ok", test.error)
1285 1 : }
1286 1 : }
1287 1 : }
1288 1 : }
1289 :
1290 : #[tokio::test]
1291 1 : async fn check_jwk_keycloak_regression() {
1292 1 : let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
1293 1 : let valid_jwk = serde_json::to_value(valid_jwk).unwrap();
1294 1 :
1295 1 : // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
1296 1 : // This is taken directly from keycloak.
1297 1 : let invalid_jwk = serde_json::json! {
1298 1 : {
1299 1 : "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
1300 1 : "kty": "RSA",
1301 1 : "alg": "RSA-OAEP",
1302 1 : "use": "enc",
1303 1 : "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
1304 1 : "e": "AQAB",
1305 1 : "x5c": [
1306 1 : "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
1307 1 : ],
1308 1 : "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
1309 1 : "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
1310 1 : }
1311 1 : };
1312 1 :
1313 1 : let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
1314 1 : let jwks_addr = jwks_server(move |path| match path {
1315 1 : "/" => Some(serde_json::to_vec(&jwks).unwrap()),
1316 1 : _ => None,
1317 1 : })
1318 1 : .await;
1319 1 :
1320 1 : let role_name = RoleName::from("anonymous");
1321 1 : let role = RoleNameInt::from(&role_name);
1322 1 :
1323 1 : let rules = vec![AuthRule {
1324 1 : id: "foo".to_owned(),
1325 1 : jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
1326 1 : audience: None,
1327 1 : role_names: vec![role],
1328 1 : }];
1329 1 :
1330 1 : let fetch = Fetch(rules);
1331 1 : let jwk_cache = JwkCache::default();
1332 1 :
1333 1 : let endpoint = EndpointId::from("ep");
1334 1 :
1335 1 : let token = new_rsa_jwt("rs1".into(), rs);
1336 1 :
1337 1 : jwk_cache
1338 1 : .check_jwt(
1339 1 : &RequestContext::test(),
1340 1 : endpoint.clone(),
1341 1 : &role_name,
1342 1 : &fetch,
1343 1 : &token,
1344 1 : )
1345 3 : .await
1346 1 : .unwrap();
1347 1 : }
1348 : }
|