LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 95.9 % 147 141
Test Date: 2024-09-20 13:14:58 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use super::base64_decode_array;
       4              : use super::key::{ScramKey, SCRAM_KEY_LEN};
       5              : use super::signature::SignatureBuilder;
       6              : use crate::sasl::ChannelBinding;
       7              : use std::fmt;
       8              : use std::ops::Range;
       9              : 
      10              : /// Faithfully taken from PostgreSQL.
      11              : pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
      12              : 
      13              : /// Although we ignore all extensions, we still have to validate the message.
      14           33 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      15           33 :     for mut chars in parts.map(|s| s.chars()) {
      16            6 :         let attr = chars.next()?;
      17            6 :         if !attr.is_ascii_alphabetic() {
      18            1 :             return None;
      19            5 :         }
      20            5 :         let eq = chars.next()?;
      21            4 :         if eq != '=' {
      22            1 :             return None;
      23            3 :         }
      24              :     }
      25              : 
      26           30 :     Some(())
      27           33 : }
      28              : 
      29              : #[derive(Debug)]
      30              : pub(crate) struct ClientFirstMessage<'a> {
      31              :     /// `client-first-message-bare`.
      32              :     pub(crate) bare: &'a str,
      33              :     /// Channel binding mode.
      34              :     pub(crate) cbind_flag: ChannelBinding<&'a str>,
      35              :     /// Client nonce.
      36              :     pub(crate) nonce: &'a str,
      37              : }
      38              : 
      39              : impl<'a> ClientFirstMessage<'a> {
      40              :     // NB: FromStr doesn't work with lifetimes
      41           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      42           21 :         let mut parts = input.split(',');
      43              : 
      44           21 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      45              : 
      46              :         // PG doesn't support authorization identity,
      47              :         // so we don't bother defining GS2 header type
      48           21 :         let authzid = parts.next()?;
      49           21 :         if !authzid.is_empty() {
      50            1 :             return None;
      51           20 :         }
      52           20 : 
      53           20 :         // Unfortunately, `parts.as_str()` is unstable
      54           20 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      55           20 :         let (_, bare) = input.split_at(pos);
      56              : 
      57              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      58           20 :         let username = parts.next()?.strip_prefix("n=")?;
      59              : 
      60              :         // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
      61           20 :         if !username.is_empty() {
      62            1 :             tracing::warn!(username, "scram username provided, but is not expected");
      63              :             // TODO(conrad):
      64              :             // return None;
      65           19 :         }
      66              : 
      67           20 :         let nonce = parts.next()?.strip_prefix("r=")?;
      68              : 
      69              :         // Validate but ignore auth extensions
      70           20 :         validate_sasl_extensions(parts)?;
      71              : 
      72           17 :         Some(Self {
      73           17 :             bare,
      74           17 :             cbind_flag,
      75           17 :             nonce,
      76           17 :         })
      77           21 :     }
      78              : 
      79              :     /// Build a response to [`ClientFirstMessage`].
      80           12 :     pub(crate) fn build_server_first_message(
      81           12 :         &self,
      82           12 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      83           12 :         salt_base64: &str,
      84           12 :         iterations: u32,
      85           12 :     ) -> OwnedServerFirstMessage {
      86              :         use std::fmt::Write;
      87              : 
      88           12 :         let mut message = String::new();
      89           12 :         write!(&mut message, "r={}", self.nonce).unwrap();
      90           12 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      91           12 :         let combined_nonce = 2..message.len();
      92           12 :         write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
      93           12 : 
      94           12 :         // This design guarantees that it's impossible to create a
      95           12 :         // server-first-message without receiving a client-first-message
      96           12 :         OwnedServerFirstMessage {
      97           12 :             message,
      98           12 :             nonce: combined_nonce,
      99           12 :         }
     100           12 :     }
     101              : }
     102              : 
     103              : #[derive(Debug)]
     104              : pub(crate) struct ClientFinalMessage<'a> {
     105              :     /// `client-final-message-without-proof`.
     106              :     pub(crate) without_proof: &'a str,
     107              :     /// Channel binding data (base64).
     108              :     pub(crate) channel_binding: &'a str,
     109              :     /// Combined client & server nonce.
     110              :     pub(crate) nonce: &'a str,
     111              :     /// Client auth proof.
     112              :     pub(crate) proof: [u8; SCRAM_KEY_LEN],
     113              : }
     114              : 
     115              : impl<'a> ClientFinalMessage<'a> {
     116              :     // NB: FromStr doesn't work with lifetimes
     117           13 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
     118           13 :         let (without_proof, proof) = input.rsplit_once(',')?;
     119              : 
     120           13 :         let mut parts = without_proof.split(',');
     121           13 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     122           13 :         let nonce = parts.next()?.strip_prefix("r=")?;
     123              : 
     124              :         // Validate but ignore auth extensions
     125           13 :         validate_sasl_extensions(parts)?;
     126              : 
     127           13 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     128              : 
     129           13 :         Some(Self {
     130           13 :             without_proof,
     131           13 :             channel_binding,
     132           13 :             nonce,
     133           13 :             proof,
     134           13 :         })
     135           13 :     }
     136              : 
     137              :     /// Build a response to [`ClientFinalMessage`].
     138            7 :     pub(crate) fn build_server_final_message(
     139            7 :         &self,
     140            7 :         signature_builder: SignatureBuilder<'_>,
     141            7 :         server_key: &ScramKey,
     142            7 :     ) -> String {
     143            7 :         let mut buf = String::from("v=");
     144            7 :         base64::encode_config_buf(
     145            7 :             signature_builder.build(server_key),
     146            7 :             base64::STANDARD,
     147            7 :             &mut buf,
     148            7 :         );
     149            7 : 
     150            7 :         buf
     151            7 :     }
     152              : }
     153              : 
     154              : /// We need to keep a convenient representation of this
     155              : /// message for the next authentication step.
     156              : pub(crate) struct OwnedServerFirstMessage {
     157              :     /// Owned `server-first-message`.
     158              :     message: String,
     159              :     /// Slice into `message`.
     160              :     nonce: Range<usize>,
     161              : }
     162              : 
     163              : impl OwnedServerFirstMessage {
     164              :     /// Extract combined nonce from the message.
     165              :     #[inline(always)]
     166            8 :     pub(crate) fn nonce(&self) -> &str {
     167            8 :         &self.message[self.nonce.clone()]
     168            8 :     }
     169              : 
     170              :     /// Get reference to a text representation of the message.
     171              :     #[inline(always)]
     172           20 :     pub(crate) fn as_str(&self) -> &str {
     173           20 :         &self.message
     174           20 :     }
     175              : }
     176              : 
     177              : impl fmt::Debug for OwnedServerFirstMessage {
     178            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     179            0 :         f.debug_struct("ServerFirstMessage")
     180            0 :             .field("message", &self.as_str())
     181            0 :             .field("nonce", &self.nonce())
     182            0 :             .finish()
     183            0 :     }
     184              : }
     185              : 
     186              : #[cfg(test)]
     187              : mod tests {
     188              :     use super::*;
     189              : 
     190              :     #[test]
     191            1 :     fn parse_client_first_message() {
     192              :         use ChannelBinding::*;
     193              : 
     194              :         // (Almost) real strings captured during debug sessions
     195            1 :         let cases = [
     196            1 :             (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     197            1 :             (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     198            1 :             (
     199            1 :                 Required("tls-server-end-point"),
     200            1 :                 "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
     201            1 :             ),
     202            1 :         ];
     203              : 
     204            4 :         for (cb, input) in cases {
     205            3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     206            3 : 
     207            3 :             assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
     208            3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     209            3 :             assert_eq!(msg.cbind_flag, cb);
     210              :         }
     211            1 :     }
     212              : 
     213              :     #[test]
     214            1 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     215            1 :         assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
     216            1 :     }
     217              : 
     218              :     #[test]
     219            1 :     fn parse_client_first_message_with_extra_params() {
     220            1 :         let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
     221            1 :         assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
     222            1 :         assert_eq!(msg.nonce, "nonce");
     223            1 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     224            1 :     }
     225              : 
     226              :     #[test]
     227            1 :     fn parse_client_first_message_with_extra_params_invalid() {
     228            1 :         // must be of the form `<ascii letter>=<...>`
     229            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
     230            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
     231            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
     232            1 :     }
     233              : 
     234              :     #[test]
     235            1 :     fn parse_client_final_message() {
     236            1 :         let input = [
     237            1 :             "c=eSws",
     238            1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     239            1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     240            1 :         ]
     241            1 :         .join(",");
     242            1 : 
     243            1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     244            1 :         assert_eq!(
     245            1 :             msg.without_proof,
     246            1 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     247            1 :         );
     248            1 :         assert_eq!(
     249            1 :             msg.nonce,
     250            1 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     251            1 :         );
     252            1 :         assert_eq!(
     253            1 :             base64::encode(msg.proof),
     254            1 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     255            1 :         );
     256            1 :     }
     257              : }
        

Generated by: LCOV version 2.1-beta