LCOV - differential code coverage report
Current view: top level - proxy/src/scram - messages.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 87.9 % 140 123 17 123
Current Date: 2023-10-19 02:04:12 Functions: 73.3 % 15 11 4 11
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Definitions for SCRAM messages.
       2                 : 
       3                 : use super::base64_decode_array;
       4                 : use super::key::{ScramKey, SCRAM_KEY_LEN};
       5                 : use super::signature::SignatureBuilder;
       6                 : use crate::sasl::ChannelBinding;
       7                 : use std::fmt;
       8                 : use std::ops::Range;
       9                 : 
      10                 : /// Faithfully taken from PostgreSQL.
      11                 : pub const SCRAM_RAW_NONCE_LEN: usize = 18;
      12                 : 
      13                 : /// Although we ignore all extensions, we still have to validate the message.
      14 CBC          68 : fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
      15              68 :     for mut chars in parts.map(|s| s.chars()) {
      16 UBC           0 :         let attr = chars.next()?;
      17               0 :         if !attr.is_ascii_alphabetic() {
      18               0 :             return None;
      19               0 :         }
      20               0 :         let eq = chars.next()?;
      21               0 :         if eq != '=' {
      22               0 :             return None;
      23               0 :         }
      24                 :     }
      25                 : 
      26 CBC          68 :     Some(())
      27              68 : }
      28                 : 
      29 UBC           0 : #[derive(Debug)]
      30                 : pub struct ClientFirstMessage<'a> {
      31                 :     /// `client-first-message-bare`.
      32                 :     pub bare: &'a str,
      33                 :     /// Channel binding mode.
      34                 :     pub cbind_flag: ChannelBinding<&'a str>,
      35                 :     /// (Client username)[<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>].
      36                 :     pub username: &'a str,
      37                 :     /// Client nonce.
      38                 :     pub nonce: &'a str,
      39                 : }
      40                 : 
      41                 : impl<'a> ClientFirstMessage<'a> {
      42                 :     // NB: FromStr doesn't work with lifetimes
      43 CBC          35 :     pub fn parse(input: &'a str) -> Option<Self> {
      44              35 :         let mut parts = input.split(',');
      45                 : 
      46              35 :         let cbind_flag = ChannelBinding::parse(parts.next()?)?;
      47                 : 
      48                 :         // PG doesn't support authorization identity,
      49                 :         // so we don't bother defining GS2 header type
      50              35 :         let authzid = parts.next()?;
      51              35 :         if !authzid.is_empty() {
      52 UBC           0 :             return None;
      53 CBC          35 :         }
      54              35 : 
      55              35 :         // Unfortunately, `parts.as_str()` is unstable
      56              35 :         let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
      57              35 :         let (_, bare) = input.split_at(pos);
      58                 : 
      59                 :         // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
      60              35 :         let username = parts.next()?.strip_prefix("n=")?;
      61              35 :         let nonce = parts.next()?.strip_prefix("r=")?;
      62                 : 
      63                 :         // Validate but ignore auth extensions
      64              35 :         validate_sasl_extensions(parts)?;
      65                 : 
      66              35 :         Some(Self {
      67              35 :             bare,
      68              35 :             cbind_flag,
      69              35 :             username,
      70              35 :             nonce,
      71              35 :         })
      72              35 :     }
      73                 : 
      74                 :     /// Build a response to [`ClientFirstMessage`].
      75              32 :     pub fn build_server_first_message(
      76              32 :         &self,
      77              32 :         nonce: &[u8; SCRAM_RAW_NONCE_LEN],
      78              32 :         salt_base64: &str,
      79              32 :         iterations: u32,
      80              32 :     ) -> OwnedServerFirstMessage {
      81              32 :         use std::fmt::Write;
      82              32 : 
      83              32 :         let mut message = String::new();
      84              32 :         write!(&mut message, "r={}", self.nonce).unwrap();
      85              32 :         base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
      86              32 :         let combined_nonce = 2..message.len();
      87              32 :         write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
      88              32 : 
      89              32 :         // This design guarantees that it's impossible to create a
      90              32 :         // server-first-message without receiving a client-first-message
      91              32 :         OwnedServerFirstMessage {
      92              32 :             message,
      93              32 :             nonce: combined_nonce,
      94              32 :         }
      95              32 :     }
      96                 : }
      97                 : 
      98 UBC           0 : #[derive(Debug)]
      99                 : pub struct ClientFinalMessage<'a> {
     100                 :     /// `client-final-message-without-proof`.
     101                 :     pub without_proof: &'a str,
     102                 :     /// Channel binding data (base64).
     103                 :     pub channel_binding: &'a str,
     104                 :     /// Combined client & server nonce.
     105                 :     pub nonce: &'a str,
     106                 :     /// Client auth proof.
     107                 :     pub proof: [u8; SCRAM_KEY_LEN],
     108                 : }
     109                 : 
     110                 : impl<'a> ClientFinalMessage<'a> {
     111                 :     // NB: FromStr doesn't work with lifetimes
     112 CBC          33 :     pub fn parse(input: &'a str) -> Option<Self> {
     113              33 :         let (without_proof, proof) = input.rsplit_once(',')?;
     114                 : 
     115              33 :         let mut parts = without_proof.split(',');
     116              33 :         let channel_binding = parts.next()?.strip_prefix("c=")?;
     117              33 :         let nonce = parts.next()?.strip_prefix("r=")?;
     118                 : 
     119                 :         // Validate but ignore auth extensions
     120              33 :         validate_sasl_extensions(parts)?;
     121                 : 
     122              33 :         let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
     123                 : 
     124              33 :         Some(Self {
     125              33 :             without_proof,
     126              33 :             channel_binding,
     127              33 :             nonce,
     128              33 :             proof,
     129              33 :         })
     130              33 :     }
     131                 : 
     132                 :     /// Build a response to [`ClientFinalMessage`].
     133              28 :     pub fn build_server_final_message(
     134              28 :         &self,
     135              28 :         signature_builder: SignatureBuilder,
     136              28 :         server_key: &ScramKey,
     137              28 :     ) -> String {
     138              28 :         let mut buf = String::from("v=");
     139              28 :         base64::encode_config_buf(
     140              28 :             signature_builder.build(server_key),
     141              28 :             base64::STANDARD,
     142              28 :             &mut buf,
     143              28 :         );
     144              28 : 
     145              28 :         buf
     146              28 :     }
     147                 : }
     148                 : 
     149                 : /// We need to keep a convenient representation of this
     150                 : /// message for the next authentication step.
     151                 : pub struct OwnedServerFirstMessage {
     152                 :     /// Owned `server-first-message`.
     153                 :     message: String,
     154                 :     /// Slice into `message`.
     155                 :     nonce: Range<usize>,
     156                 : }
     157                 : 
     158                 : impl OwnedServerFirstMessage {
     159                 :     /// Extract combined nonce from the message.
     160                 :     #[inline(always)]
     161              32 :     pub fn nonce(&self) -> &str {
     162              32 :         &self.message[self.nonce.clone()]
     163              32 :     }
     164                 : 
     165                 :     /// Get reference to a text representation of the message.
     166                 :     #[inline(always)]
     167              64 :     pub fn as_str(&self) -> &str {
     168              64 :         &self.message
     169              64 :     }
     170                 : }
     171                 : 
     172                 : impl fmt::Debug for OwnedServerFirstMessage {
     173 UBC           0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     174               0 :         f.debug_struct("ServerFirstMessage")
     175               0 :             .field("message", &self.as_str())
     176               0 :             .field("nonce", &self.nonce())
     177               0 :             .finish()
     178               0 :     }
     179                 : }
     180                 : 
     181                 : #[cfg(test)]
     182                 : mod tests {
     183                 :     use super::*;
     184                 : 
     185 CBC           1 :     #[test]
     186               1 :     fn parse_client_first_message() {
     187               1 :         use ChannelBinding::*;
     188               1 : 
     189               1 :         // (Almost) real strings captured during debug sessions
     190               1 :         let cases = [
     191               1 :             (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     192               1 :             (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
     193               1 :             (
     194               1 :                 Required("tls-server-end-point"),
     195               1 :                 "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
     196               1 :             ),
     197               1 :         ];
     198                 : 
     199               4 :         for (cb, input) in cases {
     200               3 :             let msg = ClientFirstMessage::parse(input).unwrap();
     201               3 : 
     202               3 :             assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
     203               3 :             assert_eq!(msg.username, "pepe");
     204               3 :             assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
     205               3 :             assert_eq!(msg.cbind_flag, cb);
     206                 :         }
     207               1 :     }
     208                 : 
     209               1 :     #[test]
     210               1 :     fn parse_client_final_message() {
     211               1 :         let input = [
     212               1 :             "c=eSws",
     213               1 :             "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
     214               1 :             "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
     215               1 :         ]
     216               1 :         .join(",");
     217               1 : 
     218               1 :         let msg = ClientFinalMessage::parse(&input).unwrap();
     219               1 :         assert_eq!(
     220               1 :             msg.without_proof,
     221               1 :             "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     222               1 :         );
     223               1 :         assert_eq!(
     224               1 :             msg.nonce,
     225               1 :             "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
     226               1 :         );
     227               1 :         assert_eq!(
     228               1 :             base64::encode(msg.proof),
     229               1 :             "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
     230               1 :         );
     231               1 :     }
     232                 : }
        

Generated by: LCOV version 2.1-beta