Line data Source code
1 : //! Implementation of the SCRAM authentication algorithm.
2 :
3 : use std::convert::Infallible;
4 :
5 : use postgres_protocol::authentication::sasl::ScramSha256;
6 :
7 : use super::messages::{
8 : ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
9 : };
10 : use super::secret::ServerSecret;
11 : use super::signature::SignatureBuilder;
12 : use crate::config;
13 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
14 :
/// Marker type for the `tls-server-end-point` channel-binding method — the
/// only channel-binding method this module currently accepts.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
    /// Renders the method under its wire name, `tls-server-end-point`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tls-server-end-point")
    }
}
24 :
25 : impl std::str::FromStr for TlsServerEndPoint {
26 : type Err = sasl::Error;
27 :
28 51 : fn from_str(s: &str) -> Result<Self, Self::Err> {
29 51 : match s {
30 51 : "tls-server-end-point" => Ok(TlsServerEndPoint),
31 0 : _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
32 : }
33 51 : }
34 : }
35 :
36 : struct SaslSentInner {
37 : cbind_flag: ChannelBinding<TlsServerEndPoint>,
38 : client_first_message_bare: String,
39 : server_first_message: OwnedServerFirstMessage,
40 : }
41 :
42 : struct SaslInitial {
43 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
44 : }
45 :
46 : enum ExchangeState {
47 : /// Waiting for [`ClientFirstMessage`].
48 : Initial(SaslInitial),
49 : /// Waiting for [`ClientFinalMessage`].
50 : SaltSent(SaslSentInner),
51 : }
52 :
53 : /// Server's side of SCRAM auth algorithm.
54 : pub struct Exchange<'a> {
55 : state: ExchangeState,
56 : secret: &'a ServerSecret,
57 : tls_server_end_point: config::TlsServerEndPoint,
58 : }
59 :
60 : impl<'a> Exchange<'a> {
61 63 : pub fn new(
62 63 : secret: &'a ServerSecret,
63 63 : nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
64 63 : tls_server_end_point: config::TlsServerEndPoint,
65 63 : ) -> Self {
66 63 : Self {
67 63 : state: ExchangeState::Initial(SaslInitial { nonce }),
68 63 : secret,
69 63 : tls_server_end_point,
70 63 : }
71 63 : }
72 : }
73 :
74 47 : pub fn exchange(
75 47 : secret: &ServerSecret,
76 47 : mut client: ScramSha256,
77 47 : tls_server_end_point: config::TlsServerEndPoint,
78 47 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
79 47 : use sasl::Step::*;
80 47 :
81 47 : let init = SaslInitial {
82 47 : nonce: rand::random,
83 47 : };
84 :
85 47 : let client_first = std::str::from_utf8(client.message())
86 47 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87 47 : let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
88 47 : Continue(sent, server_first) => {
89 47 : client.update(server_first.as_bytes())?;
90 47 : sent
91 : }
92 : Success(x, _) => match x {},
93 0 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
94 : };
95 :
96 47 : let client_final = std::str::from_utf8(client.message())
97 47 : .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
98 47 : let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
99 46 : Success(keys, server_final) => {
100 46 : client.finish(server_final.as_bytes())?;
101 46 : keys
102 : }
103 : Continue(x, _) => match x {},
104 1 : Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
105 : };
106 :
107 46 : Ok(sasl::Outcome::Success(keys))
108 47 : }
109 :
110 : impl SaslInitial {
111 110 : fn transition(
112 110 : &self,
113 110 : secret: &ServerSecret,
114 110 : tls_server_end_point: &config::TlsServerEndPoint,
115 110 : input: &str,
116 110 : ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
117 110 : let client_first_message = ClientFirstMessage::parse(input)
118 110 : .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
119 :
120 : // If the flag is set to "y" and the server supports channel
121 : // binding, the server MUST fail authentication
122 110 : if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
123 2 : && tls_server_end_point.supported()
124 : {
125 2 : return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
126 108 : }
127 108 :
128 108 : let server_first_message = client_first_message.build_server_first_message(
129 108 : &(self.nonce)(),
130 108 : &secret.salt_base64,
131 108 : secret.iterations,
132 108 : );
133 108 : let msg = server_first_message.as_str().to_owned();
134 :
135 108 : let next = SaslSentInner {
136 108 : cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
137 108 : client_first_message_bare: client_first_message.bare.to_owned(),
138 108 : server_first_message,
139 108 : };
140 108 :
141 108 : Ok(sasl::Step::Continue(next, msg))
142 110 : }
143 : }
144 :
impl SaslSentInner {
    /// Consumes the client-final-message and, if the client's proof verifies,
    /// produces the server-final-message.
    ///
    /// Checks, in this order: the message parses, the echoed channel-binding
    /// data matches what we expect, the combined nonce matches the one we
    /// sent, and finally the client proof against `secret.stored_key`.
    ///
    /// # Returns
    ///
    /// * `Ok(Step::Success(client_key, server_final_message))` on a valid proof.
    /// * `Ok(Step::Failure("password doesn't match"))` when the derived key does
    ///   not match `secret.stored_key`, or when `secret.doomed` is set.
    ///
    /// # Errors
    ///
    /// * `BadClientMessage` for an unparsable message or a nonce mismatch.
    /// * `ChannelBindingFailed` when the echoed channel-binding data differs.
    /// * `MissingBinding` if `cbind_flag.encode` asks for end-point data while
    ///   `tls_server_end_point` is `Undefined`.
    fn transition(
        &self,
        secret: &ServerSecret,
        tls_server_end_point: &config::TlsServerEndPoint,
        input: &str,
    ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
        let Self {
            cbind_flag,
            client_first_message_bare,
            server_first_message,
        } = self;

        let client_final_message = ClientFinalMessage::parse(input)
            .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;

        // Build the channel-binding payload we expect the client to echo back.
        // The closure supplies our `Sha256` end-point data, or fails with
        // `MissingBinding` if none is configured.
        // NOTE(review): presumably `encode` only invokes the closure when the
        // client actually requested binding — confirm against `ChannelBinding::encode`.
        let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
            config::TlsServerEndPoint::Sha256(x) => Ok(x),
            config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
        })?;

        // A mismatch might have been caused by a MITM attack.
        if client_final_message.channel_binding != channel_binding {
            return Err(SaslError::ChannelBindingFailed(
                "insecure connection: secure channel data mismatch",
            ));
        }

        // The client must echo back the combined nonce from our server-first-message.
        if client_final_message.nonce != server_first_message.nonce() {
            return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
        }

        let signature_builder = SignatureBuilder {
            client_first_message_bare,
            server_first_message: server_first_message.as_str(),
            client_final_message_without_proof: client_final_message.without_proof,
        };

        // Recover the client key from the proof using the stored key.
        let client_key = signature_builder
            .build(&secret.stored_key)
            .derive_client_key(&client_final_message.proof);

        // Auth fails either if keys don't match or it's pre-determined to fail.
        // The key is derived above before `doomed` is consulted, so a doomed
        // secret still goes through the full derivation.
        // NOTE(review): `!=` relies on ScramKey's PartialEq; confirm it is a
        // constant-time comparison so this check does not leak timing.
        if client_key.sha256() != secret.stored_key || secret.doomed {
            return Ok(sasl::Step::Failure("password doesn't match"));
        }

        let msg =
            client_final_message.build_server_final_message(signature_builder, &secret.server_key);

        Ok(sasl::Step::Success(client_key, msg))
    }
}
198 :
199 : impl sasl::Mechanism for Exchange<'_> {
200 : type Output = super::ScramKey;
201 :
202 124 : fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
203 124 : use {sasl::Step::*, ExchangeState::*};
204 124 : match &self.state {
205 63 : Initial(init) => {
206 63 : match init.transition(self.secret, &self.tls_server_end_point, input)? {
207 61 : Continue(sent, msg) => {
208 61 : self.state = SaltSent(sent);
209 61 : Ok(Continue(self, msg))
210 : }
211 : Success(x, _) => match x {},
212 0 : Failure(msg) => Ok(Failure(msg)),
213 : }
214 : }
215 61 : SaltSent(sent) => {
216 61 : match sent.transition(self.secret, &self.tls_server_end_point, input)? {
217 48 : Success(keys, msg) => Ok(Success(keys, msg)),
218 : Continue(x, _) => match x {},
219 5 : Failure(msg) => Ok(Failure(msg)),
220 : }
221 : }
222 : }
223 124 : }
224 : }
|