Line data Source code
1 : //! Definition and parser for channel binding flag (a part of the `GS2` header).
2 :
/// Channel binding flag (possibly with params).
///
/// This models the `gs2-cbind-flag` field of a GS2 header, which is one of
/// `"n"`, `"y"` or `"p=<cb-name>"`. `T` is the payload of the `p=` form: a
/// borrowed name right after parsing, or whatever it is mapped to afterwards.
///
/// `Clone`/`Copy` are derived (bounded on `T`) so that cheap instantiations such
/// as `ChannelBinding<&str>` can be reused after a consuming call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ChannelBinding<T> {
    /// Client doesn't support channel binding (wire form `n`).
    NotSupportedClient,
    /// Client thinks server doesn't support channel binding (wire form `y`).
    NotSupportedServer,
    /// Client wants to use this type of channel binding (wire form `p=<cb-name>`).
    Required(T),
}
13 :
14 : impl<T> ChannelBinding<T> {
15 12 : pub(crate) fn and_then<R, E>(
16 12 : self,
17 12 : f: impl FnOnce(T) -> Result<R, E>,
18 12 : ) -> Result<ChannelBinding<R>, E> {
19 12 : Ok(match self {
20 6 : Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
21 0 : Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
22 6 : Self::Required(x) => ChannelBinding::Required(f(x)?),
23 : })
24 12 : }
25 : }
26 :
27 : impl<'a> ChannelBinding<&'a str> {
28 : // NB: FromStr doesn't work with lifetimes
29 21 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
30 21 : Some(match input {
31 21 : "n" => Self::NotSupportedClient,
32 9 : "y" => Self::NotSupportedServer,
33 7 : other => Self::Required(other.strip_prefix("p=")?),
34 : })
35 21 : }
36 : }
37 :
38 : impl<T: std::fmt::Display> ChannelBinding<T> {
39 : /// Encode channel binding data as base64 for subsequent checks.
40 15 : pub(crate) fn encode<'a, E>(
41 15 : &self,
42 15 : get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
43 15 : ) -> Result<std::borrow::Cow<'static, str>, E> {
44 15 : Ok(match self {
45 7 : Self::NotSupportedClient => {
46 7 : // base64::encode("n,,")
47 7 : "biws".into()
48 : }
49 1 : Self::NotSupportedServer => {
50 1 : // base64::encode("y,,")
51 1 : "eSws".into()
52 : }
53 7 : Self::Required(mode) => {
54 : use std::io::Write;
55 7 : let mut cbind_input = vec![];
56 7 : write!(&mut cbind_input, "p={mode},,",).unwrap();
57 7 : cbind_input.extend_from_slice(get_cbind_data(mode)?);
58 7 : base64::encode(&cbind_input).into()
59 : }
60 : })
61 15 : }
62 : }
63 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_binding_encode() -> anyhow::Result<()> {
        use ChannelBinding::*;

        let cases = [
            (NotSupportedClient, base64::encode("n,,")),
            (NotSupportedServer, base64::encode("y,,")),
            (Required("foo"), base64::encode("p=foo,,bar")),
        ];

        for (cb, input) in cases {
            assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
        }

        Ok(())
    }

    #[test]
    fn channel_binding_parse() {
        use ChannelBinding::*;

        assert_eq!(ChannelBinding::<&str>::parse("n"), Some(NotSupportedClient));
        assert_eq!(ChannelBinding::<&str>::parse("y"), Some(NotSupportedServer));
        assert_eq!(
            ChannelBinding::<&str>::parse("p=tls-server-end-point"),
            Some(Required("tls-server-end-point"))
        );

        // Anything that isn't one of the three GS2 flag forms is rejected.
        for bad in ["", "N", "p", "x=foo", " n", "n,"] {
            assert_eq!(ChannelBinding::<&str>::parse(bad), None, "input: {bad:?}");
        }
    }

    #[test]
    fn channel_binding_and_then() {
        use ChannelBinding::*;

        let len = |s: &str| Ok::<_, ()>(s.len());
        assert_eq!(
            ChannelBinding::<&str>::NotSupportedClient.and_then(len),
            Ok(NotSupportedClient)
        );
        assert_eq!(
            ChannelBinding::<&str>::NotSupportedServer.and_then(len),
            Ok(NotSupportedServer)
        );
        assert_eq!(Required("foo").and_then(len), Ok(Required(3)));

        // Errors from the mapping function are propagated...
        assert_eq!(
            Required("foo").and_then(|_| Err::<usize, _>("boom")),
            Err("boom")
        );
        // ...but the function is never invoked for payload-less variants.
        assert_eq!(
            ChannelBinding::<&str>::NotSupportedServer.and_then(|_| Err::<usize, _>("boom")),
            Ok(NotSupportedServer)
        );
    }
}
|