Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
/// Channel binding flag of a GS2 header, optionally carrying a payload.
///
/// The payload `T` describes the requested kind of channel binding; for
/// example, `ChannelBinding::<&str>::parse` stores the borrowed binding name.
#[derive(Eq, PartialEq, Debug)]
pub(crate) enum ChannelBinding<T> {
    /// The client does not support channel binding (`"n"` flag).
    NotSupportedClient,
    /// The client thinks the server does not support channel binding (`"y"` flag).
    NotSupportedServer,
    /// The client wants to use this type of channel binding (`"p=..."` flag).
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 72 : pub(crate) fn and_then<R, E>(
16 72 : self,
17 72 : f: impl FnOnce(T) -> Result<R, E>,
18 72 : ) -> Result<ChannelBinding<R>, E> {
19 72 : Ok(match self {
20 36 : Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
21 0 : Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
22 36 : Self::Required(x) => ChannelBinding::Required(f(x)?),
23 : })
24 72 : }
25 : }
26 :
27 : impl<'a> ChannelBinding<&'a str> {
28 : // NB: FromStr doesn't work with lifetimes
29 126 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
30 126 : Some(match input {
31 126 : "n" => Self::NotSupportedClient,
32 54 : "y" => Self::NotSupportedServer,
33 42 : other => Self::Required(other.strip_prefix("p=")?),
34 : })
35 126 : }
36 : }
37 :
38 : impl<T: std::fmt::Display> ChannelBinding<T> {
39 : /// Encode channel binding data as base64 for subsequent checks.
40 90 : pub(crate) fn encode<'a, E>(
41 90 : &self,
42 90 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
43 90 : ) -> Result<std::borrow::Cow<'static, str>, E> {
44 90 : Ok(match self {
45 42 : Self::NotSupportedClient => {
46 42 : // base64::encode("n,,")
47 42 : "biws".into()
48 : }
49 6 : Self::NotSupportedServer => {
50 6 : // base64::encode("y,,")
51 6 : "eSws".into()
52 : }
53 42 : Self::Required(mode) => {
54 42 : use std::io::Write;
55 42 : let mut cbind_input = vec![];
56 42 : write!(&mut cbind_input, "p={mode},,",).unwrap();
57 42 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
58 42 : base64::encode(&cbind_input).into()
59 : }
60 : })
61 90 : }
62 : }
63 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, base64::encode("n,,")),
            (NotSupportedServer, base64::encode("y,,")),
            (Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        assert_eq!(ChannelBinding::<&str>::parse("n"), Some(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::parse("y"), Some(NotSupportedServer));
        assert_eq!(
            ChannelBinding::<&str>::parse("p=tls-server-end-point"),
            Some(Required("tls-server-end-point"))
        );

        // Unknown flags and a missing `p=` prefix are rejected.
        assert_eq!(ChannelBinding::<&str>::parse(""), None);
        assert_eq!(ChannelBinding::<&str>::parse("x"), None);
        assert_eq!(ChannelBinding::<&str>::parse("tls-server-end-point"), None);
    }

    #[test]
    fn channel_binding_and_then() {
        use ChannelBinding::*;

        let mapped = Required("foo").and_then(|s| Ok::<_, ()>(s.len()));
        assert_eq!(mapped, Ok(Required(3)));

        // The closure only runs for `Required`; other variants pass through untouched.
        let no_call = |_: &str| -> Result<usize, &'static str> { Err("closure must not be called") };
        assert_eq!(NotSupportedClient.and_then(no_call), Ok(NotSupportedClient));
        assert_eq!(NotSupportedServer.and_then(no_call), Ok(NotSupportedServer));

        // Errors from the closure are propagated unchanged.
        assert_eq!(
            Required("foo").and_then(no_call),
            Err("closure must not be called")
        );
    }
}
|