LCOV - code coverage report
Current view: top level - proxy/src/sasl - messages.rs (source / functions) Coverage Total Hit
Test: 20b6afc7b7f34578dcaab2b3acdaecfe91cd8bf1.info Lines: 96.2 % 26 25
Test Date: 2024-11-25 17:48:16 Functions: 100.0 % 3 3

            Line data    Source code
       1              : //! Definitions for SASL messages.
       2              : 
       3              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
       4              : 
       5              : use crate::parse::{split_at_const, split_cstr};
       6              : 
       7              : /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
       8              : #[derive(Debug)]
       9              : pub(crate) struct FirstMessage<'a> {
      10              :     /// Authentication method, e.g. `"SCRAM-SHA-256"`.
      11              :     pub(crate) method: &'a str,
      12              :     /// Initial client message.
      13              :     pub(crate) message: &'a str,
      14              : }
      15              : 
      16              : impl<'a> FirstMessage<'a> {
      17              :     // NB: FromStr doesn't work with lifetimes
      18           13 :     pub(crate) fn parse(bytes: &'a [u8]) -> Option<Self> {
      19           13 :         let (method_cstr, tail) = split_cstr(bytes)?;
      20           13 :         let method = method_cstr.to_str().ok()?;
      21              : 
      22           13 :         let (len_bytes, bytes) = split_at_const(tail)?;
      23           13 :         let len = u32::from_be_bytes(*len_bytes) as usize;
      24           13 :         if len != bytes.len() {
      25            0 :             return None;
      26           13 :         }
      27              : 
      28           13 :         let message = std::str::from_utf8(bytes).ok()?;
      29           13 :         Some(Self { method, message })
      30           13 :     }
      31              : }
      32              : 
      33              : /// A single SASL message.
      34              : /// This struct is deliberately decoupled from lower-level
      35              : /// [`BeAuthenticationSaslMessage`].
      36              : #[derive(Debug)]
      37              : pub(super) enum ServerMessage<T> {
      38              :     /// We expect to see more steps.
      39              :     Continue(T),
      40              :     /// This is the final step.
      41              :     Final(T),
      42              : }
      43              : 
      44              : impl<'a> ServerMessage<&'a str> {
      45           17 :     pub(super) fn to_reply(&self) -> BeMessage<'a> {
      46           17 :         BeMessage::AuthenticationSasl(match self {
      47           11 :             ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
      48            6 :             ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
      49              :         })
      50           17 :     }
      51              : }
      52              : 
      53              : #[cfg(test)]
      54              : mod tests {
      55              :     use super::*;
      56              : 
      57              :     #[test]
      58            1 :     fn parse_sasl_first_message() {
      59            1 :         let proto = "SCRAM-SHA-256";
      60            1 :         let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
      61            1 :         let sasl_len = (sasl.len() as u32).to_be_bytes();
      62            1 :         let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
      63            1 : 
      64            1 :         let password = FirstMessage::parse(&bytes).unwrap();
      65            1 :         assert_eq!(password.method, proto);
      66            1 :         assert_eq!(password.message, sasl);
      67            1 :     }
      68              : }
        

Generated by: LCOV version 2.1-beta