Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : sasl, scram,
8 : stream::{PqStream, Stream},
9 : };
10 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
11 : use std::io;
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tracing::info;
14 :
/// Every authentication selector is supposed to implement this trait.
///
/// A selector decides which backend message the server sends first to start
/// the authentication exchange; [`AuthFlow::begin`] writes that message to the
/// client and then moves the flow into the selector's state.
pub trait AuthMethod {
    /// Any authentication selector should provide initial backend message
    /// containing auth method name and parameters, e.g. md5 salt.
    ///
    /// `channel_binding` tells the selector whether channel binding may be
    /// offered; [`AuthFlow::begin`] passes `TlsServerEndPoint::supported()`
    /// for the current connection here.
    fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
21 :
/// Initial state of [`AuthFlow`].
///
/// Carries no data. [`AuthFlow::new`] creates a flow in this state without
/// writing anything to the stream; call [`AuthFlow::begin`] to choose an
/// [`AuthMethod`] and send its first message.
pub struct Begin;
24 :
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
///
/// Borrows the [`scram::ServerSecret`] that is handed to `scram::Exchange::new`
/// when the flow runs its SCRAM `authenticate` step.
pub struct Scram<'a>(pub &'a scram::ServerSecret);
27 :
28 : impl AuthMethod for Scram<'_> {
29 : #[inline(always)]
30 61 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
31 61 : if channel_binding {
32 61 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
33 : } else {
34 0 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
35 0 : scram::METHODS_WITHOUT_PLUS,
36 0 : ))
37 : }
38 61 : }
39 : }
40 :
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
///
/// The server asks for a cleartext password. The flow's `get_password` step
/// parses the received message into a [`PasswordHackPayload`] and reports
/// [`AuthErrorImpl::MissingEndpointName`] when that parse fails.
pub struct PasswordHack;
44 :
45 : impl AuthMethod for PasswordHack {
46 : #[inline(always)]
47 3 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
48 3 : Be::AuthenticationCleartextPassword
49 3 : }
50 : }
51 :
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
///
/// Holds the stored [`AuthSecret`] that the password received from the client
/// is checked against (via `validate_password_and_exchange`).
pub struct CleartextPassword(pub AuthSecret);
55 :
56 : impl AuthMethod for CleartextPassword {
57 : #[inline(always)]
58 0 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
59 0 : Be::AuthenticationCleartextPassword
60 0 : }
61 : }
62 :
/// This wrapper for [`PqStream`] performs client authentication.
///
/// `State` is a type-level marker of how far the flow has progressed: it is
/// [`Begin`] after [`AuthFlow::new`], then the chosen [`AuthMethod`] after
/// [`AuthFlow::begin`]. Each state has its own `impl` block, so only the
/// operations valid for that state are callable.
#[must_use]
pub struct AuthFlow<'a, S, State> {
    /// The underlying stream which implements libpq's protocol.
    stream: &'a mut PqStream<Stream<S>>,
    /// State might contain ancillary data (see [`Self::begin`]).
    state: State,
    /// TLS server end point read from the stream in [`AuthFlow::new`]. It
    /// decides whether channel binding is offered in `begin` and is passed on
    /// to the SCRAM exchange.
    tls_server_end_point: TlsServerEndPoint,
}
72 :
73 : /// Initial state of the stream wrapper.
74 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
75 : /// Create a new wrapper for client authentication.
76 64 : pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
77 64 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
78 64 :
79 64 : Self {
80 64 : stream,
81 64 : state: Begin,
82 64 : tls_server_end_point,
83 64 : }
84 64 : }
85 :
86 : /// Move to the next step by sending auth method's name & params to client.
87 64 : pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
88 64 : self.stream
89 64 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
90 0 : .await?;
91 :
92 64 : Ok(AuthFlow {
93 64 : stream: self.stream,
94 64 : state: method,
95 64 : tls_server_end_point: self.tls_server_end_point,
96 64 : })
97 64 : }
98 : }
99 :
100 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
101 : /// Perform user authentication. Raise an error in case authentication failed.
102 3 : pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
103 3 : let msg = self.stream.read_password_message().await?;
104 3 : let password = msg
105 3 : .strip_suffix(&[0])
106 3 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
107 :
108 3 : let payload = PasswordHackPayload::parse(password)
109 3 : // If we ended up here and the payload is malformed, it means that
110 3 : // the user neither enabled SNI nor resorted to any other method
111 3 : // for passing the project name we rely on. We should show them
112 3 : // the most helpful error message and point to the documentation.
113 3 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
114 :
115 2 : Ok(payload)
116 3 : }
117 : }
118 :
119 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
120 : /// Perform user authentication. Raise an error in case authentication failed.
121 0 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
122 0 : let msg = self.stream.read_password_message().await?;
123 0 : let password = msg
124 0 : .strip_suffix(&[0])
125 0 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
126 :
127 0 : let outcome = validate_password_and_exchange(password, self.state.0)?;
128 :
129 0 : if let sasl::Outcome::Success(_) = &outcome {
130 0 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
131 0 : }
132 :
133 0 : Ok(outcome)
134 0 : }
135 : }
136 :
137 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
138 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
139 : /// Perform user authentication. Raise an error in case authentication failed.
140 61 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
141 : // Initial client message contains the chosen auth method's name.
142 61 : let msg = self.stream.read_password_message().await?;
143 59 : let sasl = sasl::FirstMessage::parse(&msg)
144 59 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
145 :
146 : // Currently, the only supported SASL method is SCRAM.
147 59 : if !scram::METHODS.contains(&sasl.method) {
148 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
149 59 : }
150 :
151 37 : info!("client chooses {}", sasl.method);
152 :
153 59 : let secret = self.state.0;
154 59 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
155 59 : .authenticate(scram::Exchange::new(
156 59 : secret,
157 59 : rand::random,
158 59 : self.tls_server_end_point,
159 59 : ))
160 57 : .await?;
161 :
162 49 : if let sasl::Outcome::Success(_) = &outcome {
163 44 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
164 5 : }
165 :
166 49 : Ok(outcome)
167 61 : }
168 : }
169 :
170 2 : pub(super) fn validate_password_and_exchange(
171 2 : password: &[u8],
172 2 : secret: AuthSecret,
173 2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
174 2 : match secret {
175 : #[cfg(any(test, feature = "testing"))]
176 : AuthSecret::Md5(_) => {
177 : // test only
178 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
179 0 : password.to_owned(),
180 0 : )))
181 : }
182 : // perform scram authentication as both client and server to validate the keys
183 2 : AuthSecret::Scram(scram_secret) => {
184 2 : use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
185 2 : let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
186 2 : let outcome = crate::scram::exchange(
187 2 : &scram_secret,
188 2 : sasl_client,
189 2 : crate::config::TlsServerEndPoint::Undefined,
190 2 : )?;
191 :
192 2 : let client_key = match outcome {
193 2 : sasl::Outcome::Success(client_key) => client_key,
194 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
195 : };
196 :
197 2 : let keys = crate::compute::ScramKeys {
198 2 : client_key: client_key.as_bytes(),
199 2 : server_key: scram_secret.server_key.as_bytes(),
200 2 : };
201 2 :
202 2 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
203 2 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
204 2 : )))
205 : }
206 : }
207 2 : }
|