Line data Source code
1 : //! Salted Challenge Response Authentication Mechanism.
2 : //!
3 : //! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
4 : //!
5 : //! Reference implementation:
6 : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
7 : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
8 :
9 : mod countmin;
10 : mod exchange;
11 : mod key;
12 : mod messages;
13 : mod pbkdf2;
14 : mod secret;
15 : mod signature;
16 : pub mod threadpool;
17 :
18 : pub(crate) use exchange::{exchange, Exchange};
19 : use hmac::{Hmac, Mac};
20 : pub(crate) use key::ScramKey;
21 : pub(crate) use secret::ServerSecret;
22 : use sha2::{Digest, Sha256};
23 :
/// SASL mechanism name for SCRAM with SHA-256.
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// SASL mechanism name for the `-PLUS` flavour of SCRAM with SHA-256
/// (the channel-binding variant described in RFC 5802).
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

/// Every SCRAM mechanism name known to this module; the `-PLUS` variant is listed first.
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
/// The same list as [`METHODS`] with the `-PLUS` variant left out.
pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
30 :
31 : /// Decode base64 into array without any heap allocations
32 49 : fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
33 49 : let mut bytes = [0u8; N];
34 :
35 49 : let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
36 49 : if size != N {
37 0 : return None;
38 49 : }
39 49 :
40 49 : Some(bytes)
41 49 : }
42 :
43 : /// This function essentially is `Hmac(sha256, key, input)`.
44 : /// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
45 15 : fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
46 15 : let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
47 75 : parts.into_iter().for_each(|s| mac.update(s));
48 15 :
49 15 : mac.finalize().into_bytes().into()
50 15 : }
51 :
52 12 : fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
53 12 : let mut hasher = Sha256::new();
54 12 : parts.into_iter().for_each(|s| hasher.update(s));
55 12 :
56 12 : hasher.finalize().into()
57 12 : }
58 :
#[cfg(test)]
mod tests {
    use super::threadpool::ThreadPool;
    use super::{Exchange, ServerSecret};
    use crate::intern::EndpointIdInt;
    use crate::sasl::{Mechanism, Step};
    use crate::types::EndpointId;

    /// Replays a fixed SCRAM conversation (fixed secret, fixed server nonce) and checks
    /// that both server messages and the resulting key match the recorded values.
    #[test]
    fn snapshot() {
        // Server nonce is pinned so that the server messages are reproducible.
        const NONCE: [u8; 18] = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        ];

        // Recorded messages of the conversation, in the order they are exchanged.
        let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
        let server_first =
            "r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
        let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
        let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";

        // Assemble the textual secret from its components and parse it.
        let iterations = 4096;
        let salt = "QSXCR+Q6sek8bf92";
        let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
        let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
        let encoded = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
        let secret = ServerSecret::parse(&encoded).unwrap();

        let mut exchange = Exchange::new(
            &secret,
            || NONCE,
            crate::config::TlsServerEndPoint::Undefined,
        );

        // First round: the mechanism must ask to continue and emit the server-first message.
        let (next, reply) = match exchange.exchange(client_first).unwrap() {
            Step::Failure(reason) => panic!("{reason}"),
            Step::Success(_, _) => panic!("expected continue, got success"),
            Step::Continue(next, reply) => (next, reply),
        };
        assert_eq!(reply, server_first);
        exchange = next;

        // Second round: the mechanism must finish and emit the server-final message.
        let (key, reply) = match exchange.exchange(client_final).unwrap() {
            Step::Failure(reason) => panic!("{reason}"),
            Step::Continue(_, _) => panic!("expected success, got continue"),
            Step::Success(key, reply) => (key, reply),
        };
        assert_eq!(reply, server_final);

        assert_eq!(
            key.as_bytes(),
            [
                74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
                27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
            ]
        );
    }

    /// Builds a server secret from `server_password`, runs `super::exchange` against it with
    /// `client_password`, and panics with the failure reason if the outcome is a failure.
    async fn run_round_trip_test(server_password: &str, client_password: &str) {
        let pool = ThreadPool::new(1);
        let endpoint = EndpointIdInt::from(EndpointId::from("foo"));

        let scram_secret = ServerSecret::build(server_password).await.unwrap();
        let outcome = super::exchange(
            &pool,
            endpoint,
            &scram_secret,
            client_password.as_bytes(),
        )
        .await
        .unwrap();

        if let crate::sasl::Outcome::Failure(reason) = outcome {
            panic!("{reason}");
        }
    }

    /// Matching passwords on both sides must authenticate successfully.
    #[tokio::test]
    async fn round_trip() {
        run_round_trip_test("pencil", "pencil").await;
    }

    /// Mismatched passwords must surface the "password doesn't match" failure.
    #[tokio::test]
    #[should_panic(expected = "password doesn't match")]
    async fn failure() {
        run_round_trip_test("pencil", "eraser").await;
    }
}
|