Line data Source code
1 : //! Main authentication flow.
2 :
3 : use std::sync::Arc;
4 :
5 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 : use tracing::info;
8 :
9 : use super::backend::ComputeCredentialKeys;
10 : use super::{AuthError, PasswordHackPayload};
11 : use crate::context::RequestContext;
12 : use crate::control_plane::AuthSecret;
13 : use crate::intern::{EndpointIdInt, RoleNameInt};
14 : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
15 : use crate::sasl;
16 : use crate::scram::threadpool::ThreadPool;
17 : use crate::scram::{self};
18 : use crate::stream::{PqStream, Stream};
19 : use crate::tls::TlsServerEndPoint;
20 :
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
///
/// Field `0` is the stored server secret that is handed to the SCRAM
/// exchange. Field `1` is the request context. The flow uses it to pause the
/// latency timer while waiting on the client and to record the negotiated
/// auth method.
pub(crate) struct Scram<'a>(
    pub(crate) &'a scram::ServerSecret,
    pub(crate) &'a RequestContext,
);
26 :
27 : impl Scram<'_> {
28 : #[inline(always)]
29 13 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
30 13 : if channel_binding {
31 12 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
32 : } else {
33 1 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
34 1 : scram::METHODS_WITHOUT_PLUS,
35 1 : ))
36 : }
37 13 : }
38 : }
39 :
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
///
/// The client is asked for a cleartext password. Its contents are parsed as a
/// [`PasswordHackPayload`] (see `AuthFlow::get_password` in this file).
pub(crate) struct PasswordHack;
43 :
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
///
/// The received password is checked by [`validate_password_and_exchange`].
pub(crate) struct CleartextPassword {
    /// Thread pool forwarded to the SCRAM exchange that validates the password.
    pub(crate) pool: Arc<ThreadPool>,
    /// Interned endpoint id, forwarded to the SCRAM exchange.
    pub(crate) endpoint: EndpointIdInt,
    /// Interned role name, forwarded to the SCRAM exchange.
    pub(crate) role: RoleNameInt,
    /// Stored secret the client's password is validated against.
    pub(crate) secret: AuthSecret,
}
52 :
/// This wrapper for [`PqStream`] performs client authentication.
///
/// The `State` parameter selects the auth method, e.g. [`Scram`],
/// [`PasswordHack`] or [`CleartextPassword`]. Each method has its own `impl`
/// block in this file that provides the corresponding entry point.
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
    /// The underlying stream which implements libpq's protocol.
    stream: &'a mut PqStream<Stream<S>>,
    /// State might contain ancillary data.
    state: State,
    /// TLS server end point, read from the stream when the flow is created.
    /// The SCRAM flow uses it to decide which mechanisms to advertise and
    /// passes it into the SCRAM exchange.
    tls_server_end_point: TlsServerEndPoint,
}
62 :
63 : /// Initial state of the stream wrapper.
64 : impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
65 : /// Create a new wrapper for client authentication.
66 15 : pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
67 15 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
68 :
69 15 : Self {
70 15 : stream,
71 15 : state: method,
72 15 : tls_server_end_point,
73 15 : }
74 15 : }
75 : }
76 :
77 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
78 : /// Perform user authentication. Raise an error in case authentication failed.
79 1 : pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
80 1 : self.stream
81 1 : .write_message(BeMessage::AuthenticationCleartextPassword);
82 1 : self.stream.flush().await?;
83 :
84 1 : let msg = self.stream.read_password_message().await?;
85 1 : let password = msg
86 1 : .strip_suffix(&[0])
87 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
88 :
89 1 : let payload = PasswordHackPayload::parse(password)
90 : // If we ended up here and the payload is malformed, it means that
91 : // the user neither enabled SNI nor resorted to any other method
92 : // for passing the project name we rely on. We should show them
93 : // the most helpful error message and point to the documentation.
94 1 : .ok_or(AuthError::MissingEndpointName)?;
95 :
96 1 : Ok(payload)
97 1 : }
98 : }
99 :
100 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
101 : /// Perform user authentication. Raise an error in case authentication failed.
102 1 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
103 1 : self.stream
104 1 : .write_message(BeMessage::AuthenticationCleartextPassword);
105 1 : self.stream.flush().await?;
106 :
107 1 : let msg = self.stream.read_password_message().await?;
108 1 : let password = msg
109 1 : .strip_suffix(&[0])
110 1 : .ok_or(AuthError::MalformedPassword("missing terminator"))?;
111 :
112 1 : let outcome = validate_password_and_exchange(
113 1 : &self.state.pool,
114 1 : self.state.endpoint,
115 1 : self.state.role,
116 1 : password,
117 1 : self.state.secret,
118 1 : )
119 1 : .await?;
120 :
121 1 : if let sasl::Outcome::Success(_) = &outcome {
122 1 : self.stream.write_message(BeMessage::AuthenticationOk);
123 1 : }
124 :
125 1 : Ok(outcome)
126 1 : }
127 : }
128 :
129 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
130 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
131 : /// Perform user authentication. Raise an error in case authentication failed.
132 13 : pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
133 13 : let Scram(secret, ctx) = self.state;
134 13 : let channel_binding = self.tls_server_end_point;
135 :
136 : // send sasl message.
137 : {
138 : // pause the timer while we communicate with the client
139 13 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
140 :
141 13 : let sasl = self.state.first_message(channel_binding.supported());
142 13 : self.stream.write_message(sasl);
143 13 : self.stream.flush().await?;
144 : }
145 :
146 : // complete sasl handshake.
147 13 : sasl::authenticate(ctx, self.stream, |method| {
148 : // Currently, the only supported SASL method is SCRAM.
149 12 : match method {
150 12 : SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
151 6 : SCRAM_SHA_256_PLUS => {
152 6 : ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
153 6 : }
154 0 : method => return Err(sasl::Error::BadAuthMethod(method.into())),
155 : }
156 :
157 : // TODO: make this a metric instead
158 12 : info!("client chooses {}", method);
159 :
160 12 : Ok(scram::Exchange::new(secret, rand::random, channel_binding))
161 12 : })
162 13 : .await
163 13 : .map_err(AuthError::Sasl)
164 13 : }
165 : }
166 :
167 2 : pub(crate) async fn validate_password_and_exchange(
168 2 : pool: &ThreadPool,
169 2 : endpoint: EndpointIdInt,
170 2 : role: RoleNameInt,
171 2 : password: &[u8],
172 2 : secret: AuthSecret,
173 2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
174 2 : match secret {
175 : // perform scram authentication as both client and server to validate the keys
176 2 : AuthSecret::Scram(scram_secret) => {
177 2 : let outcome =
178 2 : crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
179 :
180 2 : let client_key = match outcome {
181 2 : sasl::Outcome::Success(client_key) => client_key,
182 0 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
183 : };
184 :
185 2 : let keys = crate::compute::ScramKeys {
186 2 : client_key: client_key.as_bytes(),
187 2 : server_key: scram_secret.server_key.as_bytes(),
188 2 : };
189 :
190 2 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
191 2 : postgres_client::config::AuthKeys::ScramSha256(keys),
192 2 : )))
193 : }
194 : }
195 2 : }
|