Line data Source code
1 : //! SASL-based authentication support.
2 :
3 : use hmac::{Hmac, Mac};
4 : use rand::{self, Rng};
5 : use sha2::digest::FixedOutput;
6 : use sha2::{Digest, Sha256};
7 : use std::fmt::Write;
8 : use std::io;
9 : use std::iter;
10 : use std::mem;
11 : use std::str;
12 : use tokio::task::yield_now;
13 :
/// How many random characters the client contributes to the SCRAM nonce.
const NONCE_LENGTH: usize = 24;

/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &'static str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism, the
/// channel-binding variant of SCRAM-SHA-256.
pub const SCRAM_SHA_256_PLUS: &'static str = "SCRAM-SHA-256-PLUS";
20 :
21 : // since postgres passwords are not required to exclude saslprep-prohibited
22 : // characters or even be valid UTF8, we run saslprep if possible and otherwise
23 : // return the raw password.
24 13 : fn normalize(pass: &[u8]) -> Vec<u8> {
25 13 : let pass = match str::from_utf8(pass) {
26 13 : Ok(pass) => pass,
27 0 : Err(_) => return pass.to_vec(),
28 : };
29 :
30 13 : match stringprep::saslprep(pass) {
31 13 : Ok(pass) => pass.into_owned().into_bytes(),
32 0 : Err(_) => pass.as_bytes().to_vec(),
33 : }
34 13 : }
35 :
36 29 : pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
37 29 : let mut hmac =
38 29 : Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
39 29 : hmac.update(salt);
40 29 : hmac.update(&[0, 0, 0, 1]);
41 29 : let mut prev = hmac.finalize().into_bytes();
42 29 :
43 29 : let mut hi = prev;
44 :
45 114660 : for i in 1..iterations {
46 114660 : let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
47 114660 : hmac.update(&prev);
48 114660 : prev = hmac.finalize().into_bytes();
49 :
50 3669120 : for (hi, prev) in hi.iter_mut().zip(prev) {
51 3669120 : *hi ^= prev;
52 3669120 : }
53 : // yield every ~250us
54 : // hopefully reduces tail latencies
55 114660 : if i % 1024 == 0 {
56 84 : yield_now().await
57 114576 : }
58 : }
59 :
60 29 : hi.into()
61 29 : }
62 :
/// The three channel-binding situations a SCRAM client can be in.
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// The server requested channel binding but the client is unable to provide it.
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// The server requested channel binding and the client will use the `tls-server-end-point`
    /// method, with `signature` as the binding data.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header that prefixes the client-first message and that is also
    /// part of the `c=` attribute of the client-final message.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::Unrequested => "y,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    /// The raw channel-binding data appended after the GS2 header. It is
    /// empty unless `tls-server-end-point` binding is in use.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(signature) = &self.0 {
            signature.as_slice()
        } else {
            &[]
        }
    }
}
104 :
/// The client and server keys of the SCRAM-SHA-256 mechanism.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ScramKeys<const N: usize> {
    /// Key from which the client proof is built; the server authenticates the client with it.
    pub client_key: [u8; N],
    /// Key with which the client checks the server's signature.
    pub server_key: [u8; N],
}

