Line data Source code
1 : use std::{future::Future, sync::Arc, time::Duration};
2 :
3 : use anyhow::{bail, ensure, Context};
4 : use arc_swap::ArcSwapOption;
5 : use dashmap::DashMap;
6 : use jose_jwk::crypto::KeyInfo;
7 : use signature::Verifier;
8 : use tokio::time::Instant;
9 :
10 : use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt};
11 :
12 : // TODO(conrad): make these configurable.
13 : const MIN_RENEW: Duration = Duration::from_secs(30);
14 : const AUTO_RENEW: Duration = Duration::from_secs(300);
15 : const MAX_RENEW: Duration = Duration::from_secs(3600);
16 : const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
17 :
18 : /// How to get the JWT auth rules
19 : pub trait FetchAuthRules: Clone + Send + Sync + 'static {
20 : fn fetch_auth_rules(&self) -> impl Future<Output = anyhow::Result<AuthRules>> + Send;
21 : }
22 :
23 : #[derive(Clone)]
24 : struct FetchAuthRulesFromCplane {
25 : #[allow(dead_code)]
26 : endpoint: EndpointIdInt,
27 : }
28 :
29 : impl FetchAuthRules for FetchAuthRulesFromCplane {
30 0 : async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
31 0 : Err(anyhow::anyhow!("not yet implemented"))
32 0 : }
33 : }
34 :
35 : pub struct AuthRules {
36 : jwks_urls: Vec<url::Url>,
37 : }
38 :
39 : #[derive(Default)]
40 : pub struct JwkCache {
41 : client: reqwest::Client,
42 :
43 : map: DashMap<EndpointIdInt, Arc<JwkCacheEntryLock>>,
44 : }
45 :
46 : pub struct JwkCacheEntryLock {
47 : cached: ArcSwapOption<JwkCacheEntry>,
48 : lookup: tokio::sync::Semaphore,
49 : }
50 :
51 : impl Default for JwkCacheEntryLock {
52 2 : fn default() -> Self {
53 2 : JwkCacheEntryLock {
54 2 : cached: ArcSwapOption::empty(),
55 2 : lookup: tokio::sync::Semaphore::new(1),
56 2 : }
57 2 : }
58 : }
59 :
60 : pub struct JwkCacheEntry {
61 : /// Should refetch at least every hour to verify when old keys have been removed.
62 : /// Should refetch when new key IDs are seen only every 5 minutes or so
63 : last_retrieved: Instant,
64 :
65 : /// cplane will return multiple JWKs urls that we need to scrape.
66 : key_sets: ahash::HashMap<url::Url, jose_jwk::JwkSet>,
67 : }
68 :
69 : impl JwkCacheEntryLock {
70 2 : async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
71 2 : JwkRenewalPermit::acquire_permit(self).await
72 2 : }
73 :
74 0 : fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
75 0 : JwkRenewalPermit::try_acquire_permit(self)
76 0 : }
77 :
78 2 : async fn renew_jwks<F: FetchAuthRules>(
79 2 : &self,
80 2 : _permit: JwkRenewalPermit<'_>,
81 2 : client: &reqwest::Client,
82 2 : auth_rules: &F,
83 2 : ) -> anyhow::Result<Arc<JwkCacheEntry>> {
84 2 : // double check that no one beat us to updating the cache.
85 2 : let now = Instant::now();
86 2 : let guard = self.cached.load_full();
87 2 : if let Some(cached) = guard {
88 0 : let last_update = now.duration_since(cached.last_retrieved);
89 0 : if last_update < Duration::from_secs(300) {
90 0 : return Ok(cached);
91 0 : }
92 2 : }
93 :
94 2 : let rules = auth_rules.fetch_auth_rules().await?;
95 2 : let mut key_sets = ahash::HashMap::with_capacity_and_hasher(
96 2 : rules.jwks_urls.len(),
97 2 : ahash::RandomState::new(),
98 2 : );
99 : // TODO(conrad): run concurrently
100 : // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
101 6 : for url in rules.jwks_urls {
102 4 : let req = client.get(url.clone());
103 4 : // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
104 8 : match req.send().await.and_then(|r| r.error_for_status()) {
105 : // todo: should we re-insert JWKs if we want to keep this JWKs URL?
106 : // I expect these failures would be quite sparse.
107 0 : Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
108 4 : Ok(r) => {
109 4 : let resp: http::Response<reqwest::Body> = r.into();
110 4 : match parse_json_body_with_limit::<jose_jwk::JwkSet>(
111 4 : resp.into_body(),
112 4 : MAX_JWK_BODY_SIZE,
113 4 : )
114 0 : .await
115 : {
116 0 : Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
117 4 : Ok(jwks) => {
118 4 : key_sets.insert(url, jwks);
119 4 : }
120 : }
121 : }
122 : }
123 : }
124 :
125 2 : let entry = Arc::new(JwkCacheEntry {
126 2 : last_retrieved: now,
127 2 : key_sets,
128 2 : });
129 2 : self.cached.swap(Some(Arc::clone(&entry)));
130 2 :
131 2 : Ok(entry)
132 2 : }
133 :
134 8 : async fn get_or_update_jwk_cache<F: FetchAuthRules>(
135 8 : self: &Arc<Self>,
136 8 : client: &reqwest::Client,
137 8 : fetch: &F,
138 8 : ) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
139 8 : let now = Instant::now();
140 8 : let guard = self.cached.load_full();
141 :
142 : // if we have no cached JWKs, try and get some
143 8 : let Some(cached) = guard else {
144 2 : let permit = self.acquire_permit().await;
145 8 : return self.renew_jwks(permit, client, fetch).await;
146 : };
147 :
148 6 : let last_update = now.duration_since(cached.last_retrieved);
149 6 :
150 6 : // check if the cached JWKs need updating.
151 6 : if last_update > MAX_RENEW {
152 0 : let permit = self.acquire_permit().await;
153 :
154 : // it's been too long since we checked the keys. wait for them to update.
155 0 : return self.renew_jwks(permit, client, fetch).await;
156 6 : }
157 6 :
158 6 : // every 5 minutes we should spawn a job to eagerly update the token.
159 6 : if last_update > AUTO_RENEW {
160 0 : if let Some(permit) = self.try_acquire_permit() {
161 0 : tracing::debug!("JWKs should be renewed. Renewal permit acquired");
162 0 : let permit = permit.into_owned();
163 0 : let entry = self.clone();
164 0 : let client = client.clone();
165 0 : let fetch = fetch.clone();
166 0 : tokio::spawn(async move {
167 0 : if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
168 0 : tracing::warn!(error=?e, "could not fetch JWKs in background job");
169 0 : }
170 0 : });
171 0 : } else {
172 0 : tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
173 : }
174 6 : }
175 :
176 6 : Ok(cached)
177 8 : }
178 :
179 8 : async fn check_jwt<F: FetchAuthRules>(
180 8 : self: &Arc<Self>,
181 8 : jwt: String,
182 8 : client: &reqwest::Client,
183 8 : fetch: &F,
184 8 : ) -> Result<(), anyhow::Error> {
185 : // JWT compact form is defined to be
186 : // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
187 : // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
188 :
189 8 : let (header_payload, signature) = jwt
190 8 : .rsplit_once(".")
191 8 : .context("Provided authentication token is not a valid JWT encoding")?;
192 8 : let (header, _payload) = header_payload
193 8 : .split_once(".")
194 8 : .context("Provided authentication token is not a valid JWT encoding")?;
195 :
196 8 : let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
197 8 : .context("Provided authentication token is not a valid JWT encoding")?;
198 8 : let header = serde_json::from_slice::<JWTHeader<'_>>(&header)
199 8 : .context("Provided authentication token is not a valid JWT encoding")?;
200 :
201 8 : let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
202 8 : .context("Provided authentication token is not a valid JWT encoding")?;
203 :
204 8 : ensure!(header.typ == "JWT");
205 8 : let kid = header.kid.context("missing key id")?;
206 :
207 8 : let mut guard = self.get_or_update_jwk_cache(client, fetch).await?;
208 :
209 : // get the key from the JWKs if possible. If not, wait for the keys to update.
210 8 : let jwk = loop {
211 8 : let jwk = guard
212 8 : .key_sets
213 8 : .values()
214 12 : .flat_map(|jwks| &jwks.keys)
215 20 : .find(|jwk| jwk.prm.kid.as_deref() == Some(kid));
216 :
217 0 : match jwk {
218 8 : Some(jwk) => break jwk,
219 0 : None if guard.last_retrieved.elapsed() > MIN_RENEW => {
220 0 : let permit = self.acquire_permit().await;
221 0 : guard = self.renew_jwks(permit, client, fetch).await?;
222 : }
223 : _ => {
224 0 : bail!("jwk not found");
225 : }
226 : }
227 : };
228 :
229 8 : ensure!(
230 8 : jwk.is_supported(&header.alg),
231 0 : "signature algorithm not supported"
232 : );
233 :
234 8 : match &jwk.key {
235 4 : jose_jwk::Key::Ec(key) => {
236 4 : verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
237 : }
238 4 : jose_jwk::Key::Rsa(key) => {
239 4 : verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
240 : }
241 0 : key => bail!("unsupported key type {key:?}"),
242 : };
243 :
244 : // TODO(conrad): verify iss, exp, nbf, etc...
245 :
246 8 : Ok(())
247 8 : }
248 : }
249 :
250 : impl JwkCache {
251 0 : pub async fn check_jwt(
252 0 : &self,
253 0 : endpoint: EndpointIdInt,
254 0 : jwt: String,
255 0 : ) -> Result<(), anyhow::Error> {
256 0 : // try with just a read lock first
257 0 : let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
258 0 : let entry = match entry {
259 0 : Some(entry) => entry,
260 : None => {
261 : // acquire a write lock after to insert.
262 0 : let entry = self.map.entry(endpoint).or_default();
263 0 : Arc::clone(&*entry)
264 : }
265 : };
266 :
267 0 : let fetch = FetchAuthRulesFromCplane { endpoint };
268 0 : entry.check_jwt(jwt, &self.client, &fetch).await
269 0 : }
270 : }
271 :
272 4 : fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
273 4 : use ecdsa::Signature;
274 4 : use signature::Verifier;
275 4 :
276 4 : match key.crv {
277 : jose_jwk::EcCurves::P256 => {
278 4 : let pk =
279 4 : p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
280 4 : let key = p256::ecdsa::VerifyingKey::from(&pk);
281 4 : let sig = Signature::from_slice(sig)?;
282 4 : key.verify(data, &sig)?;
283 : }
284 0 : key => bail!("unsupported ec key type {key:?}"),
285 : }
286 :
287 4 : Ok(())
288 4 : }
289 :
290 4 : fn verify_rsa_signature(
291 4 : data: &[u8],
292 4 : sig: &[u8],
293 4 : key: &jose_jwk::Rsa,
294 4 : alg: &Option<jose_jwa::Algorithm>,
295 4 : ) -> anyhow::Result<()> {
296 : use jose_jwa::{Algorithm, Signing};
297 : use rsa::{
298 : pkcs1v15::{Signature, VerifyingKey},
299 : RsaPublicKey,
300 : };
301 :
302 4 : let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
303 :
304 4 : match alg {
305 : Some(Algorithm::Signing(Signing::Rs256)) => {
306 4 : let key = VerifyingKey::<sha2::Sha256>::new(key);
307 4 : let sig = Signature::try_from(sig)?;
308 4 : key.verify(data, &sig)?;
309 : }
310 0 : _ => bail!("invalid RSA signing algorithm"),
311 : };
312 :
313 4 : Ok(())
314 4 : }
315 :
316 : /// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
317 32 : #[derive(serde::Deserialize, serde::Serialize)]
318 : struct JWTHeader<'a> {
319 : /// must be "JWT"
320 : typ: &'a str,
321 : /// must be a supported alg
322 : alg: jose_jwa::Algorithm,
323 : /// key id, must be provided for our usecase
324 : kid: Option<&'a str>,
325 : }
326 :
327 : struct JwkRenewalPermit<'a> {
328 : inner: Option<JwkRenewalPermitInner<'a>>,
329 : }
330 :
331 : enum JwkRenewalPermitInner<'a> {
332 : Owned(Arc<JwkCacheEntryLock>),
333 : Borrowed(&'a Arc<JwkCacheEntryLock>),
334 : }
335 :
336 : impl JwkRenewalPermit<'_> {
337 0 : fn into_owned(mut self) -> JwkRenewalPermit<'static> {
338 0 : JwkRenewalPermit {
339 0 : inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
340 0 : }
341 0 : }
342 :
343 2 : async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
344 2 : match from.lookup.acquire().await {
345 2 : Ok(permit) => {
346 2 : permit.forget();
347 2 : JwkRenewalPermit {
348 2 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
349 2 : }
350 : }
351 0 : Err(_) => panic!("semaphore should not be closed"),
352 : }
353 2 : }
354 :
355 0 : fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
356 0 : match from.lookup.try_acquire() {
357 0 : Ok(permit) => {
358 0 : permit.forget();
359 0 : Some(JwkRenewalPermit {
360 0 : inner: Some(JwkRenewalPermitInner::Borrowed(from)),
361 0 : })
362 : }
363 0 : Err(tokio::sync::TryAcquireError::NoPermits) => None,
364 0 : Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
365 : }
366 0 : }
367 : }
368 :
369 : impl JwkRenewalPermitInner<'_> {
370 0 : fn into_owned(self) -> JwkRenewalPermitInner<'static> {
371 0 : match self {
372 0 : JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
373 0 : JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
374 : }
375 0 : }
376 : }
377 :
378 : impl Drop for JwkRenewalPermit<'_> {
379 2 : fn drop(&mut self) {
380 2 : let entry = match &self.inner {
381 0 : None => return,
382 0 : Some(JwkRenewalPermitInner::Owned(p)) => p,
383 2 : Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
384 : };
385 2 : entry.lookup.add_permits(1);
386 2 : }
387 : }
388 :
#[cfg(test)]
mod tests {
    use super::*;

