Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
/// Channel binding flag of a `GS2` header (possibly with params).
///
/// The type parameter is the payload of [`ChannelBinding::Required`]: the raw
/// channel binding name right after parsing, or whatever it has been converted
/// into via [`ChannelBinding::and_then`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelBinding<T> {
    /// Client doesn't support channel binding (`n` flag).
    NotSupportedClient,
    /// Client thinks the server doesn't support channel binding (`y` flag).
    NotSupportedServer,
    /// Client wants to use this type of channel binding (`p=<cb-name>` flag).
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 24 : pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
16 24 : Ok(match self {
17 12 : Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
18 0 : Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
19 12 : Self::Required(x) => ChannelBinding::Required(f(x)?),
20 : })
21 24 : }
22 : }
23 :
24 : impl<'a> ChannelBinding<&'a str> {
25 : // NB: FromStr doesn't work with lifetimes
26 42 : pub fn parse(input: &'a str) -> Option<Self> {
27 42 : Some(match input {
28 42 : "n" => Self::NotSupportedClient,
29 18 : "y" => Self::NotSupportedServer,
30 14 : other => Self::Required(other.strip_prefix("p=")?),
31 : })
32 42 : }
33 : }
34 :
35 : impl<T: std::fmt::Display> ChannelBinding<T> {
36 : /// Encode channel binding data as base64 for subsequent checks.
37 30 : pub fn encode<'a, E>(
38 30 : &self,
39 30 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
40 30 : ) -> Result<std::borrow::Cow<'static, str>, E> {
41 30 : Ok(match self {
42 14 : Self::NotSupportedClient => {
43 14 : // base64::encode("n,,")
44 14 : "biws".into()
45 : }
46 2 : Self::NotSupportedServer => {
47 2 : // base64::encode("y,,")
48 2 : "eSws".into()
49 : }
50 14 : Self::Required(mode) => {
51 14 : use std::io::Write;
52 14 : let mut cbind_input = vec![];
53 14 : write!(&mut cbind_input, "p={mode},,",).unwrap();
54 14 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
55 14 : base64::encode(&cbind_input).into()
56 : }
57 : })
58 30 : }
59 : }
60 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, base64::encode("n,,")),
            (NotSupportedServer, base64::encode("y,,")),
            (Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        assert_eq!(ChannelBinding::<&str>::parse("n"), Some(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::parse("y"), Some(NotSupportedServer));
        assert_eq!(
            ChannelBinding::<&str>::parse("p=tls-server-end-point"),
            Some(Required("tls-server-end-point"))
        );

        // Neither a known flag nor `p=`-prefixed: rejected.
        assert_eq!(ChannelBinding::<&str>::parse(""), None);
        assert_eq!(ChannelBinding::<&str>::parse("x"), None);
        assert_eq!(ChannelBinding::<&str>::parse("tls-unique"), None);
    }

    #[test]
    fn channel_binding_and_then() {
        use ChannelBinding::*;

        let len = |s: &str| Ok::<usize, &str>(s.len());
        assert_eq!(Required("abc").and_then(len), Ok(Required(3)));
        assert_eq!(ChannelBinding::<&str>::NotSupportedClient.and_then(len), Ok(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::NotSupportedServer.and_then(len), Ok(NotSupportedServer));

        // Errors from the callback propagate only for `Required`.
        let fail = |_: &str| Err::<usize, &str>("unsupported");
        assert_eq!(Required("abc").and_then(fail), Err("unsupported"));
        assert_eq!(ChannelBinding::<&str>::NotSupportedServer.and_then(fail), Ok(NotSupportedServer));
    }
}
|