|             Line data    Source code 
       1              : //! Definitions for SASL messages.
       2              : 
       3              : use crate::parse::split_cstr;
       4              : 
       5              : /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
       6              : #[derive(Debug)]
       7              : pub(crate) struct FirstMessage<'a> {
       8              :     /// Authentication method, e.g. `"SCRAM-SHA-256"`.
       9              :     pub(crate) method: &'a str,
      10              :     /// Initial client message.
      11              :     pub(crate) message: &'a str,
      12              : }
      13              : 
      14              : impl<'a> FirstMessage<'a> {
      15              :     // NB: FromStr doesn't work with lifetimes
      16           13 :     pub(crate) fn parse(bytes: &'a [u8]) -> Option<Self> {
      17           13 :         let (method_cstr, tail) = split_cstr(bytes)?;
      18           13 :         let method = method_cstr.to_str().ok()?;
      19              : 
      20           13 :         let (len_bytes, bytes) = tail.split_first_chunk()?;
      21           13 :         let len = u32::from_be_bytes(*len_bytes) as usize;
      22           13 :         if len != bytes.len() {
      23            0 :             return None;
      24           13 :         }
      25              : 
      26           13 :         let message = std::str::from_utf8(bytes).ok()?;
      27           13 :         Some(Self { method, message })
      28           13 :     }
      29              : }
      30              : 
      31              : #[cfg(test)]
      32              : mod tests {
      33              :     use super::*;
      34              : 
      35              :     #[test]
      36            1 :     fn parse_sasl_first_message() {
      37            1 :         let proto = "SCRAM-SHA-256";
      38            1 :         let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
      39            1 :         let sasl_len = (sasl.len() as u32).to_be_bytes();
      40            1 :         let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
      41              : 
      42            1 :         let password = FirstMessage::parse(&bytes).unwrap();
      43            1 :         assert_eq!(password.method, proto);
      44            1 :         assert_eq!(password.message, sasl);
      45            1 :     }
      46              : }
         |