LCOV - code coverage report
Current view: top level - libs/proxy/postgres-protocol2/src/authentication - sasl.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 86.5 % 312 270
Test Date: 2025-02-20 13:11:02 Functions: 83.0 % 47 39

            Line data    Source code
       1              : //! SASL-based authentication support.
       2              : 
       3              : use hmac::{Hmac, Mac};
       4              : use rand::{self, Rng};
       5              : use sha2::digest::FixedOutput;
       6              : use sha2::{Digest, Sha256};
       7              : use std::fmt::Write;
       8              : use std::io;
       9              : use std::iter;
      10              : use std::mem;
      11              : use std::str;
      12              : use tokio::task::yield_now;
      13              : 
      14              : const NONCE_LENGTH: usize = 24;
      15              : 
      16              : /// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
      17              : pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
      18              : /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
      19              : pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
      20              : 
      21              : // since postgres passwords are not required to exclude saslprep-prohibited
      22              : // characters or even be valid UTF8, we run saslprep if possible and otherwise
      23              : // return the raw password.
      24           13 : fn normalize(pass: &[u8]) -> Vec<u8> {
      25           13 :     let pass = match str::from_utf8(pass) {
      26           13 :         Ok(pass) => pass,
      27            0 :         Err(_) => return pass.to_vec(),
      28              :     };
      29              : 
      30           13 :     match stringprep::saslprep(pass) {
      31           13 :         Ok(pass) => pass.into_owned().into_bytes(),
      32            0 :         Err(_) => pass.as_bytes().to_vec(),
      33              :     }
      34           13 : }
      35              : 
      36           29 : pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
      37           29 :     let mut hmac =
      38           29 :         Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
      39           29 :     hmac.update(salt);
      40           29 :     hmac.update(&[0, 0, 0, 1]);
      41           29 :     let mut prev = hmac.finalize().into_bytes();
      42           29 : 
      43           29 :     let mut hi = prev;
      44              : 
      45       114660 :     for i in 1..iterations {
      46       114660 :         let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
      47       114660 :         hmac.update(&prev);
      48       114660 :         prev = hmac.finalize().into_bytes();
      49              : 
      50      3669120 :         for (hi, prev) in hi.iter_mut().zip(prev) {
      51      3669120 :             *hi ^= prev;
      52      3669120 :         }
      53              :         // yield every ~250us
      54              :         // hopefully reduces tail latencies
      55       114660 :         if i % 1024 == 0 {
      56           84 :             yield_now().await
      57       114576 :         }
      58              :     }
      59              : 
      60           29 :     hi.into()
      61           29 : }
      62              : 
      63              : enum ChannelBindingInner {
      64              :     Unrequested,
      65              :     Unsupported,
      66              :     TlsServerEndPoint(Vec<u8>),
      67              : }
      68              : 
      69              : /// The channel binding configuration for a SCRAM authentication exchange.
      70              : pub struct ChannelBinding(ChannelBindingInner);
      71              : 
      72              : impl ChannelBinding {
      73              :     /// The server did not request channel binding.
      74            2 :     pub fn unrequested() -> ChannelBinding {
      75            2 :         ChannelBinding(ChannelBindingInner::Unrequested)
      76            2 :     }
      77              : 
      78              :     /// The server requested channel binding but the client is unable to provide it.
      79            4 :     pub fn unsupported() -> ChannelBinding {
      80            4 :         ChannelBinding(ChannelBindingInner::Unsupported)
      81            4 :     }
      82              : 
      83              :     /// The server requested channel binding and the client will use the `tls-server-end-point`
      84              :     /// method.
      85           10 :     pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
      86           10 :         ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
      87           10 :     }
      88              : 
      89           25 :     fn gs2_header(&self) -> &'static str {
      90           25 :         match self.0 {
      91            1 :             ChannelBindingInner::Unrequested => "y,,",
      92            8 :             ChannelBindingInner::Unsupported => "n,,",
      93           16 :             ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
      94              :         }
      95           25 :     }
      96              : 
      97           12 :     fn cbind_data(&self) -> &[u8] {
      98           12 :         match self.0 {
      99            4 :             ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
     100            8 :             ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
     101              :         }
     102           12 :     }
     103              : }
     104              : 
     105              : /// A pair of keys for the SCRAM-SHA-256 mechanism.
     106              : /// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
     107              : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     108              : pub struct ScramKeys<const N: usize> {
     109              :     /// Used by server to authenticate client.
     110              :     pub client_key: [u8; N],
     111              :     /// Used by client to verify server's signature.
     112              :     pub server_key: [u8; N],
     113              : }
     114              : 
     115              : /// Password or keys which were derived from it.
     116              : enum Credentials<const N: usize> {
     117              :     /// A regular password as a vector of bytes.
     118              :     Password(Vec<u8>),
     119              :     /// A precomputed pair of keys.
     120              :     Keys(ScramKeys<N>),
     121              : }
     122              : 
     123              : enum State {
     124              :     Update {
     125              :         nonce: String,
     126              :         password: Credentials<32>,
     127              :         channel_binding: ChannelBinding,
     128              :     },
     129              :     Finish {
     130              :         server_key: [u8; 32],
     131              :         auth_message: String,
     132              :     },
     133              :     Done,
     134              : }
     135              : 
     136              : /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
     137              : /// process.
     138              : ///
     139              : /// During the authentication process, if the backend sends an `AuthenticationSASL` message which
     140              : /// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
     141              : ///
     142              : /// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
     143              : /// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
     144              : ///
     145              : /// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
     146              : /// passed to the `update()` method, after which the buffer returned by the `message()` method
     147              : /// should be sent to the backend in a `SASLResponse` message.
     148              : ///
     149              : /// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
     150              : /// to the `finish()` method, after which the authentication process is complete.
     151              : pub struct ScramSha256 {
     152              :     message: String,
     153              :     state: State,
     154              : }
     155              : 
     156           12 : fn nonce() -> String {
     157           12 :     // rand 0.5's ThreadRng is cryptographically secure
     158           12 :     let mut rng = rand::thread_rng();
     159           12 :     (0..NONCE_LENGTH)
     160          288 :         .map(|_| {
     161          288 :             let mut v = rng.gen_range(0x21u8..0x7e);
     162          288 :             if v == 0x2c {
     163            0 :                 v = 0x7e
     164          288 :             }
     165          288 :             v as char
     166          288 :         })
     167           12 :         .collect()
     168           12 : }
     169              : 
     170              : impl ScramSha256 {
     171              :     /// Constructs a new instance which will use the provided password for authentication.
     172           12 :     pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
     173           12 :         let password = Credentials::Password(normalize(password));
     174           12 :         ScramSha256::new_inner(password, channel_binding, nonce())
     175           12 :     }
     176              : 
     177              :     /// Constructs a new instance which will use the provided key pair for authentication.
     178            0 :     pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
     179            0 :         let password = Credentials::Keys(keys);
     180            0 :         ScramSha256::new_inner(password, channel_binding, nonce())
     181            0 :     }
     182              : 
     183           13 :     fn new_inner(
     184           13 :         password: Credentials<32>,
     185           13 :         channel_binding: ChannelBinding,
     186           13 :         nonce: String,
     187           13 :     ) -> ScramSha256 {
     188           13 :         ScramSha256 {
     189           13 :             message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
     190           13 :             state: State::Update {
     191           13 :                 nonce,
     192           13 :                 password,
     193           13 :                 channel_binding,
     194           13 :             },
     195           13 :         }
     196           13 :     }
     197              : 
     198              :     /// Returns the message which should be sent to the backend in an `SASLResponse` message.
     199           25 :     pub fn message(&self) -> &[u8] {
     200           25 :         if let State::Done = self.state {
     201            0 :             panic!("invalid SCRAM state");
     202           25 :         }
     203           25 :         self.message.as_bytes()
     204           25 :     }
     205              : 
     206              :     /// Updates the state machine with the response from the backend.
     207              :     ///
     208              :     /// This should be called when an `AuthenticationSASLContinue` message is received.
     209           12 :     pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
     210           12 :         let (client_nonce, password, channel_binding) =
     211           12 :             match mem::replace(&mut self.state, State::Done) {
     212              :                 State::Update {
     213           12 :                     nonce,
     214           12 :                     password,
     215           12 :                     channel_binding,
     216           12 :                 } => (nonce, password, channel_binding),
     217            0 :                 _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
     218              :             };
     219              : 
     220           12 :         let message =
     221           12 :             str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
     222              : 
     223           12 :         let parsed = Parser::new(message).server_first_message()?;
     224              : 
     225           12 :         if !parsed.nonce.starts_with(&client_nonce) {
     226            0 :             return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
     227           12 :         }
     228              : 
     229           12 :         let (client_key, server_key) = match password {
     230           12 :             Credentials::Password(password) => {
     231           12 :                 let salt = match base64::decode(parsed.salt) {
     232           12 :                     Ok(salt) => salt,
     233            0 :                     Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
     234              :                 };
     235              : 
     236           12 :                 let salted_password = hi(&password, &salt, parsed.iteration_count).await;
     237              : 
     238           24 :                 let make_key = |name| {
     239           24 :                     let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
     240           24 :                         .expect("HMAC is able to accept all key sizes");
     241           24 :                     hmac.update(name);
     242           24 : 
     243           24 :                     let mut key = [0u8; 32];
     244           24 :                     key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
     245           24 :                     key
     246           24 :                 };
     247              : 
     248           12 :                 (make_key(b"Client Key"), make_key(b"Server Key"))
     249              :             }
     250            0 :             Credentials::Keys(keys) => (keys.client_key, keys.server_key),
     251              :         };
     252              : 
     253           12 :         let mut hash = Sha256::default();
     254           12 :         hash.update(client_key);
     255           12 :         let stored_key = hash.finalize_fixed();
     256           12 : 
     257           12 :         let mut cbind_input = vec![];
     258           12 :         cbind_input.extend(channel_binding.gs2_header().as_bytes());
     259           12 :         cbind_input.extend(channel_binding.cbind_data());
     260           12 :         let cbind_input = base64::encode(&cbind_input);
     261           12 : 
     262           12 :         self.message.clear();
     263           12 :         write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
     264           12 : 
     265           12 :         let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
     266           12 : 
     267           12 :         let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
     268           12 :             .expect("HMAC is able to accept all key sizes");
     269           12 :         hmac.update(auth_message.as_bytes());
     270           12 :         let client_signature = hmac.finalize().into_bytes();
     271           12 : 
     272           12 :         let mut client_proof = client_key;
     273          384 :         for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
     274          384 :             *proof ^= signature;
     275          384 :         }
     276              : 
     277           12 :         write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
     278           12 : 
     279           12 :         self.state = State::Finish {
     280           12 :             server_key,
     281           12 :             auth_message,
     282           12 :         };
     283           12 :         Ok(())
     284           12 :     }
     285              : 
     286              :     /// Finalizes the authentication process.
     287              :     ///
     288              :     /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
     289              :     /// Authentication has only succeeded if this method returns `Ok(())`.
     290            7 :     pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
     291            7 :         let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
     292              :             State::Finish {
     293            7 :                 server_key,
     294            7 :                 auth_message,
     295            7 :             } => (server_key, auth_message),
     296            0 :             _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
     297              :         };
     298              : 
     299            7 :         let message =
     300            7 :             str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
     301              : 
     302            7 :         let parsed = Parser::new(message).server_final_message()?;
     303              : 
     304            7 :         let verifier = match parsed {
     305            0 :             ServerFinalMessage::Error(e) => {
     306            0 :                 return Err(io::Error::new(
     307            0 :                     io::ErrorKind::Other,
     308            0 :                     format!("SCRAM error: {}", e),
     309            0 :                 ));
     310              :             }
     311            7 :             ServerFinalMessage::Verifier(verifier) => verifier,
     312              :         };
     313              : 
     314            7 :         let verifier = match base64::decode(verifier) {
     315            7 :             Ok(verifier) => verifier,
     316            0 :             Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
     317              :         };
     318              : 
     319            7 :         let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
     320            7 :             .expect("HMAC is able to accept all key sizes");
     321            7 :         hmac.update(auth_message.as_bytes());
     322            7 :         hmac.verify_slice(&verifier)
     323            7 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
     324            7 :     }
     325              : }
     326              : 
     327              : struct Parser<'a> {
     328              :     s: &'a str,
     329              :     it: iter::Peekable<str::CharIndices<'a>>,
     330              : }
     331              : 
     332              : impl<'a> Parser<'a> {
     333           20 :     fn new(s: &'a str) -> Parser<'a> {
     334           20 :         Parser {
     335           20 :             s,
     336           20 :             it: s.char_indices().peekable(),
     337           20 :         }
     338           20 :     }
     339              : 
     340          118 :     fn eat(&mut self, target: char) -> io::Result<()> {
     341          118 :         match self.it.next() {
     342          118 :             Some((_, c)) if c == target => Ok(()),
     343            0 :             Some((i, c)) => {
     344            0 :                 let m = format!(
     345            0 :                     "unexpected character at byte {}: expected `{}` but got `{}",
     346            0 :                     i, target, c
     347            0 :                 );
     348            0 :                 Err(io::Error::new(io::ErrorKind::InvalidInput, m))
     349              :             }
     350            0 :             None => Err(io::Error::new(
     351            0 :                 io::ErrorKind::UnexpectedEof,
     352            0 :                 "unexpected EOF",
     353            0 :             )),
     354              :         }
     355          118 :     }
     356              : 
     357           46 :     fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
     358           46 :     where
     359           46 :         F: Fn(char) -> bool,
     360           46 :     {
     361           46 :         let start = match self.it.peek() {
     362           46 :             Some(&(i, _)) => i,
     363            0 :             None => return Ok(""),
     364              :         };
     365              : 
     366              :         loop {
     367         1337 :             match self.it.peek() {
     368         1317 :                 Some(&(_, c)) if f(c) => {
     369         1291 :                     self.it.next();
     370         1291 :                 }
     371           26 :                 Some(&(i, _)) => return Ok(&self.s[start..i]),
     372           20 :                 None => return Ok(&self.s[start..]),
     373              :             }
     374              :         }
     375           46 :     }
     376              : 
     377           13 :     fn printable(&mut self) -> io::Result<&'a str> {
     378          631 :         self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
     379           13 :     }
     380              : 
     381           13 :     fn nonce(&mut self) -> io::Result<&'a str> {
     382           13 :         self.eat('r')?;
     383           13 :         self.eat('=')?;
     384           13 :         self.printable()
     385           13 :     }
     386              : 
     387           20 :     fn base64(&mut self) -> io::Result<&'a str> {
     388          637 :         self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
     389           20 :     }
     390              : 
     391           13 :     fn salt(&mut self) -> io::Result<&'a str> {
     392           13 :         self.eat('s')?;
     393           13 :         self.eat('=')?;
     394           13 :         self.base64()
     395           13 :     }
     396              : 
     397           13 :     fn posit_number(&mut self) -> io::Result<u32> {
     398           49 :         let n = self.take_while(|c| c.is_ascii_digit())?;
     399           13 :         n.parse()
     400           13 :             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
     401           13 :     }
     402              : 
     403           13 :     fn iteration_count(&mut self) -> io::Result<u32> {
     404           13 :         self.eat('i')?;
     405           13 :         self.eat('=')?;
     406           13 :         self.posit_number()
     407           13 :     }
     408              : 
     409           20 :     fn eof(&mut self) -> io::Result<()> {
     410           20 :         match self.it.peek() {
     411            0 :             Some(&(i, _)) => Err(io::Error::new(
     412            0 :                 io::ErrorKind::InvalidInput,
     413            0 :                 format!("unexpected trailing data at byte {}", i),
     414            0 :             )),
     415           20 :             None => Ok(()),
     416              :         }
     417           20 :     }
     418              : 
     419           13 :     fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
     420           13 :         let nonce = self.nonce()?;
     421           13 :         self.eat(',')?;
     422           13 :         let salt = self.salt()?;
     423           13 :         self.eat(',')?;
     424           13 :         let iteration_count = self.iteration_count()?;
     425           13 :         self.eof()?;
     426              : 
     427           13 :         Ok(ServerFirstMessage {
     428           13 :             nonce,
     429           13 :             salt,
     430           13 :             iteration_count,
     431           13 :         })
     432           13 :     }
     433              : 
     434            0 :     fn value(&mut self) -> io::Result<&'a str> {
     435            0 :         self.take_while(|c| matches!(c, '\0' | '=' | ','))
     436            0 :     }
     437              : 
     438            7 :     fn server_error(&mut self) -> io::Result<Option<&'a str>> {
     439            7 :         match self.it.peek() {
     440            0 :             Some(&(_, 'e')) => {}
     441            7 :             _ => return Ok(None),
     442              :         }
     443              : 
     444            0 :         self.eat('e')?;
     445            0 :         self.eat('=')?;
     446            0 :         self.value().map(Some)
     447            7 :     }
     448              : 
     449            7 :     fn verifier(&mut self) -> io::Result<&'a str> {
     450            7 :         self.eat('v')?;
     451            7 :         self.eat('=')?;
     452            7 :         self.base64()
     453            7 :     }
     454              : 
     455            7 :     fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
     456            7 :         let message = match self.server_error()? {
     457            0 :             Some(error) => ServerFinalMessage::Error(error),
     458            7 :             None => ServerFinalMessage::Verifier(self.verifier()?),
     459              :         };
     460            7 :         self.eof()?;
     461            7 :         Ok(message)
     462            7 :     }
     463              : }
     464              : 
     465              : struct ServerFirstMessage<'a> {
     466              :     nonce: &'a str,
     467              :     salt: &'a str,
     468              :     iteration_count: u32,
     469              : }
     470              : 
     471              : enum ServerFinalMessage<'a> {
     472              :     Error(&'a str),
     473              :     Verifier(&'a str),
     474              : }
     475              : 
     476              : #[cfg(test)]
     477              : mod test {
     478              :     use super::*;
     479              : 
     480              :     #[test]
     481            1 :     fn parse_server_first_message() {
     482            1 :         let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
     483            1 :         let message = Parser::new(message).server_first_message().unwrap();
     484            1 :         assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
     485            1 :         assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
     486            1 :         assert_eq!(message.iteration_count, 4096);
     487            1 :     }
     488              : 
     489              :     // recorded auth exchange from psql
     490              :     #[tokio::test]
     491            1 :     async fn exchange() {
     492            1 :         let password = "foobar";
     493            1 :         let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
     494            1 : 
     495            1 :         let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
     496            1 :         let server_first =
     497            1 :             "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
     498            1 :              =4096";
     499            1 :         let client_final =
     500            1 :             "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
     501            1 :              1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
     502            1 :         let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
     503            1 : 
     504            1 :         let mut scram = ScramSha256::new_inner(
     505            1 :             Credentials::Password(normalize(password.as_bytes())),
     506            1 :             ChannelBinding::unsupported(),
     507            1 :             nonce.to_string(),
     508            1 :         );
     509            1 :         assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
     510            1 : 
     511            1 :         scram.update(server_first.as_bytes()).await.unwrap();
     512            1 :         assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
     513            1 : 
     514            1 :         scram.finish(server_final.as_bytes()).unwrap();
     515            1 :     }
     516              : }
        

Generated by: LCOV version 2.1-beta