LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: a43a77853355b937a79c57b07a8f05607cf29e6c.info Lines: 93.0 % 115 107
Test Date: 2024-09-19 12:04:32 Functions: 39.7 % 58 23

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
       4              : use crate::{
       5              :     config::TlsServerEndPoint,
       6              :     console::AuthSecret,
       7              :     context::RequestMonitoring,
       8              :     intern::EndpointIdInt,
       9              :     sasl,
      10              :     scram::{self, threadpool::ThreadPool},
      11              :     stream::{PqStream, Stream},
      12              : };
      13              : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
      14              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
      15              : use std::{io, sync::Arc};
      16              : use tokio::io::{AsyncRead, AsyncWrite};
      17              : use tracing::info;
      18              : 
      19              : /// Every authentication selector is supposed to implement this trait.
      20              : pub(crate) trait AuthMethod {
      21              :     /// Any authentication selector should provide initial backend message
      22              :     /// containing auth method name and parameters, e.g. md5 salt.
      23              :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
      24              : }
      25              : 
      26              : /// Initial state of [`AuthFlow`].
      27              : pub(crate) struct Begin;
      28              : 
      29              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      30              : pub(crate) struct Scram<'a>(
      31              :     pub(crate) &'a scram::ServerSecret,
      32              :     pub(crate) &'a RequestMonitoring,
      33              : );
      34              : 
      35              : impl AuthMethod for Scram<'_> {
      36              :     #[inline(always)]
      37           13 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      38           13 :         if channel_binding {
      39           12 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      40              :         } else {
      41            1 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      42            1 :                 scram::METHODS_WITHOUT_PLUS,
      43            1 :             ))
      44              :         }
      45           13 :     }
      46              : }
      47              : 
      48              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      49              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      50              : pub(crate) struct PasswordHack;
      51              : 
      52              : impl AuthMethod for PasswordHack {
      53              :     #[inline(always)]
      54            1 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      55            1 :         Be::AuthenticationCleartextPassword
      56            1 :     }
      57              : }
      58              : 
      59              : /// Use clear-text password auth called `password` in docs
      60              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      61              : pub(crate) struct CleartextPassword {
      62              :     pub(crate) pool: Arc<ThreadPool>,
      63              :     pub(crate) endpoint: EndpointIdInt,
      64              :     pub(crate) secret: AuthSecret,
      65              : }
      66              : 
      67              : impl AuthMethod for CleartextPassword {
      68              :     #[inline(always)]
      69            1 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      70            1 :         Be::AuthenticationCleartextPassword
      71            1 :     }
      72              : }
      73              : 
      74              : /// This wrapper for [`PqStream`] performs client authentication.
      75              : #[must_use]
      76              : pub(crate) struct AuthFlow<'a, S, State> {
      77              :     /// The underlying stream which implements libpq's protocol.
      78              :     stream: &'a mut PqStream<Stream<S>>,
      79              :     /// State might contain ancillary data (see [`Self::begin`]).
      80              :     state: State,
      81              :     tls_server_end_point: TlsServerEndPoint,
      82              : }
      83              : 
      84              : /// Initial state of the stream wrapper.
      85              : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      86              :     /// Create a new wrapper for client authentication.
      87           15 :     pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
      88           15 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      89           15 : 
      90           15 :         Self {
      91           15 :             stream,
      92           15 :             state: Begin,
      93           15 :             tls_server_end_point,
      94           15 :         }
      95           15 :     }
      96              : 
      97              :     /// Move to the next step by sending auth method's name & params to client.
      98           15 :     pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      99           15 :         self.stream
     100           15 :             .write_message(&method.first_message(self.tls_server_end_point.supported()))
     101            0 :             .await?;
     102              : 
     103           15 :         Ok(AuthFlow {
     104           15 :             stream: self.stream,
     105           15 :             state: method,
     106           15 :             tls_server_end_point: self.tls_server_end_point,
     107           15 :         })
     108           15 :     }
     109              : }
     110              : 
     111              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     112              :     /// Perform user authentication. Raise an error in case authentication failed.
     113            1 :     pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
     114            1 :         let msg = self.stream.read_password_message().await?;
     115            1 :         let password = msg
     116            1 :             .strip_suffix(&[0])
     117            1 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     118              : 
     119            1 :         let payload = PasswordHackPayload::parse(password)
     120            1 :             // If we ended up here and the payload is malformed, it means that
     121            1 :             // the user neither enabled SNI nor resorted to any other method
     122            1 :             // for passing the project name we rely on. We should show them
     123            1 :             // the most helpful error message and point to the documentation.
     124            1 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
     125              : 
     126            1 :         Ok(payload)
     127            1 :     }
     128              : }
     129              : 
     130              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     131              :     /// Perform user authentication. Raise an error in case authentication failed.
     132            1 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     133            1 :         let msg = self.stream.read_password_message().await?;
     134            1 :         let password = msg
     135            1 :             .strip_suffix(&[0])
     136            1 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     137              : 
     138            1 :         let outcome = validate_password_and_exchange(
     139            1 :             &self.state.pool,
     140            1 :             self.state.endpoint,
     141            1 :             password,
     142            1 :             self.state.secret,
     143            1 :         )
     144            1 :         .await?;
     145              : 
     146            1 :         if let sasl::Outcome::Success(_) = &outcome {
     147            1 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     148            0 :         }
     149              : 
     150            1 :         Ok(outcome)
     151            1 :     }
     152              : }
     153              : 
     154              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     155              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     156              :     /// Perform user authentication. Raise an error in case authentication failed.
     157           13 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     158           13 :         let Scram(secret, ctx) = self.state;
     159           13 : 
     160           13 :         // pause the timer while we communicate with the client
     161           13 :         let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
     162              : 
     163              :         // Initial client message contains the chosen auth method's name.
     164           13 :         let msg = self.stream.read_password_message().await?;
     165           12 :         let sasl = sasl::FirstMessage::parse(&msg)
     166           12 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     167              : 
     168              :         // Currently, the only supported SASL method is SCRAM.
     169           12 :         if !scram::METHODS.contains(&sasl.method) {
     170            0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     171           12 :         }
     172           12 : 
     173           12 :         match sasl.method {
     174           12 :             SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
     175            6 :             SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
     176            0 :             _ => {}
     177              :         }
     178           12 :         info!("client chooses {}", sasl.method);
     179              : 
     180           12 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     181           12 :             .authenticate(scram::Exchange::new(
     182           12 :                 secret,
     183           12 :                 rand::random,
     184           12 :                 self.tls_server_end_point,
     185           12 :             ))
     186           11 :             .await?;
     187              : 
     188            7 :         if let sasl::Outcome::Success(_) = &outcome {
     189            6 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     190            1 :         }
     191              : 
     192            7 :         Ok(outcome)
     193           13 :     }
     194              : }
     195              : 
     196            2 : pub(crate) async fn validate_password_and_exchange(
     197            2 :     pool: &ThreadPool,
     198            2 :     endpoint: EndpointIdInt,
     199            2 :     password: &[u8],
     200            2 :     secret: AuthSecret,
     201            2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     202            2 :     match secret {
     203              :         #[cfg(any(test, feature = "testing"))]
     204              :         AuthSecret::Md5(_) => {
     205              :             // test only
     206            0 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
     207            0 :                 password.to_owned(),
     208            0 :             )))
     209              :         }
     210              :         // perform scram authentication as both client and server to validate the keys
     211            2 :         AuthSecret::Scram(scram_secret) => {
     212            2 :             let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
     213              : 
     214            2 :             let client_key = match outcome {
     215            2 :                 sasl::Outcome::Success(client_key) => client_key,
     216            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     217              :             };
     218              : 
     219            2 :             let keys = crate::compute::ScramKeys {
     220            2 :                 client_key: client_key.as_bytes(),
     221            2 :                 server_key: scram_secret.server_key.as_bytes(),
     222            2 :             };
     223            2 : 
     224            2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     225            2 :                 tokio_postgres::config::AuthKeys::ScramSha256(keys),
     226            2 :             )))
     227              :         }
     228              :     }
     229            2 : }
        

Generated by: LCOV version 2.1-beta