Line data Source code
1 : // For details about authentication see docs/authentication.md
2 :
3 : use std::fmt::Display;
4 : use std::fs;
5 : use std::sync::Arc;
6 : use std::{borrow::Cow, io, path::Path};
7 :
8 : use anyhow::Result;
9 : use arc_swap::ArcSwap;
10 : use camino::Utf8Path;
11 : use jsonwebtoken::{
12 : Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
13 : };
14 : use oid_registry::OID_PKCS1_RSAENCRYPTION;
15 : use pem::Pem;
16 : use rustls_pki_types::CertificateDer;
17 : use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
18 : use uuid::Uuid;
19 :
20 : use crate::id::TenantId;
21 :
22 : /// Signature algorithms to use. We allow EdDSA and RSA/SHA-256.
23 : const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
24 : const HADRON_STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::RS256;
25 :
26 0 : #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
27 : #[serde(rename_all = "lowercase")]
28 : pub enum Scope {
29 : /// Provides access to all data for a specific tenant (specified in `struct Claims` below)
30 : // TODO: join these two?
31 : Tenant,
32 : /// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
33 : /// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
34 : /// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
35 : /// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
36 : TenantEndpoint,
37 : /// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
38 : /// Should only be used e.g. for status check/tenant creation/list.
39 : PageServerApi,
40 : /// Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
41 : /// Should only be used e.g. for status check.
42 : /// Currently also used for connection from any pageserver to any safekeeper.
43 : SafekeeperData,
44 : /// The scope used by pageservers in upcalls to storage controller and cloud control plane
45 : #[serde(rename = "generations_api")]
46 : GenerationsApi,
47 : /// Allows access to control plane managment API and all storage controller endpoints.
48 : Admin,
49 :
50 : /// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration)
51 : Infra,
52 :
53 : /// Allows access to storage controller APIs used by the scrubber, to interrogate the state
54 : /// of a tenant & post scrub results.
55 : Scrubber,
56 :
57 : /// This scope is used for communication with other storage controller instances.
58 : /// At the time of writing, this is only used for the step down request.
59 : #[serde(rename = "controller_peer")]
60 : ControllerPeer,
61 : }
62 :
63 0 : fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
64 0 : where
65 0 : D: Deserializer<'de>,
66 : {
67 0 : let opt = Option::<String>::deserialize(deserializer)?;
68 0 : match opt.as_deref() {
69 0 : Some("") => Ok(None),
70 0 : Some(s) => Uuid::parse_str(s)
71 0 : .map(Some)
72 0 : .map_err(serde::de::Error::custom),
73 0 : None => Ok(None),
74 : }
75 0 : }
76 :
77 : /// JWT payload. See docs/authentication.md for the format
78 : #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
79 : pub struct Claims {
80 : #[serde(default)]
81 : pub tenant_id: Option<TenantId>,
82 : #[serde(
83 : default,
84 : skip_serializing_if = "Option::is_none",
85 : // Neon control plane includes this field as empty in the claims.
86 : // Consider it None in those cases.
87 : deserialize_with = "deserialize_empty_string_as_none_uuid"
88 : )]
89 : pub endpoint_id: Option<Uuid>,
90 : pub scope: Scope,
91 : }
92 :
93 : impl Claims {
94 0 : pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
95 0 : Self {
96 0 : tenant_id,
97 0 : scope,
98 0 : endpoint_id: None,
99 0 : }
100 0 : }
101 :
102 0 : pub fn new_for_endpoint(endpoint_id: Uuid) -> Self {
103 0 : Self {
104 0 : tenant_id: None,
105 0 : endpoint_id: Some(endpoint_id),
106 0 : scope: Scope::TenantEndpoint,
107 0 : }
108 0 : }
109 : }
110 :
111 : pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
112 :
113 : impl SwappableJwtAuth {
114 0 : pub fn new(jwt_auth: JwtAuth) -> Self {
115 0 : SwappableJwtAuth(ArcSwap::new(Arc::new(jwt_auth)))
116 0 : }
117 0 : pub fn swap(&self, jwt_auth: JwtAuth) {
118 0 : self.0.swap(Arc::new(jwt_auth));
119 0 : }
120 0 : pub fn decode<D: DeserializeOwned>(
121 0 : &self,
122 0 : token: &str,
123 0 : ) -> std::result::Result<TokenData<D>, AuthError> {
124 0 : self.0.load().decode(token)
125 0 : }
126 : }
127 :
128 : impl std::fmt::Debug for SwappableJwtAuth {
129 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 0 : write!(f, "Swappable({:?})", self.0.load())
131 0 : }
132 : }
133 :
/// Error produced when a JWT cannot be verified or decoded; wraps a human-readable reason.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AuthError(pub Cow<'static, str>);

