LCOV - code coverage report
Current view: top level - proxy/src/scram - pbkdf2.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 100.0 % 47 47
Test Date: 2025-07-16 12:29:03 Functions: 100.0 % 4 4

            Line data    Source code
       1              : use hmac::digest::consts::U32;
       2              : use hmac::digest::generic_array::GenericArray;
       3              : use hmac::{Hmac, Mac};
       4              : use sha2::Sha256;
       5              : 
       6              : pub(crate) struct Pbkdf2 {
       7              :     hmac: Hmac<Sha256>,
       8              :     prev: GenericArray<u8, U32>,
       9              :     hi: GenericArray<u8, U32>,
      10              :     iterations: u32,
      11              : }
      12              : 
      13              : // inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
      14              : impl Pbkdf2 {
      15            6 :     pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
      16              :         // key the HMAC and derive the first block in-place
      17            6 :         let mut hmac =
      18            6 :             Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
      19            6 :         hmac.update(salt);
      20            6 :         hmac.update(&1u32.to_be_bytes());
      21            6 :         let init_block = hmac.finalize_reset().into_bytes();
      22              : 
      23            6 :         Self {
      24            6 :             hmac,
      25            6 :             // one iteration spent above
      26            6 :             iterations: iterations - 1,
      27            6 :             hi: init_block,
      28            6 :             prev: init_block,
      29            6 :         }
      30            6 :     }
      31              : 
      32            6 :     pub(crate) fn cost(&self) -> u32 {
      33            6 :         (self.iterations).clamp(0, 4096)
      34            6 :     }
      35              : 
      36           20 :     pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
      37              :         let Self {
      38           20 :             hmac,
      39           20 :             prev,
      40           20 :             hi,
      41           20 :             iterations,
      42           20 :         } = self;
      43              : 
      44              :         // only do up to 4096 iterations per turn for fairness
      45           20 :         let n = (*iterations).clamp(0, 4096);
      46           20 :         for _ in 0..n {
      47        80474 :             hmac.update(prev);
      48        80474 :             let block = hmac.finalize_reset().into_bytes();
      49              : 
      50      2575168 :             for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
      51      2575168 :                 *hi_byte ^= b;
      52      2575168 :             }
      53              : 
      54        80474 :             *prev = block;
      55              :         }
      56              : 
      57           20 :         *iterations -= n;
      58           20 :         if *iterations == 0 {
      59            6 :             std::task::Poll::Ready((*hi).into())
      60              :         } else {
      61           14 :             std::task::Poll::Pending
      62              :         }
      63           20 :     }
      64              : }
      65              : 
      66              : #[cfg(test)]
      67              : mod tests {
      68              :     use pbkdf2::pbkdf2_hmac_array;
      69              :     use sha2::Sha256;
      70              : 
      71              :     use super::Pbkdf2;
      72              : 
      73              :     #[test]
      74            1 :     fn works() {
      75            1 :         let salt = b"sodium chloride";
      76            1 :         let pass = b"Ne0n_!5_50_C007";
      77              : 
      78            1 :         let mut job = Pbkdf2::start(pass, salt, 60000);
      79            1 :         let hash = loop {
      80           15 :             let std::task::Poll::Ready(hash) = job.turn() else {
      81           14 :                 continue;
      82              :             };
      83            1 :             break hash;
      84              :         };
      85              : 
      86            1 :         let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);
      87            1 :         assert_eq!(hash, expected);
      88            1 :     }
      89              : }
        

Generated by: LCOV version 2.1-beta