LCOV - differential code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 78.1 % 105 82 23 82
Current Date: 2024-01-09 02:06:09 Functions: 37.5 % 48 18 30 18
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Main authentication flow.
       2                 : 
       3                 : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
       4                 : use crate::{
       5                 :     config::TlsServerEndPoint,
       6                 :     console::AuthSecret,
       7                 :     sasl, scram,
       8                 :     stream::{PqStream, Stream},
       9                 : };
      10                 : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
      11                 : use std::io;
      12                 : use tokio::io::{AsyncRead, AsyncWrite};
      13                 : use tracing::info;
      14                 : 
      15                 : /// Every authentication selector is supposed to implement this trait.
      16                 : pub trait AuthMethod {
      17                 :     /// Any authentication selector should provide initial backend message
      18                 :     /// containing auth method name and parameters, e.g. md5 salt.
      19                 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
      20                 : }
      21                 : 
      22                 : /// Initial state of [`AuthFlow`].
      23                 : pub struct Begin;
      24                 : 
      25                 : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      26                 : pub struct Scram<'a>(pub &'a scram::ServerSecret);
      27                 : 
      28                 : impl AuthMethod for Scram<'_> {
      29                 :     #[inline(always)]
      30 CBC          48 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      31              48 :         if channel_binding {
      32              48 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      33                 :         } else {
      34 UBC           0 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      35               0 :                 scram::METHODS_WITHOUT_PLUS,
      36               0 :             ))
      37                 :         }
      38 CBC          48 :     }
      39                 : }
      40                 : 
      41                 : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      42                 : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      43                 : pub struct PasswordHack;
      44                 : 
      45                 : impl AuthMethod for PasswordHack {
      46                 :     #[inline(always)]
      47               3 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      48               3 :         Be::AuthenticationCleartextPassword
      49               3 :     }
      50                 : }
      51                 : 
      52                 : /// Use clear-text password auth called `password` in docs
      53                 : /// <https://www.postgresql.org/docs/current/auth-password.html>
      54                 : pub struct CleartextPassword(pub AuthSecret);
      55                 : 
      56                 : impl AuthMethod for CleartextPassword {
      57                 :     #[inline(always)]
      58 UBC           0 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      59               0 :         Be::AuthenticationCleartextPassword
      60               0 :     }
      61                 : }
      62                 : 
      63                 : /// This wrapper for [`PqStream`] performs client authentication.
      64                 : #[must_use]
      65                 : pub struct AuthFlow<'a, S, State> {
      66                 :     /// The underlying stream which implements libpq's protocol.
      67                 :     stream: &'a mut PqStream<Stream<S>>,
      68                 :     /// State might contain ancillary data (see [`Self::begin`]).
      69                 :     state: State,
      70                 :     tls_server_end_point: TlsServerEndPoint,
      71                 : }
      72                 : 
      73                 : /// Initial state of the stream wrapper.
      74                 : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      75                 :     /// Create a new wrapper for client authentication.
      76 CBC          51 :     pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
      77              51 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      78              51 : 
      79              51 :         Self {
      80              51 :             stream,
      81              51 :             state: Begin,
      82              51 :             tls_server_end_point,
      83              51 :         }
      84              51 :     }
      85                 : 
      86                 :     /// Move to the next step by sending auth method's name & params to client.
      87              51 :     pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      88              51 :         self.stream
      89              51 :             .write_message(&method.first_message(self.tls_server_end_point.supported()))
      90 UBC           0 :             .await?;
      91                 : 
      92 CBC          51 :         Ok(AuthFlow {
      93              51 :             stream: self.stream,
      94              51 :             state: method,
      95              51 :             tls_server_end_point: self.tls_server_end_point,
      96              51 :         })
      97              51 :     }
      98                 : }
      99                 : 
     100                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     101                 :     /// Perform user authentication. Raise an error in case authentication failed.
     102               3 :     pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
     103               3 :         let msg = self.stream.read_password_message().await?;
     104               3 :         let password = msg
     105               3 :             .strip_suffix(&[0])
     106               3 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     107                 : 
     108               3 :         let payload = PasswordHackPayload::parse(password)
     109               3 :             // If we ended up here and the payload is malformed, it means that
     110               3 :             // the user neither enabled SNI nor resorted to any other method
     111               3 :             // for passing the project name we rely on. We should show them
     112               3 :             // the most helpful error message and point to the documentation.
     113               3 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
     114                 : 
     115               2 :         Ok(payload)
     116               3 :     }
     117                 : }
     118                 : 
     119                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     120                 :     /// Perform user authentication. Raise an error in case authentication failed.
     121 UBC           0 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     122               0 :         let msg = self.stream.read_password_message().await?;
     123               0 :         let password = msg
     124               0 :             .strip_suffix(&[0])
     125               0 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     126                 : 
     127               0 :         let outcome = validate_password_and_exchange(password, self.state.0)?;
     128                 : 
     129               0 :         if let sasl::Outcome::Success(_) = &outcome {
     130               0 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     131               0 :         }
     132                 : 
     133               0 :         Ok(outcome)
     134               0 :     }
     135                 : }
     136                 : 
     137                 : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     138                 : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     139                 :     /// Perform user authentication. Raise an error in case authentication failed.
     140 CBC          48 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     141                 :         // Initial client message contains the chosen auth method's name.
     142              48 :         let msg = self.stream.read_password_message().await?;
     143              47 :         let sasl = sasl::FirstMessage::parse(&msg)
     144              47 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     145                 : 
     146                 :         // Currently, the only supported SASL method is SCRAM.
     147              47 :         if !scram::METHODS.contains(&sasl.method) {
     148 UBC           0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     149 CBC          47 :         }
     150                 : 
     151              36 :         info!("client chooses {}", sasl.method);
     152                 : 
     153              47 :         let secret = self.state.0;
     154              47 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     155              47 :             .authenticate(scram::Exchange::new(
     156              47 :                 secret,
     157              47 :                 rand::random,
     158              47 :                 self.tls_server_end_point,
     159              47 :             ))
     160              42 :             .await?;
     161                 : 
     162              42 :         if let sasl::Outcome::Success(_) = &outcome {
     163              38 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     164               4 :         }
     165                 : 
     166              42 :         Ok(outcome)
     167              48 :     }
     168                 : }
     169                 : 
     170               2 : pub(super) fn validate_password_and_exchange(
     171               2 :     password: &[u8],
     172               2 :     secret: AuthSecret,
     173               2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     174               2 :     match secret {
     175                 :         #[cfg(feature = "testing")]
     176                 :         AuthSecret::Md5(_) => {
     177                 :             // test only
     178 UBC           0 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
     179               0 :                 password.to_owned(),
     180               0 :             )))
     181                 :         }
     182                 :         // perform scram authentication as both client and server to validate the keys
     183 CBC           2 :         AuthSecret::Scram(scram_secret) => {
     184               2 :             use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
     185               2 :             let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
     186               2 :             let outcome = crate::scram::exchange(
     187               2 :                 &scram_secret,
     188               2 :                 sasl_client,
     189               2 :                 crate::config::TlsServerEndPoint::Undefined,
     190               2 :             )?;
     191                 : 
     192               2 :             let client_key = match outcome {
     193               2 :                 sasl::Outcome::Success(client_key) => client_key,
     194 UBC           0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     195                 :             };
     196                 : 
     197 CBC           2 :             let keys = crate::compute::ScramKeys {
     198               2 :                 client_key: client_key.as_bytes(),
     199               2 :                 server_key: scram_secret.server_key.as_bytes(),
     200               2 :             };
     201               2 : 
     202               2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     203               2 :                 tokio_postgres::config::AuthKeys::ScramSha256(keys),
     204               2 :             )))
     205                 :         }
     206                 :     }
     207               2 : }
        

Generated by: LCOV version 2.1-beta