impl Display for AuthError {
    /// Writes the reason text verbatim; width and alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
142 :
143 : pub struct JwtAuth {
144 : decoding_keys: Vec<DecodingKey>,
145 : validation: Validation,
146 : }
147 :
148 : impl JwtAuth {
149 37 : pub fn new(decoding_keys: Vec<DecodingKey>) -> Self {
150 37 : let mut validation = Validation::default();
151 37 : validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
152 : // The default 'required_spec_claims' is 'exp'. But we don't want to require
153 : // expiration.
154 37 : validation.required_spec_claims = [].into();
155 37 : Self {
156 37 : decoding_keys,
157 37 : validation,
158 37 : }
159 37 : }
160 :
161 0 : pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
162 0 : let metadata = key_path.metadata()?;
163 0 : let decoding_keys = if metadata.is_dir() {
164 0 : let mut keys = Vec::new();
165 0 : for entry in fs::read_dir(key_path)? {
166 0 : let path = entry?.path();
167 0 : if !path.is_file() {
168 : // Ignore directories (don't recurse)
169 0 : continue;
170 0 : }
171 0 : let public_key = fs::read(path)?;
172 0 : keys.push(DecodingKey::from_ed_pem(&public_key)?);
173 : }
174 0 : keys
175 0 : } else if metadata.is_file() {
176 0 : let public_key = fs::read(key_path)?;
177 0 : vec![DecodingKey::from_ed_pem(&public_key)?]
178 : } else {
179 0 : anyhow::bail!("path is neither a directory or a file")
180 : };
181 0 : if decoding_keys.is_empty() {
182 0 : anyhow::bail!(
183 0 : "Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
184 : );
185 0 : }
186 0 : Ok(Self::new(decoding_keys))
187 0 : }
188 :
189 : // Helper function to parse a X509 certificate file and extract the RSA public keys from it as `DecodingKey`s.
190 : // - `ceritificate_file_path`: the path to the certificate file. It must be a file, not a directory or anything else.
191 : // Returns the successfully extracted decoding keys. Non-RSA keys and non-X509-parsable certificates are skipped.
192 : // Multuple keys may be returned because a single file can contain multiple certificates.
193 5 : fn extract_rsa_decoding_keys_from_certificate<P: AsRef<Path>>(
194 5 : certificate_file_path: P,
195 5 : ) -> Result<Vec<DecodingKey>> {
196 5 : let certs: io::Result<Vec<CertificateDer<'static>>> = rustls_pemfile::certs(
197 5 : &mut io::BufReader::new(fs::File::open(certificate_file_path)?),
198 : )
199 5 : .collect();
200 :
201 5 : Ok(certs?
202 5 : .iter()
203 5 : .filter_map(
204 3 : |cert| match x509_parser::parse_x509_certificate(cert) {
205 3 : Ok((_, cert)) => {
206 3 : let public_key = cert.public_key();
207 : // Note that we are just extracting the public key from the certificate, not the signature.
208 : // So the algorithm is just the asymmetric crypto such as RSA, no hashes of or anything like
209 : // that.
210 3 : if *public_key.algorithm.oid() == OID_PKCS1_RSAENCRYPTION {
211 3 : Some(DecodingKey::from_rsa_der(&public_key.subject_public_key.data))
212 : } else {
213 0 : tracing::warn!(
214 0 : "Unsupported public key algorithm: {:?} found in certificate. Skipping.",
215 : public_key.algorithm
216 : );
217 0 : None
218 : }
219 : }
220 0 : Err(e) => {
221 0 : tracing::warn!("Error parsing certificate: {}. Skipping.", e);
222 0 : None
223 : }
224 3 : },
225 : )
226 5 : .collect())
227 5 : }
228 :
229 : /// Create a `JwtAuth` that can decode tokens using RSA public keys in X509 certificates from the given path.
230 : /// - `cert_path`: the path to a directory or a file containing X509 certificates. If it is a directory, all files
231 : /// under the first level of the directory will be inspected for certificates.
232 : /// Returns the `JwtAuth` with the decoding keys extracted from the certificates, or error.
233 : /// Used by Hadron.
234 2 : pub fn from_cert_path(cert_path: &Utf8Path) -> Result<Self> {
235 2 : tracing::info!(
236 0 : "Loading public keys in certificates from path: {}",
237 : cert_path
238 : );
239 :
240 2 : let mut decoding_keys = Vec::new();
241 :
242 2 : let metadata = cert_path.metadata()?;
243 2 : if metadata.is_dir() {
244 4 : for entry in fs::read_dir(cert_path)? {
245 4 : let path = entry?.path();
246 4 : if !path.is_file() {
247 : // Ignore directories (don't recurse)
248 0 : continue;
249 4 : }
250 4 : decoding_keys.extend(
251 4 : Self::extract_rsa_decoding_keys_from_certificate(path).unwrap_or_default(),
252 : );
253 : }
254 1 : } else if metadata.is_file() {
255 1 : decoding_keys.extend(
256 1 : Self::extract_rsa_decoding_keys_from_certificate(cert_path).unwrap_or_default(),
257 1 : );
258 1 : } else {
259 0 : anyhow::bail!("{cert_path} is neither a directory or a file")
260 : }
261 2 : if decoding_keys.is_empty() {
262 0 : anyhow::bail!(
263 0 : "Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
264 : );
265 2 : }
266 :
267 : // Note that we need to create a `JwtAuth` with a different `validation` from the default one created by `new()` in this case
268 : // because the `jsonwebtoken` crate requires that all algorithms in `validation.algorithms` belong to the same algorithm family
269 : // (all RSA or all EdDSA).
270 2 : let mut validation = Validation::default();
271 2 : validation.algorithms = vec![HADRON_STORAGE_TOKEN_ALGORITHM];
272 2 : validation.required_spec_claims = [].into();
273 2 : Ok(Self {
274 2 : validation,
275 2 : decoding_keys,
276 2 : })
277 2 : }
278 :
279 0 : pub fn from_key(key: String) -> Result<Self> {
280 0 : Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
281 0 : }
282 :
283 : /// Attempt to decode the token with the internal decoding keys.
284 : ///
285 : /// The function tries the stored decoding keys in succession,
286 : /// and returns the first yielding a successful result.
287 : /// If there is no working decoding key, it returns the last error.
288 123 : pub fn decode<D: DeserializeOwned>(
289 123 : &self,
290 123 : token: &str,
291 123 : ) -> std::result::Result<TokenData<D>, AuthError> {
292 123 : let mut res = None;
293 125 : for decoding_key in &self.decoding_keys {
294 124 : res = Some(decode(token, decoding_key, &self.validation));
295 124 : if let Some(Ok(res)) = res {
296 122 : return Ok(res);
297 2 : }
298 : }
299 1 : if let Some(res) = res {
300 1 : res.map_err(|e| AuthError(Cow::Owned(e.to_string())))
301 : } else {
302 0 : Err(AuthError(Cow::Borrowed("no JWT decoding keys configured")))
303 : }
304 6 : }
305 : }
306 :
307 : impl std::fmt::Debug for JwtAuth {
308 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 0 : f.debug_struct("JwtAuth")
310 0 : .field("validation", &self.validation)
311 0 : .finish()
312 0 : }
313 : }
314 :
315 : // this function is used only for testing purposes in CLI e g generate tokens during init
316 1 : pub fn encode_from_key_file<S: Serialize>(claims: &S, pem: &Pem) -> Result<String> {
317 1 : let key = EncodingKey::from_ed_der(pem.contents());
318 1 : Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
319 1 : }
320 :
321 : /// Encode (i.e., sign) a Hadron auth token with the given claims and RSA private key. This is used
322 : /// by HCC to sign tokens when deploying compute or returning the compute spec. The resulting token
323 : /// is used by the compute node to authenticate with HCC and PS/SK.
324 2 : pub fn encode_hadron_token<S: Serialize>(claims: &S, key_data: &[u8]) -> Result<String> {
325 2 : let key = EncodingKey::from_rsa_pem(key_data)?;
326 2 : encode_hadron_token_with_encoding_key(claims, &key)
327 2 : }
328 :
329 2 : pub fn encode_hadron_token_with_encoding_key<S: Serialize>(
330 2 : claims: &S,
331 2 : encoding_key: &EncodingKey,
332 2 : ) -> Result<String> {
333 2 : Ok(encode(
334 2 : &Header::new(HADRON_STORAGE_TOKEN_ALGORITHM),
335 2 : claims,
336 2 : encoding_key,
337 0 : )?)
338 2 : }
339 :
#[cfg(test)]
mod tests {
    use io::Write;
    use std::str::FromStr;

