LCOV - differential code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 87.9 % 140 123 17 123
Current Date: 2024-01-09 02:06:09 Functions: 73.3 % 15 11 4 11
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Definitions for SCRAM messages.
       2                 : 
       3                 : use super::base64_decode_array;
       4                 : use super::key::{ScramKey, SCRAM_KEY_LEN};
       5                 : use super::signature::SignatureBuilder;
       6                 : use crate::sasl::ChannelBinding;
       7                 : use std::fmt;
       8                 : use std::ops::Range;
       9                 : 
/// Length, in bytes, of a raw nonce, i.e. before it is base64-encoded
/// for inclusion in a message.
///
/// Faithfully taken from PostgreSQL.
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
      12                 : 
      13                 : /// Although we ignore all extensions, we still have to validate the message.
      14 CBC         103 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      15             103 :     for mut chars in parts.map(|s| s.chars()) {
      16 UBC           0 :         let attr = chars.next()?;
      17               0 :         if !attr.is_ascii_alphabetic() {
      18               0 :             return None;
      19               0 :         }
      20               0 :         let eq = chars.next()?;
      21               0 :         if eq != '=' {
      22               0 :             return None;
      23               0 :         }
      24                 :     }
      25                 : 
      26 CBC         103 :     Some(())
      27             103 : }
      28                 : 
/// Parsed `client-first-message`.
///
/// Every field borrows from the string the message was parsed from.
#[derive(Debug)]
pub struct ClientFirstMessage<'a> {
    /// `client-first-message-bare`.
    pub bare: &'a str,
    /// Channel binding mode.
    pub cbind_flag: ChannelBinding<&'a str>,
    /// [Client username](<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>).
    pub username: &'a str,
    /// Client nonce.
    pub nonce: &'a str,
}
      40                 : 
      41                 : impl<'a> ClientFirstMessage<'a> {
      42                 :     // NB: FromStr doesn't work with lifetimes
      43 CBC          53 :     pub fn parse(input: &'a str) -> Option<Self> {
      44              53 :         let mut parts = input.split(',');
      45                 : 
      46              53 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      47                 : 
      48                 :         // PG doesn't support authorization identity,
      49                 :         // so we don't bother defining GS2 header type
      50              53 :         let authzid = parts.next()?;
      51              53 :         if !authzid.is_empty() {
      52 UBC           0 :             return None;
      53 CBC          53 :         }
      54              53 : 
      55              53 :         // Unfortunately, `parts.as_str()` is unstable
      56              53 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      57              53 :         let (_, bare) = input.split_at(pos);
      58                 : 
      59                 :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      60              53 :         let username = parts.next()?.strip_prefix("n=")?;
      61              53 :         let nonce = parts.next()?.strip_prefix("r=")?;
      62                 : 
      63                 :         // Validate but ignore auth extensions
      64              53 :         validate_sasl_extensions(parts)?;
      65                 : 
      66              53 :         Some(Self {
      67              53 :             bare,
      68              53 :             cbind_flag,
      69              53 :             username,
      70              53 :             nonce,
      71              53 :         })
      72              53 :     }
      73                 : 
      74                 :     /// Build a response to [`ClientFirstMessage`].
      75              49 :     pub fn build_server_first_message(
      76              49 :         &self,
      77              49 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      78              49 :         salt_base64: &str,
      79              49 :         iterations: u32,
      80              49 :     ) -> OwnedServerFirstMessage {
      81              49 :         use std::fmt::Write;
      82              49 : 
      83              49 :         let mut message = String::new();
      84              49 :         write!(&mut message, "r={}", self.nonce).unwrap();
      85              49 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      86              49 :         let combined_nonce = 2..message.len();
      87              49 :         write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
      88              49 : 
      89              49 :         // This design guarantees that it's impossible to create a
      90              49 :         // server-first-message without receiving a client-first-message
      91              49 :         OwnedServerFirstMessage {
      92              49 :             message,
      93              49 :             nonce: combined_nonce,
      94              49 :         }
      95              49 :     }
      96                 : }
      97                 : 
