Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : context::RequestMonitoring,
8 : intern::EndpointIdInt,
9 : sasl,
10 : scram::{self, threadpool::ThreadPool},
11 : stream::{PqStream, Stream},
12 : };
13 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
14 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
15 : use std::{io, sync::Arc};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tracing::info;
18 :
/// Every authentication selector is supposed to implement this trait.
///
/// An implementor is passed to [`AuthFlow::begin`], which sends the value of
/// [`AuthMethod::first_message`] to the client and then keeps the selector as
/// the flow's state.
pub trait AuthMethod {
    /// Any authentication selector should provide initial backend message
    /// containing auth method name and parameters, e.g. md5 salt.
    ///
    /// `channel_binding` is what [`AuthFlow::begin`] obtains from the
    /// connection's TLS server end-point (`supported()`); selectors that have
    /// no use for channel binding ignore it.
    fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
25 :
/// Initial state of [`AuthFlow`].
///
/// Carries no data; [`AuthFlow::begin`] replaces it with the chosen
/// [`AuthMethod`] selector.
pub struct Begin;
28 :
29 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
30 : pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
31 :
32 : impl AuthMethod for Scram<'_> {
33 : #[inline(always)]
34 26 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
35 26 : if channel_binding {
36 24 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
37 : } else {
38 2 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
39 2 : scram::METHODS_WITHOUT_PLUS,
40 2 : ))
41 : }
42 26 : }
43 : }
44 :
45 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
46 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
47 : pub struct PasswordHack;
48 :
49 : impl AuthMethod for PasswordHack {
50 : #[inline(always)]
51 2 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
52 2 : Be::AuthenticationCleartextPassword
53 2 : }
54 : }
55 :
56 : /// Use clear-text password auth called `password` in docs
57 : /// <https://www.postgresql.org/docs/current/auth-password.html>
58 : pub struct CleartextPassword {
59 : pub pool: Arc<ThreadPool>,
60 : pub endpoint: EndpointIdInt,
61 : pub secret: AuthSecret,
62 : }
63 :
64 : impl AuthMethod for CleartextPassword {
65 : #[inline(always)]
66 2 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
67 2 : Be::AuthenticationCleartextPassword
68 2 : }
69 : }
70 :
/// This wrapper for [`PqStream`] performs client authentication.
///
/// The `State` parameter tracks the stage of the flow: it starts as [`Begin`]
/// and becomes the chosen [`AuthMethod`] selector after [`AuthFlow::begin`].
#[must_use]
pub struct AuthFlow<'a, S, State> {
    /// The underlying stream which implements libpq's protocol.
    stream: &'a mut PqStream<Stream<S>>,
    /// State might contain ancillary data (see [`Self::begin`]).
    state: State,
    /// TLS server end-point captured from the stream in `new`; decides whether
    /// channel binding is offered and is passed to the SCRAM exchange.
    tls_server_end_point: TlsServerEndPoint,
}
80 :
81 : /// Initial state of the stream wrapper.
82 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
83 : /// Create a new wrapper for client authentication.
84 30 : pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
85 30 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
86 30 :
87 30 : Self {
88 30 : stream,
89 30 : state: Begin,
90 30 : tls_server_end_point,
91 30 : }
92 30 : }
93 :
94 : /// Move to the next step by sending auth method's name & params to client.
95 30 : pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
96 30 : self.stream
97 30 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
98 0 : .await?;
99 :
100 30 : Ok(AuthFlow {
101 30 : stream: self.stream,
102 30 : state: method,
103 30 : tls_server_end_point: self.tls_server_end_point,
104 30 : })
105 30 : }
106 : }
107 :
108 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
109 : /// Perform user authentication. Raise an error in case authentication failed.
110 2 : pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
111 2 : let msg = self.stream.read_password_message().await?;
112 2 : let password = msg
113 2 : .strip_suffix(&[0])
114 2 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
115 :
116 2 : let payload = PasswordHackPayload::parse(password)
117 2 : // If we ended up here and the payload is malformed, it means that
118 2 : // the user neither enabled SNI nor resorted to any other method
119 2 : // for passing the project name we rely on. We should show them
120 2 : // the most helpful error message and point to the documentation.
121 2 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
122 :
123 2 : Ok(payload)
124 2 : }
125 : }
126 :
127 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
128 : /// Perform user authentication. Raise an error in case authentication failed.
129 2 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
130 2 : let msg = self.stream.read_password_message().await?;
131 2 : let password = msg
132 2 : .strip_suffix(&[0])
133 2 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
134 :
135 2 : let outcome = validate_password_and_exchange(
136 2 : &self.state.pool,
137 2 : self.state.endpoint,
138 2 : password,
139 2 : self.state.secret,
140 2 : )
141 2 : .await?;
142 :
143 2 : if let sasl::Outcome::Success(_) = &outcome {
144 2 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
145 0 : }
146 :
147 2 : Ok(outcome)
148 2 : }
149 : }
150 :
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
    /// Perform SCRAM user authentication over SASL.
    ///
    /// Errors are returned for I/O failures, a malformed first SASL message, an
    /// unsupported mechanism, or errors from the SASL exchange itself. The
    /// exchange may also complete with `Outcome::Failure`; that is returned as
    /// `Ok`, and `AuthenticationOk` is queued (without flushing) only on
    /// `Outcome::Success`.
    pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
        let Scram(secret, ctx) = self.state;

        // pause the timer while we communicate with the client
        // Bound to a named `_paused` (not `_`) so the value lives until the end
        // of this function; presumably a guard that resumes the timer on drop.
        let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);

        // Initial client message contains the chosen auth method's name.
        let msg = self.stream.read_password_message().await?;
        let sasl = sasl::FirstMessage::parse(&msg)
            .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;

        // Currently, the only supported SASL method is SCRAM.
        // NOTE(review): checked against the full `scram::METHODS`, not the list
        // actually advertised by `first_message` (which may omit `-PLUS`);
        // presumably `scram::Exchange` rejects channel binding when the TLS
        // end-point does not support it — confirm.
        if !scram::METHODS.contains(&sasl.method) {
            return Err(super::AuthError::bad_auth_method(sasl.method));
        }

        // Record the chosen mechanism in the request context.
        match sasl.method {
            SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
            SCRAM_SHA_256_PLUS => {
                ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
            }
            // Catch-all needed because `sasl.method` is a string; it had no
            // hits in the recorded coverage run.
            _ => {}
        }
        info!("client chooses {}", sasl.method);

        // Drive the rest of the SCRAM exchange over the stream, seeded with
        // `sasl.message` from the first client message.
        let outcome = sasl::SaslStream::new(self.stream, sasl.message)
            .authenticate(scram::Exchange::new(
                secret,
                rand::random,
                self.tls_server_end_point,
            ))
            .await?;

        if let sasl::Outcome::Success(_) = &outcome {
            self.stream.write_message_noflush(&Be::AuthenticationOk)?;
        }

        Ok(outcome)
    }
}
194 :
195 4 : pub(crate) async fn validate_password_and_exchange(
196 4 : pool: &ThreadPool,
197 4 : endpoint: EndpointIdInt,
198 4 : password: &[u8],
199 4 : secret: AuthSecret,
200 4 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
201 4 : match secret {
202 : #[cfg(any(test, feature = "testing"))]
203 : AuthSecret::Md5(_) => {
204 : // test only
205 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
206 0 : password.to_owned(),
207 0 : )))
208 : }
209 : // perform scram authentication as both client and server to validate the keys
210 4 : AuthSecret::Scram(scram_secret) => {
211 4 : let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
212 :
213 4 : let client_key = match outcome {
214 4 : sasl::Outcome::Success(client_key) => client_key,
215 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
216 : };
217 :
218 4 : let keys = crate::compute::ScramKeys {
219 4 : client_key: client_key.as_bytes(),
220 4 : server_key: scram_secret.server_key.as_bytes(),
221 4 : };
222 4 :
223 4 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
224 4 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
225 4 : )))
226 : }
227 : }
228 4 : }
|