Line data Source code
1 : //! Definitions for SCRAM messages.
2 :
3 : use std::fmt;
4 : use std::ops::Range;
5 :
6 : use base64::Engine as _;
7 : use base64::prelude::BASE64_STANDARD;
8 :
9 : use super::base64_decode_array;
10 : use super::key::{SCRAM_KEY_LEN, ScramKey};
11 : use super::signature::SignatureBuilder;
12 : use crate::sasl::ChannelBinding;
13 :
/// Length in bytes of the raw server nonce, before it is base64-encoded and
/// appended to the client nonce (see `build_server_first_message`).
///
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;
16 :
/// Checks the shape of trailing SASL extension attributes without interpreting them.
///
/// Although we ignore all extensions, we still have to validate the message:
/// every remaining part must begin with an ASCII letter immediately followed
/// by `=` (e.g. `a=foo`). The value after `=` is not inspected.
///
/// Returns `Some(())` when every part passes (trivially so when there are no
/// parts), and `None` as soon as a malformed part is found.
fn validate_sasl_extensions<'a>(mut parts: impl Iterator<Item = &'a str>) -> Option<()> {
    let is_well_formed = |part: &str| {
        let mut chars = part.chars();
        matches!(
            (chars.next(), chars.next()),
            (Some(name), Some('=')) if name.is_ascii_alphabetic()
        )
    };

    // `all` stops at the first failing part, just like an early return would.
    parts.all(is_well_formed).then_some(())
}
32 :
/// A parsed SCRAM `client-first-message`, borrowing every field from the input string.
#[derive(Debug)]
pub(crate) struct ClientFirstMessage<'a> {
    /// `client-first-message-bare`: the verbatim slice of the input that follows
    /// the GS2 header (`n=<user>,r=<nonce>[,<extensions>]`).
    pub(crate) bare: &'a str,
    /// Channel binding mode from the GS2 header (e.g. `n`, `y`, or `p=<type>`).
    pub(crate) cbind_flag: ChannelBinding<&'a str>,
    /// Client nonce: the value of the `r=` attribute.
    pub(crate) nonce: &'a str,
}
42 :
43 : impl<'a> ClientFirstMessage<'a> {
44 : // NB: FromStr doesn't work with lifetimes
45 21 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
46 21 : let mut parts = input.split(',');
47 :
48 21 : let cbind_flag = ChannelBinding::parse(parts.next()?)?;
49 :
50 : // PG doesn't support authorization identity,
51 : // so we don't bother defining GS2 header type
52 21 : let authzid = parts.next()?;
53 21 : if !authzid.is_empty() {
54 1 : return None;
55 20 : }
56 :
57 : // Unfortunately, `parts.as_str()` is unstable
58 20 : let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
59 20 : let (_, bare) = input.split_at(pos);
60 :
61 : // In theory, these might be preceded by "reserved-mext" (i.e. "m=")
62 20 : let username = parts.next()?.strip_prefix("n=")?;
63 :
64 : // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
65 20 : if !username.is_empty() {
66 1 : tracing::warn!(username, "scram username provided, but is not expected");
67 : // TODO(conrad):
68 : // return None;
69 19 : }
70 :
71 20 : let nonce = parts.next()?.strip_prefix("r=")?;
72 :
73 : // Validate but ignore auth extensions
74 20 : validate_sasl_extensions(parts)?;
75 :
76 17 : Some(Self {
77 17 : bare,
78 17 : cbind_flag,
79 17 : nonce,
80 17 : })
81 21 : }
82 :
83 : /// Build a response to [`ClientFirstMessage`].
84 12 : pub(crate) fn build_server_first_message(
85 12 : &self,
86 12 : nonce: &[u8; SCRAM_RAW_NONCE_LEN],
87 12 : salt_base64: &str,
88 12 : iterations: u32,
89 12 : ) -> OwnedServerFirstMessage {
90 12 : let mut message = String::with_capacity(128);
91 12 : message.push_str("r=");
92 :
93 : // write combined nonce
94 12 : let combined_nonce_start = message.len();
95 12 : message.push_str(self.nonce);
96 12 : BASE64_STANDARD.encode_string(nonce, &mut message);
97 12 : let combined_nonce = combined_nonce_start..message.len();
98 :
99 : // write salt and iterations
100 12 : message.push_str(",s=");
101 12 : message.push_str(salt_base64);
102 12 : message.push_str(",i=");
103 12 : message.push_str(itoa::Buffer::new().format(iterations));
104 :
105 : // This design guarantees that it's impossible to create a
106 : // server-first-message without receiving a client-first-message
107 12 : OwnedServerFirstMessage {
108 12 : message,
109 12 : nonce: combined_nonce,
110 12 : }
111 12 : }
112 : }
113 :
/// A parsed SCRAM `client-final-message`, borrowing its string fields from the input.
#[derive(Debug)]
pub(crate) struct ClientFinalMessage<'a> {
    /// `client-final-message-without-proof`: everything in the input before
    /// the last comma (i.e. before the `p=` attribute).
    pub(crate) without_proof: &'a str,
    /// Channel binding data (base64): the raw `c=` value. It is not decoded
    /// during parsing.
    pub(crate) channel_binding: &'a str,
    /// Combined client & server nonce: the `r=` value.
    pub(crate) nonce: &'a str,
    /// Client auth proof, base64-decoded from the `p=` attribute.
    pub(crate) proof: [u8; SCRAM_KEY_LEN],
}
125 :
126 : impl<'a> ClientFinalMessage<'a> {
127 : // NB: FromStr doesn't work with lifetimes
128 13 : pub(crate) fn parse(input: &'a str) -> Option<Self> {
129 13 : let (without_proof, proof) = input.rsplit_once(',')?;
130 :
131 13 : let mut parts = without_proof.split(',');
132 13 : let channel_binding = parts.next()?.strip_prefix("c=")?;
133 13 : let nonce = parts.next()?.strip_prefix("r=")?;
134 :
135 : // Validate but ignore auth extensions
136 13 : validate_sasl_extensions(parts)?;
137 :
138 13 : let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
139 :
140 13 : Some(Self {
141 13 : without_proof,
142 13 : channel_binding,
143 13 : nonce,
144 13 : proof,
145 13 : })
146 13 : }
147 :
148 : /// Build a response to [`ClientFinalMessage`].
149 7 : pub(crate) fn build_server_final_message(
150 7 : &self,
151 7 : signature_builder: SignatureBuilder<'_>,
152 7 : server_key: &ScramKey,
153 7 : ) -> String {
154 7 : let mut buf = String::from("v=");
155 7 : BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
156 :
157 7 : buf
158 7 : }
159 : }
160 :
/// Owned copy of the `server-first-message` we sent.
///
/// We need to keep a convenient representation of this message for the next
/// authentication step, which needs both the full text and the combined nonce.
pub(crate) struct OwnedServerFirstMessage {
    /// Owned `server-first-message` text.
    message: String,
    /// Byte range of the combined (client + server) nonce within `message`.
    nonce: Range<usize>,
}

