Line data Source code
1 : //! SASL-based authentication support.
2 :
3 : use std::fmt::Write;
4 : use std::{io, iter, mem, str};
5 :
6 : use hmac::{Hmac, Mac};
7 : use rand::{self, Rng};
8 : use sha2::digest::FixedOutput;
9 : use sha2::{Digest, Sha256};
10 : use tokio::task::yield_now;
11 :
/// Number of characters in the nonce generated by the client.
const NONCE_LENGTH: usize = 24;

/// SASL mechanism name for SCRAM-SHA-256.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// SASL mechanism name for SCRAM-SHA-256-PLUS, the channel-binding variant.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
18 :
19 : // since postgres passwords are not required to exclude saslprep-prohibited
20 : // characters or even be valid UTF8, we run saslprep if possible and otherwise
21 : // return the raw password.
22 13 : fn normalize(pass: &[u8]) -> Vec<u8> {
23 13 : let pass = match str::from_utf8(pass) {
24 13 : Ok(pass) => pass,
25 0 : Err(_) => return pass.to_vec(),
26 : };
27 :
28 13 : match stringprep::saslprep(pass) {
29 13 : Ok(pass) => pass.into_owned().into_bytes(),
30 0 : Err(_) => pass.as_bytes().to_vec(),
31 : }
32 13 : }
33 :
34 29 : pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
35 29 : let mut hmac =
36 29 : Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
37 29 : hmac.update(salt);
38 29 : hmac.update(&[0, 0, 0, 1]);
39 29 : let mut prev = hmac.finalize().into_bytes();
40 29 :
41 29 : let mut hi = prev;
42 :
43 114660 : for i in 1..iterations {
44 114660 : let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
45 114660 : hmac.update(&prev);
46 114660 : prev = hmac.finalize().into_bytes();
47 :
48 3669120 : for (hi, prev) in hi.iter_mut().zip(prev) {
49 3669120 : *hi ^= prev;
50 3669120 : }
51 : // yield every ~250us
52 : // hopefully reduces tail latencies
53 114660 : if i % 1024 == 0 {
54 84 : yield_now().await
55 8184 : }
56 : }
57 :
58 29 : hi.into()
59 29 : }
60 :
/// Private representation of the channel-binding mode; callers pick a mode through
/// the constructor functions on `ChannelBinding`.
enum ChannelBindingInner {
    /// The server did not ask for channel binding.
    Unrequested,
    /// The server asked for channel binding but the client cannot provide it.
    Unsupported,
    /// Binding with the `tls-server-end-point` method; the payload is the binding data.
    TlsServerEndPoint(Vec<u8>),
}

