LCOV - code coverage report
Current view: top level - proxy/src/sasl - channel_binding.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 97.4 % 39 38
Test Date: 2025-07-16 12:29:03 Functions: 100.0 % 6 6

            Line data    Source code
       1              : //! Definition and parser for channel binding flag (a part of the `GS2` header).
       2              : 
       3              : use base64::Engine as _;
       4              : use base64::prelude::BASE64_STANDARD;
       5              : 
       6              : /// Channel binding flag (possibly with params).
       7              : #[derive(Debug, PartialEq, Eq)]
       8              : pub(crate) enum ChannelBinding<T> {
       9              :     /// Client doesn't support channel binding.
      10              :     NotSupportedClient,
      11              :     /// Client thinks server doesn't support channel binding.
      12              :     NotSupportedServer,
      13              :     /// Client wants to use this type of channel binding.
      14              :     Required(T),
      15              : }
      16              : 
      17              : impl<T> ChannelBinding<T> {
      18           12 :     pub(crate) fn and_then<R, E>(
      19           12 :         self,
      20           12 :         f: impl FnOnce(T) -> Result<R, E>,
      21           12 :     ) -> Result<ChannelBinding<R>, E> {
      22           12 :         Ok(match self {
      23            6 :             Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
      24            0 :             Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
      25            6 :             Self::Required(x) => ChannelBinding::Required(f(x)?),
      26              :         })
      27           12 :     }
      28              : }
      29              : 
      30              : impl<'a> ChannelBinding<&'a str> {
      31              :     // NB: FromStr doesn't work with lifetimes
      32           21 :     pub(crate) fn parse(input: &'a str) -> Option<Self> {
      33           21 :         Some(match input {
      34           21 :             "n" => Self::NotSupportedClient,
      35            9 :             "y" => Self::NotSupportedServer,
      36            7 :             other => Self::Required(other.strip_prefix("p=")?),
      37              :         })
      38           21 :     }
      39              : }
      40              : 
      41              : impl<T: std::fmt::Display> ChannelBinding<T> {
      42              :     /// Encode channel binding data as base64 for subsequent checks.
      43           15 :     pub(crate) fn encode<'a, E>(
      44           15 :         &self,
      45           15 :         get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
      46           15 :     ) -> Result<std::borrow::Cow<'static, str>, E> {
      47           15 :         Ok(match self {
      48            7 :             Self::NotSupportedClient => {
      49              :                 // base64::encode("n,,")
      50            7 :                 "biws".into()
      51              :             }
      52            1 :             Self::NotSupportedServer => {
      53              :                 // base64::encode("y,,")
      54            1 :                 "eSws".into()
      55              :             }
      56            7 :             Self::Required(mode) => {
      57            7 :                 let mut cbind_input = format!("p={mode},,",).into_bytes();
      58            7 :                 cbind_input.extend_from_slice(get_cbind_data(mode)?);
      59            7 :                 BASE64_STANDARD.encode(&cbind_input).into()
      60              :             }
      61              :         })
      62           15 :     }
      63              : }
      64              : 
      65              : #[cfg(test)]
      66              : mod tests {
      67              :     use super::*;
      68              : 
      69              :     #[test]
      70            1 :     fn channel_binding_encode() -> anyhow::Result<()> {
      71              :         use ChannelBinding::*;
      72              : 
      73            1 :         let cases = [
      74            1 :             (NotSupportedClient, BASE64_STANDARD.encode("n,,")),
      75            1 :             (NotSupportedServer, BASE64_STANDARD.encode("y,,")),
      76            1 :             (Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")),
      77            1 :         ];
      78              : 
      79            4 :         for (cb, input) in cases {
      80            3 :             assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
      81              :         }
      82              : 
      83            1 :         Ok(())
      84            1 :     }
      85              : }
        

Generated by: LCOV version 2.1-beta