Line data Source code
1 : //! Definitions for SCRAM messages.
2 :
3 : use super::base64_decode_array;
4 : use super::key::{ScramKey, SCRAM_KEY_LEN};
5 : use super::signature::SignatureBuilder;
6 : use crate::sasl::ChannelBinding;
7 : use std::fmt;
8 : use std::ops::Range;
9 :
/// Number of raw (pre-base64) nonce bytes the server appends to the client
/// nonce when building the server-first-message.
///
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
12 :
/// Check that trailing SASL extension attributes are well formed.
///
/// Although we ignore all extensions, we still have to validate the message:
/// every remaining part must begin with an ASCII letter immediately followed
/// by `=`. The attribute value after `=` is not inspected, and an empty part
/// is rejected.
///
/// Returns `Some(())` when every part is well formed (including when there are
/// no parts at all), `None` otherwise.
fn validate_sasl_extensions<'a>(mut parts: impl Iterator<Item = &'a str>) -> Option<()> {
    let is_well_formed = |part: &str| {
        let mut chars = part.chars();
        match (chars.next(), chars.next()) {
            (Some(attr), Some('=')) => attr.is_ascii_alphabetic(),
            _ => false,
        }
    };

    if parts.all(is_well_formed) {
        Some(())
    } else {
        None
    }
}
28 :
29 : #[derive(Debug)]
30 : pub(crate) struct ClientFirstMessage<'a> {
31 : /// `client-first-message-bare`.
32 : pub(crate) bare: &'a str,
33 : /// Channel binding mode.
34 : pub(crate) cbind_flag: ChannelBinding<&'a str>,
35 : /// Client nonce.
36 : pub(crate) nonce: &'a str,
37 : }
38 :
39 : impl<'a> ClientFirstMessage<'a> {
40 : // NB: FromStr doesn't work with lifetimes
41 126 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
42 126 : let mut parts = input.split(',');
43 :
44 126 : let cbind_flag = ChannelBinding::parse(parts.next()?)?;
45 :
46 : // PG doesn't support authorization identity,
47 : // so we don't bother defining GS2 header type
48 126 : let authzid = parts.next()?;
49 126 : if !authzid.is_empty() {
50 6 : return None;
51 120 : }
52 120 :
53 120 : // Unfortunately, `parts.as_str()` is unstable
54 120 : let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
55 120 : let (_, bare) = input.split_at(pos);
56 :
57 : // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
58 120 : let username = parts.next()?.strip_prefix("n=")?;
59 :
60 : // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
61 120 : if !username.is_empty() {
62 6 : tracing::warn!(username, "scram username provided, but is not expected");
63 : // TODO(conrad):
64 : // return None;
65 114 : }
66 :
67 120 : let nonce = parts.next()?.strip_prefix("r=")?;
68 :
69 : // Validate but ignore auth extensions
70 120 : validate_sasl_extensions(parts)?;
71 :
72 102 : Some(Self {
73 102 : bare,
74 102 : cbind_flag,
75 102 : nonce,
76 102 : })
77 126 : }
78 :
79 : /// Build a response to [`ClientFirstMessage`].
80 72 : pub(crate) fn build_server_first_message(
81 72 : &self,
82 72 : nonce: &[u8; SCRAM_RAW_NONCE_LEN],
83 72 : salt_base64: &str,
84 72 : iterations: u32,
85 72 : ) -> OwnedServerFirstMessage {
86 72 : use std::fmt::Write;
87 72 :
88 72 : let mut message = String::new();
89 72 : write!(&mut message, "r={}", self.nonce).unwrap();
90 72 : base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
91 72 : let combined_nonce = 2..message.len();
92 72 : write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
93 72 :
94 72 : // This design guarantees that it's impossible to create a
95 72 : // server-first-message without receiving a client-first-message
96 72 : OwnedServerFirstMessage {
97 72 : message,
98 72 : nonce: combined_nonce,
99 72 : }
100 72 : }
101 : }
102 :
103 : #[derive(Debug)]
104 : pub(crate) struct ClientFinalMessage<'a> {
105 : /// `client-final-message-without-proof`.
106 : pub(crate) without_proof: &'a str,
107 : /// Channel binding data (base64).
108 : pub(crate) channel_binding: &'a str,
109 : /// Combined client & server nonce.
110 : pub(crate) nonce: &'a str,
111 : /// Client auth proof.
112 : pub(crate) proof: [u8; SCRAM_KEY_LEN],
113 : }
114 :
115 : impl<'a> ClientFinalMessage<'a> {
116 : // NB: FromStr doesn't work with lifetimes
117 78 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
118 78 : let (without_proof, proof) = input.rsplit_once(',')?;
119 :
120 78 : let mut parts = without_proof.split(',');
121 78 : let channel_binding = parts.next()?.strip_prefix("c=")?;
122 78 : let nonce = parts.next()?.strip_prefix("r=")?;
123 :
124 : // Validate but ignore auth extensions
125 78 : validate_sasl_extensions(parts)?;
126 :
127 78 : let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
128 :
129 78 : Some(Self {
130 78 : without_proof,
131 78 : channel_binding,
132 78 : nonce,
133 78 : proof,
134 78 : })
135 78 : }
136 :
137 : /// Build a response to [`ClientFinalMessage`].
138 42 : pub(crate) fn build_server_final_message(
139 42 : &self,
140 42 : signature_builder: SignatureBuilder<'_>,
141 42 : server_key: &ScramKey,
142 42 : ) -> String {
143 42 : let mut buf = String::from("v=");
144 42 : base64::encode_config_buf(
145 42 : signature_builder.build(server_key),
146 42 : base64::STANDARD,
147 42 : &mut buf,
148 42 : );
149 42 :
150 42 : buf
151 42 : }
152 : }
153 :
/// Owned copy of the `server-first-message` we sent, kept in a convenient
/// form for the next authentication step.
pub(crate) struct OwnedServerFirstMessage {
    /// Full owned `server-first-message` text.
    message: String,
    /// Byte range of the combined nonce inside `message`.
    nonce: Range<usize>,
}

