LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 80.4 % 56 45
Test Date: 2023-09-06 10:18:01 Functions: 37.2 % 43 16

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use super::{AuthErrorImpl, PasswordHackPayload};
       4              : use crate::{sasl, scram, stream::PqStream};
       5              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
       6              : use std::io;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : 
       9              : /// Every authentication selector is supposed to implement this trait.
      10              : pub trait AuthMethod {
      11              :     /// Any authentication selector should provide initial backend message
      12              :     /// containing auth method name and parameters, e.g. md5 salt.
      13              :     fn first_message(&self) -> BeMessage<'_>;
      14              : }
      15              : 
      16              : /// Initial state of [`AuthFlow`].
      17              : pub struct Begin;
      18              : 
      19              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      20              : pub struct Scram<'a>(pub &'a scram::ServerSecret);
      21              : 
      22              : impl AuthMethod for Scram<'_> {
      23              :     #[inline(always)]
      24           29 :     fn first_message(&self) -> BeMessage<'_> {
      25           29 :         Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      26           29 :     }
      27              : }
      28              : 
      29              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      30              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      31              : pub struct PasswordHack;
      32              : 
      33              : impl AuthMethod for PasswordHack {
      34              :     #[inline(always)]
      35            3 :     fn first_message(&self) -> BeMessage<'_> {
      36            3 :         Be::AuthenticationCleartextPassword
      37            3 :     }
      38              : }
      39              : 
      40              : /// Use clear-text password auth called `password` in docs
      41              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      42              : pub struct CleartextPassword;
      43              : 
      44              : impl AuthMethod for CleartextPassword {
      45              :     #[inline(always)]
      46            0 :     fn first_message(&self) -> BeMessage<'_> {
      47            0 :         Be::AuthenticationCleartextPassword
      48            0 :     }
      49              : }
      50              : 
      51              : /// This wrapper for [`PqStream`] performs client authentication.
      52              : #[must_use]
      53              : pub struct AuthFlow<'a, Stream, State> {
      54              :     /// The underlying stream which implements libpq's protocol.
      55              :     stream: &'a mut PqStream<Stream>,
      56              :     /// State might contain ancillary data (see [`Self::begin`]).
      57              :     state: State,
      58              : }
      59              : 
      60              : /// Initial state of the stream wrapper.
      61              : impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      62              :     /// Create a new wrapper for client authentication.
      63           32 :     pub fn new(stream: &'a mut PqStream<S>) -> Self {
      64           32 :         Self {
      65           32 :             stream,
      66           32 :             state: Begin,
      67           32 :         }
      68           32 :     }
      69              : 
      70              :     /// Move to the next step by sending auth method's name & params to client.
      71           32 :     pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      72           32 :         self.stream.write_message(&method.first_message()).await?;
      73              : 
      74           32 :         Ok(AuthFlow {
      75           32 :             stream: self.stream,
      76           32 :             state: method,
      77           32 :         })
      78           32 :     }
      79              : }
      80              : 
      81              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
      82              :     /// Perform user authentication. Raise an error in case authentication failed.
      83            3 :     pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
      84            3 :         let msg = self.stream.read_password_message().await?;
      85            3 :         let password = msg
      86            3 :             .strip_suffix(&[0])
      87            3 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
      88              : 
      89            3 :         let payload = PasswordHackPayload::parse(password)
      90            3 :             // If we ended up here and the payload is malformed, it means that
      91            3 :             // the user neither enabled SNI nor resorted to any other method
      92            3 :             // for passing the project name we rely on. We should show them
      93            3 :             // the most helpful error message and point to the documentation.
      94            3 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
      95              : 
      96            2 :         Ok(payload)
      97            3 :     }
      98              : }
      99              : 
     100              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     101              :     /// Perform user authentication. Raise an error in case authentication failed.
     102            0 :     pub async fn authenticate(self) -> super::Result<Vec<u8>> {
     103            0 :         let msg = self.stream.read_password_message().await?;
     104            0 :         let password = msg
     105            0 :             .strip_suffix(&[0])
     106            0 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     107              : 
     108            0 :         Ok(password.to_vec())
     109            0 :     }
     110              : }
     111              : 
     112              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     113              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     114              :     /// Perform user authentication. Raise an error in case authentication failed.
     115           29 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     116              :         // Initial client message contains the chosen auth method's name.
     117           29 :         let msg = self.stream.read_password_message().await?;
     118           29 :         let sasl = sasl::FirstMessage::parse(&msg)
     119           29 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     120              : 
     121              :         // Currently, the only supported SASL method is SCRAM.
     122           29 :         if !scram::METHODS.contains(&sasl.method) {
     123            0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     124           29 :         }
     125           29 : 
     126           29 :         let secret = self.state.0;
     127           29 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     128           29 :             .authenticate(scram::Exchange::new(secret, rand::random, None))
     129           29 :             .await?;
     130              : 
     131           29 :         Ok(outcome)
     132           29 :     }
     133              : }
        

Generated by: LCOV version 2.1-beta