Line data Source code
1 : use std::future::Future;
2 : use std::sync::Arc;
3 : use std::time::{Duration, SystemTime};
4 :
5 : use arc_swap::ArcSwapOption;
6 : use dashmap::DashMap;
7 : use jose_jwk::crypto::KeyInfo;
8 : use serde::de::Visitor;
9 : use serde::{Deserialize, Deserializer};
10 : use signature::Verifier;
11 : use thiserror::Error;
12 : use tokio::time::Instant;
13 :
14 : use crate::auth::backend::ComputeCredentialKeys;
15 : use crate::context::RequestMonitoring;
16 : use crate::control_plane::errors::GetEndpointJwksError;
17 : use crate::http::parse_json_body_with_limit;
18 : use crate::intern::RoleNameInt;
19 : use crate::{EndpointId, RoleName};
20 :
// TODO(conrad): make these configurable.
/// Tolerance applied to the `exp` and `nbf` claim checks to absorb clock differences
/// between us and the token issuer.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
/// Minimum age of the cached keys before an unknown key id triggers a synchronous refetch.
const MIN_RENEW: Duration = Duration::from_secs(30);
/// Age after which a request spawns a background refresh of the cached keys.
const AUTO_RENEW: Duration = Duration::from_secs(300);
/// Age after which requests wait for the keys to be refetched before using them.
const MAX_RENEW: Duration = Duration::from_secs(3600);
/// Size limit passed to `parse_json_body_with_limit` for JWKS responses
/// (presumably bytes — TODO confirm against that helper).
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
27 :
/// How to get the JWT auth rules (JWKS URLs, audiences and allowed roles) for an endpoint.
///
/// Implementors must be `Clone + Send + Sync + 'static` because a clone is moved into the
/// background renewal task spawned by the JWK cache.
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
    /// Returns every auth rule configured for `endpoint`.
    fn fetch_auth_rules(
        &self,
        ctx: &RequestMonitoring,
        endpoint: EndpointId,
    ) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
}
36 :
/// Errors returned by [`FetchAuthRules::fetch_auth_rules`].
#[derive(Error, Debug)]
pub(crate) enum FetchAuthRulesError {
    /// Wraps a failure of the control-plane `GetEndpointJwks` lookup.
    #[error(transparent)]
    GetEndpointJwks(#[from] GetEndpointJwksError),

    /// No JWKS settings exist for the requested role.
    #[error("JWKs settings for this role were not configured")]
    RoleJwksNotConfigured,
}
45 :
/// One JWT auth rule: a JWKS URL to fetch, plus the constraints applied to tokens whose
/// signing key is found in that key set.
pub(crate) struct AuthRule {
    /// Identifier of the rule; used as the key of the fetched key set in the cache entry.
    pub(crate) id: String,
    /// URL the JSON Web Key Set is downloaded from.
    pub(crate) jwks_url: url::Url,
    /// If set, the token's `aud` claim must contain this value.
    pub(crate) audience: Option<String>,
    /// Roles that may authenticate with keys from this key set.
    pub(crate) role_names: Vec<RoleNameInt>,
}
52 :
/// Cache of fetched JWKs, with one slot per (endpoint, role) pair.
#[derive(Default)]
pub struct JwkCache {
    /// HTTP client used to download the JWKS documents.
    client: reqwest::Client,

    /// Lazily created slots. Nothing in this file removes entries once inserted.
    map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
59 :
/// An immutable snapshot of the key sets fetched for an endpoint's auth rules.
pub(crate) struct JwkCacheEntry {
    /// When this snapshot was fetched; its age drives the renewal thresholds.
    ///
    /// Should refetch at least every hour (`MAX_RENEW`) to verify when old keys have been
    /// removed. An unknown key id may trigger a refetch once the snapshot is older than
    /// `MIN_RENEW`.
    last_retrieved: Instant,

