Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
4 : use crate::{
5 : config::TlsServerEndPoint,
6 : console::AuthSecret,
7 : context::RequestMonitoring,
8 : sasl, scram,
9 : stream::{PqStream, Stream},
10 : };
11 : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
12 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
13 : use std::io;
14 : use tokio::io::{AsyncRead, AsyncWrite};
15 : use tracing::info;
16 :
17 : /// Every authentication selector is supposed to implement this trait.
18 : pub trait AuthMethod {
19 : /// Any authentication selector should provide initial backend message
20 : /// containing auth method name and parameters, e.g. md5 salt.
21 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
22 : }
23 :
/// Marker state of an [`AuthFlow`] that has not yet greeted the client.
pub struct Begin;
26 :
27 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
28 : pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
29 :
30 : impl AuthMethod for Scram<'_> {
31 : #[inline(always)]
32 63 : fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
33 63 : if channel_binding {
34 63 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
35 : } else {
36 0 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
37 0 : scram::METHODS_WITHOUT_PLUS,
38 0 : ))
39 : }
40 63 : }
41 : }
42 :
/// Ad hoc authentication flow for clients that don't support SNI, as proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
///
/// The client's cleartext password message is decoded into a
/// [`PasswordHackPayload`], which must carry the endpoint name.
pub struct PasswordHack;
46 :
47 : impl AuthMethod for PasswordHack {
48 : #[inline(always)]
49 3 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
50 3 : Be::AuthenticationCleartextPassword
51 3 : }
52 : }
53 :
54 : /// Use clear-text password auth called `password` in docs
55 : /// <https://www.postgresql.org/docs/current/auth-password.html>
56 : pub struct CleartextPassword(pub AuthSecret);
57 :
58 : impl AuthMethod for CleartextPassword {
59 : #[inline(always)]
60 0 : fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
61 0 : Be::AuthenticationCleartextPassword
62 0 : }
63 : }
64 :
65 : /// This wrapper for [`PqStream`] performs client authentication.
66 : #[must_use]
67 : pub struct AuthFlow<'a, S, State> {
68 : /// The underlying stream which implements libpq's protocol.
69 : stream: &'a mut PqStream<Stream<S>>,
70 : /// State might contain ancillary data (see [`Self::begin`]).
71 : state: State,
72 : tls_server_end_point: TlsServerEndPoint,
73 : }
74 :
75 : /// Initial state of the stream wrapper.
76 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
77 : /// Create a new wrapper for client authentication.
78 66 : pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
79 66 : let tls_server_end_point = stream.get_ref().tls_server_end_point();
80 66 :
81 66 : Self {
82 66 : stream,
83 66 : state: Begin,
84 66 : tls_server_end_point,
85 66 : }
86 66 : }
87 :
88 : /// Move to the next step by sending auth method's name & params to client.
89 66 : pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
90 66 : self.stream
91 66 : .write_message(&method.first_message(self.tls_server_end_point.supported()))
92 0 : .await?;
93 :
94 66 : Ok(AuthFlow {
95 66 : stream: self.stream,
96 66 : state: method,
97 66 : tls_server_end_point: self.tls_server_end_point,
98 66 : })
99 66 : }
100 : }
101 :
102 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
103 : /// Perform user authentication. Raise an error in case authentication failed.
104 3 : pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
105 3 : let msg = self.stream.read_password_message().await?;
106 3 : let password = msg
107 3 : .strip_suffix(&[0])
108 3 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
109 :
110 3 : let payload = PasswordHackPayload::parse(password)
111 3 : // If we ended up here and the payload is malformed, it means that
112 3 : // the user neither enabled SNI nor resorted to any other method
113 3 : // for passing the project name we rely on. We should show them
114 3 : // the most helpful error message and point to the documentation.
115 3 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
116 :
117 2 : Ok(payload)
118 3 : }
119 : }
120 :
121 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
122 : /// Perform user authentication. Raise an error in case authentication failed.
123 0 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
124 0 : let msg = self.stream.read_password_message().await?;
125 0 : let password = msg
126 0 : .strip_suffix(&[0])
127 0 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
128 :
129 0 : let outcome = validate_password_and_exchange(password, self.state.0)?;
130 :
131 0 : if let sasl::Outcome::Success(_) = &outcome {
132 0 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
133 0 : }
134 :
135 0 : Ok(outcome)
136 0 : }
137 : }
138 :
139 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
140 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
141 : /// Perform user authentication. Raise an error in case authentication failed.
142 63 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
143 63 : let Scram(secret, ctx) = self.state;
144 63 :
145 63 : // pause the timer while we communicate with the client
146 63 : let _paused = ctx.latency_timer.pause();
147 :
148 : // Initial client message contains the chosen auth method's name.
149 63 : let msg = self.stream.read_password_message().await?;
150 61 : let sasl = sasl::FirstMessage::parse(&msg)
151 61 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
152 :
153 : // Currently, the only supported SASL method is SCRAM.
154 61 : if !scram::METHODS.contains(&sasl.method) {
155 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
156 61 : }
157 61 :
158 61 : match sasl.method {
159 61 : SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
160 51 : SCRAM_SHA_256_PLUS => {
161 51 : ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
162 : }
163 0 : _ => {}
164 : }
165 39 : info!("client chooses {}", sasl.method);
166 :
167 61 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
168 61 : .authenticate(scram::Exchange::new(
169 61 : secret,
170 61 : rand::random,
171 61 : self.tls_server_end_point,
172 61 : ))
173 59 : .await?;
174 :
175 51 : if let sasl::Outcome::Success(_) = &outcome {
176 46 : self.stream.write_message_noflush(&Be::AuthenticationOk)?;
177 5 : }
178 :
179 51 : Ok(outcome)
180 63 : }
181 : }
182 :
183 47 : pub(crate) fn validate_password_and_exchange(
184 47 : password: &[u8],
185 47 : secret: AuthSecret,
186 47 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
187 47 : match secret {
188 : #[cfg(any(test, feature = "testing"))]
189 : AuthSecret::Md5(_) => {
190 : // test only
191 0 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
192 0 : password.to_owned(),
193 0 : )))
194 : }
195 : // perform scram authentication as both client and server to validate the keys
196 47 : AuthSecret::Scram(scram_secret) => {
197 47 : use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
198 47 : let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
199 47 : let outcome = crate::scram::exchange(
200 47 : &scram_secret,
201 47 : sasl_client,
202 47 : crate::config::TlsServerEndPoint::Undefined,
203 47 : )?;
204 :
205 47 : let client_key = match outcome {
206 46 : sasl::Outcome::Success(client_key) => client_key,
207 1 : sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
208 : };
209 :
210 46 : let keys = crate::compute::ScramKeys {
211 46 : client_key: client_key.as_bytes(),
212 46 : server_key: scram_secret.server_key.as_bytes(),
213 46 : };
214 46 :
215 46 : Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
216 46 : tokio_postgres::config::AuthKeys::ScramSha256(keys),
217 46 : )))
218 : }
219 : }
220 47 : }
|