Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : context::RequestMonitoring,
8 : intern::EndpointIdInt,
9 : sasl,
10 : scram::{self, threadpool::ThreadPool},
11 : stream::{PqStream, Stream},
12 : };
13 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
14 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
15 : use std::{io, sync::Arc};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tracing::info;
18 :
19 : /// Every authentication selector is supposed to implement this trait.
20 : pub(crate) trait AuthMethod {
21 : /// Any authentication selector should provide initial backend message
22 : /// containing auth method name and parameters, e.g. md5 salt.
23 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
24 : }
25 :
/// Typestate marker for an [`AuthFlow`] that has not yet sent its first message.
pub(crate) struct Begin;
28 :
29 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
30 : pub(crate) struct Scram<'a>(
31 : pub(crate) &'a scram::ServerSecret,
32 : pub(crate) &'a RequestMonitoring,
33 : );
34 :
35 : impl AuthMethod for Scram<'_> {
36 : #[inline(always)]
37 13 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
38 13 : if channel_binding {
39 12 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
40 : } else {
41 1 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
42 1 : scram::METHODS_WITHOUT_PLUS,
43 1 : ))
44 : }
45 13 : }
46 : }
47 :
48 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
49 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
50 : pub(crate) struct PasswordHack;
51 :
52 : impl AuthMethod for PasswordHack {
53 : #[inline(always)]
54 1 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
55 1 : Be::AuthenticationCleartextPassword
56 1 : }
57 : }
58 :
59 : /// Use clear-text password auth called `password` in docs
60 : /// <https://www.postgresql.org/docs/current/auth-password.html>
61 : pub(crate) struct CleartextPassword {
62 : pub(crate) pool: Arc<ThreadPool>,
63 : pub(crate) endpoint: EndpointIdInt,
64 : pub(crate) secret: AuthSecret,
65 : }
66 :
67 : impl AuthMethod for CleartextPassword {
68 : #[inline(always)]
69 1 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
70 1 : Be::AuthenticationCleartextPassword
71 1 : }
72 : }
73 :
74 : /// This wrapper for [`PqStream`] performs client authentication.
75 : #[must_use]
76 : pub(crate) struct AuthFlow<'a, S, State> {
77 : /// The underlying stream which implements libpq's protocol.
78 : stream: &'a mut PqStream<Stream<S>>,
79 : /// State might contain ancillary data (see [`Self::begin`]).
80 : state: State,
81 : tls_server_end_point: TlsServerEndPoint,
82 : }
83 :
84 : /// Initial state of the stream wrapper.
85 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
86 : /// Create a new wrapper for client authentication.
87 15 : pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
88 15 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
89 15 :
90 15 : Self {
91 15 : stream,
92 15 : state: Begin,
93 15 : tls_server_end_point,
94 15 : }
95 15 : }
96 :
97 : /// Move to the next step by sending auth method's name & params to client.
98 15 : pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
99 15 : self.stream
100 15 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
101 0 : .await?;
102 :
103 15 : Ok(AuthFlow {
104 15 : stream: self.stream,
105 15 : state: method,
106 15 : tls_server_end_point: self.tls_server_end_point,
107 15 : })
108 15 : }
109 : }
110 :
111 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
112 : /// Perform user authentication. Raise an error in case authentication failed.
113 1 : pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
114 1 : let msg = self.stream.read_password_message().await?;
115 1 : let password = msg
116 1 : .strip_suffix(&[0])
117 1 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
118 :
119 1 : let payload = PasswordHackPayload::parse(password)
120 1 : // If we ended up here and the payload is malformed, it means that
121 1 : // the user neither enabled SNI nor resorted to any other method
122 1 : // for passing the project name we rely on. We should show them
123 1 : // the most helpful error message and point to the documentation.
124 1 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
125 :
126 1 : Ok(payload)
127 1 : }
128 : }
129 :
130 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
131 : /// Perform user authentication. Raise an error in case authentication failed.
132 1 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
133 1 : let msg = self.stream.read_password_message().await?;
134 1 : let password = msg
135 1 : .strip_suffix(&[0])
136 1 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
137 :
138 1 : let outcome = validate_password_and_exchange(
139 1 : &self.state.pool,
140 1 : self.state.endpoint,
141 1 : password,
142 1 : self.state.secret,
143 1 : )
144 1 : .await?;
145 :
146 1 : if let sasl::Outcome::Success(_) = &outcome {
147 1 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
148 0 : }
149 :
150 1 : Ok(outcome)
151 1 : }
152 : }
153 :
154 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
155 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
156 : /// Perform user authentication. Raise an error in case authentication failed.
157 13 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
158 13 : let Scram(secret, ctx) = self.state;
159 13 :
160 13 : // pause the timer while we communicate with the client
161 13 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
162 :
163 : // Initial client message contains the chosen auth method's name.
164 13 : let msg = self.stream.read_password_message().await?;
165 12 : let sasl = sasl::FirstMessage::parse(&msg)
166 12 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
167 :
168 : // Currently, the only supported SASL method is SCRAM.
169 12 : if !scram::METHODS.contains(&sasl.method) {
170 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
171 12 : }
172 12 :
173 12 : match sasl.method {
174 12 : SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
175 6 : SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
176 0 : _ => {}
177 : }
178 12 : info!("client chooses {}", sasl.method);
179 :
180 12 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
181 12 : .authenticate(scram::Exchange::new(
182 12 : secret,
183 12 : rand::random,
184 12 : self.tls_server_end_point,
185 12 : ))
186 11 : .await?;
187 :
188 7 : if let sasl::Outcome::Success(_) = &outcome {
189 6 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
190 1 : }
191 :
192 7 : Ok(outcome)
193 13 : }
194 : }
195 :
196 2 : pub(crate) async fn validate_password_and_exchange(
197 2 : pool: &ThreadPool,
198 2 : endpoint: EndpointIdInt,
199 2 : password: &[u8],
200 2 : secret: AuthSecret,
201 2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
202 2 : match secret {
203 : #[cfg(any(test, feature = "testing"))]
204 : AuthSecret::Md5(_) => {
205 : // test only
206 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
207 0 : password.to_owned(),
208 0 : )))
209 : }
210 : // perform scram authentication as both client and server to validate the keys
211 2 : AuthSecret::Scram(scram_secret) => {
212 2 : let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
213 :
214 2 : let client_key = match outcome {
215 2 : sasl::Outcome::Success(client_key) => client_key,
216 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
217 : };
218 :
219 2 : let keys = crate::compute::ScramKeys {
220 2 : client_key: client_key.as_bytes(),
221 2 : server_key: scram_secret.server_key.as_bytes(),
222 2 : };
223 2 :
224 2 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
225 2 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
226 2 : )))
227 : }
228 : }
229 2 : }
|