LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 95.4 % 131 125
Test Date: 2025-07-16 12:29:03 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use std::fmt;
       4              : use std::ops::Range;
       5              : 
       6              : use base64::Engine as _;
       7              : use base64::prelude::BASE64_STANDARD;
       8              : 
       9              : use super::base64_decode_array;
      10              : use super::key::{SCRAM_KEY_LEN, ScramKey};
      11              : use super::signature::SignatureBuilder;
      12              : use crate::sasl::ChannelBinding;
      13              : 
      14              : /// Faithfully taken from PostgreSQL.
      15              : pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
      16              : 
      17              : /// Although we ignore all extensions, we still have to validate the message.
      18           33 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      19           33 :     for mut chars in parts.map(|s| s.chars()) {
      20            6 :         let attr = chars.next()?;
      21            6 :         if !attr.is_ascii_alphabetic() {
      22            1 :             return None;
      23            5 :         }
      24            5 :         let eq = chars.next()?;
      25            4 :         if eq != '=' {
      26            1 :             return None;
      27            3 :         }
      28              :     }
      29              : 
      30           30 :     Some(())
      31           33 : }
      32              : 
      33              : #[derive(Debug)]
      34              : pub(crate) struct ClientFirstMessage<'a> {
      35              :     /// `client-first-message-bare`.
      36              :     pub(crate) bare: &'a str,
      37              :     /// Channel binding mode.
      38              :     pub(crate) cbind_flag: ChannelBinding<&'a str>,
      39              :     /// Client nonce.
      40              :     pub(crate) nonce: &'a str,
      41              : }
      42              : 
      43              : impl<'a> ClientFirstMessage<'a> {
      44              :     // NB: FromStr doesn't work with lifetimes
      45           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      46           21 :         let mut parts = input.split(',');
      47              : 
      48           21 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      49              : 
      50              :         // PG doesn't support authorization identity,
      51              :         // so we don't bother defining GS2 header type
      52           21 :         let authzid = parts.next()?;
      53           21 :         if !authzid.is_empty() {
      54            1 :             return None;
      55           20 :         }
      56              : 
      57              :         // Unfortunately, `parts.as_str()` is unstable
      58           20 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      59           20 :         let (_, bare) = input.split_at(pos);
      60              : 
      61              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      62           20 :         let username = parts.next()?.strip_prefix("n=")?;
      63              : 
      64              :         // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
      65           20 :         if !username.is_empty() {
      66            1 :             tracing::warn!(username, "scram username provided, but is not expected");
      67              :             // TODO(conrad):
      68              :             // return None;
      69           19 :         }
      70              : 
      71           20 :         let nonce = parts.next()?.strip_prefix("r=")?;
      72              : 
      73              :         // Validate but ignore auth extensions
      74           20 :         validate_sasl_extensions(parts)?;
      75              : 
      76           17 :         Some(Self {
      77           17 :             bare,
      78           17 :             cbind_flag,
      79           17 :             nonce,
      80           17 :         })
      81           21 :     }
      82              : 
      83              :     /// Build a response to [`ClientFirstMessage`].
      84           12 :     pub(crate) fn build_server_first_message(
      85           12 :         &self,
      86           12 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      87           12 :         salt_base64: &str,
      88           12 :         iterations: u32,
      89           12 :     ) -> OwnedServerFirstMessage {
      90           12 :         let mut message = String::with_capacity(128);
      91           12 :         message.push_str("r=");
      92              : 
      93              :         // write combined nonce
      94           12 :         let combined_nonce_start = message.len();
      95           12 :         message.push_str(self.nonce);
      96           12 :         BASE64_STANDARD.encode_string(nonce, &mut message);
      97           12 :         let combined_nonce = combined_nonce_start..message.len();
      98              : 
      99              :         // write salt and iterations
     100           12 :         message.push_str(",s=");
     101           12 :         message.push_str(salt_base64);
     102           12 :         message.push_str(",i=");
     103           12 :         message.push_str(itoa::Buffer::new().format(iterations));
     104              : 
     105              :         // This design guarantees that it's impossible to create a
     106              :         // server-first-message without receiving a client-first-message
     107           12 :         OwnedServerFirstMessage {
     108           12 :             message,
     109           12 :             nonce: combined_nonce,
     110           12 :         }
     111           12 :     }
     112              : }
     113              : 
     114              : #[derive(Debug)]
     115              : pub(crate) struct ClientFinalMessage<'a> {
     116              :     /// `client-final-message-without-proof`.
     117              :     pub(crate) without_proof: &'a str,
     118              :     /// Channel binding data (base64).
     119              :     pub(crate) channel_binding: &'a str,
     120              :     /// Combined client & server nonce.
     121              :     pub(crate) nonce: &'a str,
     122              :     /// Client auth proof.
     123              :     pub(crate) proof: [u8; SCRAM_KEY_LEN],
     124              : }
     125              : 
     126              : impl<'a> ClientFinalMessage<'a> {
     127              :     // NB: FromStr doesn't work with lifetimes
     128           13 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
     129           13 :         let (without_proof, proof) = input.rsplit_once(',')?;
     130              : 
     131           13 :         let mut parts = without_proof.split(',');
     132           13 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     133           13 :         let nonce = parts.next()?.strip_prefix("r=")?;
     134              : 
     135              :         // Validate but ignore auth extensions
     136           13 :         validate_sasl_extensions(parts)?;
     137              : 
     138           13 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     139              : 
     140           13 :         Some(Self {
     141           13 :             without_proof,
     142           13 :             channel_binding,
     143           13 :             nonce,
     144           13 :             proof,
     145           13 :         })
     146           13 :     }
     147              : 
     148              :     /// Build a response to [`ClientFinalMessage`].
     149            7 :     pub(crate) fn build_server_final_message(
     150            7 :         &self,
     151            7 :         signature_builder: SignatureBuilder<'_>,
     152            7 :         server_key: &ScramKey,
     153            7 :     ) -> String {
     154            7 :         let mut buf = String::from("v=");
     155            7 :         BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
     156              : 
     157            7 :         buf
     158            7 :     }
     159              : }
     160              : 
     161              : /// We need to keep a convenient representation of this
     162              : /// message for the next authentication step.
     163              : pub(crate) struct OwnedServerFirstMessage {
     164              :     /// Owned `server-first-message`.
     165              :     message: String,
     166              :     /// Slice into `message`.
     167              :     nonce: Range<usize>,
     168              : }
     169              : 
     170              : impl OwnedServerFirstMessage {
     171              :     /// Extract combined nonce from the message.
     172              :     #[inline(always)]
     173            8 :     pub(crate) fn nonce(&self) -> &str {
     174            8 :         &self.message[self.nonce.clone()]
     175            8 :     }
     176              : 
     177              :     /// Get reference to a text representation of the message.
     178              :     #[inline(always)]
     179           20 :     pub(crate) fn as_str(&self) -> &str {
     180           20 :         &self.message
     181           20 :     }
     182              : }
     183              : 
     184              : impl fmt::Debug for OwnedServerFirstMessage {
     185            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     186            0 :         f.debug_struct("ServerFirstMessage")
     187            0 :             .field("message", &self.as_str())
     188            0 :             .field("nonce", &self.nonce())
     189            0 :             .finish()
     190            0 :     }
     191              : }
     192              : 
     193              : #[cfg(test)]
     194              : mod tests {
     195              :     use super::*;
     196              : 
     197              :     #[test]
     198            1 :     fn parse_client_first_message() {
     199              :         use ChannelBinding::*;
     200              : 
     201              :         // (Almost) real strings captured during debug sessions
     202            1 :         let cases = [
     203            1 :             (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     204            1 :             (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     205            1 :             (
     206            1 :                 Required("tls-server-end-point"),
     207            1 :                 "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
     208            1 :             ),
     209            1 :         ];
     210              : 
     211            4 :         for (cb, input) in cases {
     212            3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     213              : 
     214            3 :             assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
     215            3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     216            3 :             assert_eq!(msg.cbind_flag, cb);
     217              :         }
     218            1 :     }
     219              : 
     220              :     #[test]
     221            1 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     222            1 :         assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
     223            1 :     }
     224              : 
     225              :     #[test]
     226            1 :     fn parse_client_first_message_with_extra_params() {
     227            1 :         let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
     228            1 :         assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
     229            1 :         assert_eq!(msg.nonce, "nonce");
     230            1 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     231            1 :     }
     232              : 
     233              :     #[test]
     234            1 :     fn parse_client_first_message_with_extra_params_invalid() {
     235              :         // must be of the form `<ascii letter>=<...>`
     236            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
     237            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
     238            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
     239            1 :     }
     240              : 
     241              :     #[test]
     242            1 :     fn parse_client_final_message() {
     243            1 :         let input = [
     244            1 :             "c=eSws",
     245            1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     246            1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     247            1 :         ]
     248            1 :         .join(",");
     249              : 
     250            1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     251            1 :         assert_eq!(
     252              :             msg.without_proof,
     253              :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     254              :         );
     255            1 :         assert_eq!(
     256              :             msg.nonce,
     257              :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     258              :         );
     259            1 :         assert_eq!(
     260            1 :             BASE64_STANDARD.encode(msg.proof),
     261              :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     262              :         );
     263            1 :     }
     264              : }
        

Generated by: LCOV version 2.1-beta