    /// cplane will return multiple JWKs urls that we need to scrape.
    /// Keyed by the `id` of the auth rule each key set was fetched for.
    key_sets: ahash::HashMap<String, KeySet>,
}
68 :
69 : impl JwkCacheEntry {
70 10 : fn find_jwk_and_audience(
71 10 : &self,
72 10 : key_id: &str,
73 10 : role_name: &RoleName,
74 10 : ) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
75 10 : self.key_sets
76 10 : .values()
77 10 : // make sure our requested role has access to the key set
78 24 : .filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
79 10 : // try and find the requested key-id in the key set
80 14 : .find_map(|key_set| {
81 14 : key_set
82 14 : .find_key(key_id)
83 14 : .map(|jwk| (jwk, key_set.audience.as_deref()))
84 14 : })
85 10 : }
86 : }
87 :
/// One fetched JWKS together with the constraints of the auth rule it came from.
struct KeySet {
    /// The keys served at the rule's JWKS URL.
    jwks: jose_jwk::JwkSet,
    /// Required `aud` claim value for tokens signed by these keys, if any.
    audience: Option<String>,
    /// Roles allowed to use these keys.
    role_names: Vec<RoleNameInt>,
}
93 :
94 : impl KeySet {
95 14 : fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
96 14 : self.jwks
97 14 : .keys
98 14 : .iter()
99 24 : .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
100 14 : }
101 : }
102 :
/// A cache slot: the current snapshot plus the lock that serializes its renewal.
pub(crate) struct JwkCacheEntryLock {
    /// Most recently fetched snapshot, or empty if nothing was fetched yet.
    cached: ArcSwapOption<JwkCacheEntry>,
    /// Renewal permit, handed out through `JwkRenewalPermit`.
    lookup: tokio::sync::Semaphore,
}
107 :
108 : impl Default for JwkCacheEntryLock {
109 1 : fn default() -> Self {
110 1 : JwkCacheEntryLock {
111 1 : cached: ArcSwapOption::empty(),
112 1 : lookup: tokio::sync::Semaphore::new(1),
113 1 : }
114 1 : }
115 : }
116 :
117 : impl JwkCacheEntryLock {
118 1 : async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
119 1 : JwkRenewalPermit::acquire_permit(self).await
120 1 : }
121 :
122 0 : fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
123 0 : JwkRenewalPermit::try_acquire_permit(self)
124 0 : }
125 :
126 1 : async fn renew_jwks<F: FetchAuthRules>(
127 1 : &self,
128 1 : _permit: JwkRenewalPermit<'_>,
129 1 : ctx: &RequestMonitoring,
130 1 : client: &reqwest::Client,
131 1 : endpoint: EndpointId,
132 1 : auth_rules: &F,
133 1 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
134 1 : // double check that no one beat us to updating the cache.
135 1 : let now = Instant::now();
136 1 : let guard = self.cached.load_full();
137 1 : if let Some(cached) = guard {
138 0 : let last_update = now.duration_since(cached.last_retrieved);
139 0 : if last_update < Duration::from_secs(300) {
140 0 : return Ok(cached);
141 0 : }
142 1 : }
143 :
144 1 : let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
145 1 : let mut key_sets =
146 1 : ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
147 :
148 : // TODO(conrad): run concurrently
149 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
150 3 : for rule in rules {
151 2 : let req = client.get(rule.jwks_url.clone());
152 2 : // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
153 2 : // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
154 4 : match req.send().await.and_then(|r| r.error_for_status()) {
155 : // todo: should we re-insert JWKs if we want to keep this JWKs URL?
156 : // I expect these failures would be quite sparse.
157 0 : Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
158 2 : Ok(r) => {
159 2 : let resp: http::Response<reqwest::Body> = r.into();
160 2 : match parse_json_body_with_limit::<jose_jwk::JwkSet>(
161 2 : resp.into_body(),
162 2 : MAX_JWK_BODY_SIZE,
163 2 : )
164 0 : .await
165 : {
166 0 : Err(e) => {
167 0 : tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
168 : }
169 2 : Ok(jwks) => {
170 2 : key_sets.insert(
171 2 : rule.id,
172 2 : KeySet {
173 2 : jwks,
174 2 : audience: rule.audience,
175 2 : role_names: rule.role_names,
176 2 : },
177 2 : );
178 2 : }
179 : }
180 : }
181 : }
182 : }
183 :
184 1 : let entry = Arc::new(JwkCacheEntry {
185 1 : last_retrieved: now,
186 1 : key_sets,
187 1 : });
188 1 : self.cached.swap(Some(Arc::clone(&entry)));
189 1 :
190 1 : Ok(entry)
191 1 : }
192 :
193 10 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
194 10 : self: &Arc<Self>,
195 10 : ctx: &RequestMonitoring,
196 10 : client: &reqwest::Client,
197 10 : endpoint: EndpointId,
198 10 : fetch: &F,
199 10 : ) -> Result<Arc<JwkCacheEntry>, JwtError> {
200 10 : let now = Instant::now();
201 10 : let guard = self.cached.load_full();
202 :
203 : // if we have no cached JWKs, try and get some
204 10 : let Some(cached) = guard else {
205 1 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
206 1 : let permit = self.acquire_permit().await;
207 4 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
208 : };
209 :
210 9 : let last_update = now.duration_since(cached.last_retrieved);
211 9 :
212 9 : // check if the cached JWKs need updating.
213 9 : if last_update > MAX_RENEW {
214 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
215 0 : let permit = self.acquire_permit().await;
216 :
217 : // it's been too long since we checked the keys. wait for them to update.
218 0 : return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
219 9 : }
220 9 :
221 9 : // every 5 minutes we should spawn a job to eagerly update the token.
222 9 : if last_update > AUTO_RENEW {
223 0 : if let Some(permit) = self.try_acquire_permit() {
224 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
225 0 : let permit = permit.into_owned();
226 0 : let entry = self.clone();
227 0 : let client = client.clone();
228 0 : let fetch = fetch.clone();
229 0 : let ctx = ctx.clone();
230 0 : tokio::spawn(async move {
231 0 : if let Err(e) = entry
232 0 : .renew_jwks(permit, &ctx, &client, endpoint, &fetch)
233 0 : .await
234 : {
235 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
236 0 : }
237 0 : });
238 0 : } else {
239 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
240 : }
241 9 : }
242 :
243 9 : Ok(cached)
244 10 : }
245 :
246 10 : async fn check_jwt<F: FetchAuthRules>(
247 10 : self: &Arc<Self>,
248 10 : ctx: &RequestMonitoring,
249 10 : jwt: &str,
250 10 : client: &reqwest::Client,
251 10 : endpoint: EndpointId,
252 10 : role_name: &RoleName,
253 10 : fetch: &F,
254 10 : ) -> Result<ComputeCredentialKeys, JwtError> {
255 : // JWT compact form is defined to be
256 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
257 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
258 :
259 10 : let (header_payload, signature) = jwt
260 10 : .rsplit_once('.')
261 10 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
262 10 : let (header, payload) = header_payload
263 10 : .split_once('.')
264 10 : .ok_or(JwtEncodingError::InvalidCompactForm)?;
265 :
266 10 : let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
267 10 : let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
268 :
269 10 : let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
270 :
271 10 : let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
272 :
273 10 : let mut guard = self
274 10 : .get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
275 4 : .await?;
276 :
277 : // get the key from the JWKs if possible. If not, wait for the keys to update.
278 9 : let (jwk, expected_audience) = loop {
279 10 : match guard.find_jwk_and_audience(kid, role_name) {
280 9 : Some(jwk) => break jwk,
281 1 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
282 0 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
283 :
284 0 : let permit = self.acquire_permit().await;
285 0 : guard = self
286 0 : .renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
287 0 : .await?;
288 : }
289 1 : _ => return Err(JwtError::JwkNotFound),
290 : }
291 : };
292 :
293 9 : if !jwk.is_supported(&header.algorithm) {
294 0 : return Err(JwtError::SignatureAlgorithmNotSupported);
295 9 : }
296 9 :
297 9 : match &jwk.key {
298 5 : jose_jwk::Key::Ec(key) => {
299 5 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
300 : }
301 4 : jose_jwk::Key::Rsa(key) => {
302 4 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
303 : }
304 0 : key => return Err(JwtError::UnsupportedKeyType(key.into())),
305 : };
306 :
307 8 : let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
308 8 : let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
309 :
310 8 : tracing::debug!(?payload, "JWT signature valid with claims");
311 :
312 8 : if let Some(aud) = expected_audience {
313 0 : if payload.audience.0.iter().all(|s| s != aud) {
314 0 : return Err(JwtError::InvalidJwtTokenAudience);
315 0 : }
316 8 : }
317 :
318 8 : let now = SystemTime::now();
319 :
320 8 : if let Some(exp) = payload.expiration {
321 8 : if now >= exp + CLOCK_SKEW_LEEWAY {
322 0 : return Err(JwtError::JwtTokenHasExpired);
323 8 : }
324 0 : }
325 :
326 8 : if let Some(nbf) = payload.not_before {
327 0 : if nbf >= now + CLOCK_SKEW_LEEWAY {
328 0 : return Err(JwtError::JwtTokenNotYetReadyToUse);
329 0 : }
330 8 : }
331 :
332 8 : Ok(ComputeCredentialKeys::JwtPayload(payloadb))
333 10 : }
334 : }
335 :
336 : impl JwkCache {
337 0 : pub(crate) async fn check_jwt<F: FetchAuthRules>(
338 0 : &self,
339 0 : ctx: &RequestMonitoring,
340 0 : endpoint: EndpointId,
341 0 : role_name: &RoleName,
342 0 : fetch: &F,
343 0 : jwt: &str,
344 0 : ) -> Result<ComputeCredentialKeys, JwtError> {
345 0 : // try with just a read lock first
346 0 : let key = (endpoint.clone(), role_name.clone());
347 0 : let entry = self.map.get(&key).as_deref().map(Arc::clone);
348 0 : let entry = entry.unwrap_or_else(|| {
349 0 : // acquire a write lock after to insert.
350 0 : let entry = self.map.entry(key).or_default();
351 0 : Arc::clone(&*entry)
352 0 : });
353 0 :
354 0 : entry
355 0 : .check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
356 0 : .await
357 0 : }
358 : }
359 :
360 5 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> Result<(), JwtError> {
361 : use ecdsa::Signature;
362 : use signature::Verifier;
363 :
364 5 : match key.crv {
365 : jose_jwk::EcCurves::P256 => {
366 5 : let pk = p256::PublicKey::try_from(key).map_err(JwtError::InvalidP256Key)?;
367 5 : let key = p256::ecdsa::VerifyingKey::from(&pk);
368 5 : let sig = Signature::from_slice(sig)?;
369 5 : key.verify(data, &sig)?;
370 : }
371 0 : key => return Err(JwtError::UnsupportedEcKeyType(key)),
372 : }
373 :
374 4 : Ok(())
375 5 : }
376 :
377 4 : fn verify_rsa_signature(
378 4 : data: &[u8],
379 4 : sig: &[u8],
380 4 : key: &jose_jwk::Rsa,
381 4 : alg: &jose_jwa::Algorithm,
382 4 : ) -> Result<(), JwtError> {
383 : use jose_jwa::{Algorithm, Signing};
384 : use rsa::pkcs1v15::{Signature, VerifyingKey};
385 : use rsa::RsaPublicKey;
386 :
387 4 : let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
388 :
389 4 : match alg {
390 : Algorithm::Signing(Signing::Rs256) => {
391 4 : let key = VerifyingKey::<sha2::Sha256>::new(key);
392 4 : let sig = Signature::try_from(sig)?;
393 4 : key.verify(data, &sig)?;
394 : }
395 0 : _ => return Err(JwtError::InvalidRsaSigningAlgorithm),
396 : };
397 :
398 4 : Ok(())
399 4 : }
400 :
/// JOSE header of the token.
///
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
    /// must be a supported alg
    #[serde(rename = "alg")]
    algorithm: jose_jwa::Algorithm,
    /// key id, must be provided for our usecase
    ///
    /// NOTE(review): serde_json can only deserialize a borrowed `&str` when the JSON string
    /// contains no escape sequences, so a `kid` with escapes is rejected as an encoding error.
    /// Consider `Cow<'a, str>` (callers would need updating).
    #[serde(rename = "kid")]
    key_id: Option<&'a str>,
}
411 :
412 : /// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
413 24 : #[derive(serde::Deserialize, Debug)]
414 : #[allow(dead_code)]
415 : struct JwtPayload<'a> {
416 : /// Audience - Recipient for which the JWT is intended
417 : #[serde(rename = "aud", default)]
418 : audience: OneOrMany,
419 : /// Expiration - Time after which the JWT expires
420 : #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
421 : expiration: Option<SystemTime>,
422 : /// Not before - Time after which the JWT expires
423 : #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
424 : not_before: Option<SystemTime>,
425 :
426 : // the following entries are only extracted for the sake of debug logging.
427 : /// Issuer of the JWT
428 : #[serde(rename = "iss")]
429 : issuer: Option<&'a str>,
430 : /// Subject of the JWT (the user)
431 : #[serde(rename = "sub")]
432 : subject: Option<&'a str>,
433 : /// Unique token identifier
434 : #[serde(rename = "jti")]
435 : jwt_id: Option<&'a str>,
436 : /// Unique session identifier
437 : #[serde(rename = "sid")]
438 : session_id: Option<&'a str>,
439 : }
440 :
/// `OneOrMany` supports parsing either a single item or an array of items.
///
/// A single string becomes a one-element list; an absent claim (via `#[serde(default)]` on
/// the field) becomes an empty list.
///
/// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
///
/// > The "aud" (audience) claim identifies the recipients that the JWT is
/// > intended for. Each principal intended to process the JWT MUST
/// > identify itself with a value in the audience claim. If the principal
/// > processing the claim does not identify itself with a value in the
/// > "aud" claim when this claim is present, then the JWT MUST be
/// > rejected. In the general case, the "aud" value is **an array of case-
/// > sensitive strings**, each containing a StringOrURI value. In the
/// > special case when the JWT has one audience, the "aud" value MAY be a
/// > **single case-sensitive string** containing a StringOrURI value. The
/// > interpretation of audience values is generally application specific.
/// > Use of this claim is OPTIONAL.
#[derive(Default, Debug)]
struct OneOrMany(Vec<String>);
458 :
459 : impl<'de> Deserialize<'de> for OneOrMany {
460 0 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
461 0 : where
462 0 : D: Deserializer<'de>,
463 0 : {
464 : struct OneOrManyVisitor;
465 : impl<'de> Visitor<'de> for OneOrManyVisitor {
466 : type Value = OneOrMany;
467 :
468 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
469 0 : formatter.write_str("a single string or an array of strings")
470 0 : }
471 :
472 0 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
473 0 : where
474 0 : E: serde::de::Error,
475 0 : {
476 0 : Ok(OneOrMany(vec![v.to_owned()]))
477 0 : }
478 :
479 0 : fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
480 0 : where
481 0 : A: serde::de::SeqAccess<'de>,
482 0 : {
483 0 : let mut v = vec![];
484 0 : while let Some(s) = seq.next_element()? {
485 0 : v.push(s);
486 0 : }
487 0 : Ok(OneOrMany(v))
488 0 : }
489 : }
490 0 : deserializer.deserialize_any(OneOrManyVisitor)
491 0 : }
492 : }
493 :
494 8 : fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
495 8 : let d = <Option<u64>>::deserialize(d)?;
496 8 : Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
497 8 : }
498 :
/// Proof that the holder owns the renewal permit of a [`JwkCacheEntryLock`].
///
/// The semaphore permit is `forget`-ed when acquired and added back in `Drop`. This lets the
/// permit either borrow the lock or own an `Arc` to it (for use in a spawned task).
struct JwkRenewalPermit<'a> {
    /// `None` only after the permit has been moved out by `into_owned`.
    inner: Option<JwkRenewalPermitInner<'a>>,
}
502 :
/// How a [`JwkRenewalPermit`] refers to the lock whose permit it holds.
enum JwkRenewalPermitInner<'a> {
    /// Holds its own `Arc` handle, so the permit can be moved into a spawned task.
    Owned(Arc<JwkCacheEntryLock>),
    /// Borrows the caller's handle.
    Borrowed(&'a Arc<JwkCacheEntryLock>),
}
507 :
508 : impl JwkRenewalPermit<'_> {
509 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
510 0 : JwkRenewalPermit {
511 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
512 0 : }
513 0 : }
514 :
515 1 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
516 1 : match from.lookup.acquire().await {
517 1 : Ok(permit) => {
518 1 : permit.forget();
519 1 : JwkRenewalPermit {
520 1 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
521 1 : }
522 : }
523 0 : Err(_) => panic!("semaphore should not be closed"),
524 : }
525 1 : }
526 :
527 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
528 0 : match from.lookup.try_acquire() {
529 0 : Ok(permit) => {
530 0 : permit.forget();
531 0 : Some(JwkRenewalPermit {
532 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
533 0 : })
534 : }
535 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
536 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
537 : }
538 0 : }
539 : }
540 :
541 : impl JwkRenewalPermitInner<'_> {
542 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
543 0 : match self {
544 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
545 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
546 : }
547 0 : }
548 : }
549 :
550 : impl Drop for JwkRenewalPermit<'_> {
551 1 : fn drop(&mut self) {
552 1 : let entry = match &self.inner {
553 0 : None => return,
554 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
555 1 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
556 : };
557 1 : entry.lookup.add_permits(1);
558 1 : }
559 : }
560 :
/// Errors produced while validating a JWT.
#[derive(Error, Debug)]
#[non_exhaustive]
pub(crate) enum JwtError {
    /// No key with the token's `kid` is available to the requested role.
    #[error("jwk not found")]
    JwkNotFound,

