Line data Source code
1 : //! Main authentication flow.
2 :
3 : use std::io;
4 : use std::sync::Arc;
5 :
6 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
7 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
8 : use tokio::io::{AsyncRead, AsyncWrite};
9 : use tracing::info;
10 :
11 : use super::backend::ComputeCredentialKeys;
12 : use super::{AuthError, PasswordHackPayload};
13 : use crate::config::TlsServerEndPoint;
14 : use crate::context::RequestContext;
15 : use crate::control_plane::AuthSecret;
16 : use crate::intern::EndpointIdInt;
17 : use crate::sasl;
18 : use crate::scram::threadpool::ThreadPool;
19 : use crate::scram::{self};
20 : use crate::stream::{PqStream, Stream};
21 :
22 : /// Every authentication selector is supposed to implement this trait.
23 : pub(crate) trait AuthMethod {
24 : /// Any authentication selector should provide initial backend message
25 : /// containing auth method name and parameters, e.g. md5 salt.
26 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
27 : }
28 :
/// Starting state of an [`AuthFlow`]; it carries no data and is replaced by the
/// chosen [`AuthMethod`] once [`AuthFlow::begin`] has run.
pub(crate) struct Begin;
31 :
32 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
33 : pub(crate) struct Scram<'a>(
34 : pub(crate) &'a scram::ServerSecret,
35 : pub(crate) &'a RequestContext,
36 : );
37 :
38 : impl AuthMethod for Scram<'_> {
39 : #[inline(always)]
40 13 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
41 13 : if channel_binding {
42 12 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
43 : } else {
44 1 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
45 1 : scram::METHODS_WITHOUT_PLUS,
46 1 : ))
47 : }
48 13 : }
49 : }
50 :
/// Selects the ad hoc auth flow for clients which don't support SNI, as proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
///
/// The client is asked for a cleartext password, whose contents are decoded into a
/// [`PasswordHackPayload`] by [`AuthFlow::get_password`].
pub(crate) struct PasswordHack;
54 :
55 : impl AuthMethod for PasswordHack {
56 : #[inline(always)]
57 1 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
58 1 : Be::AuthenticationCleartextPassword
59 1 : }
60 : }
61 :
62 : /// Use clear-text password auth called `password` in docs
63 : /// <https://www.postgresql.org/docs/current/auth-password.html>
64 : pub(crate) struct CleartextPassword {
65 : pub(crate) pool: Arc<ThreadPool>,
66 : pub(crate) endpoint: EndpointIdInt,
67 : pub(crate) secret: AuthSecret,
68 : }
69 :
70 : impl AuthMethod for CleartextPassword {
71 : #[inline(always)]
72 1 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
73 1 : Be::AuthenticationCleartextPassword
74 1 : }
75 : }
76 :
77 : /// This wrapper for [`PqStream`] performs client authentication.
78 : #[must_use]
79 : pub(crate) struct AuthFlow<'a, S, State> {
80 : /// The underlying stream which implements libpq's protocol.
81 : stream: &'a mut PqStream<Stream<S>>,
82 : /// State might contain ancillary data (see [`Self::begin`]).
83 : state: State,
84 : tls_server_end_point: TlsServerEndPoint,
85 : }
86 :
87 : /// Initial state of the stream wrapper.
88 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
89 : /// Create a new wrapper for client authentication.
90 15 : pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
91 15 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
92 15 :
93 15 : Self {
94 15 : stream,
95 15 : state: Begin,
96 15 : tls_server_end_point,
97 15 : }
98 15 : }
99 :
100 : /// Move to the next step by sending auth method's name & params to client.
101 15 : pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
102 15 : self.stream
103 15 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
104 0 : .await?;
105 :
106 15 : Ok(AuthFlow {
107 15 : stream: self.stream,
108 15 : state: method,
109 15 : tls_server_end_point: self.tls_server_end_point,
110 15 : })
111 15 : }
112 : }
113 :
114 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
115 : /// Perform user authentication. Raise an error in case authentication failed.
116 1 : pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
117 1 : let msg = self.stream.read_password_message().await?;
118 1 : let password = msg
119 1 : .strip_suffix(&[0])
120 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
121 :
122 1 : let payload = PasswordHackPayload::parse(password)
123 1 : // If we ended up here and the payload is malformed, it means that
124 1 : // the user neither enabled SNI nor resorted to any other method
125 1 : // for passing the project name we rely on. We should show them
126 1 : // the most helpful error message and point to the documentation.
127 1 : .ok_or(AuthError::MissingEndpointName)?;
128 :
129 1 : Ok(payload)
130 1 : }
131 : }
132 :
133 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
134 : /// Perform user authentication. Raise an error in case authentication failed.
135 1 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
136 1 : let msg = self.stream.read_password_message().await?;
137 1 : let password = msg
138 1 : .strip_suffix(&[0])
139 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
140 :
141 1 : let outcome = validate_password_and_exchange(
142 1 : &self.state.pool,
143 1 : self.state.endpoint,
144 1 : password,
145 1 : self.state.secret,
146 1 : )
147 1 : .await?;
148 :
149 1 : if let sasl::Outcome::Success(_) = &outcome {
150 1 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
151 0 : }
152 :
153 1 : Ok(outcome)
154 1 : }
155 : }
156 :
157 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
158 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
159 : /// Perform user authentication. Raise an error in case authentication failed.
160 13 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
161 13 : let Scram(secret, ctx) = self.state;
162 13 :
163 13 : // pause the timer while we communicate with the client
164 13 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
165 :
166 : // Initial client message contains the chosen auth method's name.
167 13 : let msg = self.stream.read_password_message().await?;
168 12 : let sasl = sasl::FirstMessage::parse(&msg)
169 12 : .ok_or(AuthError::MalformedPassword("bad sasl message"))?;
170 :
171 : // Currently, the only supported SASL method is SCRAM.
172 12 : if !scram::METHODS.contains(&sasl.method) {
173 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
174 12 : }
175 12 :
176 12 : match sasl.method {
177 12 : SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
178 6 : SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
179 0 : _ => {}
180 : }
181 :
182 : // TODO: make this a metric instead
183 12 : info!("client chooses {}", sasl.method);
184 :
185 12 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
186 12 : .authenticate(scram::Exchange::new(
187 12 : secret,
188 12 : rand::random,
189 12 : self.tls_server_end_point,
190 12 : ))
191 11 : .await?;
192 :
193 7 : if let sasl::Outcome::Success(_) = &outcome {
194 6 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
195 1 : }
196 :
197 7 : Ok(outcome)
198 13 : }
199 : }
200 :
201 2 : pub(crate) async fn validate_password_and_exchange(
202 2 : pool: &ThreadPool,
203 2 : endpoint: EndpointIdInt,
204 2 : password: &[u8],
205 2 : secret: AuthSecret,
206 2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
207 2 : match secret {
208 : #[cfg(any(test, feature = "testing"))]
209 : AuthSecret::Md5(_) => {
210 : // test only
211 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
212 0 : password.to_owned(),
213 0 : )))
214 : }
215 : // perform scram authentication as both client and server to validate the keys
216 2 : AuthSecret::Scram(scram_secret) => {
217 2 : let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
218 :
219 2 : let client_key = match outcome {
220 2 : sasl::Outcome::Success(client_key) => client_key,
221 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
222 : };
223 :
224 2 : let keys = crate::compute::ScramKeys {
225 2 : client_key: client_key.as_bytes(),
226 2 : server_key: scram_secret.server_key.as_bytes(),
227 2 : };
228 2 :
229 2 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
230 2 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
231 2 : )))
232 : }
233 : }
234 2 : }
|