Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use hmac::{Hmac, Mac};
6 : use sha2::Sha256;
7 :
8 : use super::messages::{
9 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
10 : };
11 : use super::pbkdf2::Pbkdf2;
12 : use super::secret::ServerSecret;
13 : use super::signature::SignatureBuilder;
14 : use super::threadpool::ThreadPool;
15 : use super::ScramKey;
16 : use crate::config;
17 : use crate::intern::EndpointIdInt;
18 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
19 :
/// Marker for the `tls-server-end-point` channel binding type — the single
/// channel binding mode this implementation currently supports.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Writes the method name, spelled the same way the `FromStr` impl accepts it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
29 :
30 : impl std::str::FromStr for TlsServerEndPoint {
31 : type Err = sasl::Error;
32 :
33 36 : fn from_str(s: &str) -> Result<Self, Self::Err> {
34 36 : match s {
35 36 : "tls-server-end-point" => Ok(TlsServerEndPoint),
36 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
37 : }
38 36 : }
39 : }
40 :
41 : struct SaslSentInner {
42 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
43 : client_first_message_bare: String,
44 : server_first_message: OwnedServerFirstMessage,
45 : }
46 :
47 : struct SaslInitial {
48 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
49 : }
50 :
51 : enum ExchangeState {
52 : /// Waiting for [`ClientFirstMessage`].
53 : Initial(SaslInitial),
54 : /// Waiting for [`ClientFinalMessage`].
55 : SaltSent(SaslSentInner),
56 : }
57 :
58 : /// Server's side of SCRAM auth algorithm.
59 : pub(crate) struct Exchange<'a> {
60 : state: ExchangeState,
61 : secret: &'a ServerSecret,
62 : tls_server_end_point: config::TlsServerEndPoint,
63 : }
64 :
65 : impl<'a> Exchange<'a> {
66 78 : pub(crate) fn new(
67 78 : secret: &'a ServerSecret,
68 78 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
69 78 : tls_server_end_point: config::TlsServerEndPoint,
70 78 : ) -> Self {
71 78 : Self {
72 78 : state: ExchangeState::Initial(SaslInitial { nonce }),
73 78 : secret,
74 78 : tls_server_end_point,
75 78 : }
76 78 : }
77 : }
78 :
79 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
80 24 : async fn derive_client_key(
81 24 : pool: &ThreadPool,
82 24 : endpoint: EndpointIdInt,
83 24 : password: &[u8],
84 24 : salt: &[u8],
85 24 : iterations: u32,
86 24 : ) -> ScramKey {
87 24 : let salted_password = pool
88 24 : .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
89 24 : .await
90 24 : .expect("job should not be cancelled");
91 24 :
92 24 : let make_key = |name| {
93 24 : let key = Hmac::<Sha256>::new_from_slice(&salted_password)
94 24 : .expect("HMAC is able to accept all key sizes")
95 24 : .chain_update(name)
96 24 : .finalize();
97 24 :
98 24 : <[u8; 32]>::from(key.into_bytes())
99 24 : };
100 :
101 24 : make_key(b"Client Key").into()
102 24 : }
103 :
104 24 : pub(crate) async fn exchange(
105 24 : pool: &ThreadPool,
106 24 : endpoint: EndpointIdInt,
107 24 : secret: &ServerSecret,
108 24 : password: &[u8],
109 24 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
110 24 : let salt = base64::decode(&secret.salt_base64)?;
111 24 : let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
112 :
113 24 : if secret.is_password_invalid(&client_key).into() {
114 6 : Ok(sasl::Outcome::Failure("password doesn't match"))
115 : } else {
116 18 : Ok(sasl::Outcome::Success(client_key))
117 : }
118 24 : }
119 :
120 : impl SaslInitial {
121 78 : fn transition(
122 78 : &self,
123 78 : secret: &ServerSecret,
124 78 : tls_server_end_point: &config::TlsServerEndPoint,
125 78 : input: &str,
126 78 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
127 78 : let client_first_message = ClientFirstMessage::parse(input)
128 78 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
129 :
130 : // If the flag is set to "y" and the server supports channel
131 : // binding, the server MUST fail authentication
132 78 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
133 6 : && tls_server_end_point.supported()
134 : {
135 6 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
136 72 : }
137 72 :
138 72 : let server_first_message = client_first_message.build_server_first_message(
139 72 : &(self.nonce)(),
140 72 : &secret.salt_base64,
141 72 : secret.iterations,
142 72 : );
143 72 : let msg = server_first_message.as_str().to_owned();
144 :
145 72 : let next = SaslSentInner {
146 72 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
147 72 : client_first_message_bare: client_first_message.bare.to_owned(),
148 72 : server_first_message,
149 72 : };
150 72 :
151 72 : Ok(sasl::Step::Continue(next, msg))
152 78 : }
153 : }
154 :
155 : impl SaslSentInner {
156 72 : fn transition(
157 72 : &self,
158 72 : secret: &ServerSecret,
159 72 : tls_server_end_point: &config::TlsServerEndPoint,
160 72 : input: &str,
161 72 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
162 72 : let Self {
163 72 : cbind_flag,
164 72 : client_first_message_bare,
165 72 : server_first_message,
166 72 : } = self;
167 :
168 72 : let client_final_message = ClientFinalMessage::parse(input)
169 72 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
170 :
171 72 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
172 36 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
173 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
174 72 : })?;
175 :
176 : // This might've been caused by a MITM attack
177 72 : if client_final_message.channel_binding != channel_binding {
178 24 : return Err(SaslError::ChannelBindingFailed(
179 24 : "insecure connection: secure channel data mismatch",
180 24 : ));
181 48 : }
182 48 :
183 48 : if client_final_message.nonce != server_first_message.nonce() {
184 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
185 48 : }
186 48 :
187 48 : let signature_builder = SignatureBuilder {
188 48 : client_first_message_bare,
189 48 : server_first_message: server_first_message.as_str(),
190 48 : client_final_message_without_proof: client_final_message.without_proof,
191 48 : };
192 48 :
193 48 : let client_key = signature_builder
194 48 : .build(&secret.stored_key)
195 48 : .derive_client_key(&client_final_message.proof);
196 48 :
197 48 : // Auth fails either if keys don't match or it's pre-determined to fail.
198 48 : if secret.is_password_invalid(&client_key).into() {
199 6 : return Ok(sasl::Step::Failure("password doesn't match"));
200 42 : }
201 42 :
202 42 : let msg =
203 42 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
204 42 :
205 42 : Ok(sasl::Step::Success(client_key, msg))
206 72 : }
207 : }
208 :
209 : impl sasl::Mechanism for Exchange<'_> {
210 : type Output = super::ScramKey;
211 :
212 150 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
213 150 : use {sasl::Step, ExchangeState};
214 150 : match &self.state {
215 78 : ExchangeState::Initial(init) => {
216 78 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
217 72 : Step::Continue(sent, msg) => {
218 72 : self.state = ExchangeState::SaltSent(sent);
219 72 : Ok(Step::Continue(self, msg))
220 : }
221 : Step::Success(x, _) => match x {},
222 0 : Step::Failure(msg) => Ok(Step::Failure(msg)),
223 : }
224 : }
225 72 : ExchangeState::SaltSent(sent) => {
226 72 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
227 42 : Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
228 : Step::Continue(x, _) => match x {},
229 6 : Step::Failure(msg) => Ok(Step::Failure(msg)),
230 : }
231 : }
232 : }
233 150 : }
234 : }
|