LCOV - code coverage report
Current view: top level - proxy/src/scram - pbkdf2.rs (source / functions) Coverage Total Hit
Test: 5713ff31fc16472ab3f92425989ca6addc3dcf9c.info Lines: 100.0 % 59 59
Test Date: 2025-07-30 16:18:19 Functions: 100.0 % 7 7

            Line data    Source code
       1              : //! For postgres password authentication, we need to perform a PBKDF2 using
       2              : //! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
       3              : 
       4              : use hmac::Mac as _;
       5              : use hmac::digest::consts::U32;
       6              : use hmac::digest::generic_array::GenericArray;
       7              : use zeroize::Zeroize as _;
       8              : 
       9              : use crate::metrics::Metrics;
      10              : 
      11              : /// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
      12              : pub type Prf = hmac::Hmac<sha2::Sha256>;
      13              : pub(crate) type Block = GenericArray<u8, U32>;
      14              : 
      15              : pub(crate) struct Pbkdf2 {
      16              :     hmac: Prf,
      17              :     /// U{r-1} for whatever iteration r we are currently on.
      18              :     prev: Block,
      19              :     /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
      20              :     hi: Block,
      21              :     /// number of iterations left
      22              :     iterations: u32,
      23              : }
      24              : 
      25              : impl Drop for Pbkdf2 {
      26           16 :     fn drop(&mut self) {
      27           16 :         self.prev.zeroize();
      28           16 :         self.hi.zeroize();
      29           16 :     }
      30              : }
      31              : 
      32              : // inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
      33              : impl Pbkdf2 {
      34           16 :     pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
      35              :         // key the HMAC and derive the first block in-place
      36           16 :         let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
      37              : 
      38              :         // U1 = PRF(Password, Salt + INT_32_BE(i))
      39              :         // i = 1 since we only need 1 block of output.
      40           16 :         hmac.update(salt);
      41           16 :         hmac.update(&1u32.to_be_bytes());
      42           16 :         let init_block = hmac.finalize_reset().into_bytes();
      43              : 
      44              :         // Prf::new_from_slice will run 2 sha256 rounds.
      45              :         // Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
      46           16 :         Metrics::get().proxy.sha_rounds.inc_by(4);
      47              : 
      48           16 :         Self {
      49           16 :             hmac,
      50           16 :             // one iteration spent above
      51           16 :             iterations: iterations - 1,
      52           16 :             hi: init_block,
      53           16 :             prev: init_block,
      54           16 :         }
      55           16 :     }
      56              : 
      57           15 :     pub(crate) fn cost(&self) -> u32 {
      58           15 :         (self.iterations).clamp(0, 4096)
      59           15 :     }
      60              : 
      61              :     /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
      62              :     /// function that only executes a fixed number of iterations before continuing.
      63              :     ///
      64              :     /// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
      65           30 :     pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
      66              :         let Self {
      67           30 :             hmac,
      68           30 :             prev,
      69           30 :             hi,
      70           30 :             iterations,
      71           30 :         } = self;
      72              : 
      73              :         // only do up to 4096 iterations per turn for fairness
      74           30 :         let n = (*iterations).clamp(0, 4096);
      75        88784 :         for _ in 0..n {
      76        88784 :             let next = single_round(hmac, prev);
      77        88784 :             xor_assign(hi, &next);
      78        88784 :             *prev = next;
      79        88784 :         }
      80              : 
      81              :         // Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
      82           30 :         Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
      83              : 
      84           30 :         *iterations -= n;
      85           30 :         if *iterations == 0 {
      86           16 :             std::task::Poll::Ready(*hi)
      87              :         } else {
      88           14 :             std::task::Poll::Pending
      89              :         }
      90           30 :     }
      91              : }
      92              : 
      93              : #[inline(always)]
      94        88790 : pub fn xor_assign(x: &mut Block, y: &Block) {
      95      2841280 :     for (x, &y) in std::iter::zip(x, y) {
      96      2841280 :         *x ^= y;
      97      2841280 :     }
      98        88790 : }
      99              : 
     100              : #[inline(always)]
     101        88784 : fn single_round(prf: &mut Prf, ui: &Block) -> Block {
     102              :     // Ui = PRF(Password, Ui-1)
     103        88784 :     prf.update(ui);
     104        88784 :     prf.finalize_reset().into_bytes()
     105        88784 : }
     106              : 
     #[cfg(test)]
     mod tests {
         use std::task::Poll;

         use pbkdf2::pbkdf2_hmac_array;
         use sha2::Sha256;

         use super::Pbkdf2;

         const SALT: &[u8] = b"sodium chloride";
         const PASS: &[u8] = b"Ne0n_!5_50_C007";

         /// Drives a [`Pbkdf2`] job to completion, returning the derived key together with the
         /// number of `turn` calls that were needed.
         fn derive(pass: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], u32) {
             let mut job = Pbkdf2::start(pass, salt, iterations);
             let mut turns = 0;
             loop {
                 turns += 1;
                 if let Poll::Ready(hash) = job.turn() {
                     return (hash.into(), turns);
                 }
             }
         }

         #[test]
         fn works() {
             let (hash, _) = derive(PASS, SALT, 60000);
             let expected = pbkdf2_hmac_array::<Sha256, 32>(PASS, SALT, 60000);
             assert_eq!(hash, expected);
         }

         /// Iteration counts around the 4096-iterations-per-turn boundary, including the minimal
         /// single-iteration case where `turn` has no rounds left to run.
         #[test]
         fn matches_reference_around_turn_boundaries() {
             for iterations in [1, 2, 4096, 4097, 4098, 8193] {
                 let (hash, turns) = derive(PASS, SALT, iterations);
                 let expected = pbkdf2_hmac_array::<Sha256, 32>(PASS, SALT, iterations);
                 assert_eq!(hash, expected, "iterations = {iterations}");

                 // U1 is computed by `start`; the remaining rounds run at most 4096 per turn, and
                 // even a job with nothing left needs one `turn` to report `Ready`.
                 let expected_turns = ((iterations - 1) + 4095) / 4096;
                 assert_eq!(turns, expected_turns.max(1), "iterations = {iterations}");
             }
         }
     }
        

Generated by: LCOV version 2.1-beta