Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
/// Channel binding flag (possibly with params), i.e. the first field of a `GS2` header
/// (`gs2-cbind-flag` in RFC 5802, section 7).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelBinding<T> {
    /// Client doesn't support channel binding (wire flag `n`).
    NotSupportedClient,
    /// Client thinks server doesn't support channel binding (wire flag `y`).
    NotSupportedServer,
    /// Client wants to use this type of channel binding (wire flag `p=<type>`).
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 61 : pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
16 61 : use ChannelBinding::*;
17 61 : Ok(match self {
18 12 : NotSupportedClient => NotSupportedClient,
19 0 : NotSupportedServer => NotSupportedServer,
20 49 : Required(x) => Required(f(x)?),
21 : })
22 61 : }
23 : }
24 :
25 : impl<'a> ChannelBinding<&'a str> {
26 : // NB: FromStr doesn't work with lifetimes
27 69 : pub fn parse(input: &'a str) -> Option<Self> {
28 69 : use ChannelBinding::*;
29 69 : Some(match input {
30 69 : "n" => NotSupportedClient,
31 55 : "y" => NotSupportedServer,
32 51 : other => Required(other.strip_prefix("p=")?),
33 : })
34 69 : }
35 : }
36 :
37 : impl<T: std::fmt::Display> ChannelBinding<T> {
38 : /// Encode channel binding data as base64 for subsequent checks.
39 67 : pub fn encode<'a, E>(
40 67 : &self,
41 67 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
42 67 : ) -> Result<std::borrow::Cow<'static, str>, E> {
43 67 : use ChannelBinding::*;
44 67 : Ok(match self {
45 : NotSupportedClient => {
46 : // base64::encode("n,,")
47 14 : "biws".into()
48 : }
49 : NotSupportedServer => {
50 : // base64::encode("y,,")
51 2 : "eSws".into()
52 : }
53 51 : Required(mode) => {
54 51 : use std::io::Write;
55 51 : let mut cbind_input = vec![];
56 51 : write!(&mut cbind_input, "p={mode},,",).unwrap();
57 51 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
58 51 : base64::encode(&cbind_input).into()
59 : }
60 : })
61 67 : }
62 : }
63 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Every flag encodes to base64 of its `GS2` header followed by the channel binding data.
    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, base64::encode("n,,")),
            (NotSupportedServer, base64::encode("y,,")),
            (Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    /// Known flags map to the matching variant; anything else is rejected.
    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        assert_eq!(ChannelBinding::<&str>::parse("n"), Some(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::parse("y"), Some(NotSupportedServer));
        assert_eq!(
            ChannelBinding::<&str>::parse("p=tls-server-end-point"),
            Some(Required("tls-server-end-point"))
        );
        assert_eq!(ChannelBinding::<&str>::parse("x"), None);
        assert_eq!(ChannelBinding::<&str>::parse(""), None);
        assert_eq!(ChannelBinding::<&str>::parse("tls-unique"), None);
    }

    /// `and_then` maps only the `Required` payload and never calls `f` for the other flags.
    #[test]
    fn channel_binding_and_then() {
        use ChannelBinding::*;

        fn len(s: &str) -> Result<usize, &'static str> {
            Ok(s.len())
        }

        assert_eq!(NotSupportedClient.and_then(len), Ok(NotSupportedClient));
        assert_eq!(NotSupportedServer.and_then(len), Ok(NotSupportedServer));
        assert_eq!(Required("abc").and_then(len), Ok(Required(3)));

        // Errors from `f` propagate, but only when `f` is actually invoked.
        assert_eq!(
            NotSupportedServer.and_then(|_: &str| Err::<usize, _>("boom")),
            Ok(NotSupportedServer)
        );
        assert_eq!(
            Required("abc").and_then(|_| Err::<usize, _>("boom")),
            Err("boom")
        );
    }
}
|