LCOV - code coverage report
Current view: top level - proxy/src/scram - mod.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 93.6 % 78 73
Test Date: 2025-07-31 15:59:03 Functions: 100.0 % 12 12

            Line data    Source code
       1              : //! Salted Challenge Response Authentication Mechanism.
       2              : //!
       3              : //! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
       4              : //!
       5              : //! Reference implementation:
       6              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
       7              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
       8              : 
       9              : mod cache;
      10              : mod countmin;
      11              : mod exchange;
      12              : mod key;
      13              : mod messages;
      14              : mod pbkdf2;
      15              : mod secret;
      16              : mod signature;
      17              : pub mod threadpool;
      18              : 
      19              : use base64::Engine as _;
      20              : use base64::prelude::BASE64_STANDARD;
      21              : pub(crate) use exchange::{Exchange, exchange};
      22              : pub(crate) use key::ScramKey;
      23              : pub(crate) use secret::ServerSecret;
      24              : 
      25              : const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
      26              : const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
      27              : 
      28              : /// A list of supported SCRAM methods.
      29              : pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
      30              : pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
      31              : 
      32              : /// Decode base64 into array without any heap allocations
      33           51 : fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
      34           51 :     let mut bytes = [0u8; N];
      35              : 
      36           51 :     let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?;
      37           51 :     if size != N {
      38            0 :         return None;
      39           51 :     }
      40              : 
      41           51 :     Some(bytes)
      42           51 : }
      43              : 
      44              : #[cfg(test)]
      45              : mod tests {
      46              :     use super::threadpool::ThreadPool;
      47              :     use super::{Exchange, ServerSecret};
      48              :     use crate::intern::{EndpointIdInt, RoleNameInt};
      49              :     use crate::sasl::{Mechanism, Step};
      50              :     use crate::types::{EndpointId, RoleName};
      51              : 
      52              :     #[test]
      53            1 :     fn snapshot() {
      54            1 :         let iterations = 4096;
      55            1 :         let salt = "QSXCR+Q6sek8bf92";
      56            1 :         let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
      57            1 :         let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
      58            1 :         let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
      59            1 :         let secret = ServerSecret::parse(&secret).unwrap();
      60              : 
      61              :         const NONCE: [u8; 18] = [
      62              :             1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
      63              :         ];
      64            1 :         let mut exchange =
      65            1 :             Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
      66              : 
      67            1 :         let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
      68            1 :         let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
      69            1 :         let server_first =
      70            1 :             "r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
      71            1 :         let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";
      72              : 
      73            1 :         exchange = match exchange.exchange(client_first).unwrap() {
      74            1 :             Step::Continue(exchange, message) => {
      75            1 :                 assert_eq!(message, server_first);
      76            1 :                 exchange
      77              :             }
      78            0 :             Step::Success(_, _) => panic!("expected continue, got success"),
      79            0 :             Step::Failure(f) => panic!("{f}"),
      80              :         };
      81              : 
      82            1 :         let key = match exchange.exchange(client_final).unwrap() {
      83            1 :             Step::Success(key, message) => {
      84            1 :                 assert_eq!(message, server_final);
      85            1 :                 key
      86              :             }
      87            0 :             Step::Continue(_, _) => panic!("expected success, got continue"),
      88            0 :             Step::Failure(f) => panic!("{f}"),
      89              :         };
      90              : 
      91            1 :         assert_eq!(
      92            1 :             key.as_bytes(),
      93              :             [
      94              :                 74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
      95              :                 27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
      96              :             ]
      97              :         );
      98            1 :     }
      99              : 
     100            6 :     async fn check(
     101            6 :         pool: &ThreadPool,
     102            6 :         scram_secret: &ServerSecret,
     103            6 :         password: &[u8],
     104            6 :     ) -> Result<(), &'static str> {
     105            6 :         let ep = EndpointId::from("foo");
     106            6 :         let ep = EndpointIdInt::from(ep);
     107            6 :         let role = RoleName::from("user");
     108            6 :         let role = RoleNameInt::from(&role);
     109              : 
     110            6 :         let outcome = super::exchange(pool, ep, role, scram_secret, password)
     111            6 :             .await
     112            6 :             .unwrap();
     113              : 
     114            6 :         match outcome {
     115            3 :             crate::sasl::Outcome::Success(_) => Ok(()),
     116            3 :             crate::sasl::Outcome::Failure(r) => Err(r),
     117              :         }
     118            6 :     }
     119              : 
     120            2 :     async fn run_round_trip_test(server_password: &str, client_password: &str) {
     121            2 :         let pool = ThreadPool::new(1);
     122            2 :         let scram_secret = ServerSecret::build(server_password).await.unwrap();
     123            2 :         check(&pool, &scram_secret, client_password.as_bytes())
     124            2 :             .await
     125            2 :             .unwrap();
     126            2 :     }
     127              : 
     128              :     #[tokio::test]
     129            1 :     async fn round_trip() {
     130            1 :         run_round_trip_test("pencil", "pencil").await;
     131            1 :     }
     132              : 
     133              :     #[tokio::test]
     134              :     #[should_panic(expected = "password doesn't match")]
     135            1 :     async fn failure() {
     136            1 :         run_round_trip_test("pencil", "eraser").await;
     137            1 :     }
     138              : 
     139              :     #[tokio::test]
     140              :     #[tracing_test::traced_test]
     141            1 :     async fn password_cache() {
     142            1 :         let pool = ThreadPool::new(1);
     143            1 :         let scram_secret = ServerSecret::build("password").await.unwrap();
     144              : 
     145              :         // wrong passwords are not added to cache
     146            1 :         check(&pool, &scram_secret, b"wrong").await.unwrap_err();
     147            1 :         assert!(!logs_contain("storing cached password"));
     148              : 
     149              :         // correct passwords get cached
     150            1 :         check(&pool, &scram_secret, b"password").await.unwrap();
     151            1 :         assert!(logs_contain("storing cached password"));
     152              : 
     153              :         // wrong passwords do not match the cache
     154            1 :         check(&pool, &scram_secret, b"wrong").await.unwrap_err();
     155            1 :         assert!(!logs_contain("password validated from cache"));
     156              : 
     157              :         // correct passwords match the cache
     158            1 :         check(&pool, &scram_secret, b"password").await.unwrap();
     159            1 :         assert!(logs_contain("password validated from cache"));
     160            1 :     }
     161              : }
        

Generated by: LCOV version 2.1-beta