LCOV - differential code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 80.4 % 56 45 11 45
Current Date: 2023-10-19 02:04:12 Functions: 37.2 % 43 16 27 16
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Main authentication flow.
       2                 : 
       3                 : use super::{AuthErrorImpl, PasswordHackPayload};
       4                 : use crate::{sasl, scram, stream::PqStream};
       5                 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
       6                 : use std::io;
       7                 : use tokio::io::{AsyncRead, AsyncWrite};
       8                 : 
       9                 : /// Every authentication selector is supposed to implement this trait.
      10                 : pub trait AuthMethod {
      11                 :     /// Any authentication selector should provide initial backend message
      12                 :     /// containing auth method name and parameters, e.g. md5 salt.
      13                 :     fn first_message(&self) -> BeMessage<'_>;
      14                 : }
      15                 : 
      16                 : /// Initial state of [`AuthFlow`].
      17                 : pub struct Begin;
      18                 : 
      19                 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      20                 : pub struct Scram<'a>(pub &'a scram::ServerSecret);
      21                 : 
      22                 : impl AuthMethod for Scram<'_> {
      23                 :     #[inline(always)]
      24 CBC          31 :     fn first_message(&self) -> BeMessage<'_> {
      25              31 :         Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      26              31 :     }
      27                 : }
      28                 : 
      29                 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      30                 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      31                 : pub struct PasswordHack;
      32                 : 
      33                 : impl AuthMethod for PasswordHack {
      34                 :     #[inline(always)]
      35               3 :     fn first_message(&self) -> BeMessage<'_> {
      36               3 :         Be::AuthenticationCleartextPassword
      37               3 :     }
      38                 : }
      39                 : 
      40                 : /// Use clear-text password auth called `password` in docs
      41                 : /// <https://www.postgresql.org/docs/current/auth-password.html>
      42                 : pub struct CleartextPassword;
      43                 : 
      44                 : impl AuthMethod for CleartextPassword {
      45                 :     #[inline(always)]
      46 UBC           0 :     fn first_message(&self) -> BeMessage<'_> {
      47               0 :         Be::AuthenticationCleartextPassword
      48               0 :     }
      49                 : }
      50                 : 
      51                 : /// This wrapper for [`PqStream`] performs client authentication.
      52                 : #[must_use]
      53                 : pub struct AuthFlow<'a, Stream, State> {
      54                 :     /// The underlying stream which implements libpq's protocol.
      55                 :     stream: &'a mut PqStream<Stream>,
      56                 :     /// State might contain ancillary data (see [`Self::begin`]).
      57                 :     state: State,
      58                 : }
      59                 : 
      60                 : /// Initial state of the stream wrapper.
      61                 : impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      62                 :     /// Create a new wrapper for client authentication.
      63 CBC          34 :     pub fn new(stream: &'a mut PqStream<S>) -> Self {
      64              34 :         Self {
      65              34 :             stream,
      66              34 :             state: Begin,
      67              34 :         }
      68              34 :     }
      69                 : 
      70                 :     /// Move to the next step by sending auth method's name & params to client.
      71              34 :     pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      72              34 :         self.stream.write_message(&method.first_message()).await?;
      73                 : 
      74              34 :         Ok(AuthFlow {
      75              34 :             stream: self.stream,
      76              34 :             state: method,
      77              34 :         })
      78              34 :     }
      79                 : }
      80                 : 
      81                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
      82                 :     /// Perform user authentication. Raise an error in case authentication failed.
      83               3 :     pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
      84               3 :         let msg = self.stream.read_password_message().await?;
      85               3 :         let password = msg
      86               3 :             .strip_suffix(&[0])
      87               3 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
      88                 : 
      89               3 :         let payload = PasswordHackPayload::parse(password)
      90               3 :             // If we ended up here and the payload is malformed, it means that
      91               3 :             // the user neither enabled SNI nor resorted to any other method
      92               3 :             // for passing the project name we rely on. We should show them
      93               3 :             // the most helpful error message and point to the documentation.
      94               3 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
      95                 : 
      96               2 :         Ok(payload)
      97               3 :     }
      98                 : }
      99                 : 
     100                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     101                 :     /// Perform user authentication. Raise an error in case authentication failed.
     102 UBC           0 :     pub async fn authenticate(self) -> super::Result<Vec<u8>> {
     103               0 :         let msg = self.stream.read_password_message().await?;
     104               0 :         let password = msg
     105               0 :             .strip_suffix(&[0])
     106               0 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     107                 : 
     108               0 :         Ok(password.to_vec())
     109               0 :     }
     110                 : }
     111                 : 
     112                 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     113                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     114                 :     /// Perform user authentication. Raise an error in case authentication failed.
     115 CBC          31 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     116                 :         // Initial client message contains the chosen auth method's name.
     117              31 :         let msg = self.stream.read_password_message().await?;
     118              31 :         let sasl = sasl::FirstMessage::parse(&msg)
     119              31 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     120                 : 
     121                 :         // Currently, the only supported SASL method is SCRAM.
     122              31 :         if !scram::METHODS.contains(&sasl.method) {
     123 UBC           0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     124 CBC          31 :         }
     125              31 : 
     126              31 :         let secret = self.state.0;
     127              31 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     128              31 :             .authenticate(scram::Exchange::new(secret, rand::random, None))
     129              31 :             .await?;
     130                 : 
     131              31 :         Ok(outcome)
     132              31 :     }
     133                 : }
        

Generated by: LCOV version 2.1-beta