LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: 465a86b0c1fda0069b3e0f6c1c126e6b635a1f72.info Lines: 96.1 % 152 146
Test Date: 2024-06-25 15:47:26 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use super::base64_decode_array;
       4              : use super::key::{ScramKey, SCRAM_KEY_LEN};
       5              : use super::signature::SignatureBuilder;
       6              : use crate::sasl::ChannelBinding;
       7              : use std::fmt;
       8              : use std::ops::Range;
       9              : 
      10              : /// Faithfully taken from PostgreSQL.
      11              : pub const SCRAM_RAW_NONCE_LEN: usize = 18;
      12              : 
      13              : /// Although we ignore all extensions, we still have to validate the message.
      14           66 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      15           66 :     for mut chars in parts.map(|s| s.chars()) {
      16           12 :         let attr = chars.next()?;
      17           12 :         if !attr.is_ascii_alphabetic() {
      18            2 :             return None;
      19           10 :         }
      20           10 :         let eq = chars.next()?;
      21            8 :         if eq != '=' {
      22            2 :             return None;
      23            6 :         }
      24              :     }
      25              : 
      26           60 :     Some(())
      27           66 : }
      28              : 
      29              : #[derive(Debug)]
      30              : pub struct ClientFirstMessage<'a> {
      31              :     /// `client-first-message-bare`.
      32              :     pub bare: &'a str,
      33              :     /// Channel binding mode.
      34              :     pub cbind_flag: ChannelBinding<&'a str>,
      35              :     /// Client nonce.
      36              :     pub nonce: &'a str,
      37              : }
      38              : 
      39              : impl<'a> ClientFirstMessage<'a> {
      40              :     // NB: FromStr doesn't work with lifetimes
      41           42 :     pub fn parse(input: &'a str) -> Option<Self> {
      42           42 :         let mut parts = input.split(',');
      43              : 
      44           42 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      45              : 
      46              :         // PG doesn't support authorization identity,
      47              :         // so we don't bother defining GS2 header type
      48           42 :         let authzid = parts.next()?;
      49           42 :         if !authzid.is_empty() {
      50            2 :             return None;
      51           40 :         }
      52           40 : 
      53           40 :         // Unfortunately, `parts.as_str()` is unstable
      54           40 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      55           40 :         let (_, bare) = input.split_at(pos);
      56              : 
      57              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      58           40 :         let username = parts.next()?.strip_prefix("n=")?;
      59              : 
      60              :         // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
      61           40 :         if !username.is_empty() {
      62            2 :             tracing::warn!(username, "scram username provided, but is not expected")
      63              :             // TODO(conrad):
      64              :             // return None;
      65           38 :         }
      66              : 
      67           40 :         let nonce = parts.next()?.strip_prefix("r=")?;
      68              : 
      69              :         // Validate but ignore auth extensions
      70           40 :         validate_sasl_extensions(parts)?;
      71              : 
      72           34 :         Some(Self {
      73           34 :             bare,
      74           34 :             cbind_flag,
      75           34 :             nonce,
      76           34 :         })
      77           42 :     }
      78              : 
      79              :     /// Build a response to [`ClientFirstMessage`].
      80           24 :     pub fn build_server_first_message(
      81           24 :         &self,
      82           24 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      83           24 :         salt_base64: &str,
      84           24 :         iterations: u32,
      85           24 :     ) -> OwnedServerFirstMessage {
      86           24 :         use std::fmt::Write;
      87           24 : 
      88           24 :         let mut message = String::new();
      89           24 :         write!(&mut message, "r={}", self.nonce).unwrap();
      90           24 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      91           24 :         let combined_nonce = 2..message.len();
      92           24 :         write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
      93           24 : 
      94           24 :         // This design guarantees that it's impossible to create a
      95           24 :         // server-first-message without receiving a client-first-message
      96           24 :         OwnedServerFirstMessage {
      97           24 :             message,
      98           24 :             nonce: combined_nonce,
      99           24 :         }
     100           24 :     }
     101              : }
     102              : 
     103              : #[derive(Debug)]
     104              : pub struct ClientFinalMessage<'a> {
     105              :     /// `client-final-message-without-proof`.
     106              :     pub without_proof: &'a str,
     107              :     /// Channel binding data (base64).
     108              :     pub channel_binding: &'a str,
     109              :     /// Combined client & server nonce.
     110              :     pub nonce: &'a str,
     111              :     /// Client auth proof.
     112              :     pub proof: [u8; SCRAM_KEY_LEN],
     113              : }
     114              : 
     115              : impl<'a> ClientFinalMessage<'a> {
     116              :     // NB: FromStr doesn't work with lifetimes
     117           26 :     pub fn parse(input: &'a str) -> Option<Self> {
     118           26 :         let (without_proof, proof) = input.rsplit_once(',')?;
     119              : 
     120           26 :         let mut parts = without_proof.split(',');
     121           26 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     122           26 :         let nonce = parts.next()?.strip_prefix("r=")?;
     123              : 
     124              :         // Validate but ignore auth extensions
     125           26 :         validate_sasl_extensions(parts)?;
     126              : 
     127           26 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     128              : 
     129           26 :         Some(Self {
     130           26 :             without_proof,
     131           26 :             channel_binding,
     132           26 :             nonce,
     133           26 :             proof,
     134           26 :         })
     135           26 :     }
     136              : 
     137              :     /// Build a response to [`ClientFinalMessage`].
     138           14 :     pub fn build_server_final_message(
     139           14 :         &self,
     140           14 :         signature_builder: SignatureBuilder,
     141           14 :         server_key: &ScramKey,
     142           14 :     ) -> String {
     143           14 :         let mut buf = String::from("v=");
     144           14 :         base64::encode_config_buf(
     145           14 :             signature_builder.build(server_key),
     146           14 :             base64::STANDARD,
     147           14 :             &mut buf,
     148           14 :         );
     149           14 : 
     150           14 :         buf
     151           14 :     }
     152              : }
     153              : 
     154              : /// We need to keep a convenient representation of this
     155              : /// message for the next authentication step.
     156              : pub struct OwnedServerFirstMessage {
     157              :     /// Owned `server-first-message`.
     158              :     message: String,
     159              :     /// Slice into `message`.
     160              :     nonce: Range<usize>,
     161              : }
     162              : 
     163              : impl OwnedServerFirstMessage {
     164              :     /// Extract combined nonce from the message.
     165              :     #[inline(always)]
     166           16 :     pub fn nonce(&self) -> &str {
     167           16 :         &self.message[self.nonce.clone()]
     168           16 :     }
     169              : 
     170              :     /// Get reference to a text representation of the message.
     171              :     #[inline(always)]
     172           40 :     pub fn as_str(&self) -> &str {
     173           40 :         &self.message
     174           40 :     }
     175              : }
     176              : 
     177              : impl fmt::Debug for OwnedServerFirstMessage {
     178            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     179            0 :         f.debug_struct("ServerFirstMessage")
     180            0 :             .field("message", &self.as_str())
     181            0 :             .field("nonce", &self.nonce())
     182            0 :             .finish()
     183            0 :     }
     184              : }
     185              : 
     186              : #[cfg(test)]
     187              : mod tests {
     188              :     use super::*;
     189              : 
     190              :     #[test]
     191            2 :     fn parse_client_first_message() {
     192            2 :         use ChannelBinding::*;
     193            2 : 
     194            2 :         // (Almost) real strings captured during debug sessions
     195            2 :         let cases = [
     196            2 :             (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     197            2 :             (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     198            2 :             (
     199            2 :                 Required("tls-server-end-point"),
     200            2 :                 "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
     201            2 :             ),
     202            2 :         ];
     203              : 
     204            8 :         for (cb, input) in cases {
     205            6 :             let msg = ClientFirstMessage::parse(input).unwrap();
     206            6 : 
     207            6 :             assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
     208            6 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     209            6 :             assert_eq!(msg.cbind_flag, cb);
     210              :         }
     211            2 :     }
     212              : 
     213              :     #[test]
     214            2 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     215            2 :         assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
     216            2 :     }
     217              : 
     218              :     #[test]
     219            2 :     fn parse_client_first_message_with_extra_params() {
     220            2 :         let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
     221            2 :         assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
     222            2 :         assert_eq!(msg.nonce, "nonce");
     223            2 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     224            2 :     }
     225              : 
     226              :     #[test]
     227            2 :     fn parse_client_first_message_with_extra_params_invalid() {
     228            2 :         // must be of the form `<ascii letter>=<...>`
     229            2 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
     230            2 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
     231            2 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
     232            2 :     }
     233              : 
     234              :     #[test]
     235            2 :     fn parse_client_final_message() {
     236            2 :         let input = [
     237            2 :             "c=eSws",
     238            2 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     239            2 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     240            2 :         ]
     241            2 :         .join(",");
     242            2 : 
     243            2 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     244            2 :         assert_eq!(
     245            2 :             msg.without_proof,
     246            2 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     247            2 :         );
     248            2 :         assert_eq!(
     249            2 :             msg.nonce,
     250            2 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     251            2 :         );
     252            2 :         assert_eq!(
     253            2 :             base64::encode(msg.proof),
     254            2 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     255            2 :         );
     256            2 :     }
     257              : }
        

Generated by: LCOV version 2.1-beta