    use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};

    use base64::URL_SAFE_NO_PAD;
    use bytes::Bytes;
    use http::Response;
    use http_body_util::Full;
    use hyper1::service::service_fn;
    use hyper_util::rt::TokioIo;
    use rand::rngs::OsRng;
    use signature::Signer;
    use tokio::net::TcpListener;

    /// Generates a fresh P-256 key pair and the matching ES256 JWK with id `kid`.
    fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
        let sk = p256::SecretKey::random(&mut OsRng);
        let pk = sk.public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Ec(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Generates a fresh 2048-bit RSA key pair and the matching RS256 JWK with id `kid`.
    fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
        let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
        let pk = sk.to_public_key().into();
        let jwk = jose_jwk::Jwk {
            key: jose_jwk::Key::Rsa(pk),
            prm: jose_jwk::Parameters {
                kid: Some(kid),
                alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
                ..Default::default()
            },
        };
        (sk, jwk)
    }

    /// Builds the unsigned `<header>.<payload>` part of a JWT that expires in an hour.
    fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
        let header = JWTHeader {
            typ: "JWT",
            alg: jose_jwa::Algorithm::Signing(sig),
            kid: Some(&kid),
        };
        let body = typed_json::json! {{
            "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
        }};

        let header =
            base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
        let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);

        format!("{header}.{body}")
    }

    /// Builds a JWT signed with ES256 by `key`.
    fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
        use p256::ecdsa::{Signature, SigningKey};

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
        let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    /// Builds a JWT signed with RS256 by `key`.
    fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
        use rsa::pkcs1v15::SigningKey;
        use rsa::signature::SignatureEncoding;

        let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
        let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
        let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);

        format!("{payload}.{sig}")
    }

    /// Serves two JWK sets from a local HTTP server and checks that JWTs signed
    /// by every key in either set are accepted.
    #[tokio::test]
    async fn renew() {
        let (rs1, jwk1) = new_rsa_jwk("1".into());
        let (rs2, jwk2) = new_rsa_jwk("2".into());
        let (ec1, jwk3) = new_ec_jwk("3".into());
        let (ec2, jwk4) = new_ec_jwk("4".into());

        let jwt1 = new_rsa_jwt("1".into(), rs1);
        let jwt2 = new_rsa_jwt("2".into(), rs2);
        let jwt3 = new_ec_jwt("3".into(), ec1);
        let jwt4 = new_ec_jwt("4".into(), ec2);

        let foo_jwks = jose_jwk::JwkSet {
            keys: vec![jwk1, jwk3],
        };
        let bar_jwks = jose_jwk::JwkSet {
            keys: vec![jwk2, jwk4],
        };

        let service = service_fn(move |req| {
            let foo_jwks = foo_jwks.clone();
            let bar_jwks = bar_jwks.clone();
            async move {
                let jwks = match req.uri().path() {
                    "/foo" => &foo_jwks,
                    "/bar" => &bar_jwks,
                    _ => {
                        return Response::builder()
                            .status(404)
                            .body(Full::new(Bytes::new()));
                    }
                };
                let body = serde_json::to_vec(jwks).unwrap();
                Response::builder()
                    .status(200)
                    .body(Full::new(Bytes::from(body)))
            }
        });

        // Bind to loopback only: the client connects to this exact address, and
        // the mock server must not be reachable from other interfaces.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server = hyper1::server::conn::http1::Builder::new();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (s, _) = listener.accept().await.unwrap();
                let serve = server.serve_connection(TokioIo::new(s), service.clone());
                tokio::spawn(serve.into_future());
            }
        });

        let client = reqwest::Client::new();

        /// Auth rules pointing at the two JWKs paths of the local server.
        #[derive(Clone)]
        struct Fetch(SocketAddr);

        impl FetchAuthRules for Fetch {
            async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
                Ok(AuthRules {
                    jwks_urls: vec![
                        format!("http://{}/foo", self.0).parse().unwrap(),
                        format!("http://{}/bar", self.0).parse().unwrap(),
                    ],
                })
            }
        }

        let jwk_cache = Arc::new(JwkCacheEntryLock::default());

        jwk_cache
            .check_jwt(jwt1, &client, &Fetch(addr))
            .await
            .unwrap();
        jwk_cache
            .check_jwt(jwt2, &client, &Fetch(addr))
            .await
            .unwrap();
        jwk_cache
            .check_jwt(jwt3, &client, &Fetch(addr))
            .await
            .unwrap();
        jwk_cache
            .check_jwt(jwt4, &client, &Fetch(addr))
            .await
            .unwrap();
    }
}
|