LCOV - code coverage report
Current view: top level - libs/proxy/postgres-protocol2/src/authentication - sasl.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 86.8 % 310 269
Test Date: 2025-03-12 00:01:28 Functions: 83.0 % 47 39

            Line data    Source code
       1              : //! SASL-based authentication support.
       2              : 
       3              : use std::fmt::Write;
       4              : use std::{io, iter, mem, str};
       5              : 
       6              : use hmac::{Hmac, Mac};
       7              : use rand::{self, Rng};
       8              : use sha2::digest::FixedOutput;
       9              : use sha2::{Digest, Sha256};
      10              : use tokio::task::yield_now;
      11              : 
      12              : const NONCE_LENGTH: usize = 24;
      13              : 
      14              : /// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
      15              : pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
      16              : /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
      17              : pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
      18              : 
      19              : // since postgres passwords are not required to exclude saslprep-prohibited
      20              : // characters or even be valid UTF8, we run saslprep if possible and otherwise
      21              : // return the raw password.
      22           13 : fn normalize(pass: &[u8]) -> Vec<u8> {
      23           13 :     let pass = match str::from_utf8(pass) {
      24           13 :         Ok(pass) => pass,
      25            0 :         Err(_) => return pass.to_vec(),
      26              :     };
      27              : 
      28           13 :     match stringprep::saslprep(pass) {
      29           13 :         Ok(pass) => pass.into_owned().into_bytes(),
      30            0 :         Err(_) => pass.as_bytes().to_vec(),
      31              :     }
      32           13 : }
      33              : 
      34           29 : pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
      35           29 :     let mut hmac =
      36           29 :         Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
      37           29 :     hmac.update(salt);
      38           29 :     hmac.update(&[0, 0, 0, 1]);
      39           29 :     let mut prev = hmac.finalize().into_bytes();
      40           29 : 
      41           29 :     let mut hi = prev;
      42              : 
      43       114660 :     for i in 1..iterations {
      44       114660 :         let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
      45       114660 :         hmac.update(&prev);
      46       114660 :         prev = hmac.finalize().into_bytes();
      47              : 
      48      3669120 :         for (hi, prev) in hi.iter_mut().zip(prev) {
      49      3669120 :             *hi ^= prev;
      50      3669120 :         }
      51              :         // yield every ~250us
      52              :         // hopefully reduces tail latencies
      53       114660 :         if i % 1024 == 0 {
      54           84 :             yield_now().await
      55         8184 :         }
      56              :     }
      57              : 
      58           29 :     hi.into()
      59           29 : }
      60              : 
      61              : enum ChannelBindingInner {
      62              :     Unrequested,
      63              :     Unsupported,
      64              :     TlsServerEndPoint(Vec<u8>),
      65              : }
      66              : 
      67              : /// The channel binding configuration for a SCRAM authentication exchange.
      68              : pub struct ChannelBinding(ChannelBindingInner);
      69              : 
      70              : impl ChannelBinding {
      71              :     /// The server did not request channel binding.
      72            2 :     pub fn unrequested() -> ChannelBinding {
      73            2 :         ChannelBinding(ChannelBindingInner::Unrequested)
      74            2 :     }
      75              : 
      76              :     /// The server requested channel binding but the client is unable to provide it.
      77            4 :     pub fn unsupported() -> ChannelBinding {
      78            4 :         ChannelBinding(ChannelBindingInner::Unsupported)
      79            4 :     }
      80              : 
      81              :     /// The server requested channel binding and the client will use the `tls-server-end-point`
      82              :     /// method.
      83           10 :     pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
      84           10 :         ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
      85           10 :     }
      86              : 
      87           25 :     fn gs2_header(&self) -> &'static str {
      88           25 :         match self.0 {
      89            1 :             ChannelBindingInner::Unrequested => "y,,",
      90            8 :             ChannelBindingInner::Unsupported => "n,,",
      91           16 :             ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
      92              :         }
      93           25 :     }
      94              : 
      95           12 :     fn cbind_data(&self) -> &[u8] {
      96           12 :         match self.0 {
      97            4 :             ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
      98            8 :             ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
      99              :         }
     100           12 :     }
     101              : }
     102              : 
     103              : /// A pair of keys for the SCRAM-SHA-256 mechanism.
     104              : /// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
     105              : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     106              : pub struct ScramKeys<const N: usize> {
     107              :     /// Used by server to authenticate client.
     108              :     pub client_key: [u8; N],
     109              :     /// Used by client to verify server's signature.
     110              :     pub server_key: [u8; N],
     111              : }
     112              : 
     113              : /// Password or keys which were derived from it.
     114              : enum Credentials<const N: usize> {
     115              :     /// A regular password as a vector of bytes.
     116              :     Password(Vec<u8>),
     117              :     /// A precomputed pair of keys.
     118              :     Keys(ScramKeys<N>),
     119              : }
     120              : 
     121              : enum State {
     122              :     Update {
     123              :         nonce: String,
     124              :         password: Credentials<32>,
     125              :         channel_binding: ChannelBinding,
     126              :     },
     127              :     Finish {
     128              :         server_key: [u8; 32],
     129              :         auth_message: String,
     130              :     },
     131              :     Done,
     132              : }
     133              : 
     134              : /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
     135              : /// process.
     136              : ///
     137              : /// During the authentication process, if the backend sends an `AuthenticationSASL` message which
     138              : /// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
     139              : ///
     140              : /// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
     141              : /// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
     142              : ///
     143              : /// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
     144              : /// passed to the `update()` method, after which the buffer returned by the `message()` method
     145              : /// should be sent to the backend in a `SASLResponse` message.
     146              : ///
     147              : /// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
     148              : /// to the `finish()` method, after which the authentication process is complete.
     149              : pub struct ScramSha256 {
     150              :     message: String,
     151              :     state: State,
     152              : }
     153              : 
     154           12 : fn nonce() -> String {
     155           12 :     // rand 0.5's ThreadRng is cryptographically secure
     156           12 :     let mut rng = rand::thread_rng();
     157           12 :     (0..NONCE_LENGTH)
     158          288 :         .map(|_| {
     159          288 :             let mut v = rng.gen_range(0x21u8..0x7e);
     160          288 :             if v == 0x2c {
     161            6 :                 v = 0x7e
     162          282 :             }
     163          288 :             v as char
     164          288 :         })
     165           12 :         .collect()
     166           12 : }
     167              : 
     168              : impl ScramSha256 {
     169              :     /// Constructs a new instance which will use the provided password for authentication.
     170           12 :     pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
     171           12 :         let password = Credentials::Password(normalize(password));
     172           12 :         ScramSha256::new_inner(password, channel_binding, nonce())
     173           12 :     }
     174              : 
     175              :     /// Constructs a new instance which will use the provided key pair for authentication.
     176            0 :     pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
     177            0 :         let password = Credentials::Keys(keys);
     178            0 :         ScramSha256::new_inner(password, channel_binding, nonce())
     179            0 :     }
     180              : 
     181           13 :     fn new_inner(
     182           13 :         password: Credentials<32>,
     183           13 :         channel_binding: ChannelBinding,
     184           13 :         nonce: String,
     185           13 :     ) -> ScramSha256 {
     186           13 :         ScramSha256 {
     187           13 :             message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
     188           13 :             state: State::Update {
     189           13 :                 nonce,
     190           13 :                 password,
     191           13 :                 channel_binding,
     192           13 :             },
     193           13 :         }
     194           13 :     }
     195              : 
     196              :     /// Returns the message which should be sent to the backend in an `SASLResponse` message.
     197           25 :     pub fn message(&self) -> &[u8] {
     198           25 :         if let State::Done = self.state {
     199            0 :             panic!("invalid SCRAM state");
     200           25 :         }
     201           25 :         self.message.as_bytes()
     202           25 :     }
     203              : 
     204              :     /// Updates the state machine with the response from the backend.
     205              :     ///
     206              :     /// This should be called when an `AuthenticationSASLContinue` message is received.
     207           12 :     pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
     208           12 :         let (client_nonce, password, channel_binding) =
     209           12 :             match mem::replace(&mut self.state, State::Done) {
     210              :                 State::Update {
     211           12 :                     nonce,
     212           12 :                     password,
     213           12 :                     channel_binding,
     214           12 :                 } => (nonce, password, channel_binding),
     215            0 :                 _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
     216              :             };
     217              : 
     218           12 :         let message =
     219           12 :             str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
     220              : 
     221           12 :         let parsed = Parser::new(message).server_first_message()?;
     222              : 
     223           12 :         if !parsed.nonce.starts_with(&client_nonce) {
     224            0 :             return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
     225            1 :         }
     226              : 
     227           12 :         let (client_key, server_key) = match password {
     228           12 :             Credentials::Password(password) => {
     229           12 :                 let salt = match base64::decode(parsed.salt) {
     230           12 :                     Ok(salt) => salt,
     231            0 :                     Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
     232              :                 };
     233              : 
     234           12 :                 let salted_password = hi(&password, &salt, parsed.iteration_count).await;
     235              : 
     236           24 :                 let make_key = |name| {
     237           24 :                     let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
     238           24 :                         .expect("HMAC is able to accept all key sizes");
     239           24 :                     hmac.update(name);
     240           24 : 
     241           24 :                     let mut key = [0u8; 32];
     242           24 :                     key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
     243           24 :                     key
     244           24 :                 };
     245              : 
     246           12 :                 (make_key(b"Client Key"), make_key(b"Server Key"))
     247              :             }
     248            0 :             Credentials::Keys(keys) => (keys.client_key, keys.server_key),
     249              :         };
     250              : 
     251           12 :         let mut hash = Sha256::default();
     252           12 :         hash.update(client_key);
     253           12 :         let stored_key = hash.finalize_fixed();
     254           12 : 
     255           12 :         let mut cbind_input = vec![];
     256           12 :         cbind_input.extend(channel_binding.gs2_header().as_bytes());
     257           12 :         cbind_input.extend(channel_binding.cbind_data());
     258           12 :         let cbind_input = base64::encode(&cbind_input);
     259           12 : 
     260           12 :         self.message.clear();
     261           12 :         write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
     262           12 : 
     263           12 :         let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
     264           12 : 
     265           12 :         let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
     266           12 :             .expect("HMAC is able to accept all key sizes");
     267           12 :         hmac.update(auth_message.as_bytes());
     268           12 :         let client_signature = hmac.finalize().into_bytes();
     269           12 : 
     270           12 :         let mut client_proof = client_key;
     271          384 :         for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
     272          384 :             *proof ^= signature;
     273          384 :         }
     274              : 
     275           12 :         write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
     276           12 : 
     277           12 :         self.state = State::Finish {
     278           12 :             server_key,
     279           12 :             auth_message,
     280           12 :         };
     281           12 :         Ok(())
     282            1 :     }
     283              : 
     284              :     /// Finalizes the authentication process.
     285              :     ///
     286              :     /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
     287              :     /// Authentication has only succeeded if this method returns `Ok(())`.
     288            7 :     pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
     289            7 :         let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
     290              :             State::Finish {
     291            7 :                 server_key,
     292            7 :                 auth_message,
     293            7 :             } => (server_key, auth_message),
     294            0 :             _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
     295              :         };
     296              : 
     297            7 :         let message =
     298            7 :             str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
     299              : 
     300            7 :         let parsed = Parser::new(message).server_final_message()?;
     301              : 
     302            7 :         let verifier = match parsed {
     303            0 :             ServerFinalMessage::Error(e) => {
     304            0 :                 return Err(io::Error::new(
     305            0 :                     io::ErrorKind::Other,
     306            0 :                     format!("SCRAM error: {}", e),
     307            0 :                 ));
     308              :             }
     309            7 :             ServerFinalMessage::Verifier(verifier) => verifier,
     310              :         };
     311              : 
     312            7 :         let verifier = match base64::decode(verifier) {
     313            7 :             Ok(verifier) => verifier,
     314            0 :             Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
     315              :         };
     316              : 
     317            7 :         let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
     318            7 :             .expect("HMAC is able to accept all key sizes");
     319            7 :         hmac.update(auth_message.as_bytes());
     320            7 :         hmac.verify_slice(&verifier)
     321            7 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
     322            7 :     }
     323              : }
     324              : 
     325              : struct Parser<'a> {
     326              :     s: &'a str,
     327              :     it: iter::Peekable<str::CharIndices<'a>>,
     328              : }
     329              : 
     330              : impl<'a> Parser<'a> {
     331           20 :     fn new(s: &'a str) -> Parser<'a> {
     332           20 :         Parser {
     333           20 :             s,
     334           20 :             it: s.char_indices().peekable(),
     335           20 :         }
     336           20 :     }
     337              : 
     338          118 :     fn eat(&mut self, target: char) -> io::Result<()> {
     339          118 :         match self.it.next() {
     340          118 :             Some((_, c)) if c == target => Ok(()),
     341            0 :             Some((i, c)) => {
     342            0 :                 let m = format!(
     343            0 :                     "unexpected character at byte {}: expected `{}` but got `{}",
     344            0 :                     i, target, c
     345            0 :                 );
     346            0 :                 Err(io::Error::new(io::ErrorKind::InvalidInput, m))
     347              :             }
     348            0 :             None => Err(io::Error::new(
     349            0 :                 io::ErrorKind::UnexpectedEof,
     350            0 :                 "unexpected EOF",
     351            0 :             )),
     352              :         }
     353          118 :     }
     354              : 
     355           46 :     fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
     356           46 :     where
     357           46 :         F: Fn(char) -> bool,
     358           46 :     {
     359           46 :         let start = match self.it.peek() {
     360           46 :             Some(&(i, _)) => i,
     361            0 :             None => return Ok(""),
     362              :         };
     363              : 
     364              :         loop {
     365         1337 :             match self.it.peek() {
     366         1317 :                 Some(&(_, c)) if f(c) => {
     367         1291 :                     self.it.next();
     368         1291 :                 }
     369           26 :                 Some(&(i, _)) => return Ok(&self.s[start..i]),
     370           20 :                 None => return Ok(&self.s[start..]),
     371              :             }
     372              :         }
     373           46 :     }
     374              : 
     375           13 :     fn printable(&mut self) -> io::Result<&'a str> {
     376          631 :         self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
     377           13 :     }
     378              : 
     379           13 :     fn nonce(&mut self) -> io::Result<&'a str> {
     380           13 :         self.eat('r')?;
     381           13 :         self.eat('=')?;
     382           13 :         self.printable()
     383           13 :     }
     384              : 
     385           20 :     fn base64(&mut self) -> io::Result<&'a str> {
     386          637 :         self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
     387           20 :     }
     388              : 
     389           13 :     fn salt(&mut self) -> io::Result<&'a str> {
     390           13 :         self.eat('s')?;
     391           13 :         self.eat('=')?;
     392           13 :         self.base64()
     393           13 :     }
     394              : 
     395           13 :     fn posit_number(&mut self) -> io::Result<u32> {
     396           49 :         let n = self.take_while(|c| c.is_ascii_digit())?;
     397           13 :         n.parse()
     398           13 :             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
     399           13 :     }
     400              : 
     401           13 :     fn iteration_count(&mut self) -> io::Result<u32> {
     402           13 :         self.eat('i')?;
     403           13 :         self.eat('=')?;
     404           13 :         self.posit_number()
     405           13 :     }
     406              : 
     407           20 :     fn eof(&mut self) -> io::Result<()> {
     408           20 :         match self.it.peek() {
     409            0 :             Some(&(i, _)) => Err(io::Error::new(
     410            0 :                 io::ErrorKind::InvalidInput,
     411            0 :                 format!("unexpected trailing data at byte {}", i),
     412            0 :             )),
     413           20 :             None => Ok(()),
     414              :         }
     415           20 :     }
     416              : 
     417           13 :     fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
     418           13 :         let nonce = self.nonce()?;
     419           13 :         self.eat(',')?;
     420           13 :         let salt = self.salt()?;
     421           13 :         self.eat(',')?;
     422           13 :         let iteration_count = self.iteration_count()?;
     423           13 :         self.eof()?;
     424              : 
     425           13 :         Ok(ServerFirstMessage {
     426           13 :             nonce,
     427           13 :             salt,
     428           13 :             iteration_count,
     429           13 :         })
     430           13 :     }
     431              : 
     432            0 :     fn value(&mut self) -> io::Result<&'a str> {
     433            0 :         self.take_while(|c| matches!(c, '\0' | '=' | ','))
     434            0 :     }
     435              : 
     436            7 :     fn server_error(&mut self) -> io::Result<Option<&'a str>> {
     437            7 :         match self.it.peek() {
     438            0 :             Some(&(_, 'e')) => {}
     439            7 :             _ => return Ok(None),
     440              :         }
     441              : 
     442            0 :         self.eat('e')?;
     443            0 :         self.eat('=')?;
     444            0 :         self.value().map(Some)
     445            7 :     }
     446              : 
     447            7 :     fn verifier(&mut self) -> io::Result<&'a str> {
     448            7 :         self.eat('v')?;
     449            7 :         self.eat('=')?;
     450            7 :         self.base64()
     451            7 :     }
     452              : 
     453            7 :     fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
     454            7 :         let message = match self.server_error()? {
     455            0 :             Some(error) => ServerFinalMessage::Error(error),
     456            7 :             None => ServerFinalMessage::Verifier(self.verifier()?),
     457              :         };
     458            7 :         self.eof()?;
     459            7 :         Ok(message)
     460            7 :     }
     461              : }
     462              : 
     463              : struct ServerFirstMessage<'a> {
     464              :     nonce: &'a str,
     465              :     salt: &'a str,
     466              :     iteration_count: u32,
     467              : }
     468              : 
     469              : enum ServerFinalMessage<'a> {
     470              :     Error(&'a str),
     471              :     Verifier(&'a str),
     472              : }
     473              : 
     474              : #[cfg(test)]
     475              : mod test {
     476              :     use super::*;
     477              : 
     478              :     #[test]
     479            1 :     fn parse_server_first_message() {
     480            1 :         let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
     481            1 :         let message = Parser::new(message).server_first_message().unwrap();
     482            1 :         assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
     483            1 :         assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
     484            1 :         assert_eq!(message.iteration_count, 4096);
     485            1 :     }
     486              : 
     487              :     // recorded auth exchange from psql
     488              :     #[tokio::test]
     489            1 :     async fn exchange() {
     490            1 :         let password = "foobar";
     491            1 :         let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
     492            1 : 
     493            1 :         let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
     494            1 :         let server_first = "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
     495            1 :              =4096";
     496            1 :         let client_final = "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
     497            1 :              1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
     498            1 :         let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
     499            1 : 
     500            1 :         let mut scram = ScramSha256::new_inner(
     501            1 :             Credentials::Password(normalize(password.as_bytes())),
     502            1 :             ChannelBinding::unsupported(),
     503            1 :             nonce.to_string(),
     504            1 :         );
     505            1 :         assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
     506            1 : 
     507            1 :         scram.update(server_first.as_bytes()).await.unwrap();
     508            1 :         assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
     509            1 : 
     510            1 :         scram.finish(server_final.as_bytes()).unwrap();
     511            1 :     }
     512              : }
        

Generated by: LCOV version 2.1-beta