Line data Source code
1 : //! Definitions for SCRAM messages.
2 :
3 : use std::fmt;
4 : use std::ops::Range;
5 :
6 : use super::base64_decode_array;
7 : use super::key::{ScramKey, SCRAM_KEY_LEN};
8 : use super::signature::SignatureBuilder;
9 : use crate::sasl::ChannelBinding;
10 :
/// Length, in bytes, of the raw server nonce before it is base64-encoded
/// into the `server-first-message`.
///
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
13 :
/// Checks that every remaining comma-separated attribute has the shape
/// `<ASCII letter>=<anything>`.
///
/// Although we ignore all extensions, we still have to validate the message.
/// Returns `Some(())` when all attributes (possibly none) are well-formed and
/// `None` as soon as one is not, so callers can simply use `?`.
fn validate_sasl_extensions<'a>(mut parts: impl Iterator<Item = &'a str>) -> Option<()> {
    // Only the first two bytes matter: an ASCII letter followed by '='.
    // Comparing bytes is equivalent to comparing chars here because both
    // expected characters are ASCII, and a multi-byte char never starts with
    // an ASCII byte.
    let well_formed = |attr: &str| {
        let mut bytes = attr.bytes();
        matches!(
            (bytes.next(), bytes.next()),
            (Some(letter), Some(b'=')) if letter.is_ascii_alphabetic()
        )
    };

    if parts.all(well_formed) {
        Some(())
    } else {
        None
    }
}
29 :
30 : #[derive(Debug)]
31 : pub(crate) struct ClientFirstMessage<'a> {
32 : /// `client-first-message-bare`.
33 : pub(crate) bare: &'a str,
34 : /// Channel binding mode.
35 : pub(crate) cbind_flag: ChannelBinding<&'a str>,
36 : /// Client nonce.
37 : pub(crate) nonce: &'a str,
38 : }
39 :
40 : impl<'a> ClientFirstMessage<'a> {
41 : // NB: FromStr doesn't work with lifetimes
42 21 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
43 21 : let mut parts = input.split(',');
44 :
45 21 : let cbind_flag = ChannelBinding::parse(parts.next()?)?;
46 :
47 : // PG doesn't support authorization identity,
48 : // so we don't bother defining GS2 header type
49 21 : let authzid = parts.next()?;
50 21 : if !authzid.is_empty() {
51 1 : return None;
52 20 : }
53 20 :
54 20 : // Unfortunately, `parts.as_str()` is unstable
55 20 : let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
56 20 : let (_, bare) = input.split_at(pos);
57 :
58 : // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
59 20 : let username = parts.next()?.strip_prefix("n=")?;
60 :
61 : // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
62 20 : if !username.is_empty() {
63 1 : tracing::warn!(username, "scram username provided, but is not expected");
64 : // TODO(conrad):
65 : // return None;
66 19 : }
67 :
68 20 : let nonce = parts.next()?.strip_prefix("r=")?;
69 :
70 : // Validate but ignore auth extensions
71 20 : validate_sasl_extensions(parts)?;
72 :
73 17 : Some(Self {
74 17 : bare,
75 17 : cbind_flag,
76 17 : nonce,
77 17 : })
78 21 : }
79 :
80 : /// Build a response to [`ClientFirstMessage`].
81 12 : pub(crate) fn build_server_first_message(
82 12 : &self,
83 12 : nonce: &[u8; SCRAM_RAW_NONCE_LEN],
84 12 : salt_base64: &str,
85 12 : iterations: u32,
86 12 : ) -> OwnedServerFirstMessage {
87 : use std::fmt::Write;
88 :
89 12 : let mut message = String::new();
90 12 : write!(&mut message, "r={}", self.nonce).unwrap();
91 12 : base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
92 12 : let combined_nonce = 2..message.len();
93 12 : write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
94 12 :
95 12 : // This design guarantees that it's impossible to create a
96 12 : // server-first-message without receiving a client-first-message
97 12 : OwnedServerFirstMessage {
98 12 : message,
99 12 : nonce: combined_nonce,
100 12 : }
101 12 : }
102 : }
103 :
104 : #[derive(Debug)]
105 : pub(crate) struct ClientFinalMessage<'a> {
106 : /// `client-final-message-without-proof`.
107 : pub(crate) without_proof: &'a str,
108 : /// Channel binding data (base64).
109 : pub(crate) channel_binding: &'a str,
110 : /// Combined client & server nonce.
111 : pub(crate) nonce: &'a str,
112 : /// Client auth proof.
113 : pub(crate) proof: [u8; SCRAM_KEY_LEN],
114 : }
115 :
116 : impl<'a> ClientFinalMessage<'a> {
117 : // NB: FromStr doesn't work with lifetimes
118 13 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
119 13 : let (without_proof, proof) = input.rsplit_once(',')?;
120 :
121 13 : let mut parts = without_proof.split(',');
122 13 : let channel_binding = parts.next()?.strip_prefix("c=")?;
123 13 : let nonce = parts.next()?.strip_prefix("r=")?;
124 :
125 : // Validate but ignore auth extensions
126 13 : validate_sasl_extensions(parts)?;
127 :
128 13 : let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
129 :
130 13 : Some(Self {
131 13 : without_proof,
132 13 : channel_binding,
133 13 : nonce,
134 13 : proof,
135 13 : })
136 13 : }
137 :
138 : /// Build a response to [`ClientFinalMessage`].
139 7 : pub(crate) fn build_server_final_message(
140 7 : &self,
141 7 : signature_builder: SignatureBuilder<'_>,
142 7 : server_key: &ScramKey,
143 7 : ) -> String {
144 7 : let mut buf = String::from("v=");
145 7 : base64::encode_config_buf(
146 7 : signature_builder.build(server_key),
147 7 : base64::STANDARD,
148 7 : &mut buf,
149 7 : );
150 7 :
151 7 : buf
152 7 : }
153 : }
154 :
/// Owned copy of a `server-first-message`.
///
/// We need to keep a convenient representation of this message for the next
/// authentication step: both its full text and the combined nonce inside it.
pub(crate) struct OwnedServerFirstMessage {
    /// The complete `server-first-message` text.
    message: String,
    /// Byte range of the combined nonce within `message`.
    nonce: Range<usize>,
}