/// The secret the client authenticates with: a password or keys derived from it.
enum Credentials<const N: usize> {
    /// Password bytes; the keys are derived from them during the exchange.
    Password(Vec<u8>),
    /// A key pair that was computed ahead of time.
    Keys(ScramKeys<N>),
}
122 :
123 : enum State {
124 : Update {
125 : nonce: String,
126 : password: Credentials<32>,
127 : channel_binding: ChannelBinding,
128 : },
129 : Finish {
130 : server_key: [u8; 32],
131 : auth_message: String,
132 : },
133 : Done,
134 : }
135 :
136 : /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
137 : /// process.
138 : ///
139 : /// During the authentication process, if the backend sends an `AuthenticationSASL` message which
140 : /// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
141 : ///
142 : /// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
143 : /// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
144 : ///
145 : /// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
146 : /// passed to the `update()` method, after which the buffer returned by the `message()` method
147 : /// should be sent to the backend in a `SASLResponse` message.
148 : ///
149 : /// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
150 : /// to the `finish()` method, after which the authentication process is complete.
151 : pub struct ScramSha256 {
152 : message: String,
153 : state: State,
154 : }
155 :
156 12 : fn nonce() -> String {
157 12 : // rand 0.5's ThreadRng is cryptographically secure
158 12 : let mut rng = rand::thread_rng();
159 12 : (0..NONCE_LENGTH)
160 288 : .map(|_| {
161 288 : let mut v = rng.gen_range(0x21u8..0x7e);
162 288 : if v == 0x2c {
163 0 : v = 0x7e
164 288 : }
165 288 : v as char
166 288 : })
167 12 : .collect()
168 12 : }
169 :
impl ScramSha256 {
    /// Constructs a new instance which will use the provided password for authentication.
    ///
    /// The password is passed through `normalize` and a random client nonce is generated.
    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
        let password = Credentials::Password(normalize(password));
        ScramSha256::new_inner(password, channel_binding, nonce())
    }

    /// Constructs a new instance which will use the provided key pair for authentication.
    ///
    /// With precomputed keys, the salt and iteration count sent by the server are not used.
    pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
        let password = Credentials::Keys(keys);
        ScramSha256::new_inner(password, channel_binding, nonce())
    }

    /// Shared constructor. Builds the client-first message
    /// (`gs2-header` + `n=,r=<nonce>`) and enters the `Update` state.
    /// The nonce is a parameter so tests can pin it.
    fn new_inner(
        password: Credentials<32>,
        channel_binding: ChannelBinding,
        nonce: String,
    ) -> ScramSha256 {
        ScramSha256 {
            // The user name is left empty (`n=`).
            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
            state: State::Update {
                nonce,
                password,
                channel_binding,
            },
        }
    }

    /// Returns the message which should be sent to the backend next.
    ///
    /// Before `update()` this is the client-first message (sent in a `SASLInitialResponse`);
    /// after a successful `update()` it is the client-final message (sent in a `SASLResponse`).
    ///
    /// # Panics
    ///
    /// Panics if the exchange is in the `Done` state, i.e. after `finish()` has been called or
    /// after `update()` returned an error.
    pub fn message(&self) -> &[u8] {
        if let State::Done = self.state {
            panic!("invalid SCRAM state");
        }
        self.message.as_bytes()
    }

    /// Updates the state machine with the response from the backend.
    ///
    /// This should be called when an `AuthenticationSASLContinue` message is received. On
    /// success, `message()` returns the client-final message including the client proof.
    ///
    /// # Errors
    ///
    /// Fails if called in the wrong state, if `message` is not UTF-8 or not a well-formed
    /// server-first message, if the server nonce does not start with the client nonce, or if
    /// the salt is not valid base64. After any error the exchange is left in the `Done` state.
    pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
        // Move the state out, leaving `Done` behind so every early return below
        // leaves the exchange unusable.
        let (client_nonce, password, channel_binding) =
            match mem::replace(&mut self.state, State::Done) {
                State::Update {
                    nonce,
                    password,
                    channel_binding,
                } => (nonce, password, channel_binding),
                _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
            };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_first_message()?;

        // The server must echo our nonce as the prefix of the combined nonce.
        if !parsed.nonce.starts_with(&client_nonce) {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
        }

        let (client_key, server_key) = match password {
            Credentials::Password(password) => {
                let salt = match base64::decode(parsed.salt) {
                    Ok(salt) => salt,
                    Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
                };

                // SaltedPassword := Hi(password, salt, i)
                let salted_password = hi(&password, &salt, parsed.iteration_count).await;

                // HMAC(SaltedPassword, name), used for "Client Key" and "Server Key".
                let make_key = |name| {
                    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
                        .expect("HMAC is able to accept all key sizes");
                    hmac.update(name);

                    let mut key = [0u8; 32];
                    key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
                    key
                };

                (make_key(b"Client Key"), make_key(b"Server Key"))
            }
            Credentials::Keys(keys) => (keys.client_key, keys.server_key),
        };

        // StoredKey := SHA-256(ClientKey)
        let mut hash = Sha256::default();
        hash.update(client_key);
        let stored_key = hash.finalize_fixed();

        // The `c=` attribute is base64(gs2-header || channel-binding data).
        let mut cbind_input = vec![];
        cbind_input.extend(channel_binding.gs2_header().as_bytes());
        cbind_input.extend(channel_binding.cbind_data());
        let cbind_input = base64::encode(&cbind_input);

        // client-final-message-without-proof; the proof is appended further down.
        self.message.clear();
        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();

        // AuthMessage := client-first-message-bare "," server-first-message ","
        //                client-final-message-without-proof
        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);

        // ClientSignature := HMAC(StoredKey, AuthMessage)
        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        let client_signature = hmac.finalize().into_bytes();

        // ClientProof := ClientKey XOR ClientSignature
        let mut client_proof = client_key;
        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
            *proof ^= signature;
        }

        write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();

        // Keep what `finish` needs to verify the server's signature.
        self.state = State::Finish {
            server_key,
            auth_message,
        };
        Ok(())
    }

    /// Finalizes the authentication process.
    ///
    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
    /// Authentication has only succeeded if this method returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Fails if called in the wrong state, if `message` is not UTF-8 or not a well-formed
    /// server-final message, if the server reported an error (`e=`), if the verifier is not
    /// valid base64, or if the verifier does not match `HMAC(ServerKey, AuthMessage)`.
    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
        let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
            State::Finish {
                server_key,
                auth_message,
            } => (server_key, auth_message),
            _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
        };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_final_message()?;

        let verifier = match parsed {
            ServerFinalMessage::Error(e) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("SCRAM error: {}", e),
                ));
            }
            ServerFinalMessage::Verifier(verifier) => verifier,
        };

        let verifier = match base64::decode(verifier) {
            Ok(verifier) => verifier,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        // ServerSignature := HMAC(ServerKey, AuthMessage). The comparison against the
        // received verifier is done by the hmac crate's `verify_slice`.
        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        hmac.verify_slice(&verifier)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
    }
}
326 :
/// A small hand-written parser for the server messages of a SCRAM exchange,
/// following the grammar in RFC 5802, section 7.
struct Parser<'a> {
    /// The whole input; every returned slice borrows from it.
    s: &'a str,
    /// Cursor over the `(byte offset, char)` pairs of `s`.
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes the next character, failing unless it equals `target`.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest (possibly empty) run of characters satisfying `f`
    /// and returns it as a slice of the input.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// `printable`: ASCII 0x21..=0x7e, excluding `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// `nonce = "r=" printable`
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Characters of the base64 alphabet, including `=` padding. Whether the
    /// result is well-formed base64 is left to the decoder.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// `salt = "s=" base64`
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// `posit-number`: a strictly positive base-10 integer that fits in a `u32`.
    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        let n: u32 = n
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        // RFC 5802 defines posit-number as starting with a non-zero digit, so
        // an iteration count of zero is a protocol violation.
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number but got 0",
            ));
        }
        Ok(n)
    }

    /// `iteration-count = "i=" posit-number`
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the whole input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    /// `server-first-message = nonce "," salt "," iteration-count`
    /// (extensions are not accepted).
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// `value`: the characters up to (not including) the first NUL, `=` or `,`.
    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses `"e=" value` if the message starts with `e`; otherwise consumes
    /// nothing and returns `None`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// `verifier = "v=" base64`
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// `server-final-message = (server-error / verifier)`
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The attributes of a server-first message, borrowed from the raw message.
struct ServerFirstMessage<'a> {
    /// Client nonce followed by the server's contribution.
    nonce: &'a str,
    /// Base64-encoded salt, not yet decoded.
    salt: &'a str,
    iteration_count: u32,
}

