TLA Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
/// Channel binding flag (possibly with params).
///
/// This is the `gs2-cbind-flag` part of a `GS2` header (RFC 5802, section 7).
/// `T` is the payload of [`ChannelBinding::Required`], i.e. whatever
/// representation is used for the channel binding type name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelBinding<T> {
    /// Client doesn't support channel binding (flag `n`).
    NotSupportedClient,
    /// Client thinks server doesn't support channel binding (flag `y`).
    NotSupportedServer,
    /// Client wants to use this type of channel binding (flag `p=<cb-name>`).
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 49 : pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
16 49 : use ChannelBinding::*;
17 49 : Ok(match self {
18 7 : NotSupportedClient => NotSupportedClient,
19 UBC 0 : NotSupportedServer => NotSupportedServer,
20 CBC 42 : Required(x) => Required(f(x)?),
21 : })
22 49 : }
23 : }
24 :
25 : impl<'a> ChannelBinding<&'a str> {
26 : // NB: FromStr doesn't work with lifetimes
27 53 : pub fn parse(input: &'a str) -> Option<Self> {
28 53 : use ChannelBinding::*;
29 53 : Some(match input {
30 53 : "n" => NotSupportedClient,
31 45 : "y" => NotSupportedServer,
32 43 : other => Required(other.strip_prefix("p=")?),
33 : })
34 53 : }
35 : }
36 :
37 : impl<T: std::fmt::Display> ChannelBinding<T> {
38 : /// Encode channel binding data as base64 for subsequent checks.
39 52 : pub fn encode<'a, E>(
40 52 : &self,
41 52 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
42 52 : ) -> Result<std::borrow::Cow<'static, str>, E> {
43 52 : use ChannelBinding::*;
44 52 : Ok(match self {
45 : NotSupportedClient => {
46 : // base64::encode("n,,")
47 8 : "biws".into()
48 : }
49 : NotSupportedServer => {
50 : // base64::encode("y,,")
51 1 : "eSws".into()
52 : }
53 43 : Required(mode) => {
54 43 : use std::io::Write;
55 43 : let mut cbind_input = vec![];
56 43 : write!(&mut cbind_input, "p={mode},,",).unwrap();
57 43 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
58 43 : base64::encode(&cbind_input).into()
59 : }
60 : })
61 52 : }
62 : }
63 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, base64::encode("n,,")),
            (NotSupportedServer, base64::encode("y,,")),
            (Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    /// Well-formed flags parse into the matching variant; unknown flags don't.
    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        assert_eq!(ChannelBinding::<&str>::parse("n"), Some(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::parse("y"), Some(NotSupportedServer));
        assert_eq!(
            ChannelBinding::<&str>::parse("p=tls-server-end-point"),
            Some(Required("tls-server-end-point"))
        );
        assert_eq!(ChannelBinding::<&str>::parse("q=foo"), None);
        assert_eq!(ChannelBinding::<&str>::parse(""), None);
    }

    /// Every variant goes through `and_then`, including `NotSupportedServer`,
    /// and errors from the mapping function are propagated.
    #[test]
    fn channel_binding_and_then() {
        use ChannelBinding::*;

        let len = |s: &str| -> Result<usize, &'static str> { Ok(s.len()) };
        assert_eq!(NotSupportedClient.and_then(len), Ok(NotSupportedClient));
        assert_eq!(NotSupportedServer.and_then(len), Ok(NotSupportedServer));
        assert_eq!(Required("abc").and_then(len), Ok(Required(3)));
        assert_eq!(
            Required("abc").and_then(|_| Err::<usize, _>("rejected")),
            Err("rejected")
        );
    }
}
|