impl OwnedServerFirstMessage {
    /// Returns the combined nonce, borrowed out of the stored message.
    #[inline(always)]
    pub(crate) fn nonce(&self) -> &str {
        let (start, end) = (self.nonce.start, self.nonce.end);
        &self.message[start..end]
    }

    /// Returns the full text of the message.
    #[inline(always)]
    pub(crate) fn as_str(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Debug for OwnedServerFirstMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = self.as_str();
        let combined_nonce = self.nonce();

        f.debug_struct("ServerFirstMessage")
            .field("message", &text)
            .field("nonce", &combined_nonce)
            .finish()
    }
}
192 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_client_first_message() {
        // (Almost) real strings captured during debug sessions; only the
        // GS2 channel-binding flag differs between them.
        let samples = [
            (
                "n,,n=,r=t8JwklwKecDLwSsA72rHmVju",
                ChannelBinding::NotSupportedClient,
            ),
            (
                "y,,n=,r=t8JwklwKecDLwSsA72rHmVju",
                ChannelBinding::NotSupportedServer,
            ),
            (
                "p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
                ChannelBinding::Required("tls-server-end-point"),
            ),
        ];

        for (raw, expected_cbind) in samples {
            let parsed = ClientFirstMessage::parse(raw).unwrap();

            assert_eq!(parsed.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
            assert_eq!(parsed.nonce, "t8JwklwKecDLwSsA72rHmVju");
            assert_eq!(parsed.cbind_flag, expected_cbind);
        }
    }

    #[test]
    fn parse_client_first_message_with_invalid_gs2_authz() {
        let parsed = ClientFirstMessage::parse("n,authzid,n=,r=nonce");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_client_first_message_with_extra_params() {
        let parsed = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();

        assert_eq!(parsed.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
        assert_eq!(parsed.nonce, "nonce");
        assert_eq!(parsed.cbind_flag, ChannelBinding::NotSupportedClient);
    }

    #[test]
    fn parse_client_first_message_with_extra_params_invalid() {
        // must be of the form `<ascii letter>=<...>`
        let rejected = [
            "n,,n=,r=nonce,abc=foo",
            "n,,n=,r=nonce,1=foo",
            "n,,n=,r=nonce,a",
        ];
        for raw in rejected {
            assert!(ClientFirstMessage::parse(raw).is_none(), "{raw}");
        }
    }

    #[test]
    fn parse_client_final_message() {
        let input = concat!(
            "c=eSws,",
            "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU,",
            "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
        );

        let parsed = ClientFinalMessage::parse(input).unwrap();

        assert_eq!(
            parsed.without_proof,
            "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );
        assert_eq!(
            parsed.nonce,
            "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
        );

        let reencoded_proof = BASE64_STANDARD.encode(parsed.proof);
        assert_eq!(
            reencoded_proof,
            "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
        );
    }
}
|