impl OwnedServerFirstMessage {
    /// Borrow the combined nonce out of the stored message.
    #[inline(always)]
    pub(crate) fn nonce(&self) -> &str {
        &self.message[self.nonce.start..self.nonce.end]
    }

    /// Borrow the complete text of the message.
    #[inline(always)]
    pub(crate) fn as_str(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Debug for OwnedServerFirstMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = self.as_str();
        let nonce = self.nonce();
        f.debug_struct("ServerFirstMessage")
            .field("message", &message)
            .field("nonce", &nonce)
            .finish()
    }
}
185 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_client_first_message() {
        use ChannelBinding::*;

        // (Almost) real strings captured during debug sessions
        let cases = [
            (NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
            (NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
            (
                Required("tls-server-end-point"),
                "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
            ),
        ];

        for (expected_cbind, raw) in cases {
            let parsed = ClientFirstMessage::parse(raw).unwrap();

            assert_eq!(parsed.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
            assert_eq!(parsed.nonce, "t8JwklwKecDLwSsA72rHmVju");
            assert_eq!(parsed.cbind_flag, expected_cbind);
        }
    }

    #[test]
    fn parse_client_first_message_with_invalid_gs2_authz() {
        let parsed = ClientFirstMessage::parse("n,authzid,n=,r=nonce");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_client_first_message_with_extra_params() {
        let parsed = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
        assert_eq!(parsed.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
        assert_eq!(parsed.nonce, "nonce");
        assert_eq!(parsed.cbind_flag, ChannelBinding::NotSupportedClient);
    }

    #[test]
    fn parse_client_first_message_with_extra_params_invalid() {
        // must be of the form `<ascii letter>=<...>`
        let rejected = [
            "n,,n=,r=nonce,abc=foo",
            "n,,n=,r=nonce,1=foo",
            "n,,n=,r=nonce,a",
        ];
        for raw in rejected {
            assert!(ClientFirstMessage::parse(raw).is_none());
        }
    }

    #[test]
    fn parse_client_final_message() {
        let raw = concat!(
            "c=eSws,",
            "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU,",
            "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
        );

        let parsed = ClientFinalMessage::parse(raw).unwrap();
        assert_eq!(
            parsed.without_proof,
            "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );
        assert_eq!(
            parsed.nonce,
            "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );
        assert_eq!(
            base64::encode(parsed.proof),
            "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
        );
    }
}
|