/// How channel binding is handled during a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// Use when the server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// Use when the server requested channel binding but this client cannot supply it.
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// Use when the server requested channel binding and the client will bind with the
    /// `tls-server-end-point` method, using `signature` as the binding data.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header for this mode. It prefixes the client-first-message and is
    /// later base64-encoded together with `cbind_data()` for the `c=` attribute.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// The binding data to append after the GS2 header (empty unless binding is in use).
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(data) = &self.0 {
            data.as_slice()
        } else {
            &[]
        }
    }
}
102 :
/// The client and server keys for the SCRAM-SHA-256 mechanism, as derived from a password.
///
/// Supplying these instead of a password skips the password-based key derivation.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for how they are defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScramKeys<const N: usize> {
    /// `ClientKey`: used by the server to authenticate the client.
    pub client_key: [u8; N],
    /// `ServerKey`: used by the client to verify the server's signature.
    pub server_key: [u8; N],
}
112 :
113 : /// Password or keys which were derived from it.
114 : enum Credentials<const N: usize> {
115 : /// A regular password as a vector of bytes.
116 : Password(Vec<u8>),
117 : /// A precomputed pair of keys.
118 : Keys(ScramKeys<N>),
119 : }
120 :
121 : enum State {
122 : Update {
123 : nonce: String,
124 : password: Credentials<32>,
125 : channel_binding: ChannelBinding,
126 : },
127 : Finish {
128 : server_key: [u8; 32],
129 : auth_message: String,
130 : },
131 : Done,
132 : }
133 :
134 : /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
135 : /// process.
136 : ///
137 : /// During the authentication process, if the backend sends an `AuthenticationSASL` message which
138 : /// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
139 : ///
140 : /// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
141 : /// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
142 : ///
143 : /// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
144 : /// passed to the `update()` method, after which the buffer returned by the `message()` method
145 : /// should be sent to the backend in a `SASLResponse` message.
146 : ///
147 : /// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
148 : /// to the `finish()` method, after which the authentication process is complete.
149 : pub struct ScramSha256 {
150 : message: String,
151 : state: State,
152 : }
153 :
154 12 : fn nonce() -> String {
155 12 : // rand 0.5's ThreadRng is cryptographically secure
156 12 : let mut rng = rand::thread_rng();
157 12 : (0..NONCE_LENGTH)
158 288 : .map(|_| {
159 288 : let mut v = rng.gen_range(0x21u8..0x7e);
160 288 : if v == 0x2c {
161 6 : v = 0x7e
162 282 : }
163 288 : v as char
164 288 : })
165 12 : .collect()
166 12 : }
167 :
168 : impl ScramSha256 {
169 : /// Constructs a new instance which will use the provided password for authentication.
170 12 : pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
171 12 : let password = Credentials::Password(normalize(password));
172 12 : ScramSha256::new_inner(password, channel_binding, nonce())
173 12 : }
174 :
175 : /// Constructs a new instance which will use the provided key pair for authentication.
176 0 : pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
177 0 : let password = Credentials::Keys(keys);
178 0 : ScramSha256::new_inner(password, channel_binding, nonce())
179 0 : }
180 :
181 13 : fn new_inner(
182 13 : password: Credentials<32>,
183 13 : channel_binding: ChannelBinding,
184 13 : nonce: String,
185 13 : ) -> ScramSha256 {
186 13 : ScramSha256 {
187 13 : message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
188 13 : state: State::Update {
189 13 : nonce,
190 13 : password,
191 13 : channel_binding,
192 13 : },
193 13 : }
194 13 : }
195 :
196 : /// Returns the message which should be sent to the backend in an `SASLResponse` message.
197 25 : pub fn message(&self) -> &[u8] {
198 25 : if let State::Done = self.state {
199 0 : panic!("invalid SCRAM state");
200 25 : }
201 25 : self.message.as_bytes()
202 25 : }
203 :
204 : /// Updates the state machine with the response from the backend.
205 : ///
206 : /// This should be called when an `AuthenticationSASLContinue` message is received.
207 12 : pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
208 12 : let (client_nonce, password, channel_binding) =
209 12 : match mem::replace(&mut self.state, State::Done) {
210 : State::Update {
211 12 : nonce,
212 12 : password,
213 12 : channel_binding,
214 12 : } => (nonce, password, channel_binding),
215 0 : _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
216 : };
217 :
218 12 : let message =
219 12 : str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
220 :
221 12 : let parsed = Parser::new(message).server_first_message()?;
222 :
223 12 : if !parsed.nonce.starts_with(&client_nonce) {
224 0 : return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
225 1 : }
226 :
227 12 : let (client_key, server_key) = match password {
228 12 : Credentials::Password(password) => {
229 12 : let salt = match base64::decode(parsed.salt) {
230 12 : Ok(salt) => salt,
231 0 : Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
232 : };
233 :
234 12 : let salted_password = hi(&password, &salt, parsed.iteration_count).await;
235 :
236 24 : let make_key = |name| {
237 24 : let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
238 24 : .expect("HMAC is able to accept all key sizes");
239 24 : hmac.update(name);
240 24 :
241 24 : let mut key = [0u8; 32];
242 24 : key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
243 24 : key
244 24 : };
245 :
246 12 : (make_key(b"Client Key"), make_key(b"Server Key"))
247 : }
248 0 : Credentials::Keys(keys) => (keys.client_key, keys.server_key),
249 : };
250 :
251 12 : let mut hash = Sha256::default();
252 12 : hash.update(client_key);
253 12 : let stored_key = hash.finalize_fixed();
254 12 :
255 12 : let mut cbind_input = vec![];
256 12 : cbind_input.extend(channel_binding.gs2_header().as_bytes());
257 12 : cbind_input.extend(channel_binding.cbind_data());
258 12 : let cbind_input = base64::encode(&cbind_input);
259 12 :
260 12 : self.message.clear();
261 12 : write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
262 12 :
263 12 : let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
264 12 :
265 12 : let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
266 12 : .expect("HMAC is able to accept all key sizes");
267 12 : hmac.update(auth_message.as_bytes());
268 12 : let client_signature = hmac.finalize().into_bytes();
269 12 :
270 12 : let mut client_proof = client_key;
271 384 : for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
272 384 : *proof ^= signature;
273 384 : }
274 :
275 12 : write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
276 12 :
277 12 : self.state = State::Finish {
278 12 : server_key,
279 12 : auth_message,
280 12 : };
281 12 : Ok(())
282 1 : }
283 :
284 : /// Finalizes the authentication process.
285 : ///
286 : /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
287 : /// Authentication has only succeeded if this method returns `Ok(())`.
288 7 : pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
289 7 : let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
290 : State::Finish {
291 7 : server_key,
292 7 : auth_message,
293 7 : } => (server_key, auth_message),
294 0 : _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
295 : };
296 :
297 7 : let message =
298 7 : str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
299 :
300 7 : let parsed = Parser::new(message).server_final_message()?;
301 :
302 7 : let verifier = match parsed {
303 0 : ServerFinalMessage::Error(e) => {
304 0 : return Err(io::Error::new(
305 0 : io::ErrorKind::Other,
306 0 : format!("SCRAM error: {}", e),
307 0 : ));
308 : }
309 7 : ServerFinalMessage::Verifier(verifier) => verifier,
310 : };
311 :
312 7 : let verifier = match base64::decode(verifier) {
313 7 : Ok(verifier) => verifier,
314 0 : Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
315 : };
316 :
317 7 : let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
318 7 : .expect("HMAC is able to accept all key sizes");
319 7 : hmac.update(auth_message.as_bytes());
320 7 : hmac.verify_slice(&verifier)
321 7 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
322 7 : }
323 : }
324 :
/// Minimal cursor-based parser for the server messages of a SCRAM exchange.
///
/// Every `&'a str` it returns borrows from the input given to `Parser::new`.
struct Parser<'a> {
    /// The full input being parsed.
    s: &'a str,
    /// Cursor over `s` yielding `(byte offset, char)`, peekable for one-char lookahead.
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the start of `s`.
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes exactly one character, which must be `target`.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if a different character is next, `UnexpectedEof` if the input
    /// is exhausted.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest run of characters for which `f` returns `true` and returns
    /// it as a slice of the input. The slice is empty if the next character does not
    /// match or the input is exhausted. This never fails; the `io::Result` only keeps
    /// call sites uniform.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// Reads printable ASCII (0x21..=0x7e) other than `,`, the character set of nonces.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// Parses the `r=<nonce>` attribute.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Reads a run of base64-alphabet characters (including `=` padding). The content
    /// is not validated here; callers decode it separately.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// Parses the `s=<base64 salt>` attribute.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Reads a strictly positive decimal integer that fits in a `u32`.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if there are no digits, the value overflows `u32`, or it is zero.
    fn posit_number(&mut self) -> io::Result<u32> {
        let digits = self.take_while(|c| c.is_ascii_digit())?;
        let n: u32 = digits
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        // RFC 5802 requires `posit-number` to be positive. An iteration count of zero
        // is not a meaningful input to Hi().
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number but got 0",
            ));
        }
        Ok(n)
    }

    /// Parses the `i=<iteration count>` attribute.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the entire input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    /// Parses a complete server-first-message: `r=<nonce>,s=<salt>,i=<iterations>`.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// Reads an attribute value: every character up to, but not including, the next
    /// NUL, `=` or `,` (or the end of input).
    fn value(&mut self) -> io::Result<&'a str> {
        // `take_while` keeps characters for which the predicate holds, so the
        // delimiters must be *excluded*. Matching them directly would stop at once and
        // return "".
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// If the next attribute is `e=...`, parses it and returns its value. Otherwise
    /// consumes nothing and returns `None`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// Parses the `v=<base64 server signature>` attribute.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a complete server-final-message: either `e=<error>` or `v=<verifier>`.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The attributes of a parsed server-first-message; all strings borrow from the input.
