LCOV - code coverage report
Current view: top level - proxy/src/sasl - channel_binding.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 97.7 % 43 42
Test Date: 2023-09-06 10:18:01 Functions: 77.8 % 9 7

            Line data    Source code
       1              : //! Definition and parser for channel binding flag (a part of the `GS2` header).
       2              : 
       3              : /// Channel binding flag (possibly with params).
       4            3 : #[derive(Debug, PartialEq, Eq)]
       5              : pub enum ChannelBinding<T> {
       6              :     /// Client doesn't support channel binding.
       7              :     NotSupportedClient,
       8              :     /// Client thinks server doesn't support channel binding.
       9              :     NotSupportedServer,
      10              :     /// Client wants to use this type of channel binding.
      11              :     Required(T),
      12              : }
      13              : 
      14              : impl<T> ChannelBinding<T> {
      15           30 :     pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
      16           30 :         use ChannelBinding::*;
      17           30 :         Ok(match self {
      18            1 :             NotSupportedClient => NotSupportedClient,
      19           29 :             NotSupportedServer => NotSupportedServer,
      20            0 :             Required(x) => Required(f(x)?),
      21              :         })
      22           30 :     }
      23              : }
      24              : 
      25              : impl<'a> ChannelBinding<&'a str> {
      26              :     // NB: FromStr doesn't work with lifetimes
      27           33 :     pub fn parse(input: &'a str) -> Option<Self> {
      28           33 :         use ChannelBinding::*;
      29           33 :         Some(match input {
      30           33 :             "n" => NotSupportedClient,
      31           31 :             "y" => NotSupportedServer,
      32            1 :             other => Required(other.strip_prefix("p=")?),
      33              :         })
      34           33 :     }
      35              : }
      36              : 
      37              : impl<T: std::fmt::Display> ChannelBinding<T> {
      38              :     /// Encode channel binding data as base64 for subsequent checks.
      39           33 :     pub fn encode<E>(
      40           33 :         &self,
      41           33 :         get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
      42           33 :     ) -> Result<std::borrow::Cow<'static, str>, E> {
      43           33 :         use ChannelBinding::*;
      44           33 :         Ok(match self {
      45              :             NotSupportedClient => {
      46              :                 // base64::encode("n,,")
      47            2 :                 "biws".into()
      48              :             }
      49              :             NotSupportedServer => {
      50              :                 // base64::encode("y,,")
      51           30 :                 "eSws".into()
      52              :             }
      53            1 :             Required(mode) => {
      54            1 :                 let msg = format!(
      55            1 :                     "p={mode},,{data}",
      56            1 :                     mode = mode,
      57            1 :                     data = get_cbind_data(mode)?
      58              :                 );
      59            1 :                 base64::encode(msg).into()
      60              :             }
      61              :         })
      62           33 :     }
      63              : }
      64              : 
      65              : #[cfg(test)]
      66              : mod tests {
      67              :     use super::*;
      68              : 
      69            1 :     #[test]
      70            1 :     fn channel_binding_encode() -> anyhow::Result<()> {
      71            1 :         use ChannelBinding::*;
      72            1 : 
      73            1 :         let cases = [
      74            1 :             (NotSupportedClient, base64::encode("n,,")),
      75            1 :             (NotSupportedServer, base64::encode("y,,")),
      76            1 :             (Required("foo"), base64::encode("p=foo,,bar")),
      77            1 :         ];
      78              : 
      79            4 :         for (cb, input) in cases {
      80            3 :             assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
      81              :         }
      82              : 
      83            1 :         Ok(())
      84            1 :     }
      85              : }
        

Generated by: LCOV version 2.1-beta