Line data Source code
1 : //! For postgres password authentication, we need to perform a PBKDF2 using
2 : //! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
3 :
4 : use hmac::Mac as _;
5 : use hmac::digest::consts::U32;
6 : use hmac::digest::generic_array::GenericArray;
7 : use zeroize::Zeroize as _;
8 :
9 : use crate::metrics::Metrics;
10 :
11 : /// The pseudo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
12 : pub type Prf = hmac::Hmac<sha2::Sha256>;
13 : pub(crate) type Block = GenericArray<u8, U32>;
14 :
15 : pub(crate) struct Pbkdf2 {
16 : hmac: Prf,
17 : /// U{r-1} for whatever iteration r we are currently on.
18 : prev: Block,
19 : /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
20 : hi: Block,
21 : /// number of iterations left
22 : iterations: u32,
23 : }
24 :
25 : impl Drop for Pbkdf2 {
26 16 : fn drop(&mut self) {
27 16 : self.prev.zeroize();
28 16 : self.hi.zeroize();
29 16 : }
30 : }
31 :
32 : // inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
33 : impl Pbkdf2 {
34 16 : pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
35 : // key the HMAC and derive the first block in-place
36 16 : let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
37 :
38 : // U1 = PRF(Password, Salt + INT_32_BE(i))
39 : // i = 1 since we only need 1 block of output.
40 16 : hmac.update(salt);
41 16 : hmac.update(&1u32.to_be_bytes());
42 16 : let init_block = hmac.finalize_reset().into_bytes();
43 :
44 : // Prf::new_from_slice will run 2 sha256 rounds.
45 : // Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
46 16 : Metrics::get().proxy.sha_rounds.inc_by(4);
47 :
48 16 : Self {
49 16 : hmac,
50 16 : // one iteration spent above
51 16 : iterations: iterations - 1,
52 16 : hi: init_block,
53 16 : prev: init_block,
54 16 : }
55 16 : }
56 :
57 16 : pub(crate) fn cost(&self) -> u32 {
58 16 : (self.iterations).clamp(0, 4096)
59 16 : }
60 :
61 : /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
62 : /// function that only executes a fixed number of iterations before continuing.
63 : ///
64 : /// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
65 30 : pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
66 : let Self {
67 30 : hmac,
68 30 : prev,
69 30 : hi,
70 30 : iterations,
71 30 : } = self;
72 :
73 : // only do up to 4096 iterations per turn for fairness
74 30 : let n = (*iterations).clamp(0, 4096);
75 88784 : for _ in 0..n {
76 88784 : let next = single_round(hmac, prev);
77 88784 : xor_assign(hi, &next);
78 88784 : *prev = next;
79 88784 : }
80 :
81 : // Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
82 30 : Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
83 :
84 30 : *iterations -= n;
85 30 : if *iterations == 0 {
86 16 : std::task::Poll::Ready(*hi)
87 : } else {
88 14 : std::task::Poll::Pending
89 : }
90 30 : }
91 : }
92 :
93 : #[inline(always)]
94 88790 : pub fn xor_assign(x: &mut Block, y: &Block) {
95 2841280 : for (x, &y) in std::iter::zip(x, y) {
96 2841280 : *x ^= y;
97 2841280 : }
98 88790 : }
99 :
100 : #[inline(always)]
101 88784 : fn single_round(prf: &mut Prf, ui: &Block) -> Block {
102 : // Ui = PRF(Password, Ui-1)
103 88784 : prf.update(ui);
104 88784 : prf.finalize_reset().into_bytes()
105 88784 : }
106 :
#[cfg(test)]
mod tests {
    use pbkdf2::pbkdf2_hmac_array;
    use sha2::Sha256;

    use super::Pbkdf2;

    /// The cooperative implementation must agree with the reference `pbkdf2` crate.
    #[test]
    fn works() {
        const ROUNDS: u32 = 60000;

        let salt = b"sodium chloride";
        let pass = b"Ne0n_!5_50_C007";

        // Keep driving the job until it reports completion.
        let mut job = Pbkdf2::start(pass, salt, ROUNDS);
        let derived: [u8; 32] = loop {
            if let std::task::Poll::Ready(block) = job.turn() {
                break block.into();
            }
        };

        let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, ROUNDS);
        assert_eq!(derived, expected);
    }
}
|