LCOV - code coverage report
Current view: top level - proxy/src/scram - pbkdf2.rs (source / functions) Coverage Total Hit
Test: 960803fca14b2e843c565dddf575f7017d250bc3.info Lines: 100.0 % 56 56
Test Date: 2024-06-22 23:41:44 Functions: 100.0 % 4 4

            Line data    Source code
       1              : use hmac::{
       2              :     digest::{consts::U32, generic_array::GenericArray},
       3              :     Hmac, Mac,
       4              : };
       5              : use sha2::Sha256;
       6              : 
       7              : pub struct Pbkdf2 {
       8              :     hmac: Hmac<Sha256>,
       9              :     prev: GenericArray<u8, U32>,
      10              :     hi: GenericArray<u8, U32>,
      11              :     iterations: u32,
      12              : }
      13              : 
      14              : // inspired by <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
      15              : impl Pbkdf2 {
      16           12 :     pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
      17           12 :         let hmac =
      18           12 :             Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
      19           12 : 
      20           12 :         let prev = hmac
      21           12 :             .clone()
      22           12 :             .chain_update(salt)
      23           12 :             .chain_update(1u32.to_be_bytes())
      24           12 :             .finalize()
      25           12 :             .into_bytes();
      26           12 : 
      27           12 :         Self {
      28           12 :             hmac,
      29           12 :             // one consumed for the hash above
      30           12 :             iterations: iterations - 1,
      31           12 :             hi: prev,
      32           12 :             prev,
      33           12 :         }
      34           12 :     }
      35              : 
      36           11 :     pub fn cost(&self) -> u32 {
      37           11 :         (self.iterations).clamp(0, 4096)
      38           11 :     }
      39              : 
      40          304 :     pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
      41          304 :         let Self {
      42          304 :             hmac,
      43          304 :             prev,
      44          304 :             hi,
      45          304 :             iterations,
      46          304 :         } = self;
      47          304 : 
      48          304 :         // only do 4096 iterations per turn before sharing the thread for fairness
      49          304 :         let n = (*iterations).clamp(0, 4096);
      50          304 :         for _ in 0..n {
      51      1240948 :             *prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
      52              : 
      53     39710336 :             for (hi, prev) in hi.iter_mut().zip(*prev) {
      54     39710336 :                 *hi ^= prev;
      55     39710336 :             }
      56              :         }
      57              : 
      58          304 :         *iterations -= n;
      59          304 :         if *iterations == 0 {
      60           12 :             std::task::Poll::Ready((*hi).into())
      61              :         } else {
      62          292 :             std::task::Poll::Pending
      63              :         }
      64          304 :     }
      65              : }
      66              : 
      67              : #[cfg(test)]
      68              : mod tests {
      69              :     use super::Pbkdf2;
      70              :     use pbkdf2::pbkdf2_hmac_array;
      71              :     use sha2::Sha256;
      72              : 
      73              :     #[test]
      74            2 :     fn works() {
      75            2 :         let salt = b"sodium chloride";
      76            2 :         let pass = b"Ne0n_!5_50_C007";
      77            2 : 
      78            2 :         let mut job = Pbkdf2::start(pass, salt, 600000);
      79            2 :         let hash = loop {
      80          294 :             let std::task::Poll::Ready(hash) = job.turn() else {
      81          292 :                 continue;
      82              :             };
      83            2 :             break hash;
      84            2 :         };
      85            2 : 
      86            2 :         let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
      87            2 :         assert_eq!(hash, expected)
      88            2 :     }
      89              : }
        

Generated by: LCOV version 2.1-beta