LCOV - code coverage report
Current view: top level - proxy/src/sasl - channel_binding.rs (source / functions) Coverage Total Hit
Test: 7eb96e224e685167ad85f58f858387d8cf253f63.info Lines: 97.6 % 42 41
Test Date: 2024-09-23 21:23:07 Functions: 100.0 % 6 6

            Line data    Source code
       1              : //! Definition and parser for channel binding flag (a part of the `GS2` header).
       2              : 
       3              : /// Channel binding flag (possibly with params).
       4              : #[derive(Debug, PartialEq, Eq)]
       5              : pub(crate) enum ChannelBinding<T> {
       6              :     /// Client doesn't support channel binding.
       7              :     NotSupportedClient,
       8              :     /// Client thinks server doesn't support channel binding.
       9              :     NotSupportedServer,
      10              :     /// Client wants to use this type of channel binding.
      11              :     Required(T),
      12              : }
      13              : 
      14              : impl<T> ChannelBinding<T> {
      15           12 :     pub(crate) fn and_then<R, E>(
      16           12 :         self,
      17           12 :         f: impl FnOnce(T) -> Result<R, E>,
      18           12 :     ) -> Result<ChannelBinding<R>, E> {
      19           12 :         Ok(match self {
      20            6 :             Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
      21            0 :             Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
      22            6 :             Self::Required(x) => ChannelBinding::Required(f(x)?),
      23              :         })
      24           12 :     }
      25              : }
      26              : 
      27              : impl<'a> ChannelBinding<&'a str> {
      28              :     // NB: FromStr doesn't work with lifetimes
      29           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      30           21 :         Some(match input {
      31           21 :             "n" => Self::NotSupportedClient,
      32            9 :             "y" => Self::NotSupportedServer,
      33            7 :             other => Self::Required(other.strip_prefix("p=")?),
      34              :         })
      35           21 :     }
      36              : }
      37              : 
      38              : impl<T: std::fmt::Display> ChannelBinding<T> {
      39              :     /// Encode channel binding data as base64 for subsequent checks.
      40           15 :     pub(crate) fn encode<'a, E>(
      41           15 :         &self,
      42           15 :         get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
      43           15 :     ) -> Result<std::borrow::Cow<'static, str>, E> {
      44           15 :         Ok(match self {
      45            7 :             Self::NotSupportedClient => {
      46            7 :                 // base64::encode("n,,")
      47            7 :                 "biws".into()
      48              :             }
      49            1 :             Self::NotSupportedServer => {
      50            1 :                 // base64::encode("y,,")
      51            1 :                 "eSws".into()
      52              :             }
      53            7 :             Self::Required(mode) => {
      54              :                 use std::io::Write;
      55            7 :                 let mut cbind_input = vec![];
      56            7 :                 write!(&mut cbind_input, "p={mode},,",).unwrap();
      57            7 :                 cbind_input.extend_from_slice(get_cbind_data(mode)?);
      58            7 :                 base64::encode(&cbind_input).into()
      59              :             }
      60              :         })
      61           15 :     }
      62              : }
      63              : 
      64              : #[cfg(test)]
      65              : mod tests {
      66              :     use super::*;
      67              : 
      68              :     #[test]
      69            1 :     fn channel_binding_encode() -> anyhow::Result<()> {
      70              :         use ChannelBinding::*;
      71              : 
      72            1 :         let cases = [
      73            1 :             (NotSupportedClient, base64::encode("n,,")),
      74            1 :             (NotSupportedServer, base64::encode("y,,")),
      75            1 :             (Required("foo"), base64::encode("p=foo,,bar")),
      76            1 :         ];
      77              : 
      78            4 :         for (cb, input) in cases {
      79            3 :             assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
      80              :         }
      81              : 
      82            1 :         Ok(())
      83            1 :     }
      84              : }
        

Generated by: LCOV version 2.1-beta