TLA Line data Source code
1 : //! Main authentication flow.
2 :
3 : use super::{AuthErrorImpl, PasswordHackPayload};
4 : use crate::{sasl, scram, stream::PqStream};
5 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
6 : use std::io;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 :
/// Every authentication selector is supposed to implement this trait.
///
/// An implementor is handed to `AuthFlow::begin`, which writes the message
/// returned by [`AuthMethod::first_message`] to the client and then moves the
/// flow into a state typed by that implementor.
pub trait AuthMethod {
    /// Any authentication selector should provide initial backend message
    /// containing auth method name and parameters, e.g. md5 salt.
    fn first_message(&self) -> BeMessage<'_>;
}
15 :
/// Initial state of [`AuthFlow`]: no authentication method has been sent to
/// the client yet (that happens in `AuthFlow::begin`).
pub struct Begin;
18 :
19 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
20 : pub struct Scram<'a>(pub &'a scram::ServerSecret);
21 :
22 : impl AuthMethod for Scram<'_> {
23 : #[inline(always)]
24 CBC 31 : fn first_message(&self) -> BeMessage<'_> {
25 31 : Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
26 31 : }
27 : }
28 :
29 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
30 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
31 : pub struct PasswordHack;
32 :
33 : impl AuthMethod for PasswordHack {
34 : #[inline(always)]
35 3 : fn first_message(&self) -> BeMessage<'_> {
36 3 : Be::AuthenticationCleartextPassword
37 3 : }
38 : }
39 :
40 : /// Use clear-text password auth called `password` in docs
41 : /// <https://www.postgresql.org/docs/current/auth-password.html>
42 : pub struct CleartextPassword;
43 :
44 : impl AuthMethod for CleartextPassword {
45 : #[inline(always)]
46 UBC 0 : fn first_message(&self) -> BeMessage<'_> {
47 0 : Be::AuthenticationCleartextPassword
48 0 : }
49 : }
50 :
51 : /// This wrapper for [`PqStream`] performs client authentication.
52 : #[must_use]
53 : pub struct AuthFlow<'a, Stream, State> {
54 : /// The underlying stream which implements libpq's protocol.
55 : stream: &'a mut PqStream<Stream>,
56 : /// State might contain ancillary data (see [`Self::begin`]).
57 : state: State,
58 : }
59 :
60 : /// Initial state of the stream wrapper.
61 : impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
62 : /// Create a new wrapper for client authentication.
63 CBC 34 : pub fn new(stream: &'a mut PqStream<S>) -> Self {
64 34 : Self {
65 34 : stream,
66 34 : state: Begin,
67 34 : }
68 34 : }
69 :
70 : /// Move to the next step by sending auth method's name & params to client.
71 34 : pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
72 34 : self.stream.write_message(&method.first_message()).await?;
73 :
74 34 : Ok(AuthFlow {
75 34 : stream: self.stream,
76 34 : state: method,
77 34 : })
78 34 : }
79 : }
80 :
81 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
82 : /// Perform user authentication. Raise an error in case authentication failed.
83 3 : pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
84 3 : let msg = self.stream.read_password_message().await?;
85 3 : let password = msg
86 3 : .strip_suffix(&[0])
87 3 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
88 :
89 3 : let payload = PasswordHackPayload::parse(password)
90 3 : // If we ended up here and the payload is malformed, it means that
91 3 : // the user neither enabled SNI nor resorted to any other method
92 3 : // for passing the project name we rely on. We should show them
93 3 : // the most helpful error message and point to the documentation.
94 3 : .ok_or(AuthErrorImpl::MissingEndpointName)?;
95 :
96 2 : Ok(payload)
97 3 : }
98 : }
99 :
100 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
101 : /// Perform user authentication. Raise an error in case authentication failed.
102 UBC 0 : pub async fn authenticate(self) -> super::Result<Vec<u8>> {
103 0 : let msg = self.stream.read_password_message().await?;
104 0 : let password = msg
105 0 : .strip_suffix(&[0])
106 0 : .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
107 :
108 0 : Ok(password.to_vec())
109 0 : }
110 : }
111 :
112 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
113 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
114 : /// Perform user authentication. Raise an error in case authentication failed.
115 CBC 31 : pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
116 : // Initial client message contains the chosen auth method's name.
117 31 : let msg = self.stream.read_password_message().await?;
118 31 : let sasl = sasl::FirstMessage::parse(&msg)
119 31 : .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
120 :
121 : // Currently, the only supported SASL method is SCRAM.
122 31 : if !scram::METHODS.contains(&sasl.method) {
123 UBC 0 : return Err(super::AuthError::bad_auth_method(sasl.method));
124 CBC 31 : }
125 31 :
126 31 : let secret = self.state.0;
127 31 : let outcome = sasl::SaslStream::new(self.stream, sasl.message)
128 31 : .authenticate(scram::Exchange::new(secret, rand::random, None))
129 31 : .await?;
130 :
131 31 : Ok(outcome)
132 31 : }
133 : }
|