Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use hmac::{Hmac, Mac};
6 : use sha2::Sha256;
7 :
8 : use super::messages::{
9 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
10 : };
11 : use super::pbkdf2::Pbkdf2;
12 : use super::secret::ServerSecret;
13 : use super::signature::SignatureBuilder;
14 : use super::threadpool::ThreadPool;
15 : use super::ScramKey;
16 : use crate::config;
17 : use crate::intern::EndpointIdInt;
18 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
19 :
/// Marker type for `tls-server-end-point`, the single SCRAM channel-binding
/// method this server knows how to verify.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Renders the method under its wire name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
29 :
30 : impl std::str::FromStr for TlsServerEndPoint {
31 : type Err = sasl::Error;
32 :
33 6 : fn from_str(s: &str) -> Result<Self, Self::Err> {
34 6 : match s {
35 6 : "tls-server-end-point" => Ok(TlsServerEndPoint),
36 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
37 : }
38 6 : }
39 : }
40 :
41 : struct SaslSentInner {
42 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
43 : client_first_message_bare: String,
44 : server_first_message: OwnedServerFirstMessage,
45 : }
46 :
47 : struct SaslInitial {
48 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
49 : }
50 :
51 : enum ExchangeState {
52 : /// Waiting for [`ClientFirstMessage`].
53 : Initial(SaslInitial),
54 : /// Waiting for [`ClientFinalMessage`].
55 : SaltSent(SaslSentInner),
56 : }
57 :
58 : /// Server's side of SCRAM auth algorithm.
59 : pub(crate) struct Exchange<'a> {
60 : state: ExchangeState,
61 : secret: &'a ServerSecret,
62 : tls_server_end_point: config::TlsServerEndPoint,
63 : }
64 :
65 : impl<'a> Exchange<'a> {
66 13 : pub(crate) fn new(
67 13 : secret: &'a ServerSecret,
68 13 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
69 13 : tls_server_end_point: config::TlsServerEndPoint,
70 13 : ) -> Self {
71 13 : Self {
72 13 : state: ExchangeState::Initial(SaslInitial { nonce }),
73 13 : secret,
74 13 : tls_server_end_point,
75 13 : }
76 13 : }
77 : }
78 :
79 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
80 4 : async fn derive_client_key(
81 4 : pool: &ThreadPool,
82 4 : endpoint: EndpointIdInt,
83 4 : password: &[u8],
84 4 : salt: &[u8],
85 4 : iterations: u32,
86 4 : ) -> ScramKey {
87 4 : let salted_password = pool
88 4 : .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
89 4 : .await;
90 :
91 4 : let make_key = |name| {
92 4 : let key = Hmac::<Sha256>::new_from_slice(&salted_password)
93 4 : .expect("HMAC is able to accept all key sizes")
94 4 : .chain_update(name)
95 4 : .finalize();
96 4 :
97 4 : <[u8; 32]>::from(key.into_bytes())
98 4 : };
99 :
100 4 : make_key(b"Client Key").into()
101 4 : }
102 :
103 4 : pub(crate) async fn exchange(
104 4 : pool: &ThreadPool,
105 4 : endpoint: EndpointIdInt,
106 4 : secret: &ServerSecret,
107 4 : password: &[u8],
108 4 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
109 4 : let salt = base64::decode(&secret.salt_base64)?;
110 4 : let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
111 :
112 4 : if secret.is_password_invalid(&client_key).into() {
113 1 : Ok(sasl::Outcome::Failure("password doesn't match"))
114 : } else {
115 3 : Ok(sasl::Outcome::Success(client_key))
116 : }
117 4 : }
118 :
119 : impl SaslInitial {
120 13 : fn transition(
121 13 : &self,
122 13 : secret: &ServerSecret,
123 13 : tls_server_end_point: &config::TlsServerEndPoint,
124 13 : input: &str,
125 13 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
126 13 : let client_first_message = ClientFirstMessage::parse(input)
127 13 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
128 :
129 : // If the flag is set to "y" and the server supports channel
130 : // binding, the server MUST fail authentication
131 13 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
132 1 : && tls_server_end_point.supported()
133 : {
134 1 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
135 12 : }
136 12 :
137 12 : let server_first_message = client_first_message.build_server_first_message(
138 12 : &(self.nonce)(),
139 12 : &secret.salt_base64,
140 12 : secret.iterations,
141 12 : );
142 12 : let msg = server_first_message.as_str().to_owned();
143 :
144 12 : let next = SaslSentInner {
145 12 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
146 12 : client_first_message_bare: client_first_message.bare.to_owned(),
147 12 : server_first_message,
148 12 : };
149 12 :
150 12 : Ok(sasl::Step::Continue(next, msg))
151 13 : }
152 : }
153 :
154 : impl SaslSentInner {
155 12 : fn transition(
156 12 : &self,
157 12 : secret: &ServerSecret,
158 12 : tls_server_end_point: &config::TlsServerEndPoint,
159 12 : input: &str,
160 12 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
161 12 : let Self {
162 12 : cbind_flag,
163 12 : client_first_message_bare,
164 12 : server_first_message,
165 12 : } = self;
166 :
167 12 : let client_final_message = ClientFinalMessage::parse(input)
168 12 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
169 :
170 12 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
171 6 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
172 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
173 12 : })?;
174 :
175 : // This might've been caused by a MITM attack
176 12 : if client_final_message.channel_binding != channel_binding {
177 4 : return Err(SaslError::ChannelBindingFailed(
178 4 : "insecure connection: secure channel data mismatch",
179 4 : ));
180 8 : }
181 8 :
182 8 : if client_final_message.nonce != server_first_message.nonce() {
183 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
184 8 : }
185 8 :
186 8 : let signature_builder = SignatureBuilder {
187 8 : client_first_message_bare,
188 8 : server_first_message: server_first_message.as_str(),
189 8 : client_final_message_without_proof: client_final_message.without_proof,
190 8 : };
191 8 :
192 8 : let client_key = signature_builder
193 8 : .build(&secret.stored_key)
194 8 : .derive_client_key(&client_final_message.proof);
195 8 :
196 8 : // Auth fails either if keys don't match or it's pre-determined to fail.
197 8 : if secret.is_password_invalid(&client_key).into() {
198 1 : return Ok(sasl::Step::Failure("password doesn't match"));
199 7 : }
200 7 :
201 7 : let msg =
202 7 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
203 7 :
204 7 : Ok(sasl::Step::Success(client_key, msg))
205 12 : }
206 : }
207 :
208 : impl sasl::Mechanism for Exchange<'_> {
209 : type Output = super::ScramKey;
210 :
211 25 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
212 : use sasl::Step;
213 : use ExchangeState;
214 25 : match &self.state {
215 13 : ExchangeState::Initial(init) => {
216 13 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
217 12 : Step::Continue(sent, msg) => {
218 12 : self.state = ExchangeState::SaltSent(sent);
219 12 : Ok(Step::Continue(self, msg))
220 : }
221 0 : Step::Failure(msg) => Ok(Step::Failure(msg)),
222 : }
223 : }
224 12 : ExchangeState::SaltSent(sent) => {
225 12 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
226 7 : Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
227 1 : Step::Failure(msg) => Ok(Step::Failure(msg)),
228 : }
229 : }
230 : }
231 25 : }
232 : }
|