    use super::*;

    // Generated with:
    //
    // openssl genpkey -algorithm ed25519 -out ed25519-priv.pem
    // openssl pkey -in ed25519-priv.pem -pubout -out ed25519-pub.pem
    const TEST_PUB_KEY_ED25519: &str = r#"
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
-----END PUBLIC KEY-----
"#;

    const TEST_PRIV_KEY_ED25519: &str = r#"
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
-----END PRIVATE KEY-----
"#;

    const TEST_TENANT_ID: &str = "3d1f7595b468230304e0b73cecbcb081";

    // A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
    //
    // ```
    // {
    //   "scope": "tenant",
    //   "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
    //   "iss": "neon.controlplane",
    //   "iat": 1678442479
    // }
    // ```
    const TEST_TOKEN_EDDSA: &str = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJpYXQiOjE2Nzg0NDI0Nzl9.rNheBnluMJNgXzSTTJoTNIGy4P_qe0JUHl_nVEGuDCTgHOThPVr552EnmKccrCKquPeW3c2YUk0Y9Oh4KyASAw";

    /// Verifier holding only the test Ed25519 public key.
    fn test_auth() -> JwtAuth {
        JwtAuth::new(vec![
            DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519.as_bytes()).unwrap(),
        ])
    }

    #[test]
    fn test_decode() {
        let expected_claims = Claims {
            tenant_id: Some(TenantId::from_str(TEST_TENANT_ID).unwrap()),
            endpoint_id: None,
            scope: Scope::Tenant,
        };

        // Check it can be validated with the public key
        let claims_from_token: Claims = test_auth().decode(TEST_TOKEN_EDDSA).unwrap().claims;
        assert_eq!(claims_from_token, expected_claims);
    }

    #[test]
    fn test_decode_without_keys() {
        // With no keys configured, decoding must fail with the dedicated error message.
        let auth = JwtAuth::new(Vec::new());
        let err = auth
            .decode::<Claims>(TEST_TOKEN_EDDSA)
            .err()
            .expect("decoding must fail when no keys are configured");
        assert_eq!(err.0, "no JWT decoding keys configured");
    }

    #[test]
    fn test_swappable_jwt_auth() {
        let auth =
            SwappableJwtAuth::new(JwtAuth::from_key(TEST_PUB_KEY_ED25519.to_string()).unwrap());
        let claims: Claims = auth.decode(TEST_TOKEN_EDDSA).unwrap().claims;
        assert_eq!(claims.scope, Scope::Tenant);

        // After swapping in a verifier without keys, the same token must be rejected.
        auth.swap(JwtAuth::new(Vec::new()));
        assert!(auth.decode::<Claims>(TEST_TOKEN_EDDSA).is_err());
    }

    #[test]
    fn test_claims_constructors() {
        assert_eq!(
            Claims::new(
                Some(TenantId::from_str(TEST_TENANT_ID).unwrap()),
                Scope::PageServerApi
            ),
            Claims {
                tenant_id: Some(TenantId::from_str(TEST_TENANT_ID).unwrap()),
                endpoint_id: None,
                scope: Scope::PageServerApi,
            }
        );

        let endpoint_id = Uuid::from_u128(0x1234);
        assert_eq!(
            Claims::new_for_endpoint(endpoint_id),
            Claims {
                tenant_id: None,
                endpoint_id: Some(endpoint_id),
                scope: Scope::TenantEndpoint,
            }
        );
    }

    #[test]
    fn test_encode() {
        let claims = Claims {
            tenant_id: Some(TenantId::from_str(TEST_TENANT_ID).unwrap()),
            endpoint_id: None,
            scope: Scope::Tenant,
        };

        let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
        let encoded = encode_from_key_file(&claims, &pem).unwrap();

        // decode it back
        let decoded: TokenData<Claims> = test_auth().decode(&encoded).unwrap();

        assert_eq!(decoded.claims, claims);
    }

    #[test]
    fn test_endpoint_id_deserialization() {
        let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
        let auth = test_auth();

        // A populated endpoint_id survives an encode/decode round trip.
        let claims = Claims::new_for_endpoint(Uuid::from_u128(0x1234));
        let encoded = encode_from_key_file(&claims, &pem).unwrap();
        assert_eq!(auth.decode::<Claims>(&encoded).unwrap().claims, claims);

        // The control plane may send an empty endpoint_id; it must decode as `None`.
        #[derive(Serialize)]
        struct RawClaims {
            endpoint_id: &'static str,
            scope: &'static str,
        }
        let raw = RawClaims {
            endpoint_id: "",
            scope: "tenant",
        };
        let encoded = encode_from_key_file(&raw, &pem).unwrap();
        let decoded: Claims = auth.decode(&encoded).unwrap().claims;
        assert_eq!(decoded.endpoint_id, None);
        assert_eq!(decoded.tenant_id, None);
        assert_eq!(decoded.scope, Scope::Tenant);
    }

    #[test]
    fn test_decode_with_key_from_certificate() {
        // Tests that we can sign (encode) a token with a RSA private key and verify (decode) it with the
        // corresponding public key extracted from a certificate.

        // Generate two RSA key pairs and create self-signed certificates with it.
        let key_pair_1 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
        let key_pair_2 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
        let mut params = rcgen::CertificateParams::default();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, "eng-brickstore@databricks.com");
        let cert_1 = params.clone().self_signed(&key_pair_1).unwrap();
        let cert_2 = params.self_signed(&key_pair_2).unwrap();

        // Write the certificates and keys to a temporary dir.
        let dir = camino_tempfile::tempdir().unwrap();
        {
            fs::File::create(dir.path().join("cert_1.pem"))
                .unwrap()
                .write_all(cert_1.pem().as_bytes())
                .unwrap();
            fs::File::create(dir.path().join("key_1.pem"))
                .unwrap()
                .write_all(key_pair_1.serialize_pem().as_bytes())
                .unwrap();
            fs::File::create(dir.path().join("cert_2.pem"))
                .unwrap()
                .write_all(cert_2.pem().as_bytes())
                .unwrap();
            fs::File::create(dir.path().join("key_2.pem"))
                .unwrap()
                .write_all(key_pair_2.serialize_pem().as_bytes())
                .unwrap();
        }
        // Instantiate a `JwtAuth` with the certificate path. The resulting `JwtAuth` should extract the RSA public
        // keys out of the X509 certificates and use them as the decoding keys. Since we specified a directory, both
        // X509 certificates will be loaded, but the private key files are skipped.
        let auth = JwtAuth::from_cert_path(dir.path()).unwrap();
        assert_eq!(auth.decoding_keys.len(), 2);

        // Also create a `JwtAuth`, specifying a single certificate file for it to get the decoding key from.
        let auth_cert_1 = JwtAuth::from_cert_path(&dir.path().join("cert_1.pem")).unwrap();
        assert_eq!(auth_cert_1.decoding_keys.len(), 1);

        // Encode tokens with some claims.
        let claims = Claims {
            tenant_id: Some(TenantId::generate()),
            endpoint_id: None,
            scope: Scope::Tenant,
        };
        let encoded_1 =
            encode_hadron_token(&claims, key_pair_1.serialize_pem().as_bytes()).unwrap();
        let encoded_2 =
            encode_hadron_token(&claims, key_pair_2.serialize_pem().as_bytes()).unwrap();

        // Verify that we can decode the token with matching decoding keys (decoding also verifies the signature).
        assert_eq!(auth.decode::<Claims>(&encoded_1).unwrap().claims, claims);
        assert_eq!(auth.decode::<Claims>(&encoded_2).unwrap().claims, claims);
        assert_eq!(
            auth_cert_1.decode::<Claims>(&encoded_1).unwrap().claims,
            claims
        );

        // Verify that the token cannot be decoded with a mismatched decode key.
        assert!(auth_cert_1.decode::<Claims>(&encoded_2).is_err());
    }
}
|