    /// The token header has no `kid`.
    #[error("missing key id")]
    MissingKeyId,

    /// The token is not well-formed compact-serialized JSON/base64.
    #[error("Provided authentication token is not a valid JWT encoding")]
    JwtEncoding(#[from] JwtEncodingError),

    /// The key set requires an audience that the token's `aud` claim does not contain.
    #[error("invalid JWT token audience")]
    InvalidJwtTokenAudience,

    /// `exp` (plus clock-skew leeway) has passed.
    #[error("JWT token has expired")]
    JwtTokenHasExpired,

    /// `nbf` is later than now plus clock-skew leeway.
    #[error("JWT token is not yet ready to use")]
    JwtTokenNotYetReadyToUse,

    /// The EC JWK could not be converted into a P-256 public key.
    #[error("invalid P256 key")]
    InvalidP256Key(jose_jwk::crypto::Error),

    /// The RSA JWK could not be converted into an RSA public key.
    #[error("invalid RSA key")]
    InvalidRsaKey(jose_jwk::crypto::Error),

    /// An RSA key was used with an algorithm other than `RS256`.
    #[error("invalid RSA signing algorithm")]
    InvalidRsaSigningAlgorithm,

    /// An EC key uses a curve other than P-256.
    #[error("unsupported EC key type {0:?}")]
    UnsupportedEcKeyType(jose_jwk::EcCurves),

    /// The JWK is neither an EC nor an RSA key.
    #[error("unsupported key type {0:?}")]
    UnsupportedKeyType(KeyType),

    /// The key does not support the header's `alg`.
    #[error("signature algorithm not supported")]
    SignatureAlgorithmNotSupported,

    /// The signature was malformed or did not verify.
    #[error("signature error: {0}")]
    Signature(#[from] signature::Error),

    /// The auth rules for the endpoint could not be fetched.
    #[error("failed to fetch auth rules: {0}")]
    FetchAuthRules(#[from] FetchAuthRulesError),
}
606 :
607 : impl From<base64::DecodeError> for JwtError {
608 0 : fn from(err: base64::DecodeError) -> Self {
609 0 : JwtEncodingError::Base64Decode(err).into()
610 0 : }
611 : }
612 :
613 : impl From<serde_json::Error> for JwtError {
614 0 : fn from(err: serde_json::Error) -> Self {
615 0 : JwtEncodingError::SerdeJson(err).into()
616 0 : }
617 : }
618 :
/// Ways in which a token fails to be a well-formed compact-serialized JWT.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum JwtEncodingError {
    /// A segment was not valid unpadded URL-safe base64.
    #[error(transparent)]
    Base64Decode(#[from] base64::DecodeError),

    /// The decoded header or payload was not the expected JSON.
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),

    /// The token has fewer than two `.` separators.
    #[error("invalid compact form")]
    InvalidCompactForm,
}
631 :
/// Key type (and curve, where relevant) of a JWK; carried in
/// `JwtError::UnsupportedKeyType` for its `Debug` output only.
#[allow(dead_code, reason = "Debug use only")]
#[derive(Debug)]
pub(crate) enum KeyType {
    Ec(jose_jwk::EcCurves),
    Rsa,
    Oct,
    Okp(jose_jwk::OkpCurves),
    /// Any key kind not listed above.
    Unknown,
}
641 :
642 : impl From<&jose_jwk::Key> for KeyType {
643 0 : fn from(key: &jose_jwk::Key) -> Self {
644 0 : match key {
645 0 : jose_jwk::Key::Ec(ec) => Self::Ec(ec.crv),
646 0 : jose_jwk::Key::Rsa(_rsa) => Self::Rsa,
647 0 : jose_jwk::Key::Oct(_oct) => Self::Oct,
648 0 : jose_jwk::Key::Okp(okp) => Self::Okp(okp.crv),
649 0 : _ => Self::Unknown,
650 : }
651 0 : }
652 : }
653 :
#[cfg(test)]
mod tests {
    use std::future::IntoFuture;
    use std::net::SocketAddr;
    use std::time::SystemTime;

    use base64::URL_SAFE_NO_PAD;
    use bytes::Bytes;
    use http::Response;
    use http_body_util::Full;
    use hyper::service::service_fn;
    use hyper_util::rt::TokioIo;
    use rand::rngs::OsRng;
    use rsa::pkcs8::DecodePrivateKey;
    use signature::Signer;
    use tokio::net::TcpListener;

    use super::*;
    use crate::RoleName;

    /// Generates a random P-256 key pair and wraps the public half in a JWK with id `kid`.
    fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
        let sk = p256::SecretKey::random(&mut OsRng);
        let pk = sk.public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Ec(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Parses a PKCS#8 PEM RSA private key and wraps its public half in a JWK with id `kid`.
    fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
        let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap();
        let pk = sk.to_public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Rsa(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Builds the unsigned `<B64(header)>.<B64(payload)>` part of a JWT whose header carries
    /// `kid` and `sig`, and whose only claim is an `exp` one hour from now.
    fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
        let header = JwtHeader {
            algorithm: jose_jwa::Algorithm::Signing(sig),
            key_id: Some(&kid),
        };
        let body = typed_json::json! {{
            "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
        }};

        let header =
            base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
        let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);

        format!("{header}.{body}")
    }

    /// Builds a complete ES256 JWT claiming key id `kid`, signed with `key`.
    fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
        use p256::ecdsa::{Signature, SigningKey};

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
        let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    /// Builds a complete RS256 JWT claiming key id `kid`, signed with `key`.
    fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
        use rsa::pkcs1v15::SigningKey;
        use rsa::signature::SignatureEncoding;

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
        let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    // RSA key gen is slow....
    const RS1: &str = "-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y
aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA
SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA
cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X
mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s
PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM
nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9
ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi
5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+
KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz
1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP
We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony
UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu
GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+
HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU
f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24
DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY
9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF
/4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL
3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE
qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d
ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW
rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf
JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI
rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9
8LMFlz1QPqSZYN/A/kOcLBfa3A==
-----END PRIVATE KEY-----
";
    const RS2: &str = "-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J
HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J
zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk
qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO
IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt
ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy
uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv
K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG
LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe
QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT
zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX
0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR
+Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1
P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL
x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/
FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg
iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M
4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA
iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj
N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB
uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U
qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP
WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB
rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne
X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
5JiconnI5aLek0QVPoFaVXFa
-----END PRIVATE KEY-----
";

    /// End-to-end check of `JwkCacheEntryLock::check_jwt` against a local JWKS server that
    /// serves two key sets (`/foo`: keys 1 and 3, `/bar`: keys 2 and 4).
    #[tokio::test]
    async fn renew() {
        let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
        let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
        let (ec1, jwk3) = new_ec_jwk("3".into());
        let (ec2, jwk4) = new_ec_jwk("4".into());

        let foo_jwks = jose_jwk::JwkSet {
            keys: vec![jwk1, jwk3],
        };
        let bar_jwks = jose_jwk::JwkSet {
            keys: vec![jwk2, jwk4],
        };

        // Serves the two key sets as JSON; any other path is a 404.
        let service = service_fn(move |req| {
            let foo_jwks = foo_jwks.clone();
            let bar_jwks = bar_jwks.clone();
            async move {
                let jwks = match req.uri().path() {
                    "/foo" => &foo_jwks,
                    "/bar" => &bar_jwks,
                    _ => {
                        return Response::builder()
                            .status(404)
                            .body(Full::new(Bytes::new()));
                    }
                };
                let body = serde_json::to_vec(jwks).unwrap();
                Response::builder()
                    .status(200)
                    .body(Full::new(Bytes::from(body)))
            }
        });

        // Bind to an ephemeral port and serve every accepted connection on its own task.
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let server = hyper::server::conn::http1::Builder::new();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (s, _) = listener.accept().await.unwrap();
                let serve = server.serve_connection(TokioIo::new(s), service.clone());
                tokio::spawn(serve.into_future());
            }
        });

