LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: 691a4c28fe7169edd60b367c52d448a0a6605f1f.info Lines: 96.1 % 152 146
Test Date: 2024-05-10 13:18:37 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use super::base64_decode_array;
       4              : use super::key::{ScramKey, SCRAM_KEY_LEN};
       5              : use super::signature::SignatureBuilder;
       6              : use crate::sasl::ChannelBinding;
       7              : use std::fmt;
       8              : use std::ops::Range;
       9              : 
      10              : /// Faithfully taken from PostgreSQL.
      11              : pub const SCRAM_RAW_NONCE_LEN: usize = 18;
      12              : 
      13              : /// Although we ignore all extensions, we still have to validate the message.
      14           66 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      15           66 :     for mut chars in parts.map(|s| s.chars()) {
      16           12 :         let attr = chars.next()?;
      17           12 :         if !attr.is_ascii_alphabetic() {
      18            2 :             return None;
      19           10 :         }
      20           10 :         let eq = chars.next()?;
      21            8 :         if eq != '=' {
      22            2 :             return None;
      23            6 :         }
      24              :     }
      25              : 
      26           60 :     Some(())
      27           66 : }
      28              : 
      29              : #[derive(Debug)]
      30              : pub struct ClientFirstMessage<'a> {
      31              :     /// `client-first-message-bare`.
      32              :     pub bare: &'a str,
      33              :     /// Channel binding mode.
      34              :     pub cbind_flag: ChannelBinding<&'a str>,
      35              :     /// (Client username)[<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>].
      36              :     pub username: &'a str,
      37              :     /// Client nonce.
      38              :     pub nonce: &'a str,
      39              : }
      40              : 
      41              : impl<'a> ClientFirstMessage<'a> {
      42              :     // NB: FromStr doesn't work with lifetimes
      43           42 :     pub fn parse(input: &'a str) -> Option<Self> {
      44           42 :         let mut parts = input.split(',');
      45              : 
      46           42 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      47              : 
      48              :         // PG doesn't support authorization identity,
      49              :         // so we don't bother defining GS2 header type
      50           42 :         let authzid = parts.next()?;
      51           42 :         if !authzid.is_empty() {
      52            2 :             return None;
      53           40 :         }
      54           40 : 
      55           40 :         // Unfortunately, `parts.as_str()` is unstable
      56           40 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      57           40 :         let (_, bare) = input.split_at(pos);
      58              : 
      59              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      60           40 :         let username = parts.next()?.strip_prefix("n=")?;
      61           40 :         let nonce = parts.next()?.strip_prefix("r=")?;
      62              : 
      63              :         // Validate but ignore auth extensions
      64           40 :         validate_sasl_extensions(parts)?;
      65              : 
      66           34 :         Some(Self {
      67           34 :             bare,
      68           34 :             cbind_flag,
      69           34 :             username,
      70           34 :             nonce,
      71           34 :         })
      72           42 :     }
      73              : 
      74              :     /// Build a response to [`ClientFirstMessage`].
      75           24 :     pub fn build_server_first_message(
      76           24 :         &self,
      77           24 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      78           24 :         salt_base64: &str,
      79           24 :         iterations: u32,
      80           24 :     ) -> OwnedServerFirstMessage {
      81           24 :         use std::fmt::Write;
      82           24 : 
      83           24 :         let mut message = String::new();
      84           24 :         write!(&mut message, "r={}", self.nonce).unwrap();
      85           24 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      86           24 :         let combined_nonce = 2..message.len();
      87           24 :         write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
      88           24 : 
      89           24 :         // This design guarantees that it's impossible to create a
      90           24 :         // server-first-message without receiving a client-first-message
      91           24 :         OwnedServerFirstMessage {
      92           24 :             message,
      93           24 :             nonce: combined_nonce,
      94           24 :         }
      95           24 :     }
      96              : }
      97              : 
      98              : #[derive(Debug)]
      99              : pub struct ClientFinalMessage<'a> {
     100              :     /// `client-final-message-without-proof`.
     101              :     pub without_proof: &'a str,
     102              :     /// Channel binding data (base64).
     103              :     pub channel_binding: &'a str,
     104              :     /// Combined client & server nonce.
     105              :     pub nonce: &'a str,
     106              :     /// Client auth proof.
     107              :     pub proof: [u8; SCRAM_KEY_LEN],
     108              : }
     109              : 
     110              : impl<'a> ClientFinalMessage<'a> {
     111              :     // NB: FromStr doesn't work with lifetimes
     112           26 :     pub fn parse(input: &'a str) -> Option<Self> {
     113           26 :         let (without_proof, proof) = input.rsplit_once(',')?;
     114              : 
     115           26 :         let mut parts = without_proof.split(',');
     116           26 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     117           26 :         let nonce = parts.next()?.strip_prefix("r=")?;
     118              : 
     119              :         // Validate but ignore auth extensions
     120           26 :         validate_sasl_extensions(parts)?;
     121              : 
     122           26 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     123              : 
     124           26 :         Some(Self {
     125           26 :             without_proof,
     126           26 :             channel_binding,
     127           26 :             nonce,
     128           26 :             proof,
     129           26 :         })
     130           26 :     }
     131              : 
     132              :     /// Build a response to [`ClientFinalMessage`].
     133           14 :     pub fn build_server_final_message(
     134           14 :         &self,
     135           14 :         signature_builder: SignatureBuilder,
     136           14 :         server_key: &ScramKey,
     137           14 :     ) -> String {
     138           14 :         let mut buf = String::from("v=");
     139           14 :         base64::encode_config_buf(
     140           14 :             signature_builder.build(server_key),
     141           14 :             base64::STANDARD,
     142           14 :             &mut buf,
     143           14 :         );
     144           14 : 
     145           14 :         buf
     146           14 :     }
     147              : }
     148              : 
     149              : /// We need to keep a convenient representation of this
     150              : /// message for the next authentication step.
     151              : pub struct OwnedServerFirstMessage {
     152              :     /// Owned `server-first-message`.
     153              :     message: String,
     154              :     /// Slice into `message`.
     155              :     nonce: Range<usize>,
     156              : }
     157              : 
     158              : impl OwnedServerFirstMessage {
     159              :     /// Extract combined nonce from the message.
     160              :     #[inline(always)]
     161           16 :     pub fn nonce(&self) -> &str {
     162           16 :         &self.message[self.nonce.clone()]
     163           16 :     }
     164              : 
     165              :     /// Get reference to a text representation of the message.
     166              :     #[inline(always)]
     167           40 :     pub fn as_str(&self) -> &str {
     168           40 :         &self.message
     169           40 :     }
     170              : }
     171              : 
     172              : impl fmt::Debug for OwnedServerFirstMessage {
     173            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     174            0 :         f.debug_struct("ServerFirstMessage")
     175            0 :             .field("message", &self.as_str())
     176            0 :             .field("nonce", &self.nonce())
     177            0 :             .finish()
     178            0 :     }
     179              : }
     180              : 
     181              : #[cfg(test)]
     182              : mod tests {
     183              :     use super::*;
     184              : 
     185              :     #[test]
     186            2 :     fn parse_client_first_message() {
     187            2 :         use ChannelBinding::*;
     188            2 : 
     189            2 :         // (Almost) real strings captured during debug sessions
     190            2 :         let cases = [
     191            2 :             (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     192            2 :             (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     193            2 :             (
     194            2 :                 Required("tls-server-end-point"),
     195            2 :                 "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
     196            2 :             ),
     197            2 :         ];
     198              : 
     199            8 :         for (cb, input) in cases {
     200            6 :             let msg = ClientFirstMessage::parse(input).unwrap();
     201            6 : 
     202            6 :             assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
     203            6 :             assert_eq!(msg.username, "pepe");
     204            6 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     205            6 :             assert_eq!(msg.cbind_flag, cb);
     206              :         }
     207            2 :     }
     208              : 
     209              :     #[test]
     210            2 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     211            2 :         assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none())
     212            2 :     }
     213              : 
     214              :     #[test]
     215            2 :     fn parse_client_first_message_with_extra_params() {
     216            2 :         let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap();
     217            2 :         assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz");
     218            2 :         assert_eq!(msg.username, "user");
     219            2 :         assert_eq!(msg.nonce, "nonce");
     220            2 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     221            2 :     }
     222              : 
     223              :     #[test]
     224            2 :     fn parse_client_first_message_with_extra_params_invalid() {
     225            2 :         // must be of the form `<ascii letter>=<...>`
     226            2 :         assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none());
     227            2 :         assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none());
     228            2 :         assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none());
     229            2 :     }
     230              : 
     231              :     #[test]
     232            2 :     fn parse_client_final_message() {
     233            2 :         let input = [
     234            2 :             "c=eSws",
     235            2 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     236            2 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     237            2 :         ]
     238            2 :         .join(",");
     239            2 : 
     240            2 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     241            2 :         assert_eq!(
     242            2 :             msg.without_proof,
     243            2 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     244            2 :         );
     245            2 :         assert_eq!(
     246            2 :             msg.nonce,
     247            2 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     248            2 :         );
     249            2 :         assert_eq!(
     250            2 :             base64::encode(msg.proof),
     251            2 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     252            2 :         );
     253            2 :     }
     254              : }
        

Generated by: LCOV version 2.1-beta