Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
3 : use base64::Engine as _;
4 : use base64::prelude::BASE64_STANDARD;
5 :
/// GS2 channel binding flag, optionally carrying a parameter.
///
/// `T` is the payload of the `p=` form: the raw name borrowed from the input
/// right after parsing, or whatever it has been mapped to via `and_then`.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum ChannelBinding<T> {
    /// Flag `n`: the client does not support channel binding.
    NotSupportedClient,
    /// Flag `y`: the client thinks the server does not support channel binding.
    NotSupportedServer,
    /// Flag `p=<name>`: the client wants this type of channel binding.
    Required(T),
}
16 :
17 : impl<T> ChannelBinding<T> {
18 12 : pub(crate) fn and_then<R, E>(
19 12 : self,
20 12 : f: impl FnOnce(T) -> Result<R, E>,
21 12 : ) -> Result<ChannelBinding<R>, E> {
22 12 : Ok(match self {
23 6 : Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
24 0 : Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
25 6 : Self::Required(x) => ChannelBinding::Required(f(x)?),
26 : })
27 12 : }
28 : }
29 :
30 : impl<'a> ChannelBinding<&'a str> {
31 : // NB: FromStr doesn't work with lifetimes
32 21 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
33 21 : Some(match input {
34 21 : "n" => Self::NotSupportedClient,
35 9 : "y" => Self::NotSupportedServer,
36 7 : other => Self::Required(other.strip_prefix("p=")?),
37 : })
38 21 : }
39 : }
40 :
41 : impl<T: std::fmt::Display> ChannelBinding<T> {
42 : /// Encode channel binding data as base64 for subsequent checks.
43 15 : pub(crate) fn encode<'a, E>(
44 15 : &self,
45 15 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
46 15 : ) -> Result<std::borrow::Cow<'static, str>, E> {
47 15 : Ok(match self {
48 7 : Self::NotSupportedClient => {
49 : // base64::encode("n,,")
50 7 : "biws".into()
51 : }
52 1 : Self::NotSupportedServer => {
53 : // base64::encode("y,,")
54 1 : "eSws".into()
55 : }
56 7 : Self::Required(mode) => {
57 7 : let mut cbind_input = format!("p={mode},,",).into_bytes();
58 7 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
59 7 : BASE64_STANDARD.encode(&cbind_input).into()
60 : }
61 : })
62 15 : }
63 : }
64 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Each flag must encode to base64 of its GS2 header plus binding data.
    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, BASE64_STANDARD.encode("n,,")),
            (NotSupportedServer, BASE64_STANDARD.encode("y,,")),
            (Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    /// Only `n`, `y` and `p=<name>` are accepted; anything else is `None`.
    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        let cases = [
            ("n", Some(NotSupportedClient)),
            ("y", Some(NotSupportedServer)),
            ("p=tls-server-end-point", Some(Required("tls-server-end-point"))),
            ("", None),
            ("x", None),
            ("nn", None),
            ("tls-server-end-point", None),
        ];

        for (input, expected) in cases {
            assert_eq!(
                ChannelBinding::<&str>::parse(input),
                expected,
                "input: {input:?}"
            );
        }
    }

    /// `and_then` maps only `Required` and passes the other variants through.
    /// This also covers the `NotSupportedServer` arm, which no test reached.
    #[test]
    fn channel_binding_and_then() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let len = |s: &str| anyhow::Ok(s.len());

        assert_eq!(
            ChannelBinding::<&str>::NotSupportedClient.and_then(len)?,
            NotSupportedClient
        );
        assert_eq!(
            ChannelBinding::<&str>::NotSupportedServer.and_then(len)?,
            NotSupportedServer
        );
        assert_eq!(Required("foo").and_then(len)?, Required(3));

        // Errors from the mapping function must propagate unchanged.
        let failing =
            Required("foo").and_then(|_: &str| Err::<usize, _>(anyhow::anyhow!("boom")));
        assert!(failing.is_err());

        Ok(())
    }
}
|