        let client = reqwest::Client::new();

        /// Test rule source: both key sets, no audience, allowed for the given roles.
        #[derive(Clone)]
        struct Fetch(SocketAddr, Vec<RoleNameInt>);

        impl FetchAuthRules for Fetch {
            async fn fetch_auth_rules(
                &self,
                _ctx: &RequestMonitoring,
                _endpoint: EndpointId,
            ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
                Ok(vec![
                    AuthRule {
                        id: "foo".to_owned(),
                        jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
                        audience: None,
                        role_names: self.1.clone(),
                    },
                    AuthRule {
                        id: "bar".to_owned(),
                        jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
                        audience: None,
                        role_names: self.1.clone(),
                    },
                ])
            }
        }

        let role_name1 = RoleName::from("anonymous");
        let role_name2 = RoleName::from("authenticated");

        let fetch = Fetch(
            addr,
            vec![
                RoleNameInt::from(&role_name1),
                RoleNameInt::from(&role_name2),
            ],
        );

        let endpoint = EndpointId::from("ep");

        let jwk_cache = Arc::new(JwkCacheEntryLock::default());

        let jwt1 = new_rsa_jwt("1".into(), rs1);
        let jwt2 = new_rsa_jwt("2".into(), rs2);
        let jwt3 = new_ec_jwt("3".into(), &ec1);
        let jwt4 = new_ec_jwt("4".into(), &ec2);

