LCOV - differential code coverage report
Current view: top level - libs/utils/src - auth.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 89.2 % 120 107 13 107
Current Date: 2024-01-09 02:06:09 Functions: 59.2 % 49 29 20 29
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : // For details about authentication see docs/authentication.md
       2                 : 
       3                 : use arc_swap::ArcSwap;
       4                 : use serde;
       5                 : use std::{borrow::Cow, fmt::Display, fs, sync::Arc};
       6                 : 
       7                 : use anyhow::Result;
       8                 : use camino::Utf8Path;
       9                 : use jsonwebtoken::{
      10                 :     decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
      11                 : };
      12                 : use serde::{Deserialize, Serialize};
      13                 : 
      14                 : use crate::{http::error::ApiError, id::TenantId};
      15                 : 
      16                 : /// Algorithm to use. We require EdDSA.
      17                 : const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
      18                 : 
      19 CBC         686 : #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
      20                 : #[serde(rename_all = "lowercase")]
      21                 : pub enum Scope {
      22                 :     // Provides access to all data for a specific tenant (specified in `struct Claims` below)
      23                 :     // TODO: join these two?
      24                 :     Tenant,
      25                 :     // Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
      26                 :     // Should only be used e.g. for status check/tenant creation/list.
      27                 :     PageServerApi,
      28                 :     // Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
      29                 :     // Should only be used e.g. for status check.
      30                 :     // Currently also used for connection from any pageserver to any safekeeper.
      31                 :     SafekeeperData,
      32                 : }
      33                 : 
      34                 : /// JWT payload. See docs/authentication.md for the format
      35            1660 : #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
      36                 : pub struct Claims {
      37                 :     #[serde(default)]
      38                 :     pub tenant_id: Option<TenantId>,
      39                 :     pub scope: Scope,
      40                 : }
      41                 : 
      42                 : impl Claims {
      43             163 :     pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
      44             163 :         Self { tenant_id, scope }
      45             163 :     }
      46                 : }
      47                 : 
      48                 : pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
      49                 : 
      50                 : impl SwappableJwtAuth {
      51              31 :     pub fn new(jwt_auth: JwtAuth) -> Self {
      52              31 :         SwappableJwtAuth(ArcSwap::new(Arc::new(jwt_auth)))
      53              31 :     }
      54               6 :     pub fn swap(&self, jwt_auth: JwtAuth) {
      55               6 :         self.0.swap(Arc::new(jwt_auth));
      56               6 :     }
      57             216 :     pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
      58             216 :         self.0.load().decode(token)
      59             216 :     }
      60                 : }
      61                 : 
      62                 : impl std::fmt::Debug for SwappableJwtAuth {
      63 UBC           0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      64               0 :         write!(f, "Swappable({:?})", self.0.load())
      65               0 :     }
      66                 : }
      67                 : 
      68               0 : #[derive(Clone, PartialEq, Eq, Hash, Debug)]
      69                 : pub struct AuthError(pub Cow<'static, str>);
      70                 : 
      71                 : impl Display for AuthError {
      72 CBC           3 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      73               3 :         write!(f, "{}", self.0)
      74               3 :     }
      75                 : }
      76                 : 
      77                 : impl From<AuthError> for ApiError {
      78               3 :     fn from(_value: AuthError) -> Self {
      79               3 :         // Don't pass on the value of the AuthError as a precautionary measure.
      80               3 :         // Being intentionally vague in public error communication hurts debugability
      81               3 :         // but it is more secure.
      82               3 :         ApiError::Forbidden("JWT authentication error".to_string())
      83               3 :     }
      84                 : }
      85                 : 
      86                 : pub struct JwtAuth {
      87                 :     decoding_keys: Vec<DecodingKey>,
      88                 :     validation: Validation,
      89                 : }
      90                 : 
      91                 : impl JwtAuth {
      92              82 :     pub fn new(decoding_keys: Vec<DecodingKey>) -> Self {
      93              82 :         let mut validation = Validation::default();
      94              82 :         validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
      95              82 :         // The default 'required_spec_claims' is 'exp'. But we don't want to require
      96              82 :         // expiration.
      97              82 :         validation.required_spec_claims = [].into();
      98              82 :         Self {
      99              82 :             decoding_keys,
     100              82 :             validation,
     101              82 :         }
     102              82 :     }
     103                 : 
     104              80 :     pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
     105              80 :         let metadata = key_path.metadata()?;
     106              80 :         let decoding_keys = if metadata.is_dir() {
     107               2 :             let mut keys = Vec::new();
     108               3 :             for entry in fs::read_dir(key_path)? {
     109               3 :                 let path = entry?.path();
     110               3 :                 if !path.is_file() {
     111                 :                     // Ignore directories (don't recurse)
     112 UBC           0 :                     continue;
     113 CBC           3 :                 }
     114               3 :                 let public_key = fs::read(path)?;
     115               3 :                 keys.push(DecodingKey::from_ed_pem(&public_key)?);
     116                 :             }
     117               2 :             keys
     118              78 :         } else if metadata.is_file() {
     119              78 :             let public_key = fs::read(key_path)?;
     120              78 :             vec![DecodingKey::from_ed_pem(&public_key)?]
     121                 :         } else {
     122 UBC           0 :             anyhow::bail!("path is neither a directory or a file")
     123                 :         };
     124 CBC          80 :         if decoding_keys.is_empty() {
     125 UBC           0 :             anyhow::bail!("Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected.");
     126 CBC          80 :         }
     127              80 :         Ok(Self::new(decoding_keys))
     128              80 :     }
     129                 : 
     130                 :     /// Attempt to decode the token with the internal decoding keys.
     131                 :     ///
     132                 :     /// The function tries the stored decoding keys in succession,
     133                 :     /// and returns the first yielding a successful result.
     134                 :     /// If there is no working decoding key, it returns the last error.
     135             346 :     pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
     136             346 :         let mut res = None;
     137             351 :         for decoding_key in &self.decoding_keys {
     138             348 :             res = Some(decode(token, decoding_key, &self.validation));
     139             348 :             if let Some(Ok(res)) = res {
     140             343 :                 return Ok(res);
     141               5 :             }
     142                 :         }
     143               3 :         if let Some(res) = res {
     144               3 :             res.map_err(|e| AuthError(Cow::Owned(e.to_string())))
     145                 :         } else {
     146 UBC           0 :             Err(AuthError(Cow::Borrowed("no JWT decoding keys configured")))
     147                 :         }
     148 CBC         346 :     }
     149                 : }
     150                 : 
     151                 : impl std::fmt::Debug for JwtAuth {
     152 UBC           0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     153               0 :         f.debug_struct("JwtAuth")
     154               0 :             .field("validation", &self.validation)
     155               0 :             .finish()
     156               0 :     }
     157                 : }
     158                 : 
     159                 : // this function is used only for testing purposes in CLI e g generate tokens during init
     160 CBC         164 : pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
     161             164 :     let key = EncodingKey::from_ed_pem(key_data)?;
     162             164 :     Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
     163             164 : }
     164                 : 
     165                 : #[cfg(test)]
     166                 : mod tests {
     167                 :     use super::*;
     168                 :     use std::str::FromStr;
     169                 : 
     170                 :     // Generated with:
     171                 :     //
     172                 :     // openssl genpkey -algorithm ed25519 -out ed25519-priv.pem
     173                 :     // openssl pkey -in ed25519-priv.pem -pubout -out ed25519-pub.pem
     174                 :     const TEST_PUB_KEY_ED25519: &[u8] = br#"
     175                 : -----BEGIN PUBLIC KEY-----
     176                 : MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
     177                 : -----END PUBLIC KEY-----
     178                 : "#;
     179                 : 
     180                 :     const TEST_PRIV_KEY_ED25519: &[u8] = br#"
     181                 : -----BEGIN PRIVATE KEY-----
     182                 : MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
     183                 : -----END PRIVATE KEY-----
     184                 : "#;
     185                 : 
     186               1 :     #[test]
     187               1 :     fn test_decode() {
     188               1 :         let expected_claims = Claims {
     189               1 :             tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
     190               1 :             scope: Scope::Tenant,
     191               1 :         };
     192               1 : 
     193               1 :         // A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
     194               1 :         //
     195               1 :         // ```
     196               1 :         // {
     197               1 :         //   "scope": "tenant",
     198               1 :         //   "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
     199               1 :         //   "iss": "neon.controlplane",
     200               1 :         //   "exp": 1709200879,
     201               1 :         //   "iat": 1678442479
     202               1 :         // }
     203               1 :         // ```
     204               1 :         //
     205               1 :         let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.U3eA8j-uU-JnhzeO3EDHRuXLwkAUFCPxtGHEgw6p7Ccc3YRbFs2tmCdbD9PZEXP-XsxSeBQi1FY0YPcT3NXADw";
     206               1 : 
     207               1 :         // Check it can be validated with the public key
     208               1 :         let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
     209               1 :         let claims_from_token = auth.decode(encoded_eddsa).unwrap().claims;
     210               1 :         assert_eq!(claims_from_token, expected_claims);
     211               1 :     }
     212                 : 
     213               1 :     #[test]
     214               1 :     fn test_encode() {
     215               1 :         let claims = Claims {
     216               1 :             tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
     217               1 :             scope: Scope::Tenant,
     218               1 :         };
     219               1 : 
     220               1 :         let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519).unwrap();
     221               1 : 
     222               1 :         // decode it back
     223               1 :         let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
     224               1 :         let decoded = auth.decode(&encoded).unwrap();
     225               1 : 
     226               1 :         assert_eq!(decoded.claims, claims);
     227               1 :     }
     228                 : }
        

Generated by: LCOV version 2.1-beta