Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use hmac::{Hmac, Mac};
6 : use sha2::Sha256;
7 : use tokio::task::yield_now;
8 :
9 : use super::messages::{
10 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
11 : };
12 : use super::secret::ServerSecret;
13 : use super::signature::SignatureBuilder;
14 : use super::ScramKey;
15 : use crate::config;
16 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
17 :
18 : /// The only channel binding mode we currently support.
19 : #[derive(Debug)]
20 : struct TlsServerEndPoint;
21 :
22 : impl std::fmt::Display for TlsServerEndPoint {
23 12 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 12 : write!(f, "tls-server-end-point")
25 12 : }
26 : }
27 :
28 : impl std::str::FromStr for TlsServerEndPoint {
29 : type Err = sasl::Error;
30 :
31 12 : fn from_str(s: &str) -> Result<Self, Self::Err> {
32 12 : match s {
33 12 : "tls-server-end-point" => Ok(TlsServerEndPoint),
34 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
35 : }
36 12 : }
37 : }
38 :
39 : struct SaslSentInner {
40 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
41 : client_first_message_bare: String,
42 : server_first_message: OwnedServerFirstMessage,
43 : }
44 :
45 : struct SaslInitial {
46 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
47 : }
48 :
49 : enum ExchangeState {
50 : /// Waiting for [`ClientFirstMessage`].
51 : Initial(SaslInitial),
52 : /// Waiting for [`ClientFinalMessage`].
53 : SaltSent(SaslSentInner),
54 : }
55 :
56 : /// Server's side of SCRAM auth algorithm.
57 : pub struct Exchange<'a> {
58 : state: ExchangeState,
59 : secret: &'a ServerSecret,
60 : tls_server_end_point: config::TlsServerEndPoint,
61 : }
62 :
63 : impl<'a> Exchange<'a> {
64 26 : pub fn new(
65 26 : secret: &'a ServerSecret,
66 26 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
67 26 : tls_server_end_point: config::TlsServerEndPoint,
68 26 : ) -> Self {
69 26 : Self {
70 26 : state: ExchangeState::Initial(SaslInitial { nonce }),
71 26 : secret,
72 26 : tls_server_end_point,
73 26 : }
74 26 : }
75 : }
76 :
77 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
78 8 : async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
79 8 : let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
80 8 : let mut prev = hmac
81 8 : .clone()
82 8 : .chain_update(salt)
83 8 : .chain_update(1u32.to_be_bytes())
84 8 : .finalize()
85 8 : .into_bytes();
86 8 :
87 8 : let mut hi = prev;
88 :
89 32760 : for i in 1..iterations {
90 32760 : prev = hmac.clone().chain_update(prev).finalize().into_bytes();
91 :
92 1048320 : for (hi, prev) in hi.iter_mut().zip(prev) {
93 1048320 : *hi ^= prev;
94 1048320 : }
95 : // yield every ~250us
96 : // hopefully reduces tail latencies
97 32760 : if i % 1024 == 0 {
98 24 : yield_now().await
99 32736 : }
100 : }
101 :
102 8 : hi.into()
103 8 : }
104 :
105 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
106 8 : async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
107 24 : let salted_password = pbkdf2(password, salt, iterations).await;
108 :
109 8 : let make_key = |name| {
110 8 : let key = Hmac::<Sha256>::new_from_slice(&salted_password)
111 8 : .expect("HMAC is able to accept all key sizes")
112 8 : .chain_update(name)
113 8 : .finalize();
114 8 :
115 8 : <[u8; 32]>::from(key.into_bytes())
116 8 : };
117 :
118 8 : make_key(b"Client Key").into()
119 8 : }
120 :
121 8 : pub async fn exchange(
122 8 : secret: &ServerSecret,
123 8 : password: &[u8],
124 8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
125 8 : let salt = base64::decode(&secret.salt_base64)?;
126 24 : let client_key = derive_client_key(password, &salt, secret.iterations).await;
127 :
128 8 : if secret.is_password_invalid(&client_key).into() {
129 2 : Ok(sasl::Outcome::Failure("password doesn't match"))
130 : } else {
131 6 : Ok(sasl::Outcome::Success(client_key))
132 : }
133 8 : }
134 :
135 : impl SaslInitial {
136 26 : fn transition(
137 26 : &self,
138 26 : secret: &ServerSecret,
139 26 : tls_server_end_point: &config::TlsServerEndPoint,
140 26 : input: &str,
141 26 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
142 26 : let client_first_message = ClientFirstMessage::parse(input)
143 26 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
144 :
145 : // If the flag is set to "y" and the server supports channel
146 : // binding, the server MUST fail authentication
147 26 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
148 2 : && tls_server_end_point.supported()
149 : {
150 2 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
151 24 : }
152 24 :
153 24 : let server_first_message = client_first_message.build_server_first_message(
154 24 : &(self.nonce)(),
155 24 : &secret.salt_base64,
156 24 : secret.iterations,
157 24 : );
158 24 : let msg = server_first_message.as_str().to_owned();
159 :
160 24 : let next = SaslSentInner {
161 24 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
162 24 : client_first_message_bare: client_first_message.bare.to_owned(),
163 24 : server_first_message,
164 24 : };
165 24 :
166 24 : Ok(sasl::Step::Continue(next, msg))
167 26 : }
168 : }
169 :
170 : impl SaslSentInner {
171 24 : fn transition(
172 24 : &self,
173 24 : secret: &ServerSecret,
174 24 : tls_server_end_point: &config::TlsServerEndPoint,
175 24 : input: &str,
176 24 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
177 24 : let Self {
178 24 : cbind_flag,
179 24 : client_first_message_bare,
180 24 : server_first_message,
181 24 : } = self;
182 :
183 24 : let client_final_message = ClientFinalMessage::parse(input)
184 24 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
185 :
186 24 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
187 12 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
188 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
189 24 : })?;
190 :
191 : // This might've been caused by a MITM attack
192 24 : if client_final_message.channel_binding != channel_binding {
193 8 : return Err(SaslError::ChannelBindingFailed(
194 8 : "insecure connection: secure channel data mismatch",
195 8 : ));
196 16 : }
197 16 :
198 16 : if client_final_message.nonce != server_first_message.nonce() {
199 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
200 16 : }
201 16 :
202 16 : let signature_builder = SignatureBuilder {
203 16 : client_first_message_bare,
204 16 : server_first_message: server_first_message.as_str(),
205 16 : client_final_message_without_proof: client_final_message.without_proof,
206 16 : };
207 16 :
208 16 : let client_key = signature_builder
209 16 : .build(&secret.stored_key)
210 16 : .derive_client_key(&client_final_message.proof);
211 16 :
212 16 : // Auth fails either if keys don't match or it's pre-determined to fail.
213 16 : if secret.is_password_invalid(&client_key).into() {
214 2 : return Ok(sasl::Step::Failure("password doesn't match"));
215 14 : }
216 14 :
217 14 : let msg =
218 14 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
219 14 :
220 14 : Ok(sasl::Step::Success(client_key, msg))
221 24 : }
222 : }
223 :
224 : impl sasl::Mechanism for Exchange<'_> {
225 : type Output = super::ScramKey;
226 :
227 50 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
228 50 : use {sasl::Step::*, ExchangeState::*};
229 50 : match &self.state {
230 26 : Initial(init) => {
231 26 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
232 24 : Continue(sent, msg) => {
233 24 : self.state = SaltSent(sent);
234 24 : Ok(Continue(self, msg))
235 : }
236 : Success(x, _) => match x {},
237 0 : Failure(msg) => Ok(Failure(msg)),
238 : }
239 : }
240 24 : SaltSent(sent) => {
241 24 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
242 14 : Success(keys, msg) => Ok(Success(keys, msg)),
243 : Continue(x, _) => match x {},
244 2 : Failure(msg) => Ok(Failure(msg)),
245 : }
246 : }
247 : }
248 50 : }
249 : }
|