        // signed with ec2 but claims kid "3" (ec1's key), so the ECDSA signature check fails
        let bad_jwt = new_ec_jwt("3".into(), &ec2);
        // this role_name is not accepted
        let bad_role_name = RoleName::from("cloud_admin");

        // First call: the cache is empty, so this also performs the initial JWKS fetch.
        let err = jwk_cache
            .check_jwt(
                &RequestMonitoring::test(),
                &bad_jwt,
                &client,
                endpoint.clone(),
                &role_name1,
                &fetch,
            )
            .await
            .unwrap_err();
        assert!(err.to_string().contains("signature error"));

        // The keys were just fetched (younger than MIN_RENEW), so the unknown role fails
        // without a renewal.
        let err = jwk_cache
            .check_jwt(
                &RequestMonitoring::test(),
                &jwt1,
                &client,
                endpoint.clone(),
                &bad_role_name,
                &fetch,
            )
            .await
            .unwrap_err();
        assert!(err.to_string().contains("jwk not found"));

        // Every token must validate for every allowed role.
        let tokens = [jwt1, jwt2, jwt3, jwt4];
        let role_names = [role_name1, role_name2];
        for role in &role_names {
            for token in &tokens {
                jwk_cache
                    .check_jwt(
                        &RequestMonitoring::test(),
                        token,
                        &client,
                        endpoint.clone(),
                        role,
                        &fetch,
                    )
                    .await
                    .unwrap();
            }
        }
    }
}
|