/// Parsed `client-final-message`.
///
/// The string fields borrow from the string the message was parsed from;
/// the proof is stored decoded.
#[derive(Debug)]
pub struct ClientFinalMessage<'a> {
    /// `client-final-message-without-proof`.
    pub without_proof: &'a str,
    /// Channel binding data (base64).
    pub channel_binding: &'a str,
    /// Combined client & server nonce.
    pub nonce: &'a str,
    /// Client auth proof, decoded from the base64 value of the `p=` attribute.
    pub proof: [u8; SCRAM_KEY_LEN],
}
     109                 : 
     110                 : impl<'a> ClientFinalMessage<'a> {
     111                 :     // NB: FromStr doesn't work with lifetimes
     112 CBC          50 :     pub fn parse(input: &'a str) -> Option<Self> {
     113              50 :         let (without_proof, proof) = input.rsplit_once(',')?;
     114                 : 
     115              50 :         let mut parts = without_proof.split(',');
     116              50 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     117              50 :         let nonce = parts.next()?.strip_prefix("r=")?;
     118                 : 
     119                 :         // Validate but ignore auth extensions
     120              50 :         validate_sasl_extensions(parts)?;
     121                 : 
     122              50 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     123                 : 
     124              50 :         Some(Self {
     125              50 :             without_proof,
     126              50 :             channel_binding,
     127              50 :             nonce,
     128              50 :             proof,
     129              50 :         })
     130              50 :     }
     131                 : 
     132                 :     /// Build a response to [`ClientFinalMessage`].
     133              41 :     pub fn build_server_final_message(
     134              41 :         &self,
     135              41 :         signature_builder: SignatureBuilder,
     136              41 :         server_key: &ScramKey,
     137              41 :     ) -> String {
     138              41 :         let mut buf = String::from("v=");
     139              41 :         base64::encode_config_buf(
     140              41 :             signature_builder.build(server_key),
     141              41 :             base64::STANDARD,
     142              41 :             &mut buf,
     143              41 :         );
     144              41 : 
     145              41 :         buf
     146              41 :     }
     147                 : }
     148                 : 
     149                 : /// We need to keep a convenient representation of this
     150                 : /// message for the next authentication step.
     151                 : pub struct OwnedServerFirstMessage {
     152                 :     /// Owned `server-first-message`.
     153                 :     message: String,
     154                 :     /// Slice into `message`.
     155                 :     nonce: Range<usize>,
     156                 : }
     157                 : 
     158                 : impl OwnedServerFirstMessage {
     159                 :     /// Extract combined nonce from the message.
     160                 :     #[inline(always)]
     161              45 :     pub fn nonce(&self) -> &str {
     162              45 :         &self.message[self.nonce.clone()]
     163              45 :     }
     164                 : 
     165                 :     /// Get reference to a text representation of the message.
     166                 :     #[inline(always)]
     167              94 :     pub fn as_str(&self) -> &str {
     168              94 :         &self.message
     169              94 :     }
     170                 : }
     171                 : 
     172                 : impl fmt::Debug for OwnedServerFirstMessage {
     173 UBC           0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     174               0 :         f.debug_struct("ServerFirstMessage")
     175               0 :             .field("message", &self.as_str())
     176               0 :             .field("nonce", &self.nonce())
     177               0 :             .finish()
     178               0 :     }
     179                 : }
     180                 : 
     181                 : #[cfg(test)]
     182                 : mod tests {
     183                 :     use super::*;
     184                 : 
     185 CBC           1 :     #[test]
     186               1 :     fn parse_client_first_message() {
     187               1 :         use ChannelBinding::*;
     188               1 : 
     189               1 :         // (Almost) real strings captured during debug sessions
     190               1 :         let cases = [
     191               1 :             (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     192               1 :             (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     193               1 :             (
     194               1 :                 Required("tls-server-end-point"),
     195               1 :                 "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
     196               1 :             ),
     197               1 :         ];
     198                 : 
     199               4 :         for (cb, input) in cases {
     200               3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     201               3 : 
     202               3 :             assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
     203               3 :             assert_eq!(msg.username, "pepe");
     204               3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     205               3 :             assert_eq!(msg.cbind_flag, cb);
     206                 :         }
     207               1 :     }
     208                 : 
     209               1 :     #[test]
     210               1 :     fn parse_client_final_message() {
     211               1 :         let input = [
     212               1 :             "c=eSws",
     213               1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     214               1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     215               1 :         ]
     216               1 :         .join(",");
     217               1 : 
     218               1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     219               1 :         assert_eq!(
     220               1 :             msg.without_proof,
     221               1 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     222               1 :         );
     223               1 :         assert_eq!(
     224               1 :             msg.nonce,
     225               1 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     226               1 :         );
     227               1 :         assert_eq!(
     228               1 :             base64::encode(msg.proof),
     229               1 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     230               1 :         );
     231               1 :     }
     232                 : }
        

Generated by: LCOV version 2.1-beta