struct ServerFirstMessage<'a> {
    /// The client nonce followed by the server's extension of it.
    nonce: &'a str,
    /// Base64-encoded salt, not yet decoded.
    salt: &'a str,
    /// Iteration count for Hi().
    iteration_count: u32,
}

/// A parsed server-final-message.
enum ServerFinalMessage<'a> {
    /// The server reported an error (`e=`).
    Error(&'a str),
    /// The server signature, base64-encoded (`v=`).
    Verifier(&'a str),
}
473 :
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a sample server-first-message and checks each attribute.
    #[test]
    fn parse_server_first_message() {
        let input = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let parsed = Parser::new(input).server_first_message().unwrap();

        assert_eq!(parsed.iteration_count, 4096);
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
    }

    /// Replays an authentication exchange recorded from psql. Both client messages are
    /// checked byte-for-byte, and the server's final signature must verify.
    #[tokio::test]
    async fn exchange() {
        let password = "foobar";
        let client_nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        let expected_client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=4096";
        let expected_client_final = concat!(
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,",
            "p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8="
        );
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            Credentials::Password(normalize(password.as_bytes())),
            ChannelBinding::unsupported(),
            client_nonce.to_owned(),
        );
        assert_eq!(str::from_utf8(scram.message()).unwrap(), expected_client_first);

        scram.update(server_first.as_bytes()).await.unwrap();
        assert_eq!(str::from_utf8(scram.message()).unwrap(), expected_client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }
}
|