Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use postgres_protocol::authentication::sasl::ScramSha256;
6 :
7 : use super::messages::{
8 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
9 : };
10 : use super::secret::ServerSecret;
11 : use super::signature::SignatureBuilder;
12 : use crate::config;
13 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
14 :
15 : /// The only channel binding mode we currently support.
16 0 : #[derive(Debug)]
17 : struct TlsServerEndPoint;
18 :
19 : impl std::fmt::Display for TlsServerEndPoint {
20 49 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 49 : write!(f, "tls-server-end-point")
22 49 : }
23 : }
24 :
25 : impl std::str::FromStr for TlsServerEndPoint {
26 : type Err = sasl::Error;
27 :
28 49 : fn from_str(s: &str) -> Result<Self, Self::Err> {
29 49 : match s {
30 49 : "tls-server-end-point" => Ok(TlsServerEndPoint),
31 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
32 : }
33 49 : }
34 : }
35 :
36 : struct SaslSentInner {
37 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
38 : client_first_message_bare: String,
39 : server_first_message: OwnedServerFirstMessage,
40 : }
41 :
42 : struct SaslInitial {
43 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
44 : }
45 :
46 : enum ExchangeState {
47 : /// Waiting for [`ClientFirstMessage`].
48 : Initial(SaslInitial),
49 : /// Waiting for [`ClientFinalMessage`].
50 : SaltSent(SaslSentInner),
51 : }
52 :
53 : /// Server's side of SCRAM auth algorithm.
54 : pub struct Exchange<'a> {
55 : state: ExchangeState,
56 : secret: &'a ServerSecret,
57 : tls_server_end_point: config::TlsServerEndPoint,
58 : }
59 :
60 : impl<'a> Exchange<'a> {
61 61 : pub fn new(
62 61 : secret: &'a ServerSecret,
63 61 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
64 61 : tls_server_end_point: config::TlsServerEndPoint,
65 61 : ) -> Self {
66 61 : Self {
67 61 : state: ExchangeState::Initial(SaslInitial { nonce }),
68 61 : secret,
69 61 : tls_server_end_point,
70 61 : }
71 61 : }
72 : }
73 :
74 2 : pub fn exchange(
75 2 : secret: &ServerSecret,
76 2 : mut client: ScramSha256,
77 2 : tls_server_end_point: config::TlsServerEndPoint,
78 2 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
79 2 : use sasl::Step::*;
80 2 :
81 2 : let init = SaslInitial {
82 2 : nonce: rand::random,
83 2 : };
84 :
85 2 : let client_first = std::str::from_utf8(client.message())
86 2 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87 2 : let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
88 2 : Continue(sent, server_first) => {
89 2 : client.update(server_first.as_bytes())?;
90 2 : sent
91 : }
92 : Success(x, _) => match x {},
93 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
94 : };
95 :
96 2 : let client_final = std::str::from_utf8(client.message())
97 2 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
98 2 : let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
99 2 : Success(keys, server_final) => {
100 2 : client.finish(server_final.as_bytes())?;
101 2 : keys
102 : }
103 : Continue(x, _) => match x {},
104 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
105 : };
106 :
107 2 : Ok(sasl::Outcome::Success(keys))
108 2 : }
109 :
110 : impl SaslInitial {
111 63 : fn transition(
112 63 : &self,
113 63 : secret: &ServerSecret,
114 63 : tls_server_end_point: &config::TlsServerEndPoint,
115 63 : input: &str,
116 63 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
117 63 : let client_first_message = ClientFirstMessage::parse(input)
118 63 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
119 :
120 : // If the flag is set to "y" and the server supports channel
121 : // binding, the server MUST fail authentication
122 63 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
123 2 : && tls_server_end_point.supported()
124 : {
125 2 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
126 61 : }
127 61 :
128 61 : let server_first_message = client_first_message.build_server_first_message(
129 61 : &(self.nonce)(),
130 61 : &secret.salt_base64,
131 61 : secret.iterations,
132 61 : );
133 61 : let msg = server_first_message.as_str().to_owned();
134 :
135 61 : let next = SaslSentInner {
136 61 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
137 61 : client_first_message_bare: client_first_message.bare.to_owned(),
138 61 : server_first_message,
139 61 : };
140 61 :
141 61 : Ok(sasl::Step::Continue(next, msg))
142 63 : }
143 : }
144 :
145 : impl SaslSentInner {
146 61 : fn transition(
147 61 : &self,
148 61 : secret: &ServerSecret,
149 61 : tls_server_end_point: &config::TlsServerEndPoint,
150 61 : input: &str,
151 61 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
152 61 : let Self {
153 61 : cbind_flag,
154 61 : client_first_message_bare,
155 61 : server_first_message,
156 61 : } = self;
157 :
158 61 : let client_final_message = ClientFinalMessage::parse(input)
159 61 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
160 :
161 61 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
162 49 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
163 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
164 61 : })?;
165 :
166 : // This might've been caused by a MITM attack
167 61 : if client_final_message.channel_binding != channel_binding {
168 8 : return Err(SaslError::ChannelBindingFailed(
169 8 : "insecure connection: secure channel data mismatch",
170 8 : ));
171 53 : }
172 53 :
173 53 : if client_final_message.nonce != server_first_message.nonce() {
174 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
175 53 : }
176 53 :
177 53 : let signature_builder = SignatureBuilder {
178 53 : client_first_message_bare,
179 53 : server_first_message: server_first_message.as_str(),
180 53 : client_final_message_without_proof: client_final_message.without_proof,
181 53 : };
182 53 :
183 53 : let client_key = signature_builder
184 53 : .build(&secret.stored_key)
185 53 : .derive_client_key(&client_final_message.proof);
186 53 :
187 53 : // Auth fails either if keys don't match or it's pre-determined to fail.
188 53 : if client_key.sha256() != secret.stored_key || secret.doomed {
189 5 : return Ok(sasl::Step::Failure("password doesn't match"));
190 48 : }
191 48 :
192 48 : let msg =
193 48 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
194 48 :
195 48 : Ok(sasl::Step::Success(client_key, msg))
196 61 : }
197 : }
198 :
199 : impl sasl::Mechanism for Exchange<'_> {
200 : type Output = super::ScramKey;
201 :
202 120 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
203 120 : use {sasl::Step::*, ExchangeState::*};
204 120 : match &self.state {
205 61 : Initial(init) => {
206 61 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
207 59 : Continue(sent, msg) => {
208 59 : self.state = SaltSent(sent);
209 59 : Ok(Continue(self, msg))
210 : }
211 : Success(x, _) => match x {},
212 0 : Failure(msg) => Ok(Failure(msg)),
213 : }
214 : }
215 59 : SaltSent(sent) => {
216 59 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
217 46 : Success(keys, msg) => Ok(Success(keys, msg)),
218 : Continue(x, _) => match x {},
219 5 : Failure(msg) => Ok(Failure(msg)),
220 : }
221 : }
222 : }
223 120 : }
224 : }
|