Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : context::RequestMonitoring,
8 : sasl, scram,
9 : stream::{PqStream, Stream},
10 : };
11 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
12 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
13 : use std::io;
14 : use tokio::io::{AsyncRead, AsyncWrite};
15 : use tracing::info;
16 :
/// Every authentication selector is supposed to implement this trait.
///
/// A selector is handed to `AuthFlow::begin`, which sends the selector's
/// [`AuthMethod::first_message`] to the client and then keeps the selector
/// as the flow's state for the rest of the exchange.
pub trait AuthMethod {
    /// Any authentication selector should provide initial backend message
    /// containing auth method name and parameters, e.g. md5 salt.
    ///
    /// `channel_binding` is `true` when the TLS server end point of the
    /// connection supports channel binding, so a selector can decide whether
    /// to advertise channel-binding mechanisms.
    fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
23 :
/// Initial state of [`AuthFlow`].
///
/// `AuthFlow::new` produces a flow in this state. Nothing has been sent to
/// the client yet; `AuthFlow::begin` leaves this state by sending the chosen
/// method's first message.
pub struct Begin;
26 :
27 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
28 : pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
29 :
30 : impl AuthMethod for Scram<'_> {
31 : #[inline(always)]
32 24 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
33 24 : if channel_binding {
34 24 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
35 : } else {
36 0 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
37 0 : scram::METHODS_WITHOUT_PLUS,
38 0 : ))
39 : }
40 24 : }
41 : }
42 :
43 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
44 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
45 : pub struct PasswordHack;
46 :
47 : impl AuthMethod for PasswordHack {
48 : #[inline(always)]
49 0 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
50 0 : Be::AuthenticationCleartextPassword
51 0 : }
52 : }
53 :
54 : /// Use clear-text password auth called `password` in docs
55 : /// <https://www.postgresql.org/docs/current/auth-password.html>
56 : pub struct CleartextPassword(pub AuthSecret);
57 :
58 : impl AuthMethod for CleartextPassword {
59 : #[inline(always)]
60 0 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
61 0 : Be::AuthenticationCleartextPassword
62 0 : }
63 : }
64 :
/// This wrapper for [`PqStream`] performs client authentication.
///
/// `State` is a type-level marker for the current stage of the flow.
/// - It starts as [`Begin`].
/// - After `AuthFlow::begin` it is the chosen [`AuthMethod`] selector.
///
/// Each selector state has its own `impl` block with the method that finishes
/// the flow.
#[must_use]
pub struct AuthFlow<'a, S, State> {
    /// The underlying stream which implements libpq's protocol.
    stream: &'a mut PqStream<Stream<S>>,
    /// State might contain ancillary data (see [`Self::begin`]).
    state: State,
    /// TLS server end point of the underlying stream, captured when the flow is
    /// created. It decides whether channel binding is advertised, and it is
    /// passed to the SCRAM exchange.
    tls_server_end_point: TlsServerEndPoint,
}
74 :
75 : /// Initial state of the stream wrapper.
76 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
77 : /// Create a new wrapper for client authentication.
78 24 : pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
79 24 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
80 24 :
81 24 : Self {
82 24 : stream,
83 24 : state: Begin,
84 24 : tls_server_end_point,
85 24 : }
86 24 : }
87 :
88 : /// Move to the next step by sending auth method's name & params to client.
89 24 : pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
90 24 : self.stream
91 24 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
92 0 : .await?;
93 :
94 24 : Ok(AuthFlow {
95 24 : stream: self.stream,
96 24 : state: method,
97 24 : tls_server_end_point: self.tls_server_end_point,
98 24 : })
99 24 : }
100 : }
101 :
102 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
103 : /// Perform user authentication. Raise an error in case authentication failed.
104 0 : pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
105 0 : let msg = self.stream.read_password_message().await?;
106 0 : let password = msg
107 0 : .strip_suffix(&[0])
108 0 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
109 :
110 0 : let payload = PasswordHackPayload::parse(password)
111 0 : // If we ended up here and the payload is malformed, it means that
112 0 : // the user neither enabled SNI nor resorted to any other method
113 0 : // for passing the project name we rely on. We should show them
114 0 : // the most helpful error message and point to the documentation.
115 0 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
116 :
117 0 : Ok(payload)
118 0 : }
119 : }
120 :
121 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
122 : /// Perform user authentication. Raise an error in case authentication failed.
123 0 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
124 0 : let msg = self.stream.read_password_message().await?;
125 0 : let password = msg
126 0 : .strip_suffix(&[0])
127 0 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
128 :
129 0 : let outcome = validate_password_and_exchange(password, self.state.0)?;
130 :
131 0 : if let sasl::Outcome::Success(_) = &outcome {
132 0 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
133 0 : }
134 :
135 0 : Ok(outcome)
136 0 : }
137 : }
138 :
139 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
140 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
141 : /// Perform user authentication. Raise an error in case authentication failed.
142 24 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
143 24 : let Scram(secret, ctx) = self.state;
144 24 :
145 24 : // pause the timer while we communicate with the client
146 24 : let _paused = ctx.latency_timer.pause();
147 :
148 : // Initial client message contains the chosen auth method's name.
149 24 : let msg = self.stream.read_password_message().await?;
150 22 : let sasl = sasl::FirstMessage::parse(&msg)
151 22 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
152 :
153 : // Currently, the only supported SASL method is SCRAM.
154 22 : if !scram::METHODS.contains(&sasl.method) {
155 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
156 22 : }
157 22 :
158 22 : match sasl.method {
159 22 : SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
160 12 : SCRAM_SHA_256_PLUS => {
161 12 : ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
162 : }
163 0 : _ => {}
164 : }
165 0 : info!("client chooses {}", sasl.method);
166 :
167 22 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
168 22 : .authenticate(scram::Exchange::new(
169 22 : secret,
170 22 : rand::random,
171 22 : self.tls_server_end_point,
172 22 : ))
173 20 : .await?;
174 :
175 12 : if let sasl::Outcome::Success(_) = &outcome {
176 10 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
177 2 : }
178 :
179 12 : Ok(outcome)
180 24 : }
181 : }
182 :
183 0 : pub(crate) fn validate_password_and_exchange(
184 0 : password: &[u8],
185 0 : secret: AuthSecret,
186 0 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
187 0 : match secret {
188 : #[cfg(any(test, feature = "testing"))]
189 : AuthSecret::Md5(_) => {
190 : // test only
191 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
192 0 : password.to_owned(),
193 0 : )))
194 : }
195 : // perform scram authentication as both client and server to validate the keys
196 0 : AuthSecret::Scram(scram_secret) => {
197 0 : use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
198 0 : let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
199 0 : let outcome = crate::scram::exchange(
200 0 : &scram_secret,
201 0 : sasl_client,
202 0 : crate::config::TlsServerEndPoint::Undefined,
203 0 : )?;
204 :
205 0 : let client_key = match outcome {
206 0 : sasl::Outcome::Success(client_key) => client_key,
207 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
208 : };
209 :
210 0 : let keys = crate::compute::ScramKeys {
211 0 : client_key: client_key.as_bytes(),
212 0 : server_key: scram_secret.server_key.as_bytes(),
213 0 : };
214 0 :
215 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
216 0 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
217 0 : )))
218 : }
219 : }
220 0 : }
|