impl OwnedServerFirstMessage {
    /// Returns the combined nonce, borrowed from the stored message.
    #[inline(always)]
    pub(crate) fn nonce(&self) -> &str {
        let Range { start, end } = self.nonce;
        &self.message[start..end]
    }

    /// Returns the full text of the message.
    #[inline(always)]
    pub(crate) fn as_str(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Debug for OwnedServerFirstMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show the nonce as text rather than as a raw byte range.
        let mut out = f.debug_struct("ServerFirstMessage");
        out.field("message", &self.as_str());
        out.field("nonce", &self.nonce());
        out.finish()
    }
}
186 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_client_first_message() {
        // (Almost) real strings captured during debug sessions. They differ
        // only in the GS2 channel-binding flag, so the bare part and the nonce
        // are the same for every case.
        let expected_bare = "n=,r=t8JwklwKecDLwSsA72rHmVju";
        let expected_nonce = "t8JwklwKecDLwSsA72rHmVju";
        let cases = [
            (
                ChannelBinding::NotSupportedClient,
                "n,,n=,r=t8JwklwKecDLwSsA72rHmVju",
            ),
            (
                ChannelBinding::NotSupportedServer,
                "y,,n=,r=t8JwklwKecDLwSsA72rHmVju",
            ),
            (
                ChannelBinding::Required("tls-server-end-point"),
                "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
            ),
        ];

        for (expected_flag, raw) in cases {
            let parsed = ClientFirstMessage::parse(raw).unwrap();

            assert_eq!(parsed.bare, expected_bare);
            assert_eq!(parsed.nonce, expected_nonce);
            assert_eq!(parsed.cbind_flag, expected_flag);
        }
    }

    #[test]
    fn parse_client_first_message_with_invalid_gs2_authz() {
        // A non-empty authorization identity must be rejected.
        let parsed = ClientFirstMessage::parse("n,authzid,n=,r=nonce");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_client_first_message_with_extra_params() {
        let parsed = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();

        // Extensions are validated and remain part of the bare message.
        assert_eq!(parsed.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
        assert_eq!(parsed.nonce, "nonce");
        assert_eq!(parsed.cbind_flag, ChannelBinding::NotSupportedClient);
    }

    #[test]
    fn parse_client_first_message_with_extra_params_invalid() {
        // must be of the form `<ascii letter>=<...>`
        let rejected = [
            "n,,n=,r=nonce,abc=foo",
            "n,,n=,r=nonce,1=foo",
            "n,,n=,r=nonce,a",
        ];
        for raw in rejected {
            assert!(ClientFirstMessage::parse(raw).is_none());
        }
    }

    #[test]
    fn parse_client_final_message() {
        let input = concat!(
            "c=eSws,",
            "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU,",
            "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
        );

        let parsed = ClientFinalMessage::parse(input).unwrap();

        assert_eq!(
            parsed.without_proof,
            "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );
        assert_eq!(
            parsed.nonce,
            "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );
        assert_eq!(
            base64::encode(parsed.proof),
            "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
        );
    }
}
|