LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 78.1 % 105 82
Test Date: 2024-02-07 07:37:29 Functions: 37.5 % 48 18

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
       4              : use crate::{
       5              :     config::TlsServerEndPoint,
       6              :     console::AuthSecret,
       7              :     sasl, scram,
       8              :     stream::{PqStream, Stream},
       9              : };
      10              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
      11              : use std::io;
      12              : use tokio::io::{AsyncRead, AsyncWrite};
      13              : use tracing::info;
      14              : 
      15              : /// Every authentication selector is supposed to implement this trait.
      16              : pub trait AuthMethod {
      17              :     /// Any authentication selector should provide initial backend message
      18              :     /// containing auth method name and parameters, e.g. md5 salt.
      19              :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
      20              : }
      21              : 
      22              : /// Initial state of [`AuthFlow`].
      23              : pub struct Begin;
      24              : 
      25              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      26              : pub struct Scram<'a>(pub &'a scram::ServerSecret);
      27              : 
      28              : impl AuthMethod for Scram<'_> {
      29              :     #[inline(always)]
      30           61 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      31           61 :         if channel_binding {
      32           61 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      33              :         } else {
      34            0 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      35            0 :                 scram::METHODS_WITHOUT_PLUS,
      36            0 :             ))
      37              :         }
      38           61 :     }
      39              : }
      40              : 
      41              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      42              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      43              : pub struct PasswordHack;
      44              : 
      45              : impl AuthMethod for PasswordHack {
      46              :     #[inline(always)]
      47            3 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      48            3 :         Be::AuthenticationCleartextPassword
      49            3 :     }
      50              : }
      51              : 
      52              : /// Use clear-text password auth called `password` in docs
      53              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      54              : pub struct CleartextPassword(pub AuthSecret);
      55              : 
      56              : impl AuthMethod for CleartextPassword {
      57              :     #[inline(always)]
      58            0 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      59            0 :         Be::AuthenticationCleartextPassword
      60            0 :     }
      61              : }
      62              : 
      63              : /// This wrapper for [`PqStream`] performs client authentication.
      64              : #[must_use]
      65              : pub struct AuthFlow<'a, S, State> {
      66              :     /// The underlying stream which implements libpq's protocol.
      67              :     stream: &'a mut PqStream<Stream<S>>,
      68              :     /// State might contain ancillary data (see [`Self::begin`]).
      69              :     state: State,
      70              :     tls_server_end_point: TlsServerEndPoint,
      71              : }
      72              : 
      73              : /// Initial state of the stream wrapper.
      74              : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      75              :     /// Create a new wrapper for client authentication.
      76           64 :     pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
      77           64 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      78           64 : 
      79           64 :         Self {
      80           64 :             stream,
      81           64 :             state: Begin,
      82           64 :             tls_server_end_point,
      83           64 :         }
      84           64 :     }
      85              : 
      86              :     /// Move to the next step by sending auth method's name & params to client.
      87           64 :     pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      88           64 :         self.stream
      89           64 :             .write_message(&method.first_message(self.tls_server_end_point.supported()))
      90            0 :             .await?;
      91              : 
      92           64 :         Ok(AuthFlow {
      93           64 :             stream: self.stream,
      94           64 :             state: method,
      95           64 :             tls_server_end_point: self.tls_server_end_point,
      96           64 :         })
      97           64 :     }
      98              : }
      99              : 
     100              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     101              :     /// Perform user authentication. Raise an error in case authentication failed.
     102            3 :     pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
     103            3 :         let msg = self.stream.read_password_message().await?;
     104            3 :         let password = msg
     105            3 :             .strip_suffix(&[0])
     106            3 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     107              : 
     108            3 :         let payload = PasswordHackPayload::parse(password)
     109            3 :             // If we ended up here and the payload is malformed, it means that
     110            3 :             // the user neither enabled SNI nor resorted to any other method
     111            3 :             // for passing the project name we rely on. We should show them
     112            3 :             // the most helpful error message and point to the documentation.
     113            3 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
     114              : 
     115            2 :         Ok(payload)
     116            3 :     }
     117              : }
     118              : 
     119              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     120              :     /// Perform user authentication. Raise an error in case authentication failed.
     121            0 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     122            0 :         let msg = self.stream.read_password_message().await?;
     123            0 :         let password = msg
     124            0 :             .strip_suffix(&[0])
     125            0 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     126              : 
     127            0 :         let outcome = validate_password_and_exchange(password, self.state.0)?;
     128              : 
     129            0 :         if let sasl::Outcome::Success(_) = &outcome {
     130            0 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     131            0 :         }
     132              : 
     133            0 :         Ok(outcome)
     134            0 :     }
     135              : }
     136              : 
     137              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     138              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     139              :     /// Perform user authentication. Raise an error in case authentication failed.
     140           61 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     141              :         // Initial client message contains the chosen auth method's name.
     142           61 :         let msg = self.stream.read_password_message().await?;
     143           59 :         let sasl = sasl::FirstMessage::parse(&msg)
     144           59 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     145              : 
     146              :         // Currently, the only supported SASL method is SCRAM.
     147           59 :         if !scram::METHODS.contains(&sasl.method) {
     148            0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     149           59 :         }
     150              : 
     151           37 :         info!("client chooses {}", sasl.method);
     152              : 
     153           59 :         let secret = self.state.0;
     154           59 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     155           59 :             .authenticate(scram::Exchange::new(
     156           59 :                 secret,
     157           59 :                 rand::random,
     158           59 :                 self.tls_server_end_point,
     159           59 :             ))
     160           57 :             .await?;
     161              : 
     162           49 :         if let sasl::Outcome::Success(_) = &outcome {
     163           44 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     164            5 :         }
     165              : 
     166           49 :         Ok(outcome)
     167           61 :     }
     168              : }
     169              : 
     170            2 : pub(super) fn validate_password_and_exchange(
     171            2 :     password: &[u8],
     172            2 :     secret: AuthSecret,
     173            2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     174            2 :     match secret {
     175              :         #[cfg(any(test, feature = "testing"))]
     176              :         AuthSecret::Md5(_) => {
     177              :             // test only
     178            0 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
     179            0 :                 password.to_owned(),
     180            0 :             )))
     181              :         }
     182              :         // perform scram authentication as both client and server to validate the keys
     183            2 :         AuthSecret::Scram(scram_secret) => {
     184            2 :             use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
     185            2 :             let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
     186            2 :             let outcome = crate::scram::exchange(
     187            2 :                 &scram_secret,
     188            2 :                 sasl_client,
     189            2 :                 crate::config::TlsServerEndPoint::Undefined,
     190            2 :             )?;
     191              : 
     192            2 :             let client_key = match outcome {
     193            2 :                 sasl::Outcome::Success(client_key) => client_key,
     194            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     195              :             };
     196              : 
     197            2 :             let keys = crate::compute::ScramKeys {
     198            2 :                 client_key: client_key.as_bytes(),
     199            2 :                 server_key: scram_secret.server_key.as_bytes(),
     200            2 :             };
     201            2 : 
     202            2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     203            2 :                 tokio_postgres::config::AuthKeys::ScramSha256(keys),
     204            2 :             )))
     205              :         }
     206              :     }
     207            2 : }
        

Generated by: LCOV version 2.1-beta