LCOV - code coverage report
Current view: top level - proxy/src - scram.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 92.6 % 68 63
Test Date: 2023-09-06 10:18:01 Functions: 100.0 % 13 13

            Line data    Source code
       1              : //! Salted Challenge Response Authentication Mechanism.
       2              : //!
       3              : //! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
       4              : //!
       5              : //! Reference implementation:
       6              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
       7              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
       8              : 
       9              : mod exchange;
      10              : mod key;
      11              : mod messages;
      12              : mod secret;
      13              : mod signature;
      14              : 
      15              : #[cfg(any(test, doc))]
      16              : mod password;
      17              : 
      18              : pub use exchange::Exchange;
      19              : pub use key::ScramKey;
      20              : pub use secret::ServerSecret;
      21              : pub use secret::*;
      22              : 
      23              : use hmac::{Hmac, Mac};
      24              : use sha2::{Digest, Sha256};
      25              : 
      // TODO: add SCRAM-SHA-256-PLUS
      /// A list of supported SCRAM methods.
      ///
      /// Currently contains a single entry, `"SCRAM-SHA-256"`.
      pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
      29              : 
      30              : /// Decode base64 into array without any heap allocations
      31           81 : fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
      32           81 :     let mut bytes = [0u8; N];
      33              : 
      34           81 :     let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
      35           81 :     if size != N {
      36            0 :         return None;
      37           81 :     }
      38           81 : 
      39           81 :     Some(bytes)
      40           81 : }
      41              : 
      42              : /// This function essentially is `Hmac(sha256, key, input)`.
      43              : /// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
      44         4162 : fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
      45         4162 :     let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
      46         4387 :     parts.into_iter().for_each(|s| mac.update(s));
      47         4162 : 
      48         4162 :     mac.finalize().into_bytes().into()
      49         4162 : }
      50              : 
      51           37 : fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
      52           37 :     let mut hasher = Sha256::new();
      53           39 :     parts.into_iter().for_each(|s| hasher.update(s));
      54           37 : 
      55           37 :     hasher.finalize().into()
      56           37 : }
      57              : 
      #[cfg(test)]
      mod tests {
          use super::{password::SaltedPassword, Exchange, ServerSecret};
          use crate::sasl::{Mechanism, Step};

          /// Runs a full two-step exchange with fixed inputs (password, salt,
          /// iteration count and server nonce) and checks both server replies and
          /// the resulting key against known-good values.
          #[test]
          fn happy_path() {
              // Fixed server nonce so that the server messages are deterministic.
              const NONCE: [u8; 18] = [
                  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
              ];

              let iterations = 4096;
              let salt_base64 = "QSXCR+Q6sek8bf92";

              let salt = base64::decode(salt_base64).unwrap();
              let pw = SaltedPassword::new(b"pencil", salt.as_slice(), iterations);

              let secret = ServerSecret {
                  iterations,
                  salt_base64: salt_base64.to_owned(),
                  stored_key: pw.client_key().sha256(),
                  server_key: pw.server_key(),
                  doomed: false,
              };

              // Client messages and the replies the server is expected to produce.
              let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
              let server_first =
                  "r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
              let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
              let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";

              // Step 1: client-first -> server-first, exchange must continue.
              let exchange = Exchange::new(&secret, || NONCE, None);
              let exchange = match exchange.exchange(client_first).unwrap() {
                  Step::Continue(next, reply) => {
                      assert_eq!(reply, server_first);
                      next
                  }
                  Step::Success(_, _) => panic!("expected continue, got success"),
                  Step::Failure(f) => panic!("{f}"),
              };

              // Step 2: client-final -> server-final, exchange must succeed.
              let key = match exchange.exchange(client_final).unwrap() {
                  Step::Success(derived, reply) => {
                      assert_eq!(reply, server_final);
                      derived
                  }
                  Step::Continue(_, _) => panic!("expected success, got continue"),
                  Step::Failure(f) => panic!("{f}"),
              };

              assert_eq!(
                  key.as_bytes(),
                  [
                      74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
                      27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
                  ]
              );
          }
      }
        

Generated by: LCOV version 2.1-beta