Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use hmac::{Hmac, Mac};
6 : use sha2::Sha256;
7 :
8 : use super::messages::{
9 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
10 : };
11 : use super::pbkdf2::Pbkdf2;
12 : use super::secret::ServerSecret;
13 : use super::signature::SignatureBuilder;
14 : use super::threadpool::ThreadPool;
15 : use super::ScramKey;
16 : use crate::config;
17 : use crate::intern::EndpointIdInt;
18 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
19 :
20 : /// The only channel binding mode we currently support.
21 : #[derive(Debug)]
22 : struct TlsServerEndPoint;
23 :
24 : impl std::fmt::Display for TlsServerEndPoint {
25 12 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 12 : write!(f, "tls-server-end-point")
27 12 : }
28 : }
29 :
30 : impl std::str::FromStr for TlsServerEndPoint {
31 : type Err = sasl::Error;
32 :
33 12 : fn from_str(s: &str) -> Result<Self, Self::Err> {
34 12 : match s {
35 12 : "tls-server-end-point" => Ok(TlsServerEndPoint),
36 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
37 : }
38 12 : }
39 : }
40 :
41 : struct SaslSentInner {
42 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
43 : client_first_message_bare: String,
44 : server_first_message: OwnedServerFirstMessage,
45 : }
46 :
47 : struct SaslInitial {
48 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
49 : }
50 :
51 : enum ExchangeState {
52 : /// Waiting for [`ClientFirstMessage`].
53 : Initial(SaslInitial),
54 : /// Waiting for [`ClientFinalMessage`].
55 : SaltSent(SaslSentInner),
56 : }
57 :
58 : /// Server's side of SCRAM auth algorithm.
59 : pub struct Exchange<'a> {
60 : state: ExchangeState,
61 : secret: &'a ServerSecret,
62 : tls_server_end_point: config::TlsServerEndPoint,
63 : }
64 :
65 : impl<'a> Exchange<'a> {
66 26 : pub fn new(
67 26 : secret: &'a ServerSecret,
68 26 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
69 26 : tls_server_end_point: config::TlsServerEndPoint,
70 26 : ) -> Self {
71 26 : Self {
72 26 : state: ExchangeState::Initial(SaslInitial { nonce }),
73 26 : secret,
74 26 : tls_server_end_point,
75 26 : }
76 26 : }
77 : }
78 :
79 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
80 8 : async fn derive_client_key(
81 8 : pool: &ThreadPool,
82 8 : endpoint: EndpointIdInt,
83 8 : password: &[u8],
84 8 : salt: &[u8],
85 8 : iterations: u32,
86 8 : ) -> ScramKey {
87 8 : let salted_password = pool
88 8 : .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
89 8 : .await
90 8 : .expect("job should not be cancelled");
91 8 :
92 8 : let make_key = |name| {
93 8 : let key = Hmac::<Sha256>::new_from_slice(&salted_password)
94 8 : .expect("HMAC is able to accept all key sizes")
95 8 : .chain_update(name)
96 8 : .finalize();
97 8 :
98 8 : <[u8; 32]>::from(key.into_bytes())
99 8 : };
100 :
101 8 : make_key(b"Client Key").into()
102 8 : }
103 :
104 8 : pub async fn exchange(
105 8 : pool: &ThreadPool,
106 8 : endpoint: EndpointIdInt,
107 8 : secret: &ServerSecret,
108 8 : password: &[u8],
109 8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
110 8 : let salt = base64::decode(&secret.salt_base64)?;
111 8 : let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
112 :
113 8 : if secret.is_password_invalid(&client_key).into() {
114 2 : Ok(sasl::Outcome::Failure("password doesn't match"))
115 : } else {
116 6 : Ok(sasl::Outcome::Success(client_key))
117 : }
118 8 : }
119 :
120 : impl SaslInitial {
121 26 : fn transition(
122 26 : &self,
123 26 : secret: &ServerSecret,
124 26 : tls_server_end_point: &config::TlsServerEndPoint,
125 26 : input: &str,
126 26 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
127 26 : let client_first_message = ClientFirstMessage::parse(input)
128 26 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
129 :
130 : // If the flag is set to "y" and the server supports channel
131 : // binding, the server MUST fail authentication
132 26 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
133 2 : && tls_server_end_point.supported()
134 : {
135 2 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
136 24 : }
137 24 :
138 24 : let server_first_message = client_first_message.build_server_first_message(
139 24 : &(self.nonce)(),
140 24 : &secret.salt_base64,
141 24 : secret.iterations,
142 24 : );
143 24 : let msg = server_first_message.as_str().to_owned();
144 :
145 24 : let next = SaslSentInner {
146 24 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
147 24 : client_first_message_bare: client_first_message.bare.to_owned(),
148 24 : server_first_message,
149 24 : };
150 24 :
151 24 : Ok(sasl::Step::Continue(next, msg))
152 26 : }
153 : }
154 :
155 : impl SaslSentInner {
156 24 : fn transition(
157 24 : &self,
158 24 : secret: &ServerSecret,
159 24 : tls_server_end_point: &config::TlsServerEndPoint,
160 24 : input: &str,
161 24 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
162 24 : let Self {
163 24 : cbind_flag,
164 24 : client_first_message_bare,
165 24 : server_first_message,
166 24 : } = self;
167 :
168 24 : let client_final_message = ClientFinalMessage::parse(input)
169 24 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
170 :
171 24 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
172 12 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
173 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
174 24 : })?;
175 :
176 : // This might've been caused by a MITM attack
177 24 : if client_final_message.channel_binding != channel_binding {
178 8 : return Err(SaslError::ChannelBindingFailed(
179 8 : "insecure connection: secure channel data mismatch",
180 8 : ));
181 16 : }
182 16 :
183 16 : if client_final_message.nonce != server_first_message.nonce() {
184 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
185 16 : }
186 16 :
187 16 : let signature_builder = SignatureBuilder {
188 16 : client_first_message_bare,
189 16 : server_first_message: server_first_message.as_str(),
190 16 : client_final_message_without_proof: client_final_message.without_proof,
191 16 : };
192 16 :
193 16 : let client_key = signature_builder
194 16 : .build(&secret.stored_key)
195 16 : .derive_client_key(&client_final_message.proof);
196 16 :
197 16 : // Auth fails either if keys don't match or it's pre-determined to fail.
198 16 : if secret.is_password_invalid(&client_key).into() {
199 2 : return Ok(sasl::Step::Failure("password doesn't match"));
200 14 : }
201 14 :
202 14 : let msg =
203 14 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
204 14 :
205 14 : Ok(sasl::Step::Success(client_key, msg))
206 24 : }
207 : }
208 :
209 : impl sasl::Mechanism for Exchange<'_> {
210 : type Output = super::ScramKey;
211 :
212 50 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
213 50 : use {sasl::Step::*, ExchangeState::*};
214 50 : match &self.state {
215 26 : Initial(init) => {
216 26 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
217 24 : Continue(sent, msg) => {
218 24 : self.state = SaltSent(sent);
219 24 : Ok(Continue(self, msg))
220 : }
221 : Success(x, _) => match x {},
222 0 : Failure(msg) => Ok(Failure(msg)),
223 : }
224 : }
225 24 : SaltSent(sent) => {
226 24 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
227 14 : Success(keys, msg) => Ok(Success(keys, msg)),
228 : Continue(x, _) => match x {},
229 2 : Failure(msg) => Ok(Failure(msg)),
230 : }
231 : }
232 : }
233 50 : }
234 : }
|