LCOV - differential code coverage report
Current view: top level - proxy/src/sasl - channel_binding.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 97.7 % 43 42 1 42
Current Date: 2024-01-09 02:06:09 Functions: 87.5 % 8 7 1 7
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
       2                 : 
       3                 : /// Channel binding flag (possibly with params).
       4 CBC          53 : #[derive(Debug, PartialEq, Eq)]
       5                 : pub enum ChannelBinding<T> {
       6                 :     /// Client doesn't support channel binding.
       7                 :     NotSupportedClient,
       8                 :     /// Client thinks server doesn't support channel binding.
       9                 :     NotSupportedServer,
      10                 :     /// Client wants to use this type of channel binding.
      11                 :     Required(T),
      12                 : }
      13                 : 
      14                 : impl<T> ChannelBinding<T> {
      15              49 :     pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
      16              49 :         use ChannelBinding::*;
      17              49 :         Ok(match self {
      18               7 :             NotSupportedClient => NotSupportedClient,
      19 UBC           0 :             NotSupportedServer => NotSupportedServer,
      20 CBC          42 :             Required(x) => Required(f(x)?),
      21                 :         })
      22              49 :     }
      23                 : }
      24                 : 
      25                 : impl<'a> ChannelBinding<&'a str> {
      26                 :     // NB: FromStr doesn't work with lifetimes
      27              53 :     pub fn parse(input: &'a str) -> Option<Self> {
      28              53 :         use ChannelBinding::*;
      29              53 :         Some(match input {
      30              53 :             "n" => NotSupportedClient,
      31              45 :             "y" => NotSupportedServer,
      32              43 :             other => Required(other.strip_prefix("p=")?),
      33                 :         })
      34              53 :     }
      35                 : }
      36                 : 
      37                 : impl<T: std::fmt::Display> ChannelBinding<T> {
      38                 :     /// Encode channel binding data as base64 for subsequent checks.
      39              52 :     pub fn encode<'a, E>(
      40              52 :         &self,
      41              52 :         get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
      42              52 :     ) -> Result<std::borrow::Cow<'static, str>, E> {
      43              52 :         use ChannelBinding::*;
      44              52 :         Ok(match self {
      45                 :             NotSupportedClient => {
      46                 :                 // base64::encode("n,,")
      47               8 :                 "biws".into()
      48                 :             }
      49                 :             NotSupportedServer => {
      50                 :                 // base64::encode("y,,")
      51               1 :                 "eSws".into()
      52                 :             }
      53              43 :             Required(mode) => {
      54              43 :                 use std::io::Write;
      55              43 :                 let mut cbind_input = vec![];
      56              43 :                 write!(&mut cbind_input, "p={mode},,",).unwrap();
      57              43 :                 cbind_input.extend_from_slice(get_cbind_data(mode)?);
      58              43 :                 base64::encode(&cbind_input).into()
      59                 :             }
      60                 :         })
      61              52 :     }
      62                 : }
      63                 : 
      64                 : #[cfg(test)]
      65                 : mod tests {
      66                 :     use super::*;
      67                 : 
      68               1 :     #[test]
      69               1 :     fn channel_binding_encode() -> anyhow::Result<()> {
      70               1 :         use ChannelBinding::*;
      71               1 : 
      72               1 :         let cases = [
      73               1 :             (NotSupportedClient, base64::encode("n,,")),
      74               1 :             (NotSupportedServer, base64::encode("y,,")),
      75               1 :             (Required("foo"), base64::encode("p=foo,,bar")),
      76               1 :         ];
      77                 : 
      78               4 :         for (cb, input) in cases {
      79               3 :             assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
      80                 :         }
      81                 : 
      82               1 :         Ok(())
      83               1 :     }
      84                 : }
        

Generated by: LCOV version 2.1-beta