Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : context::RequestMonitoring,
8 : intern::EndpointIdInt,
9 : sasl,
10 : scram::{self, threadpool::ThreadPool},
11 : stream::{PqStream, Stream},
12 : };
13 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
14 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
15 : use std::{io, sync::Arc};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tracing::info;
18 :
/// Every authentication selector is supposed to implement this trait.
///
/// A selector is also used as the `State` of `AuthFlow` once
/// `AuthFlow::begin` has sent its first message.
pub(crate) trait AuthMethod {
    /// Any authentication selector should provide initial backend message
    /// containing auth method name and parameters, e.g. md5 salt.
    ///
    /// `channel_binding` says whether channel binding may be offered to the
    /// client; `AuthFlow::begin` passes `tls_server_end_point.supported()`.
    fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
25 :
/// Initial state of [`AuthFlow`].
///
/// No backend message has been sent yet; `AuthFlow::begin` moves the flow
/// into the state of the chosen [`AuthMethod`].
pub(crate) struct Begin;
28 :
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(
    /// Server-side SCRAM secret handed to `scram::Exchange::new` during
    /// authentication.
    pub(crate) &'a scram::ServerSecret,
    /// Request context: used to pause the latency timer while waiting on the
    /// client and to record which SCRAM variant the client picked.
    pub(crate) &'a RequestMonitoring,
);
34 :
35 : impl AuthMethod for Scram<'_> {
36 : #[inline(always)]
37 78 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
38 78 : if channel_binding {
39 72 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
40 : } else {
41 6 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
42 6 : scram::METHODS_WITHOUT_PLUS,
43 6 : ))
44 : }
45 78 : }
46 : }
47 :
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
///
/// The endpoint name travels inside the cleartext password message and is
/// extracted with `PasswordHackPayload::parse`.
pub(crate) struct PasswordHack;
51 :
52 : impl AuthMethod for PasswordHack {
53 : #[inline(always)]
54 6 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
55 6 : Be::AuthenticationCleartextPassword
56 6 : }
57 : }
58 :
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub(crate) struct CleartextPassword {
    /// Thread pool handed to `validate_password_and_exchange`.
    pub(crate) pool: Arc<ThreadPool>,
    /// Endpoint the password is being checked for; forwarded to
    /// `validate_password_and_exchange`.
    pub(crate) endpoint: EndpointIdInt,
    /// Stored secret the client's password is validated against.
    pub(crate) secret: AuthSecret,
}
66 :
67 : impl AuthMethod for CleartextPassword {
68 : #[inline(always)]
69 6 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
70 6 : Be::AuthenticationCleartextPassword
71 6 : }
72 : }
73 :
/// This wrapper for [`PqStream`] performs client authentication.
///
/// `State` marks the current step at the type level: [`Begin`] before the
/// first backend message is sent, then the chosen [`AuthMethod`] afterwards.
/// Each state has its own `impl` block exposing the operations valid there.
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
    /// The underlying stream which implements libpq's protocol.
    stream: &'a mut PqStream<Stream<S>>,
    /// State might contain ancillary data (see [`Self::begin`]).
    state: State,
    /// TLS server end point of the underlying stream, captured in `new`.
    /// `begin` uses its `supported()` flag to decide whether channel binding
    /// is offered, and SCRAM passes it to `scram::Exchange::new`.
    tls_server_end_point: TlsServerEndPoint,
}
83 :
84 : /// Initial state of the stream wrapper.
85 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
86 : /// Create a new wrapper for client authentication.
87 90 : pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
88 90 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
89 90 :
90 90 : Self {
91 90 : stream,
92 90 : state: Begin,
93 90 : tls_server_end_point,
94 90 : }
95 90 : }
96 :
97 : /// Move to the next step by sending auth method's name & params to client.
98 90 : pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
99 90 : self.stream
100 90 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
101 0 : .await?;
102 :
103 90 : Ok(AuthFlow {
104 90 : stream: self.stream,
105 90 : state: method,
106 90 : tls_server_end_point: self.tls_server_end_point,
107 90 : })
108 90 : }
109 : }
110 :
111 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
112 : /// Perform user authentication. Raise an error in case authentication failed.
113 6 : pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
114 6 : let msg = self.stream.read_password_message().await?;
115 6 : let password = msg
116 6 : .strip_suffix(&[0])
117 6 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
118 :
119 6 : let payload = PasswordHackPayload::parse(password)
120 6 : // If we ended up here and the payload is malformed, it means that
121 6 : // the user neither enabled SNI nor resorted to any other method
122 6 : // for passing the project name we rely on. We should show them
123 6 : // the most helpful error message and point to the documentation.
124 6 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
125 :
126 6 : Ok(payload)
127 6 : }
128 : }
129 :
130 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
131 : /// Perform user authentication. Raise an error in case authentication failed.
132 6 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
133 6 : let msg = self.stream.read_password_message().await?;
134 6 : let password = msg
135 6 : .strip_suffix(&[0])
136 6 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
137 :
138 6 : let outcome = validate_password_and_exchange(
139 6 : &self.state.pool,
140 6 : self.state.endpoint,
141 6 : password,
142 6 : self.state.secret,
143 6 : )
144 6 : .await?;
145 :
146 6 : if let sasl::Outcome::Success(_) = &outcome {
147 6 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
148 0 : }
149 :
150 6 : Ok(outcome)
151 6 : }
152 : }
153 :
154 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
155 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
156 : /// Perform user authentication. Raise an error in case authentication failed.
157 78 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
158 78 : let Scram(secret, ctx) = self.state;
159 78 :
160 78 : // pause the timer while we communicate with the client
161 78 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
162 :
163 : // Initial client message contains the chosen auth method's name.
164 78 : let msg = self.stream.read_password_message().await?;
165 72 : let sasl = sasl::FirstMessage::parse(&msg)
166 72 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
167 :
168 : // Currently, the only supported SASL method is SCRAM.
169 72 : if !scram::METHODS.contains(&sasl.method) {
170 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
171 72 : }
172 72 :
173 72 : match sasl.method {
174 72 : SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
175 36 : SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
176 0 : _ => {}
177 : }
178 72 : info!("client chooses {}", sasl.method);
179 :
180 72 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
181 72 : .authenticate(scram::Exchange::new(
182 72 : secret,
183 72 : rand::random,
184 72 : self.tls_server_end_point,
185 72 : ))
186 66 : .await?;
187 :
188 42 : if let sasl::Outcome::Success(_) = &outcome {
189 36 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
190 6 : }
191 :
192 42 : Ok(outcome)
193 78 : }
194 : }
195 :
196 12 : pub(crate) async fn validate_password_and_exchange(
197 12 : pool: &ThreadPool,
198 12 : endpoint: EndpointIdInt,
199 12 : password: &[u8],
200 12 : secret: AuthSecret,
201 12 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
202 12 : match secret {
203 : #[cfg(any(test, feature = "testing"))]
204 : AuthSecret::Md5(_) => {
205 : // test only
206 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
207 0 : password.to_owned(),
208 0 : )))
209 : }
210 : // perform scram authentication as both client and server to validate the keys
211 12 : AuthSecret::Scram(scram_secret) => {
212 12 : let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
213 :
214 12 : let client_key = match outcome {
215 12 : sasl::Outcome::Success(client_key) => client_key,
216 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
217 : };
218 :
219 12 : let keys = crate::compute::ScramKeys {
220 12 : client_key: client_key.as_bytes(),
221 12 : server_key: scram_secret.server_key.as_bytes(),
222 12 : };
223 12 :
224 12 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
225 12 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
226 12 : )))
227 : }
228 : }
229 12 : }
|