LCOV - code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 95.9 % 147 141
Test Date: 2024-10-22 22:13:45 Functions: 92.9 % 14 13

            Line data    Source code
       1              : //! Definitions for SCRAM messages.
       2              : 
       3              : use std::fmt;
       4              : use std::ops::Range;
       5              : 
       6              : use super::base64_decode_array;
       7              : use super::key::{ScramKey, SCRAM_KEY_LEN};
       8              : use super::signature::SignatureBuilder;
       9              : use crate::sasl::ChannelBinding;
      10              : 
      11              : /// Faithfully taken from PostgreSQL.
      12              : pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
      13              : 
      14              : /// Although we ignore all extensions, we still have to validate the message.
      15           33 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      16           33 :     for mut chars in parts.map(|s| s.chars()) {
      17            6 :         let attr = chars.next()?;
      18            6 :         if !attr.is_ascii_alphabetic() {
      19            1 :             return None;
      20            5 :         }
      21            5 :         let eq = chars.next()?;
      22            4 :         if eq != '=' {
      23            1 :             return None;
      24            3 :         }
      25              :     }
      26              : 
      27           30 :     Some(())
      28           33 : }
      29              : 
      30              : #[derive(Debug)]
      31              : pub(crate) struct ClientFirstMessage<'a> {
      32              :     /// `client-first-message-bare`.
      33              :     pub(crate) bare: &'a str,
      34              :     /// Channel binding mode.
      35              :     pub(crate) cbind_flag: ChannelBinding<&'a str>,
      36              :     /// Client nonce.
      37              :     pub(crate) nonce: &'a str,
      38              : }
      39              : 
      40              : impl<'a> ClientFirstMessage<'a> {
      41              :     // NB: FromStr doesn't work with lifetimes
      42           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      43           21 :         let mut parts = input.split(',');
      44              : 
      45           21 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      46              : 
      47              :         // PG doesn't support authorization identity,
      48              :         // so we don't bother defining GS2 header type
      49           21 :         let authzid = parts.next()?;
      50           21 :         if !authzid.is_empty() {
      51            1 :             return None;
      52           20 :         }
      53           20 : 
      54           20 :         // Unfortunately, `parts.as_str()` is unstable
      55           20 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      56           20 :         let (_, bare) = input.split_at(pos);
      57              : 
      58              :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      59           20 :         let username = parts.next()?.strip_prefix("n=")?;
      60              : 
      61              :         // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
      62           20 :         if !username.is_empty() {
      63            1 :             tracing::warn!(username, "scram username provided, but is not expected");
      64              :             // TODO(conrad):
      65              :             // return None;
      66           19 :         }
      67              : 
      68           20 :         let nonce = parts.next()?.strip_prefix("r=")?;
      69              : 
      70              :         // Validate but ignore auth extensions
      71           20 :         validate_sasl_extensions(parts)?;
      72              : 
      73           17 :         Some(Self {
      74           17 :             bare,
      75           17 :             cbind_flag,
      76           17 :             nonce,
      77           17 :         })
      78           21 :     }
      79              : 
      80              :     /// Build a response to [`ClientFirstMessage`].
      81           12 :     pub(crate) fn build_server_first_message(
      82           12 :         &self,
      83           12 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      84           12 :         salt_base64: &str,
      85           12 :         iterations: u32,
      86           12 :     ) -> OwnedServerFirstMessage {
      87              :         use std::fmt::Write;
      88              : 
      89           12 :         let mut message = String::new();
      90           12 :         write!(&mut message, "r={}", self.nonce).unwrap();
      91           12 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      92           12 :         let combined_nonce = 2..message.len();
      93           12 :         write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
      94           12 : 
      95           12 :         // This design guarantees that it's impossible to create a
      96           12 :         // server-first-message without receiving a client-first-message
      97           12 :         OwnedServerFirstMessage {
      98           12 :             message,
      99           12 :             nonce: combined_nonce,
     100           12 :         }
     101           12 :     }
     102              : }
     103              : 
     104              : #[derive(Debug)]
     105              : pub(crate) struct ClientFinalMessage<'a> {
     106              :     /// `client-final-message-without-proof`.
     107              :     pub(crate) without_proof: &'a str,
     108              :     /// Channel binding data (base64).
     109              :     pub(crate) channel_binding: &'a str,
     110              :     /// Combined client & server nonce.
     111              :     pub(crate) nonce: &'a str,
     112              :     /// Client auth proof.
     113              :     pub(crate) proof: [u8; SCRAM_KEY_LEN],
     114              : }
     115              : 
     116              : impl<'a> ClientFinalMessage<'a> {
     117              :     // NB: FromStr doesn't work with lifetimes
     118           13 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
     119           13 :         let (without_proof, proof) = input.rsplit_once(',')?;
     120              : 
     121           13 :         let mut parts = without_proof.split(',');
     122           13 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     123           13 :         let nonce = parts.next()?.strip_prefix("r=")?;
     124              : 
     125              :         // Validate but ignore auth extensions
     126           13 :         validate_sasl_extensions(parts)?;
     127              : 
     128           13 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     129              : 
     130           13 :         Some(Self {
     131           13 :             without_proof,
     132           13 :             channel_binding,
     133           13 :             nonce,
     134           13 :             proof,
     135           13 :         })
     136           13 :     }
     137              : 
     138              :     /// Build a response to [`ClientFinalMessage`].
     139            7 :     pub(crate) fn build_server_final_message(
     140            7 :         &self,
     141            7 :         signature_builder: SignatureBuilder<'_>,
     142            7 :         server_key: &ScramKey,
     143            7 :     ) -> String {
     144            7 :         let mut buf = String::from("v=");
     145            7 :         base64::encode_config_buf(
     146            7 :             signature_builder.build(server_key),
     147            7 :             base64::STANDARD,
     148            7 :             &mut buf,
     149            7 :         );
     150            7 : 
     151            7 :         buf
     152            7 :     }
     153              : }
     154              : 
     155              : /// We need to keep a convenient representation of this
     156              : /// message for the next authentication step.
     157              : pub(crate) struct OwnedServerFirstMessage {
     158              :     /// Owned `server-first-message`.
     159              :     message: String,
     160              :     /// Slice into `message`.
     161              :     nonce: Range<usize>,
     162              : }
     163              : 
     164              : impl OwnedServerFirstMessage {
     165              :     /// Extract combined nonce from the message.
     166              :     #[inline(always)]
     167            8 :     pub(crate) fn nonce(&self) -> &str {
     168            8 :         &self.message[self.nonce.clone()]
     169            8 :     }
     170              : 
     171              :     /// Get reference to a text representation of the message.
     172              :     #[inline(always)]
     173           20 :     pub(crate) fn as_str(&self) -> &str {
     174           20 :         &self.message
     175           20 :     }
     176              : }
     177              : 
     178              : impl fmt::Debug for OwnedServerFirstMessage {
     179            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     180            0 :         f.debug_struct("ServerFirstMessage")
     181            0 :             .field("message", &self.as_str())
     182            0 :             .field("nonce", &self.nonce())
     183            0 :             .finish()
     184            0 :     }
     185              : }
     186              : 
     187              : #[cfg(test)]
     188              : mod tests {
     189              :     use super::*;
     190              : 
     191              :     #[test]
     192            1 :     fn parse_client_first_message() {
     193              :         use ChannelBinding::*;
     194              : 
     195              :         // (Almost) real strings captured during debug sessions
     196            1 :         let cases = [
     197            1 :             (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     198            1 :             (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
     199            1 :             (
     200            1 :                 Required("tls-server-end-point"),
     201            1 :                 "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
     202            1 :             ),
     203            1 :         ];
     204              : 
     205            4 :         for (cb, input) in cases {
     206            3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     207            3 : 
     208            3 :             assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
     209            3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     210            3 :             assert_eq!(msg.cbind_flag, cb);
     211              :         }
     212            1 :     }
     213              : 
     214              :     #[test]
     215            1 :     fn parse_client_first_message_with_invalid_gs2_authz() {
     216            1 :         assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
     217            1 :     }
     218              : 
     219              :     #[test]
     220            1 :     fn parse_client_first_message_with_extra_params() {
     221            1 :         let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
     222            1 :         assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
     223            1 :         assert_eq!(msg.nonce, "nonce");
     224            1 :         assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
     225            1 :     }
     226              : 
     227              :     #[test]
     228            1 :     fn parse_client_first_message_with_extra_params_invalid() {
     229            1 :         // must be of the form `<ascii letter>=<...>`
     230            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
     231            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
     232            1 :         assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
     233            1 :     }
     234              : 
     235              :     #[test]
     236            1 :     fn parse_client_final_message() {
     237            1 :         let input = [
     238            1 :             "c=eSws",
     239            1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     240            1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     241            1 :         ]
     242            1 :         .join(",");
     243            1 : 
     244            1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     245            1 :         assert_eq!(
     246            1 :             msg.without_proof,
     247            1 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     248            1 :         );
     249            1 :         assert_eq!(
     250            1 :             msg.nonce,
     251            1 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     252            1 :         );
     253            1 :         assert_eq!(
     254            1 :             base64::encode(msg.proof),
     255            1 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     256            1 :         );
     257            1 :     }
     258              : }
        

Generated by: LCOV version 2.1-beta