TLA Line data Source code
//! Definition and parser for the channel binding flag (part of the `GS2` header).
2 :
/// The channel binding flag of a `GS2` header, optionally carrying the
/// requested channel binding type.
#[derive(Debug, PartialEq, Eq)]
pub enum ChannelBinding<T> {
    /// Flag `n`: the client does not support channel binding.
    NotSupportedClient,
    /// Flag `y`: the client supports channel binding but thinks the server does not.
    NotSupportedServer,
    /// Flag `p=<type>`: the client wants to use this type of channel binding.
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 32 : pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
16 32 : use ChannelBinding::*;
17 32 : Ok(match self {
18 1 : NotSupportedClient => NotSupportedClient,
19 31 : NotSupportedServer => NotSupportedServer,
20 UBC 0 : Required(x) => Required(f(x)?),
21 : })
22 CBC 32 : }
23 : }
24 :
25 : impl<'a> ChannelBinding<&'a str> {
26 : // NB: FromStr doesn't work with lifetimes
27 35 : pub fn parse(input: &'a str) -> Option<Self> {
28 35 : use ChannelBinding::*;
29 35 : Some(match input {
30 35 : "n" => NotSupportedClient,
31 33 : "y" => NotSupportedServer,
32 1 : other => Required(other.strip_prefix("p=")?),
33 : })
34 35 : }
35 : }
36 :
37 : impl<T: std::fmt::Display> ChannelBinding<T> {
38 : /// Encode channel binding data as base64 for subsequent checks.
39 35 : pub fn encode<E>(
40 35 : &self,
41 35 : get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
42 35 : ) -> Result<std::borrow::Cow<'static, str>, E> {
43 35 : use ChannelBinding::*;
44 35 : Ok(match self {
45 : NotSupportedClient => {
46 : // base64::encode("n,,")
47 2 : "biws".into()
48 : }
49 : NotSupportedServer => {
50 : // base64::encode("y,,")
51 32 : "eSws".into()
52 : }
53 1 : Required(mode) => {
54 1 : let msg = format!(
55 1 : "p={mode},,{data}",
56 1 : mode = mode,
57 1 : data = get_cbind_data(mode)?
58 : );
59 1 : base64::encode(msg).into()
60 : }
61 : })
62 35 : }
63 : }
64 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        // Each flag paired with the base64 of its expected GS2 header; the
        // `Required` case also includes the stubbed cbind data "bar".
        let expectations = [
            (ChannelBinding::NotSupportedClient, base64::encode("n,,")),
            (ChannelBinding::NotSupportedServer, base64::encode("y,,")),
            (ChannelBinding::Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (binding, expected) in expectations {
            let actual = binding.encode(|_| anyhow::Ok("bar".to_owned()))?;
            assert_eq!(actual, expected);
        }

        Ok(())
    }
}
|