/// A server-final message: either an error or a base64 server signature.
enum ServerFinalMessage<'a> {
    Error(&'a str),
    Verifier(&'a str),
}
475 :
#[cfg(test)]
mod test {
    use super::*;

    /// Views a protocol message as text so it can be compared with a literal.
    fn text(bytes: &[u8]) -> &str {
        str::from_utf8(bytes).unwrap()
    }

    #[test]
    fn parse_server_first_message() {
        let parsed = Parser::new(
            "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096",
        )
        .server_first_message()
        .unwrap();

        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    // Replays an auth exchange recorded from psql, with the client nonce pinned
    // so that every client message is reproducible.
    #[tokio::test]
    async fn exchange() {
        const PASSWORD: &str = "foobar";
        const CLIENT_NONCE: &str = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        const CLIENT_FIRST: &str = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        const SERVER_FIRST: &str =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        const CLIENT_FINAL: &str =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        const SERVER_FINAL: &str = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            Credentials::Password(normalize(PASSWORD.as_bytes())),
            ChannelBinding::unsupported(),
            CLIENT_NONCE.to_string(),
        );
        assert_eq!(text(scram.message()), CLIENT_FIRST);

        scram.update(SERVER_FIRST.as_bytes()).await.unwrap();
        assert_eq!(text(scram.message()), CLIENT_FINAL);

        scram.finish(SERVER_FINAL.as_bytes()).unwrap();
    }
}
|