Line data Source code
1 : //! Main authentication flow.
2 :
3 : use std::sync::Arc;
4 :
5 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 : use tracing::info;
8 :
9 : use super::backend::ComputeCredentialKeys;
10 : use super::{AuthError, PasswordHackPayload};
11 : use crate::context::RequestContext;
12 : use crate::control_plane::AuthSecret;
13 : use crate::intern::EndpointIdInt;
14 : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
15 : use crate::sasl;
16 : use crate::scram::threadpool::ThreadPool;
17 : use crate::scram::{self};
18 : use crate::stream::{PqStream, Stream};
19 : use crate::tls::TlsServerEndPoint;
20 :
21 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
22 : pub(crate) struct Scram<'a>(
23 : pub(crate) &'a scram::ServerSecret,
24 : pub(crate) &'a RequestContext,
25 : );
26 :
27 : impl Scram<'_> {
28 : #[inline(always)]
29 13 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
30 13 : if channel_binding {
31 12 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
32 : } else {
33 1 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
34 1 : scram::METHODS_WITHOUT_PLUS,
35 1 : ))
36 : }
37 13 : }
38 : }
39 :
/// Marker state for the ad hoc auth flow aimed at clients which don't
/// support SNI, as proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
43 :
44 : /// Use clear-text password auth called `password` in docs
45 : /// <https://www.postgresql.org/docs/current/auth-password.html>
46 : pub(crate) struct CleartextPassword {
47 : pub(crate) pool: Arc<ThreadPool>,
48 : pub(crate) endpoint: EndpointIdInt,
49 : pub(crate) secret: AuthSecret,
50 : }
51 :
52 : /// This wrapper for [`PqStream`] performs client authentication.
53 : #[must_use]
54 : pub(crate) struct AuthFlow<'a, S, State> {
55 : /// The underlying stream which implements libpq's protocol.
56 : stream: &'a mut PqStream<Stream<S>>,
57 : /// State might contain ancillary data.
58 : state: State,
59 : tls_server_end_point: TlsServerEndPoint,
60 : }
61 :
62 : /// Initial state of the stream wrapper.
63 : impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
64 : /// Create a new wrapper for client authentication.
65 15 : pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
66 15 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
67 :
68 15 : Self {
69 15 : stream,
70 15 : state: method,
71 15 : tls_server_end_point,
72 15 : }
73 15 : }
74 : }
75 :
76 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
77 : /// Perform user authentication. Raise an error in case authentication failed.
78 1 : pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
79 1 : self.stream
80 1 : .write_message(BeMessage::AuthenticationCleartextPassword);
81 1 : self.stream.flush().await?;
82 :
83 1 : let msg = self.stream.read_password_message().await?;
84 1 : let password = msg
85 1 : .strip_suffix(&[0])
86 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
87 :
88 1 : let payload = PasswordHackPayload::parse(password)
89 : // If we ended up here and the payload is malformed, it means that
90 : // the user neither enabled SNI nor resorted to any other method
91 : // for passing the project name we rely on. We should show them
92 : // the most helpful error message and point to the documentation.
93 1 : .ok_or(AuthError::MissingEndpointName)?;
94 :
95 1 : Ok(payload)
96 1 : }
97 : }
98 :
99 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
100 : /// Perform user authentication. Raise an error in case authentication failed.
101 1 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
102 1 : self.stream
103 1 : .write_message(BeMessage::AuthenticationCleartextPassword);
104 1 : self.stream.flush().await?;
105 :
106 1 : let msg = self.stream.read_password_message().await?;
107 1 : let password = msg
108 1 : .strip_suffix(&[0])
109 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
110 :
111 1 : let outcome = validate_password_and_exchange(
112 1 : &self.state.pool,
113 1 : self.state.endpoint,
114 1 : password,
115 1 : self.state.secret,
116 1 : )
117 1 : .await?;
118 :
119 1 : if let sasl::Outcome::Success(_) = &outcome {
120 1 : self.stream.write_message(BeMessage::AuthenticationOk);
121 1 : }
122 :
123 1 : Ok(outcome)
124 1 : }
125 : }
126 :
127 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
128 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
129 : /// Perform user authentication. Raise an error in case authentication failed.
130 13 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
131 13 : let Scram(secret, ctx) = self.state;
132 13 : let channel_binding = self.tls_server_end_point;
133 :
134 : // send sasl message.
135 : {
136 : // pause the timer while we communicate with the client
137 13 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
138 :
139 13 : let sasl = self.state.first_message(channel_binding.supported());
140 13 : self.stream.write_message(sasl);
141 13 : self.stream.flush().await?;
142 : }
143 :
144 : // complete sasl handshake.
145 13 : sasl::authenticate(ctx, self.stream, |method| {
146 : // Currently, the only supported SASL method is SCRAM.
147 12 : match method {
148 12 : SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
149 6 : SCRAM_SHA_256_PLUS => {
150 6 : ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
151 6 : }
152 0 : method => return Err(sasl::Error::BadAuthMethod(method.into())),
153 : }
154 :
155 : // TODO: make this a metric instead
156 12 : info!("client chooses {}", method);
157 :
158 12 : Ok(scram::Exchange::new(secret, rand::random, channel_binding))
159 12 : })
160 13 : .await
161 13 : .map_err(AuthError::Sasl)
162 13 : }
163 : }
164 :
165 2 : pub(crate) async fn validate_password_and_exchange(
166 2 : pool: &ThreadPool,
167 2 : endpoint: EndpointIdInt,
168 2 : password: &[u8],
169 2 : secret: AuthSecret,
170 2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
171 2 : match secret {
172 : // perform scram authentication as both client and server to validate the keys
173 2 : AuthSecret::Scram(scram_secret) => {
174 2 : let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
175 :
176 2 : let client_key = match outcome {
177 2 : sasl::Outcome::Success(client_key) => client_key,
178 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
179 : };
180 :
181 2 : let keys = crate::compute::ScramKeys {
182 2 : client_key: client_key.as_bytes(),
183 2 : server_key: scram_secret.server_key.as_bytes(),
184 2 : };
185 :
186 2 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
187 2 : postgres_client::config::AuthKeys::ScramSha256(keys),
188 2 : )))
189 : }
190 : }
191 2 : }
|