Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use hmac::{Hmac, Mac};
6 : use sha2::Sha256;
7 :
8 : use super::messages::{
9 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
10 : };
11 : use super::pbkdf2::Pbkdf2;
12 : use super::secret::ServerSecret;
13 : use super::signature::SignatureBuilder;
14 : use super::threadpool::ThreadPool;
15 : use super::ScramKey;
16 : use crate::intern::EndpointIdInt;
17 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
18 :
/// The only channel binding mode we currently support.
///
/// A zero-sized marker: its textual form is `tls-server-end-point` (see the
/// `Display` / `FromStr` impls). It derives the common value traits so it can be
/// copied and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TlsServerEndPoint;
22 :
23 : impl std::fmt::Display for TlsServerEndPoint {
24 6 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 6 : write!(f, "tls-server-end-point")
26 6 : }
27 : }
28 :
29 : impl std::str::FromStr for TlsServerEndPoint {
30 : type Err = sasl::Error;
31 :
32 6 : fn from_str(s: &str) -> Result<Self, Self::Err> {
33 6 : match s {
34 6 : "tls-server-end-point" => Ok(TlsServerEndPoint),
35 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
36 : }
37 6 : }
38 : }
39 :
40 : struct SaslSentInner {
41 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
42 : client_first_message_bare: String,
43 : server_first_message: OwnedServerFirstMessage,
44 : }
45 :
/// State before the client-first-message has arrived.
struct SaslInitial {
    /// Produces the raw nonce bytes handed to `build_server_first_message`;
    /// called once in `SaslInitial::transition`.
    nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
}
49 :
/// Position of an [`Exchange`] in the SCRAM conversation.
///
/// `Initial` moves to `SaltSent` once a valid client-first-message has been
/// answered with a server-first-message (which carries the salt). The next
/// client message ends the exchange.
enum ExchangeState {
    /// Waiting for [`ClientFirstMessage`].
    Initial(SaslInitial),
    /// Waiting for [`ClientFinalMessage`].
    SaltSent(SaslSentInner),
}
56 :
57 : /// Server's side of SCRAM auth algorithm.
58 : pub(crate) struct Exchange<'a> {
59 : state: ExchangeState,
60 : secret: &'a ServerSecret,
61 : tls_server_end_point: crate::tls::TlsServerEndPoint,
62 : }
63 :
64 : impl<'a> Exchange<'a> {
65 13 : pub(crate) fn new(
66 13 : secret: &'a ServerSecret,
67 13 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
68 13 : tls_server_end_point: crate::tls::TlsServerEndPoint,
69 13 : ) -> Self {
70 13 : Self {
71 13 : state: ExchangeState::Initial(SaslInitial { nonce }),
72 13 : secret,
73 13 : tls_server_end_point,
74 13 : }
75 13 : }
76 : }
77 :
78 : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
79 4 : async fn derive_client_key(
80 4 : pool: &ThreadPool,
81 4 : endpoint: EndpointIdInt,
82 4 : password: &[u8],
83 4 : salt: &[u8],
84 4 : iterations: u32,
85 4 : ) -> ScramKey {
86 4 : let salted_password = pool
87 4 : .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
88 4 : .await;
89 :
90 4 : let make_key = |name| {
91 4 : let key = Hmac::<Sha256>::new_from_slice(&salted_password)
92 4 : .expect("HMAC is able to accept all key sizes")
93 4 : .chain_update(name)
94 4 : .finalize();
95 4 :
96 4 : <[u8; 32]>::from(key.into_bytes())
97 4 : };
98 :
99 4 : make_key(b"Client Key").into()
100 4 : }
101 :
102 4 : pub(crate) async fn exchange(
103 4 : pool: &ThreadPool,
104 4 : endpoint: EndpointIdInt,
105 4 : secret: &ServerSecret,
106 4 : password: &[u8],
107 4 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
108 4 : let salt = base64::decode(&secret.salt_base64)?;
109 4 : let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
110 :
111 4 : if secret.is_password_invalid(&client_key).into() {
112 1 : Ok(sasl::Outcome::Failure("password doesn't match"))
113 : } else {
114 3 : Ok(sasl::Outcome::Success(client_key))
115 : }
116 4 : }
117 :
118 : impl SaslInitial {
119 13 : fn transition(
120 13 : &self,
121 13 : secret: &ServerSecret,
122 13 : tls_server_end_point: &crate::tls::TlsServerEndPoint,
123 13 : input: &str,
124 13 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
125 13 : let client_first_message = ClientFirstMessage::parse(input)
126 13 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
127 :
128 : // If the flag is set to "y" and the server supports channel
129 : // binding, the server MUST fail authentication
130 13 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
131 1 : && tls_server_end_point.supported()
132 : {
133 1 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
134 12 : }
135 12 :
136 12 : let server_first_message = client_first_message.build_server_first_message(
137 12 : &(self.nonce)(),
138 12 : &secret.salt_base64,
139 12 : secret.iterations,
140 12 : );
141 12 : let msg = server_first_message.as_str().to_owned();
142 :
143 12 : let next = SaslSentInner {
144 12 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
145 12 : client_first_message_bare: client_first_message.bare.to_owned(),
146 12 : server_first_message,
147 12 : };
148 12 :
149 12 : Ok(sasl::Step::Continue(next, msg))
150 13 : }
151 : }
152 :
153 : impl SaslSentInner {
154 12 : fn transition(
155 12 : &self,
156 12 : secret: &ServerSecret,
157 12 : tls_server_end_point: &crate::tls::TlsServerEndPoint,
158 12 : input: &str,
159 12 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
160 12 : let Self {
161 12 : cbind_flag,
162 12 : client_first_message_bare,
163 12 : server_first_message,
164 12 : } = self;
165 :
166 12 : let client_final_message = ClientFinalMessage::parse(input)
167 12 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
168 :
169 12 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
170 6 : crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
171 0 : crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
172 12 : })?;
173 :
174 : // This might've been caused by a MITM attack
175 12 : if client_final_message.channel_binding != channel_binding {
176 4 : return Err(SaslError::ChannelBindingFailed(
177 4 : "insecure connection: secure channel data mismatch",
178 4 : ));
179 8 : }
180 8 :
181 8 : if client_final_message.nonce != server_first_message.nonce() {
182 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
183 8 : }
184 8 :
185 8 : let signature_builder = SignatureBuilder {
186 8 : client_first_message_bare,
187 8 : server_first_message: server_first_message.as_str(),
188 8 : client_final_message_without_proof: client_final_message.without_proof,
189 8 : };
190 8 :
191 8 : let client_key = signature_builder
192 8 : .build(&secret.stored_key)
193 8 : .derive_client_key(&client_final_message.proof);
194 8 :
195 8 : // Auth fails either if keys don't match or it's pre-determined to fail.
196 8 : if secret.is_password_invalid(&client_key).into() {
197 1 : return Ok(sasl::Step::Failure("password doesn't match"));
198 7 : }
199 7 :
200 7 : let msg =
201 7 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
202 7 :
203 7 : Ok(sasl::Step::Success(client_key, msg))
204 12 : }
205 : }
206 :
207 : impl sasl::Mechanism for Exchange<'_> {
208 : type Output = super::ScramKey;
209 :
210 25 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
211 : use sasl::Step;
212 : use ExchangeState;
213 25 : match &self.state {
214 13 : ExchangeState::Initial(init) => {
215 13 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
216 12 : Step::Continue(sent, msg) => {
217 12 : self.state = ExchangeState::SaltSent(sent);
218 12 : Ok(Step::Continue(self, msg))
219 : }
220 0 : Step::Failure(msg) => Ok(Step::Failure(msg)),
221 : }
222 : }
223 12 : ExchangeState::SaltSent(sent) => {
224 12 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
225 7 : Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
226 1 : Step::Failure(msg) => Ok(Step::Failure(msg)),
227 : }
228 : }
229 : }
230 25 : }
231 : }
|