Line data Source code
1 : //! Salted Challenge Response Authentication Mechanism.
2 : //!
3 : //! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
4 : //!
5 : //! Reference implementation:
6 : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
7 : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
8 :
9 : mod cache;
10 : mod countmin;
11 : mod exchange;
12 : mod key;
13 : mod messages;
14 : mod pbkdf2;
15 : mod secret;
16 : mod signature;
17 : pub mod threadpool;
18 :
19 : use base64::Engine as _;
20 : use base64::prelude::BASE64_STANDARD;
21 : pub(crate) use exchange::{Exchange, exchange};
22 : pub(crate) use key::ScramKey;
23 : pub(crate) use secret::ServerSecret;
24 :
/// Name of the plain SCRAM-SHA-256 method.
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// Name of the `-PLUS` flavour of the SCRAM-SHA-256 method.
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

/// Every supported SCRAM method; the `-PLUS` flavour is listed first.
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
/// Supported SCRAM methods with the `-PLUS` flavour left out.
pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
31 :
32 : /// Decode base64 into array without any heap allocations
33 51 : fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
34 51 : let mut bytes = [0u8; N];
35 :
36 51 : let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?;
37 51 : if size != N {
38 0 : return None;
39 51 : }
40 :
41 51 : Some(bytes)
42 51 : }
43 :
44 : #[cfg(test)]
45 : mod tests {
46 : use super::threadpool::ThreadPool;
47 : use super::{Exchange, ServerSecret};
48 : use crate::intern::{EndpointIdInt, RoleNameInt};
49 : use crate::sasl::{Mechanism, Step};
50 : use crate::types::{EndpointId, RoleName};
51 :
/// Replays a fixed SCRAM-SHA-256 conversation with a deterministic nonce and checks
/// both server messages and the resulting key byte-for-byte.
#[test]
fn snapshot() {
    const NONCE: [u8; 18] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    ];
    const EXPECTED_KEY: [u8; 32] = [
        74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134, 27,
        125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
    ];

    // Messages exchanged in the conversation, in the order they are sent.
    let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
    let server_first = "r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
    let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
    let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";

    // Assemble the textual server secret and parse it.
    let iterations = 4096;
    let salt = "QSXCR+Q6sek8bf92";
    let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
    let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
    let encoded_secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
    let secret = ServerSecret::parse(&encoded_secret).unwrap();

    let exchange = Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);

    // First round: the server must answer with `server_first` and ask to continue.
    let first_step = exchange.exchange(client_first).unwrap();
    let exchange = match first_step {
        Step::Failure(f) => panic!("{f}"),
        Step::Success(..) => panic!("expected continue, got success"),
        Step::Continue(next, message) => {
            assert_eq!(message, server_first);
            next
        }
    };

    // Second round: the server must finish with `server_final` and yield the key.
    let final_step = exchange.exchange(client_final).unwrap();
    let key = match final_step {
        Step::Failure(f) => panic!("{f}"),
        Step::Continue(..) => panic!("expected success, got continue"),
        Step::Success(key, message) => {
            assert_eq!(message, server_final);
            key
        }
    };

    assert_eq!(key.as_bytes(), EXPECTED_KEY);
}
99 :
100 6 : async fn check(
101 6 : pool: &ThreadPool,
102 6 : scram_secret: &ServerSecret,
103 6 : password: &[u8],
104 6 : ) -> Result<(), &'static str> {
105 6 : let ep = EndpointId::from("foo");
106 6 : let ep = EndpointIdInt::from(ep);
107 6 : let role = RoleName::from("user");
108 6 : let role = RoleNameInt::from(&role);
109 :
110 6 : let outcome = super::exchange(pool, ep, role, scram_secret, password)
111 6 : .await
112 6 : .unwrap();
113 :
114 6 : match outcome {
115 3 : crate::sasl::Outcome::Success(_) => Ok(()),
116 3 : crate::sasl::Outcome::Failure(r) => Err(r),
117 : }
118 6 : }
119 :
120 2 : async fn run_round_trip_test(server_password: &str, client_password: &str) {
121 2 : let pool = ThreadPool::new(1);
122 2 : let scram_secret = ServerSecret::build(server_password).await.unwrap();
123 2 : check(&pool, &scram_secret, client_password.as_bytes())
124 2 : .await
125 2 : .unwrap();
126 2 : }
127 :
128 : #[tokio::test]
129 1 : async fn round_trip() {
130 1 : run_round_trip_test("pencil", "pencil").await;
131 1 : }
132 :
133 : #[tokio::test]
134 : #[should_panic(expected = "password doesn't match")]
135 1 : async fn failure() {
136 1 : run_round_trip_test("pencil", "eraser").await;
137 1 : }
138 :
139 : #[tokio::test]
140 : #[tracing_test::traced_test]
141 1 : async fn password_cache() {
142 1 : let pool = ThreadPool::new(1);
143 1 : let scram_secret = ServerSecret::build("password").await.unwrap();
144 :
145 : // wrong passwords are not added to cache
146 1 : check(&pool, &scram_secret, b"wrong").await.unwrap_err();
147 1 : assert!(!logs_contain("storing cached password"));
148 :
149 : // correct passwords get cached
150 1 : check(&pool, &scram_secret, b"password").await.unwrap();
151 1 : assert!(logs_contain("storing cached password"));
152 :
153 : // wrong passwords do not match the cache
154 1 : check(&pool, &scram_secret, b"wrong").await.unwrap_err();
155 1 : assert!(!logs_contain("password validated from cache"));
156 :
157 : // correct passwords match the cache
158 1 : check(&pool, &scram_secret, b"password").await.unwrap();
159 1 : assert!(logs_contain("password validated from cache"));
160 1 : }
161 : }
|