Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use postgres_protocol::authentication::sasl::ScramSha256;
6 :
7 : use super::messages::{
8 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
9 : };
10 : use super::secret::ServerSecret;
11 : use super::signature::SignatureBuilder;
12 : use crate::config;
13 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
14 :
/// Marker for the `tls-server-end-point` channel-binding method — the single
/// method this module is willing to negotiate.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Emits the method name, i.e. the same string that parsing accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
24 :
25 : impl std::str::FromStr for TlsServerEndPoint {
26 : type Err = sasl::Error;
27 :
28 12 : fn from_str(s: &str) -> Result<Self, Self::Err> {
29 12 : match s {
30 12 : "tls-server-end-point" => Ok(TlsServerEndPoint),
31 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
32 : }
33 12 : }
34 : }
35 :
36 : struct SaslSentInner {
37 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
38 : client_first_message_bare: String,
39 : server_first_message: OwnedServerFirstMessage,
40 : }
41 :
42 : struct SaslInitial {
43 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
44 : }
45 :
46 : enum ExchangeState {
47 : /// Waiting for [`ClientFirstMessage`].
48 : Initial(SaslInitial),
49 : /// Waiting for [`ClientFinalMessage`].
50 : SaltSent(SaslSentInner),
51 : }
52 :
53 : /// Server's side of SCRAM auth algorithm.
54 : pub struct Exchange<'a> {
55 : state: ExchangeState,
56 : secret: &'a ServerSecret,
57 : tls_server_end_point: config::TlsServerEndPoint,
58 : }
59 :
60 : impl<'a> Exchange<'a> {
61 24 : pub fn new(
62 24 : secret: &'a ServerSecret,
63 24 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
64 24 : tls_server_end_point: config::TlsServerEndPoint,
65 24 : ) -> Self {
66 24 : Self {
67 24 : state: ExchangeState::Initial(SaslInitial { nonce }),
68 24 : secret,
69 24 : tls_server_end_point,
70 24 : }
71 24 : }
72 : }
73 :
74 4 : pub fn exchange(
75 4 : secret: &ServerSecret,
76 4 : mut client: ScramSha256,
77 4 : tls_server_end_point: config::TlsServerEndPoint,
78 4 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
79 4 : use sasl::Step::*;
80 4 :
81 4 : let init = SaslInitial {
82 4 : nonce: rand::random,
83 4 : };
84 :
85 4 : let client_first = std::str::from_utf8(client.message())
86 4 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87 4 : let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
88 4 : Continue(sent, server_first) => {
89 4 : client.update(server_first.as_bytes())?;
90 4 : sent
91 : }
92 : Success(x, _) => match x {},
93 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
94 : };
95 :
96 4 : let client_final = std::str::from_utf8(client.message())
97 4 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
98 4 : let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
99 2 : Success(keys, server_final) => {
100 2 : client.finish(server_final.as_bytes())?;
101 2 : keys
102 : }
103 : Continue(x, _) => match x {},
104 2 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
105 : };
106 :
107 2 : Ok(sasl::Outcome::Success(keys))
108 4 : }
109 :
110 : impl SaslInitial {
111 28 : fn transition(
112 28 : &self,
113 28 : secret: &ServerSecret,
114 28 : tls_server_end_point: &config::TlsServerEndPoint,
115 28 : input: &str,
116 28 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
117 28 : let client_first_message = ClientFirstMessage::parse(input)
118 28 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
119 :
120 : // If the flag is set to "y" and the server supports channel
121 : // binding, the server MUST fail authentication
122 28 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
123 2 : && tls_server_end_point.supported()
124 : {
125 2 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
126 26 : }
127 26 :
128 26 : let server_first_message = client_first_message.build_server_first_message(
129 26 : &(self.nonce)(),
130 26 : &secret.salt_base64,
131 26 : secret.iterations,
132 26 : );
133 26 : let msg = server_first_message.as_str().to_owned();
134 :
135 26 : let next = SaslSentInner {
136 26 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
137 26 : client_first_message_bare: client_first_message.bare.to_owned(),
138 26 : server_first_message,
139 26 : };
140 26 :
141 26 : Ok(sasl::Step::Continue(next, msg))
142 28 : }
143 : }
144 :
145 : impl SaslSentInner {
146 26 : fn transition(
147 26 : &self,
148 26 : secret: &ServerSecret,
149 26 : tls_server_end_point: &config::TlsServerEndPoint,
150 26 : input: &str,
151 26 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
152 26 : let Self {
153 26 : cbind_flag,
154 26 : client_first_message_bare,
155 26 : server_first_message,
156 26 : } = self;
157 :
158 26 : let client_final_message = ClientFinalMessage::parse(input)
159 26 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
160 :
161 26 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
162 12 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
163 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
164 26 : })?;
165 :
166 : // This might've been caused by a MITM attack
167 26 : if client_final_message.channel_binding != channel_binding {
168 8 : return Err(SaslError::ChannelBindingFailed(
169 8 : "insecure connection: secure channel data mismatch",
170 8 : ));
171 18 : }
172 18 :
173 18 : if client_final_message.nonce != server_first_message.nonce() {
174 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
175 18 : }
176 18 :
177 18 : let signature_builder = SignatureBuilder {
178 18 : client_first_message_bare,
179 18 : server_first_message: server_first_message.as_str(),
180 18 : client_final_message_without_proof: client_final_message.without_proof,
181 18 : };
182 18 :
183 18 : let client_key = signature_builder
184 18 : .build(&secret.stored_key)
185 18 : .derive_client_key(&client_final_message.proof);
186 18 :
187 18 : // Auth fails either if keys don't match or it's pre-determined to fail.
188 18 : if client_key.sha256() != secret.stored_key || secret.doomed {
189 4 : return Ok(sasl::Step::Failure("password doesn't match"));
190 14 : }
191 14 :
192 14 : let msg =
193 14 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
194 14 :
195 14 : Ok(sasl::Step::Success(client_key, msg))
196 26 : }
197 : }
198 :
199 : impl sasl::Mechanism for Exchange<'_> {
200 : type Output = super::ScramKey;
201 :
202 46 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
203 46 : use {sasl::Step::*, ExchangeState::*};
204 46 : match &self.state {
205 24 : Initial(init) => {
206 24 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
207 22 : Continue(sent, msg) => {
208 22 : self.state = SaltSent(sent);
209 22 : Ok(Continue(self, msg))
210 : }
211 : Success(x, _) => match x {},
212 0 : Failure(msg) => Ok(Failure(msg)),
213 : }
214 : }
215 22 : SaltSent(sent) => {
216 22 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
217 12 : Success(keys, msg) => Ok(Success(keys, msg)),
218 : Continue(x, _) => match x {},
219 2 : Failure(msg) => Ok(Failure(msg)),
220 : }
221 : }
222 : }
223 46 : }
224 : }
|