LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: 553e39c2773e5840c720c90d86e56f89a4330d43.info Lines: 95.8 % 143 137
Test Date: 2025-06-13 20:01:21 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use std::fmt;
       4              : use std::ops::Range;
       5              : 
       6              : use base64::Engine as _;
       7              : use base64::prelude::BASE64_STANDARD;
       8              : 
       9              : use super::base64_decode_array;
      10              : use super::key::{SCRAM_KEY_LEN, ScramKey};
      11              : use super::signature::SignatureBuilder;
      12              : use crate::sasl::ChannelBinding;
      13              : 
      14              : /// Faithfully taken from PostgreSQL.
      15              : pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
      16              : 
      17              : /// Although we ignore all extensions, we still have to validate the message.
      18           33 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      19           33 :     for mut chars in parts.map(|s| s.chars()) {
      20            6 :         let attr = chars.next()?;
      21            6 :         if !attr.is_ascii_alphabetic() {
      22            1 :             return None;
      23            5 :         }
      24            5 :         let eq = chars.next()?;
      25            4 :         if eq != '=' {
      26            1 :             return None;
      27            3 :         }
      28              :     }
      29              : 
      30           30 :     Some(())
      31           33 : }
      32              : 
      33              : #[derive(Debug)]
      34              : pub(crate) struct ClientFirstMessage<'a> {
      35              :     /// `client-first-message-bare`.
      36              :     pub(crate) bare: &'a str,
      37              :     /// Channel binding mode.
      38              :     pub(crate) cbind_flag: ChannelBinding<&'a str>,
      39              :     /// Client nonce.
      40              :     pub(crate) nonce: &'a str,
      41              : }
      42              : 
      43              : impl<'a> ClientFirstMessage<'a> {
      44              :     // NB: FromStr doesn't work with lifetimes
      45           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      46           21 :         let mut parts = input.split(',');
      47              : 
      48           21 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      49              : 
      50              :         // PG doesn't support authorization identity,
      51              :         // so we don't bother defining GS2 header type
      52           21 :         let authzid = parts.next()?;
      53           21 :         if !authzid.is_empty() {
      54            1 :             return None;
      55           20 :         }
      56           20 : 
      57           20 :         // Unfortunately, `parts.as_str()` is unstable
      58           20 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      59           20 :         let (_, bare) = input.split_at(pos);
      60              : 
      61              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      62           20 :         let username = parts.next()?.strip_prefix("n=")?;
      63              : 
      64              :         // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
      65           20 :         if !username.is_empty() {
      66            1 :             tracing::warn!(username, "scram username provided, but is not expected");
      67              :             // TODO(conrad):
      68              :             // return None;
      69           19 :         }
      70              : 
      71           20 :         let nonce = parts.next()?.strip_prefix("r=")?;
      72              : 
      73              :         // Validate but ignore auth extensions
      74           20 :         validate_sasl_extensions(parts)?;
      75              : 
      76           17 :         Some(Self {
      77           17 :             bare,
      78           17 :             cbind_flag,
      79           17 :             nonce,
      80           17 :         })
      81           21 :     }
      82              : 
      83              :     /// Build a response to [`ClientFirstMessage`].
      84           12 :     pub(crate) fn build_server_first_message(
      85           12 :         &self,
      86           12 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      87           12 :         salt_base64: &str,
      88           12 :         iterations: u32,
      89           12 :     ) -> OwnedServerFirstMessage {
      90              :         use std::fmt::Write;
      91              : 
      92           12 :         let mut message = String::new();
      93           12 :         write!(&mut message, "r={}", self.nonce).unwrap();
      94           12 :         BASE64_STANDARD.encode_string(nonce, &mut message);
      95           12 :         let combined_nonce = 2..message.len();
      96           12 :         write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
      97           12 : 
      98           12 :         // This design guarantees that it's impossible to create a
      99           12 :         // server-first-message without receiving a client-first-message
     100           12 :         OwnedServerFirstMessage {
     101           12 :             message,
     102           12 :             nonce: combined_nonce,
     103           12 :         }
     104           12 :     }
     105              : }
     106              : 
     107              : #[derive(Debug)]
     108              : pub(crate) struct ClientFinalMessage<'a> {
     109              :     /// `client-final-message-without-proof`.
     110              :     pub(crate) without_proof: &'a str,
     111              :     /// Channel binding data (base64).
     112              :     pub(crate) channel_binding: &'a str,
     113              :     /// Combined client & server nonce.
     114              :     pub(crate) nonce: &'a str,
     115              :     /// Client auth proof.
     116              :     pub(crate) proof: [u8; SCRAM_KEY_LEN],
     117              : }
     118              : 
     119              : impl<'a> ClientFinalMessage<'a> {
     120              :     // NB: FromStr doesn't work with lifetimes
     121           13 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
     122           13 :         let (without_proof, proof) = input.rsplit_once(',')?;
     123              : 
     124           13 :         let mut parts = without_proof.split(',');
     125           13 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     126           13 :         let nonce = parts.next()?.strip_prefix("r=")?;
     127              : 
     128              :         // Validate but ignore auth extensions
     129           13 :         validate_sasl_extensions(parts)?;
     130              : 
     131           13 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     132              : 
     133           13 :         Some(Self {
     134           13 :             without_proof,
     135           13 :             channel_binding,
     136           13 :             nonce,
     137           13 :             proof,
     138           13 :         })
     139           13 :     }
     140              : 
     141              :     /// Build a response to [`ClientFinalMessage`].
     142            7 :     pub(crate) fn build_server_final_message(
     143            7 :         &self,
     144            7 :         signature_builder: SignatureBuilder<'_>,
     145            7 :         server_key: &ScramKey,
     146            7 :     ) -> String {
     147            7 :         let mut buf = String::from("v=");
     148            7 :         BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
     149            7 : 
     150            7 :         buf
     151            7 :     }
     152              : }
     153              : 
     154              : /// We need to keep a convenient representation of this
     155              : /// message for the next authentication step.
     156              : pub(crate) struct OwnedServerFirstMessage {
     157              :     /// Owned `server-first-message`.
     158              :     message: String,
     159              :     /// Slice into `message`.
     160              :     nonce: Range<usize>,
     161              : }
     162              : 
     163              : impl OwnedServerFirstMessage {
     164              :     /// Extract combined nonce from the message.
     165              :     #[inline(always)]
     166            8 :     pub(crate) fn nonce(&self) -> &str {
     167            8 :         &self.message[self.nonce.clone()]
     168            8 :     }
     169              : 
     170              :     /// Get reference to a text representation of the message.
     171              :     #[inline(always)]
     172           20 :     pub(crate) fn as_str(&self) -> &str {
     173           20 :         &self.message
     174           20 :     }
     175              : }
     176              : 
     177              : impl fmt::Debug for OwnedServerFirstMessage {
     178            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     179            0 :         f.debug_struct("ServerFirstMessage")
     180            0 :             .field("message", &self.as_str())
     181            0 :             .field("nonce", &self.nonce())
     182            0 :             .finish()
     183            0 :     }
     184              : }
     185              : 
     186              : #[cfg(test)]
     187              : mod tests {
     188              :     use super::*;
     189              : 
     190              :     #[test]
     191            1 :     fn parse_client_first_message() {
     192              :         use ChannelBinding::*;
     193              : 
     194              :         // (Almost) real strings captured during debug sessions
     195            1 :         let cases = [
     196            1 :             (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     197            1 :             (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     198            1 :             (
     199            1 :                 Required("tls-server-end-point"),
     200            1 :                 "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
     201            1 :             ),
     202            1 :         ];
     203              : 
     204            4 :         for (cb, input) in cases {
     205            3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     206            3 : 
     207            3 :             assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
     208            3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     209            3 :             assert_eq!(msg.cbind_flag, cb);
     210              :         }
     211            1 :     }
     212              : 
     213              :     #[test]
     214            1 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     215            1 :         assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
     216            1 :     }
     217              : 
     218              :     #[test]
     219            1 :     fn parse_client_first_message_with_extra_params() {
     220            1 :         let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
     221            1 :         assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
     222            1 :         assert_eq!(msg.nonce, "nonce");
     223            1 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     224            1 :     }
     225              : 
     226              :     #[test]
     227            1 :     fn parse_client_first_message_with_extra_params_invalid() {
     228            1 :         // must be of the form `<ascii letter>=<...>`
     229            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
     230            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
     231            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
     232            1 :     }
     233              : 
     234              :     #[test]
     235            1 :     fn parse_client_final_message() {
     236            1 :         let input = [
     237            1 :             "c=eSws",
     238            1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     239            1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     240            1 :         ]
     241            1 :         .join(",");
     242            1 : 
     243            1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     244            1 :         assert_eq!(
     245            1 :             msg.without_proof,
     246            1 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     247            1 :         );
     248            1 :         assert_eq!(
     249            1 :             msg.nonce,
     250            1 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     251            1 :         );
     252            1 :         assert_eq!(
     253            1 :             BASE64_STANDARD.encode(msg.proof),
     254            1 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     255            1 :         );
     256            1 :     }
     257              : }
        

Generated by: LCOV version 2.1-beta