Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use base64::Engine as _;
6 : use base64::prelude::BASE64_STANDARD;
7 : use tracing::{debug, trace};
8 :
9 : use super::messages::{
10 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
11 : };
12 : use super::pbkdf2::Pbkdf2;
13 : use super::secret::ServerSecret;
14 : use super::signature::SignatureBuilder;
15 : use super::threadpool::ThreadPool;
16 : use super::{ScramKey, pbkdf2};
17 : use crate::intern::{EndpointIdInt, RoleNameInt};
18 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
19 : use crate::scram::cache::Pbkdf2CacheEntry;
20 :
/// Marker for the `tls-server-end-point` channel binding mode, which is the
/// only mode currently supported.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Writes the SASL name of this channel binding mode.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
30 :
31 : impl std::str::FromStr for TlsServerEndPoint {
32 : type Err = sasl::Error;
33 :
34 6 : fn from_str(s: &str) -> Result<Self, Self::Err> {
35 6 : match s {
36 6 : "tls-server-end-point" => Ok(TlsServerEndPoint),
37 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
38 : }
39 6 : }
40 : }
41 :
42 : struct SaslSentInner {
43 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
44 : client_first_message_bare: String,
45 : server_first_message: OwnedServerFirstMessage,
46 : }
47 :
48 : struct SaslInitial {
49 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
50 : }
51 :
52 : enum ExchangeState {
53 : /// Waiting for [`ClientFirstMessage`].
54 : Initial(SaslInitial),
55 : /// Waiting for [`ClientFinalMessage`].
56 : SaltSent(SaslSentInner),
57 : }
58 :
59 : /// Server's side of SCRAM auth algorithm.
60 : pub(crate) struct Exchange<'a> {
61 : state: ExchangeState,
62 : secret: &'a ServerSecret,
63 : tls_server_end_point: crate::tls::TlsServerEndPoint,
64 : }
65 :
66 : impl<'a> Exchange<'a> {
67 13 : pub(crate) fn new(
68 13 : secret: &'a ServerSecret,
69 13 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
70 13 : tls_server_end_point: crate::tls::TlsServerEndPoint,
71 13 : ) -> Self {
72 13 : Self {
73 13 : state: ExchangeState::Initial(SaslInitial { nonce }),
74 13 : secret,
75 13 : tls_server_end_point,
76 13 : }
77 13 : }
78 : }
79 :
80 14 : async fn derive_client_key(
81 14 : pool: &ThreadPool,
82 14 : endpoint: EndpointIdInt,
83 14 : password: &[u8],
84 14 : salt: &[u8],
85 14 : iterations: u32,
86 14 : ) -> pbkdf2::Block {
87 14 : pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
88 14 : .await
89 14 : }
90 :
91 : /// For cleartext flow, we need to derive the client key to
92 : /// 1. authenticate the client.
93 : /// 2. authenticate with compute.
94 8 : pub(crate) async fn exchange(
95 8 : pool: &ThreadPool,
96 8 : endpoint: EndpointIdInt,
97 8 : role: RoleNameInt,
98 8 : secret: &ServerSecret,
99 8 : password: &[u8],
100 8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
101 8 : if secret.iterations > CACHED_ROUNDS {
102 8 : exchange_with_cache(pool, endpoint, role, secret, password).await
103 : } else {
104 0 : let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
105 0 : let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
106 0 : Ok(validate_pbkdf2(secret, &hash))
107 : }
108 8 : }
109 :
110 : /// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
111 : /// which is not enough by itself to perform an offline brute force.
112 8 : async fn exchange_with_cache(
113 8 : pool: &ThreadPool,
114 8 : endpoint: EndpointIdInt,
115 8 : role: RoleNameInt,
116 8 : secret: &ServerSecret,
117 8 : password: &[u8],
118 8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
119 8 : let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
120 :
121 8 : debug_assert!(
122 8 : secret.iterations > CACHED_ROUNDS,
123 0 : "we should not cache password data if there isn't enough rounds needed"
124 : );
125 :
126 : // compute the prefix of the pbkdf2 output.
127 8 : let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
128 :
129 8 : if let Some(entry) = pool.cache.get_entry(endpoint, role) {
130 : // hot path: let's check the threadpool cache
131 2 : if secret.cached_at == entry.cached_from {
132 : // cache is valid. compute the full hash by adding the prefix to the suffix.
133 2 : let mut hash = prefix;
134 2 : pbkdf2::xor_assign(&mut hash, &entry.suffix);
135 2 : let outcome = validate_pbkdf2(secret, &hash);
136 :
137 2 : if matches!(outcome, sasl::Outcome::Success(_)) {
138 1 : trace!("password validated from cache");
139 1 : }
140 :
141 2 : return Ok(outcome);
142 0 : }
143 :
144 : // cached key is no longer valid.
145 0 : debug!("invalidating cached password");
146 0 : entry.invalidate();
147 6 : }
148 :
149 : // slow path: full password hash.
150 6 : let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
151 6 : let outcome = validate_pbkdf2(secret, &hash);
152 :
153 6 : let client_key = match outcome {
154 4 : sasl::Outcome::Success(client_key) => client_key,
155 2 : sasl::Outcome::Failure(_) => return Ok(outcome),
156 : };
157 :
158 4 : trace!("storing cached password");
159 :
160 : // time to cache, compute the suffix by subtracting the prefix from the hash.
161 4 : let mut suffix = hash;
162 4 : pbkdf2::xor_assign(&mut suffix, &prefix);
163 :
164 4 : pool.cache.insert(
165 4 : endpoint,
166 4 : role,
167 4 : Pbkdf2CacheEntry {
168 4 : cached_from: secret.cached_at,
169 4 : suffix,
170 4 : },
171 : );
172 :
173 4 : Ok(sasl::Outcome::Success(client_key))
174 8 : }
175 :
176 8 : fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
177 8 : let client_key = super::ScramKey::client_key(&(*hash).into());
178 8 : if secret.is_password_invalid(&client_key).into() {
179 3 : sasl::Outcome::Failure("password doesn't match")
180 : } else {
181 5 : sasl::Outcome::Success(client_key)
182 : }
183 8 : }
184 :
/// Number of leading pbkdf2 rounds that are always recomputed. Only secrets
/// with more iterations than this use the pbkdf2 cache.
const CACHED_ROUNDS: u32 = 16;
186 :
187 : impl SaslInitial {
188 13 : fn transition(
189 13 : &self,
190 13 : secret: &ServerSecret,
191 13 : tls_server_end_point: &crate::tls::TlsServerEndPoint,
192 13 : input: &str,
193 13 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
194 13 : let client_first_message = ClientFirstMessage::parse(input)
195 13 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
196 :
197 : // If the flag is set to "y" and the server supports channel
198 : // binding, the server MUST fail authentication
199 13 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
200 1 : && tls_server_end_point.supported()
201 : {
202 1 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
203 12 : }
204 :
205 12 : let server_first_message = client_first_message.build_server_first_message(
206 12 : &(self.nonce)(),
207 12 : &secret.salt_base64,
208 12 : secret.iterations,
209 : );
210 12 : let msg = server_first_message.as_str().to_owned();
211 :
212 12 : let next = SaslSentInner {
213 12 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
214 12 : client_first_message_bare: client_first_message.bare.to_owned(),
215 12 : server_first_message,
216 : };
217 :
218 12 : Ok(sasl::Step::Continue(next, msg))
219 13 : }
220 : }
221 :
222 : impl SaslSentInner {
223 12 : fn transition(
224 12 : &self,
225 12 : secret: &ServerSecret,
226 12 : tls_server_end_point: &crate::tls::TlsServerEndPoint,
227 12 : input: &str,
228 12 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
229 : let Self {
230 12 : cbind_flag,
231 12 : client_first_message_bare,
232 12 : server_first_message,
233 12 : } = self;
234 :
235 12 : let client_final_message = ClientFinalMessage::parse(input)
236 12 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
237 :
238 12 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
239 6 : crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
240 0 : crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
241 6 : })?;
242 :
243 : // This might've been caused by a MITM attack
244 12 : if client_final_message.channel_binding != channel_binding {
245 4 : return Err(SaslError::ChannelBindingFailed(
246 4 : "insecure connection: secure channel data mismatch",
247 4 : ));
248 8 : }
249 :
250 8 : if client_final_message.nonce != server_first_message.nonce() {
251 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
252 8 : }
253 :
254 8 : let signature_builder = SignatureBuilder {
255 8 : client_first_message_bare,
256 8 : server_first_message: server_first_message.as_str(),
257 8 : client_final_message_without_proof: client_final_message.without_proof,
258 8 : };
259 :
260 8 : let client_key = signature_builder
261 8 : .build(&secret.stored_key)
262 8 : .derive_client_key(&client_final_message.proof);
263 :
264 : // Auth fails either if keys don't match or it's pre-determined to fail.
265 8 : if secret.is_password_invalid(&client_key).into() {
266 1 : return Ok(sasl::Step::Failure("password doesn't match"));
267 7 : }
268 :
269 7 : let msg =
270 7 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
271 :
272 7 : Ok(sasl::Step::Success(client_key, msg))
273 12 : }
274 : }
275 :
276 : impl sasl::Mechanism for Exchange<'_> {
277 : type Output = super::ScramKey;
278 :
279 25 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
280 : use ExchangeState;
281 : use sasl::Step;
282 25 : match &self.state {
283 13 : ExchangeState::Initial(init) => {
284 13 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
285 12 : Step::Continue(sent, msg) => {
286 12 : self.state = ExchangeState::SaltSent(sent);
287 12 : Ok(Step::Continue(self, msg))
288 : }
289 0 : Step::Failure(msg) => Ok(Step::Failure(msg)),
290 : }
291 : }
292 12 : ExchangeState::SaltSent(sent) => {
293 12 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
294 7 : Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
295 1 : Step::Failure(msg) => Ok(Step::Failure(msg)),
296 : }
297 : }
298 : }
299 25 : }
300 : }
|