Line data Source code
1 : //! Definitions for SASL messages.
2 :
3 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
4 :
5 : use crate::parse::{split_at_const, split_cstr};
6 :
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
///
/// When produced by `FirstMessage::parse`, both fields borrow directly from the
/// raw message bytes; nothing is copied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct FirstMessage<'a> {
    /// Authentication method, e.g. `"SCRAM-SHA-256"`.
    pub(crate) method: &'a str,
    /// Initial client message.
    pub(crate) message: &'a str,
}
15 :
16 : impl<'a> FirstMessage<'a> {
17 : // NB: FromStr doesn't work with lifetimes
18 13 : pub(crate) fn parse(bytes: &'a [u8]) -> Option<Self> {
19 13 : let (method_cstr, tail) = split_cstr(bytes)?;
20 13 : let method = method_cstr.to_str().ok()?;
21 :
22 13 : let (len_bytes, bytes) = split_at_const(tail)?;
23 13 : let len = u32::from_be_bytes(*len_bytes) as usize;
24 13 : if len != bytes.len() {
25 0 : return None;
26 13 : }
27 :
28 13 : let message = std::str::from_utf8(bytes).ok()?;
29 13 : Some(Self { method, message })
30 13 : }
31 : }
32 :
33 : /// A single SASL message.
34 : /// This struct is deliberately decoupled from lower-level
35 : /// [`BeAuthenticationSaslMessage`].
36 : #[derive(Debug)]
37 : pub(super) enum ServerMessage<T> {
38 : /// We expect to see more steps.
39 : Continue(T),
40 : /// This is the final step.
41 : Final(T),
42 : }
43 :
44 : impl<'a> ServerMessage<&'a str> {
45 17 : pub(super) fn to_reply(&self) -> BeMessage<'a> {
46 17 : BeMessage::AuthenticationSasl(match self {
47 11 : ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
48 6 : ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
49 : })
50 17 : }
51 : }
52 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a SASL first-message payload: NUL-terminated method name,
    /// big-endian `u32` length, then the raw initial client message.
    fn encode_first_message(method: &str, len: u32, message: &[u8]) -> Vec<u8> {
        [method.as_bytes(), &[0], &len.to_be_bytes(), message].concat()
    }

    #[test]
    fn parse_sasl_first_message() {
        let proto = "SCRAM-SHA-256";
        let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
        let bytes = encode_first_message(proto, sasl.len() as u32, sasl.as_bytes());

        let password = FirstMessage::parse(&bytes).unwrap();
        assert_eq!(password.method, proto);
        assert_eq!(password.message, sasl);
    }

    /// The declared length must match the number of bytes that follow it exactly.
    #[test]
    fn parse_sasl_first_message_rejects_length_mismatch() {
        let proto = "SCRAM-SHA-256";
        let sasl = b"n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
        let len = sasl.len() as u32;

        let too_long = encode_first_message(proto, len + 1, sasl);
        assert!(FirstMessage::parse(&too_long).is_none());

        let too_short = encode_first_message(proto, len - 1, sasl);
        assert!(FirstMessage::parse(&too_short).is_none());

        // A length of -1 with no payload never matches the remaining byte count.
        let absent = encode_first_message(proto, u32::MAX, b"");
        assert!(FirstMessage::parse(&absent).is_none());
    }

    #[test]
    fn parse_sasl_first_message_rejects_malformed_input() {
        // Method name without a NUL terminator.
        assert!(FirstMessage::parse(b"SCRAM-SHA-256").is_none());

        // Length field cut short (only 2 of 4 bytes present).
        assert!(FirstMessage::parse(b"SCRAM-SHA-256\0\0\0").is_none());

        // Client message is not valid UTF-8.
        let bad_utf8 = encode_first_message("SCRAM-SHA-256", 1, &[0xff]);
        assert!(FirstMessage::parse(&bad_utf8).is_none());
    }
}
|