TLA Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use postgres_protocol::authentication::sasl::ScramSha256;
6 :
7 : use super::messages::{
8 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
9 : };
10 : use super::secret::ServerSecret;
11 : use super::signature::SignatureBuilder;
12 : use crate::config;
13 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
14 :
/// Marker for the `tls-server-end-point` channel binding type — the only
/// channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Formats the marker as its channel binding name, `tls-server-end-point`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
24 :
25 : impl std::str::FromStr for TlsServerEndPoint {
26 : type Err = sasl::Error;
27 :
28 42 : fn from_str(s: &str) -> Result<Self, Self::Err> {
29 42 : match s {
30 42 : "tls-server-end-point" => Ok(TlsServerEndPoint),
31 UBC 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
32 : }
33 CBC 42 : }
34 : }
35 :
/// State carried from the client-first-message to the client-final-message.
struct SaslSentInner {
    /// Channel binding flag from the client-first-message, with the method
    /// name (if any) parsed into [`TlsServerEndPoint`].
    cbind_flag: ChannelBinding<TlsServerEndPoint>,
    /// The client-first-message-bare, kept because it is part of the signed
    /// auth message.
    client_first_message_bare: String,
    /// The server-first-message we sent; supplies the combined nonce and is
    /// also part of the signed auth message.
    server_first_message: OwnedServerFirstMessage,
}
41 :
/// Initial state of the exchange: no client message has been processed yet.
struct SaslInitial {
    /// Generator for the server's raw nonce; invoked when building the
    /// server-first-message.
    nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
}
45 :
/// Position of a server-side SCRAM exchange.
enum ExchangeState {
    /// Waiting for [`ClientFirstMessage`].
    Initial(SaslInitial),
    /// Waiting for [`ClientFinalMessage`].
    SaltSent(SaslSentInner),
}
52 :
/// Server's side of SCRAM auth algorithm.
///
/// Implements [`sasl::Mechanism`]: each call to `exchange` consumes one client
/// message and advances the internal `ExchangeState`.
pub struct Exchange<'a> {
    /// Current position in the exchange.
    state: ExchangeState,
    /// Stored credentials the client's proof is verified against.
    secret: &'a ServerSecret,
    /// The server's channel binding data (`Sha256(..)` or `Undefined`).
    tls_server_end_point: config::TlsServerEndPoint,
}
59 :
60 : impl<'a> Exchange<'a> {
61 48 : pub fn new(
62 48 : secret: &'a ServerSecret,
63 48 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
64 48 : tls_server_end_point: config::TlsServerEndPoint,
65 48 : ) -> Self {
66 48 : Self {
67 48 : state: ExchangeState::Initial(SaslInitial { nonce }),
68 48 : secret,
69 48 : tls_server_end_point,
70 48 : }
71 48 : }
72 : }
73 :
74 2 : pub fn exchange(
75 2 : secret: &ServerSecret,
76 2 : mut client: ScramSha256,
77 2 : tls_server_end_point: config::TlsServerEndPoint,
78 2 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
79 2 : use sasl::Step::*;
80 2 :
81 2 : let init = SaslInitial {
82 2 : nonce: rand::random,
83 2 : };
84 :
85 2 : let client_first = std::str::from_utf8(client.message())
86 2 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87 2 : let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
88 2 : Continue(sent, server_first) => {
89 2 : client.update(server_first.as_bytes())?;
90 2 : sent
91 : }
92 : Success(x, _) => match x {},
93 UBC 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
94 : };
95 :
96 CBC 2 : let client_final = std::str::from_utf8(client.message())
97 2 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
98 2 : let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
99 2 : Success(keys, server_final) => {
100 2 : client.finish(server_final.as_bytes())?;
101 2 : keys
102 : }
103 : Continue(x, _) => match x {},
104 UBC 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
105 : };
106 :
107 CBC 2 : Ok(sasl::Outcome::Success(keys))
108 2 : }
109 :
110 : impl SaslInitial {
111 50 : fn transition(
112 50 : &self,
113 50 : secret: &ServerSecret,
114 50 : tls_server_end_point: &config::TlsServerEndPoint,
115 50 : input: &str,
116 50 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
117 50 : let client_first_message = ClientFirstMessage::parse(input)
118 50 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
119 :
120 : // If the flag is set to "y" and the server supports channel
121 : // binding, the server MUST fail authentication
122 50 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
123 1 : && tls_server_end_point.supported()
124 : {
125 1 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
126 49 : }
127 49 :
128 49 : let server_first_message = client_first_message.build_server_first_message(
129 49 : &(self.nonce)(),
130 49 : &secret.salt_base64,
131 49 : secret.iterations,
132 49 : );
133 49 : let msg = server_first_message.as_str().to_owned();
134 :
135 49 : let next = SaslSentInner {
136 49 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
137 49 : client_first_message_bare: client_first_message.bare.to_owned(),
138 49 : server_first_message,
139 49 : };
140 49 :
141 49 : Ok(sasl::Step::Continue(next, msg))
142 50 : }
143 : }
144 :
145 : impl SaslSentInner {
146 49 : fn transition(
147 49 : &self,
148 49 : secret: &ServerSecret,
149 49 : tls_server_end_point: &config::TlsServerEndPoint,
150 49 : input: &str,
151 49 : ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
152 49 : let Self {
153 49 : cbind_flag,
154 49 : client_first_message_bare,
155 49 : server_first_message,
156 49 : } = self;
157 :
158 49 : let client_final_message = ClientFinalMessage::parse(input)
159 49 : .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
160 :
161 49 : let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
162 42 : config::TlsServerEndPoint::Sha256(x) => Ok(x),
163 UBC 0 : config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
164 CBC 49 : })?;
165 :
166 : // This might've been caused by a MITM attack
167 49 : if client_final_message.channel_binding != channel_binding {
168 4 : return Err(SaslError::ChannelBindingFailed(
169 4 : "insecure connection: secure channel data mismatch",
170 4 : ));
171 45 : }
172 45 :
173 45 : if client_final_message.nonce != server_first_message.nonce() {
174 UBC 0 : return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
175 CBC 45 : }
176 45 :
177 45 : let signature_builder = SignatureBuilder {
178 45 : client_first_message_bare,
179 45 : server_first_message: server_first_message.as_str(),
180 45 : client_final_message_without_proof: client_final_message.without_proof,
181 45 : };
182 45 :
183 45 : let client_key = signature_builder
184 45 : .build(&secret.stored_key)
185 45 : .derive_client_key(&client_final_message.proof);
186 45 :
187 45 : // Auth fails either if keys don't match or it's pre-determined to fail.
188 45 : if client_key.sha256() != secret.stored_key || secret.doomed {
189 4 : return Ok(sasl::Step::Failure("password doesn't match"));
190 41 : }
191 41 :
192 41 : let msg =
193 41 : client_final_message.build_server_final_message(signature_builder, &secret.server_key);
194 41 :
195 41 : Ok(sasl::Step::Success(client_key, msg))
196 49 : }
197 : }
198 :
199 : impl sasl::Mechanism for Exchange<'_> {
200 : type Output = super::ScramKey;
201 :
202 95 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
203 95 : use {sasl::Step::*, ExchangeState::*};
204 95 : match &self.state {
205 48 : Initial(init) => {
206 48 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
207 47 : Continue(sent, msg) => {
208 47 : self.state = SaltSent(sent);
209 47 : Ok(Continue(self, msg))
210 : }
211 : Success(x, _) => match x {},
212 UBC 0 : Failure(msg) => Ok(Failure(msg)),
213 : }
214 : }
215 CBC 47 : SaltSent(sent) => {
216 47 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
217 39 : Success(keys, msg) => Ok(Success(keys, msg)),
218 : Continue(x, _) => match x {},
219 4 : Failure(msg) => Ok(Failure(msg)),
220 : }
221 : }
222 : }
223 95 : }
224 : }
|