LCOV - code coverage report
Current view: top level - libs/utils/src - auth.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 40.3 % 119 48
Test Date: 2025-07-22 17:50:06 Functions: 19.2 % 26 5

            Line data    Source code
       1              : // For details about authentication see docs/authentication.md
       2              : 
       3              : use std::borrow::Cow;
       4              : use std::fmt::Display;
       5              : use std::fs;
       6              : use std::sync::Arc;
       7              : 
       8              : use anyhow::Result;
       9              : use arc_swap::ArcSwap;
      10              : use camino::Utf8Path;
      11              : use jsonwebtoken::{
      12              :     Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
      13              : };
      14              : use pem::Pem;
      15              : use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
      16              : use uuid::Uuid;
      17              : 
      18              : use crate::id::TenantId;
      19              : 
      20              : /// Algorithm to use. We require EdDSA.
      21              : const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
      22              : 
      23            0 : #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
      24              : #[serde(rename_all = "lowercase")]
      25              : pub enum Scope {
      26              :     /// Provides access to all data for a specific tenant (specified in `struct Claims` below)
      27              :     // TODO: join these two?
      28              :     Tenant,
      29              :     /// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
      30              :     /// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
      31              :     /// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
      32              :     /// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
      33              :     TenantEndpoint,
      34              :     /// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
      35              :     /// Should only be used e.g. for status check/tenant creation/list.
      36              :     PageServerApi,
      37              :     /// Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
      38              :     /// Should only be used e.g. for status check.
      39              :     /// Currently also used for connection from any pageserver to any safekeeper.
      40              :     SafekeeperData,
      41              :     /// The scope used by pageservers in upcalls to storage controller and cloud control plane
      42              :     #[serde(rename = "generations_api")]
      43              :     GenerationsApi,
      44              :     /// Allows access to control plane management API and all storage controller endpoints.
      45              :     Admin,
      46              : 
      47              :     /// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration)
      48              :     Infra,
      49              : 
      50              :     /// Allows access to storage controller APIs used by the scrubber, to interrogate the state
      51              :     /// of a tenant & post scrub results.
      52              :     Scrubber,
      53              : 
      54              :     /// This scope is used for communication with other storage controller instances.
      55              :     /// At the time of writing, this is only used for the step down request.
      56              :     #[serde(rename = "controller_peer")]
      57              :     ControllerPeer,
      58              : }
      59              : 
      60            0 : fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
      61            0 : where
      62            0 :     D: Deserializer<'de>,
      63              : {
      64            0 :     let opt = Option::<String>::deserialize(deserializer)?;
      65            0 :     match opt.as_deref() {
      66            0 :         Some("") => Ok(None),
      67            0 :         Some(s) => Uuid::parse_str(s)
      68            0 :             .map(Some)
      69            0 :             .map_err(serde::de::Error::custom),
      70            0 :         None => Ok(None),
      71              :     }
      72            0 : }
      73              : 
      74              : /// JWT payload. See docs/authentication.md for the format
      75              : #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
      76              : pub struct Claims {
      77              :     #[serde(default)]
      78              :     pub tenant_id: Option<TenantId>,
      79              :     #[serde(
      80              :         default,
      81              :         skip_serializing_if = "Option::is_none",
      82              :         // Neon control plane includes this field as empty in the claims.
      83              :         // Consider it None in those cases.
      84              :         deserialize_with = "deserialize_empty_string_as_none_uuid"
      85              :     )]
      86              :     pub endpoint_id: Option<Uuid>,
      87              :     pub scope: Scope,
      88              : }
      89              : 
      90              : impl Claims {
      91            0 :     pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
      92            0 :         Self {
      93            0 :             tenant_id,
      94            0 :             scope,
      95            0 :             endpoint_id: None,
      96            0 :         }
      97            0 :     }
      98              : }
      99              : 
     100              : pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
     101              : 
     102              : impl SwappableJwtAuth {
     103            0 :     pub fn new(jwt_auth: JwtAuth) -> Self {
     104            0 :         SwappableJwtAuth(ArcSwap::new(Arc::new(jwt_auth)))
     105            0 :     }
     106            0 :     pub fn swap(&self, jwt_auth: JwtAuth) {
     107            0 :         self.0.swap(Arc::new(jwt_auth));
     108            0 :     }
     109            0 :     pub fn decode<D: DeserializeOwned>(
     110            0 :         &self,
     111            0 :         token: &str,
     112            0 :     ) -> std::result::Result<TokenData<D>, AuthError> {
     113            0 :         self.0.load().decode(token)
     114            0 :     }
     115              : }
     116              : 
     117              : impl std::fmt::Debug for SwappableJwtAuth {
     118            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     119            0 :         write!(f, "Swappable({:?})", self.0.load())
     120            0 :     }
     121              : }
     122              : 
     123              : #[derive(Clone, PartialEq, Eq, Hash, Debug)]
     124              : pub struct AuthError(pub Cow<'static, str>);
     125              : 
     126              : impl Display for AuthError {
     127            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     128            0 :         write!(f, "{}", self.0)
     129            0 :     }
     130              : }
     131              : 
     132              : pub struct JwtAuth {
     133              :     decoding_keys: Vec<DecodingKey>,
     134              :     validation: Validation,
     135              : }
     136              : 
     137              : impl JwtAuth {
     138            2 :     pub fn new(decoding_keys: Vec<DecodingKey>) -> Self {
     139            2 :         let mut validation = Validation::default();
     140            2 :         validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
     141              :         // The default 'required_spec_claims' is 'exp'. But we don't want to require
     142              :         // expiration.
     143            2 :         validation.required_spec_claims = [].into();
     144            2 :         Self {
     145            2 :             decoding_keys,
     146            2 :             validation,
     147            2 :         }
     148            2 :     }
     149              : 
     150            0 :     pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
     151            0 :         let metadata = key_path.metadata()?;
     152            0 :         let decoding_keys = if metadata.is_dir() {
     153            0 :             let mut keys = Vec::new();
     154            0 :             for entry in fs::read_dir(key_path)? {
     155            0 :                 let path = entry?.path();
     156            0 :                 if !path.is_file() {
     157              :                     // Ignore directories (don't recurse)
     158            0 :                     continue;
     159            0 :                 }
     160            0 :                 let public_key = fs::read(path)?;
     161            0 :                 keys.push(DecodingKey::from_ed_pem(&public_key)?);
     162              :             }
     163            0 :             keys
     164            0 :         } else if metadata.is_file() {
     165            0 :             let public_key = fs::read(key_path)?;
     166            0 :             vec![DecodingKey::from_ed_pem(&public_key)?]
     167              :         } else {
     168            0 :             anyhow::bail!("path is neither a directory or a file")
     169              :         };
     170            0 :         if decoding_keys.is_empty() {
     171            0 :             anyhow::bail!(
     172            0 :                 "Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
     173              :             );
     174            0 :         }
     175            0 :         Ok(Self::new(decoding_keys))
     176            0 :     }
     177              : 
     178            0 :     pub fn from_key(key: String) -> Result<Self> {
     179            0 :         Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
     180            0 :     }
     181              : 
     182              :     /// Attempt to decode the token with the internal decoding keys.
     183              :     ///
     184              :     /// The function tries the stored decoding keys in succession,
     185              :     /// and returns the first yielding a successful result.
     186              :     /// If there is no working decoding key, it returns the last error.
     187            2 :     pub fn decode<D: DeserializeOwned>(
     188            2 :         &self,
     189            2 :         token: &str,
     190            2 :     ) -> std::result::Result<TokenData<D>, AuthError> {
     191            2 :         let mut res = None;
     192            2 :         for decoding_key in &self.decoding_keys {
     193            2 :             res = Some(decode(token, decoding_key, &self.validation));
     194            2 :             if let Some(Ok(res)) = res {
     195            2 :                 return Ok(res);
     196            0 :             }
     197              :         }
     198            0 :         if let Some(res) = res {
     199            0 :             res.map_err(|e| AuthError(Cow::Owned(e.to_string())))
     200              :         } else {
     201            0 :             Err(AuthError(Cow::Borrowed("no JWT decoding keys configured")))
     202              :         }
     203            2 :     }
     204              : }
     205              : 
     206              : impl std::fmt::Debug for JwtAuth {
     207            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     208            0 :         f.debug_struct("JwtAuth")
     209            0 :             .field("validation", &self.validation)
     210            0 :             .finish()
     211            0 :     }
     212              : }
     213              : 
     214              : // This function is used only for testing purposes in the CLI, e.g. to generate tokens during init.
     215            1 : pub fn encode_from_key_file<S: Serialize>(claims: &S, pem: &Pem) -> Result<String> {
     216            1 :     let key = EncodingKey::from_ed_der(pem.contents());
     217            1 :     Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
     218            1 : }
     219              : 
     220              : #[cfg(test)]
     221              : mod tests {
     222              :     use std::str::FromStr;
     223              : 
     224              :     use super::*;
     225              : 
     226              :     // Generated with:
     227              :     //
     228              :     // openssl genpkey -algorithm ed25519 -out ed25519-priv.pem
     229              :     // openssl pkey -in ed25519-priv.pem -pubout -out ed25519-pub.pem
     230              :     const TEST_PUB_KEY_ED25519: &str = r#"
     231              : -----BEGIN PUBLIC KEY-----
     232              : MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
     233              : -----END PUBLIC KEY-----
     234              : "#;
     235              : 
     236              :     const TEST_PRIV_KEY_ED25519: &str = r#"
     237              : -----BEGIN PRIVATE KEY-----
     238              : MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
     239              : -----END PRIVATE KEY-----
     240              : "#;
     241              : 
     242              :     #[test]
     243            1 :     fn test_decode() {
     244            1 :         let expected_claims = Claims {
     245            1 :             tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
     246            1 :             scope: Scope::Tenant,
     247            1 :             endpoint_id: None,
     248            1 :         };
     249              : 
     250              :         // A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
     251              :         //
     252              :         // ```
     253              :         // {
     254              :         //   "scope": "tenant",
     255              :         //   "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
     256              :         //   "iss": "neon.controlplane",
     257              :         //   "iat": 1678442479
     258              :         // }
     259              :         // ```
     260              :         //
     261            1 :         let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJpYXQiOjE2Nzg0NDI0Nzl9.rNheBnluMJNgXzSTTJoTNIGy4P_qe0JUHl_nVEGuDCTgHOThPVr552EnmKccrCKquPeW3c2YUk0Y9Oh4KyASAw";
     262              : 
     263              :         // Check it can be validated with the public key
     264            1 :         let auth = JwtAuth::new(vec![
     265            1 :             DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519.as_bytes()).unwrap(),
     266              :         ]);
     267            1 :         let claims_from_token: Claims = auth.decode(encoded_eddsa).unwrap().claims;
     268            1 :         assert_eq!(claims_from_token, expected_claims);
     269            1 :     }
     270              : 
     271              :     #[test]
     272            1 :     fn test_encode() {
     273            1 :         let claims = Claims {
     274            1 :             tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
     275            1 :             scope: Scope::Tenant,
     276            1 :             endpoint_id: None,
     277            1 :         };
     278              : 
     279            1 :         let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
     280            1 :         let encoded = encode_from_key_file(&claims, &pem).unwrap();
     281              : 
     282              :         // decode it back
     283            1 :         let auth = JwtAuth::new(vec![
     284            1 :             DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519.as_bytes()).unwrap(),
     285              :         ]);
     286            1 :         let decoded: TokenData<Claims> = auth.decode(&encoded).unwrap();
     287              : 
     288            1 :         assert_eq!(decoded.claims, claims);
     289            1 :     }
     290              : }
        

Generated by: LCOV version 2.1-beta