LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 97.7 % 87 85
Test Date: 2025-07-16 12:29:03 Functions: 39.4 % 33 13

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use std::sync::Arc;
       4              : 
       5              : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
       6              : use tokio::io::{AsyncRead, AsyncWrite};
       7              : use tracing::info;
       8              : 
       9              : use super::backend::ComputeCredentialKeys;
      10              : use super::{AuthError, PasswordHackPayload};
      11              : use crate::context::RequestContext;
      12              : use crate::control_plane::AuthSecret;
      13              : use crate::intern::EndpointIdInt;
      14              : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
      15              : use crate::sasl;
      16              : use crate::scram::threadpool::ThreadPool;
      17              : use crate::scram::{self};
      18              : use crate::stream::{PqStream, Stream};
      19              : use crate::tls::TlsServerEndPoint;
      20              : 
      21              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
                        ///
                        /// Bundles the two borrowed inputs that the SCRAM flow destructures when it runs.
      22              : pub(crate) struct Scram<'a>(
                            /// Server-side secret handed to `scram::Exchange::new` during the handshake.
      23              :     pub(crate) &'a scram::ServerSecret,
                            /// Per-request context: used to pause the latency timer and to record the
                            /// SASL method the client picked.
      24              :     pub(crate) &'a RequestContext,
      25              : );
      26              : 
      27              : impl Scram<'_> {
                            /// Build the first SASL message, listing the mechanisms offered to the client.
                            ///
                            /// When `channel_binding` is `true` the `scram::METHODS` list is advertised;
                            /// otherwise `scram::METHODS_WITHOUT_PLUS` is advertised (presumably the same
                            /// list minus `SCRAM-SHA-256-PLUS` -- confirm against the `scram` module).
      28              :     #[inline(always)]
      29           13 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      30           13 :         if channel_binding {
      31           12 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      32              :         } else {
      33            1 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      34            1 :                 scram::METHODS_WITHOUT_PLUS,
      35            1 :             ))
      36              :         }
      37           13 :     }
      38              : }
      39              : 
      40              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      41              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
                        ///
                        /// Marker state for [`AuthFlow`] that carries no data: the flow asks the client
                        /// for a cleartext password and parses what it receives as a
                        /// [`PasswordHackPayload`].
      42              : pub(crate) struct PasswordHack;
      43              : 
      44              : /// Use clear-text password auth, called `password` in the PostgreSQL docs:
      45              : /// <https://www.postgresql.org/docs/current/auth-password.html>
                        ///
                        /// The fields are the inputs forwarded to `validate_password_and_exchange`
                        /// once the client has sent its password.
      46              : pub(crate) struct CleartextPassword {
                            /// Thread pool passed to the SCRAM exchange that validates the password.
      47              :     pub(crate) pool: Arc<ThreadPool>,
                            /// Endpoint identifier forwarded to that exchange.
      48              :     pub(crate) endpoint: EndpointIdInt,
                            /// Stored secret the received password is checked against.
      49              :     pub(crate) secret: AuthSecret,
      50              : }
      51              : 
      52              : /// This wrapper for [`PqStream`] performs client authentication.
                        ///
                        /// `State` selects the auth method; flows are implemented for [`Scram`],
                        /// [`PasswordHack`] and [`CleartextPassword`].
      53              : #[must_use]
      54              : pub(crate) struct AuthFlow<'a, S, State> {
      55              :     /// The underlying stream which implements libpq's protocol.
      56              :     stream: &'a mut PqStream<Stream<S>>,
      57              :     /// Auth-method state; it may carry ancillary data (secret, thread pool, request context).
      58              :     state: State,
                            /// TLS server end point read from the stream in `new`. The SCRAM flow uses it
                            /// to decide whether to offer channel binding and passes it to the exchange.
      59              :     tls_server_end_point: TlsServerEndPoint,
      60              : }
      61              : 
      62              : /// Construction of the stream wrapper, shared by every auth method `M`.
      63              : impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
      64              :     /// Create a new wrapper for client authentication.
                            ///
                            /// Reads `tls_server_end_point()` from the underlying stream once and stores it
                            /// next to `method`, which becomes the flow's state. Nothing is sent to the
                            /// client at this point.
      65           15 :     pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
      66           15 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      67              : 
      68           15 :         Self {
      69           15 :             stream,
      70           15 :             state: method,
      71           15 :             tls_server_end_point,
      72           15 :         }
      73           15 :     }
      74              : }
      75              : 
      76              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
      77              :     /// Ask the client for a cleartext password and parse it as a [`PasswordHackPayload`].
                            ///
                            /// No credentials are checked here; the parsed payload is returned to the caller.
                            ///
                            /// # Errors
                            ///
                            /// - errors from flushing the request or reading the password message are propagated;
                            /// - [`AuthError::MalformedPassword`] if the message does not end with a NUL byte;
                            /// - [`AuthError::MissingEndpointName`] if the payload cannot be parsed.
      78            1 :     pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
      79            1 :         self.stream
      80            1 :             .write_message(BeMessage::AuthenticationCleartextPassword);
      81            1 :         self.stream.flush().await?;
      82              : 
      83            1 :         let msg = self.stream.read_password_message().await?;
                                // The password must be NUL-terminated; drop the terminator before parsing.
      84            1 :         let password = msg
      85            1 :             .strip_suffix(&[0])
      86            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
      87              : 
      88            1 :         let payload = PasswordHackPayload::parse(password)
      89              :             // If we ended up here and the payload is malformed, it means that
      90              :             // the user neither enabled SNI nor resorted to any other method
      91              :             // for passing the project name we rely on. We should show them
      92              :             // the most helpful error message and point to the documentation.
      93            1 :             .ok_or(AuthError::MissingEndpointName)?;
      94              : 
      95            1 :         Ok(payload)
      96            1 :     }
      97              : }
      98              : 
      99              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     100              :     /// Ask the client for a cleartext password and validate it against the stored secret.
                            ///
                            /// The outcome of `validate_password_and_exchange` is returned as-is, so a
                            /// rejected password can surface as `Ok(sasl::Outcome::Failure(..))` rather than
                            /// as an error.
                            ///
                            /// # Errors
                            ///
                            /// - errors from flushing the request or reading the password message are propagated;
                            /// - [`AuthError::MalformedPassword`] if the message does not end with a NUL byte;
                            /// - any error returned by `validate_password_and_exchange`.
     101            1 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     102            1 :         self.stream
     103            1 :             .write_message(BeMessage::AuthenticationCleartextPassword);
     104            1 :         self.stream.flush().await?;
     105              : 
     106            1 :         let msg = self.stream.read_password_message().await?;
                                // The password must be NUL-terminated; drop the terminator before validating.
     107            1 :         let password = msg
     108            1 :             .strip_suffix(&[0])
     109            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
     110              : 
     111            1 :         let outcome = validate_password_and_exchange(
     112            1 :             &self.state.pool,
     113            1 :             self.state.endpoint,
     114            1 :             password,
     115            1 :             self.state.secret,
     116            1 :         )
     117            1 :         .await?;
     118              : 
                                // `AuthenticationOk` is written only on success and is not flushed in this
                                // function (presumably the caller flushes it later -- confirm at the call site).
     119            1 :         if let sasl::Outcome::Success(_) = &outcome {
     120            1 :             self.stream.write_message(BeMessage::AuthenticationOk);
     121            1 :         }
     122              : 
     123            1 :         Ok(outcome)
     124            1 :     }
     125              : }
     126              : 
     127              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     128              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     129              :     /// Run the SASL/SCRAM handshake with the client and return its outcome.
                            ///
                            /// Advertises the SASL mechanisms, then hands the stream to `sasl::authenticate`
                            /// with a callback that records the client's chosen mechanism in the request
                            /// context and creates the `scram::Exchange`.
                            ///
                            /// # Errors
                            ///
                            /// Propagates the error from flushing the first message. Errors from the SASL
                            /// handshake are wrapped in [`AuthError::Sasl`]; this includes
                            /// `sasl::Error::BadAuthMethod`, returned by the callback for any mechanism other
                            /// than `SCRAM_SHA_256` or `SCRAM_SHA_256_PLUS`.
     130           13 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     131           13 :         let Scram(secret, ctx) = self.state;
     132           13 :         let channel_binding = self.tls_server_end_point;
     133              : 
     134              :         // send sasl message.
     135              :         {
     136              :             // pause the timer while we communicate with the client
                                    // (`_paused` lives until the end of this block, i.e. across write and flush)
     137           13 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
     138              : 
     139           13 :             let sasl = self.state.first_message(channel_binding.supported());
     140           13 :             self.stream.write_message(sasl);
     141           13 :             self.stream.flush().await?;
     142              :         }
     143              : 
     144              :         // complete sasl handshake.
     145           13 :         sasl::authenticate(ctx, self.stream, |method| {
     146              :             // Currently, the only supported SASL method is SCRAM.
     147           12 :             match method {
     148           12 :                 SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
     149            6 :                 SCRAM_SHA_256_PLUS => {
     150            6 :                     ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
     151            6 :                 }
                                        // Any other mechanism name is refused; no exchange is created.
     152            0 :                 method => return Err(sasl::Error::BadAuthMethod(method.into())),
     153              :             }
     154              : 
     155              :             // TODO: make this a metric instead
     156           12 :             info!("client chooses {}", method);
     157              : 
                                    // `rand::random` is handed to the exchange as its source of randomness
                                    // (presumably for the server nonce -- confirm in `scram::Exchange`).
     158           12 :             Ok(scram::Exchange::new(secret, rand::random, channel_binding))
     159           12 :         })
     160           13 :         .await
     161           13 :         .map_err(AuthError::Sasl)
     162           13 :     }
     163              : }
     164              : 
                        /// Check a cleartext `password` against `secret` and build the resulting credential keys.
                        ///
                        /// Runs `crate::scram::exchange` with the given `pool`, `endpoint` and password. If the
                        /// exchange reports `sasl::Outcome::Failure`, that failure is returned inside `Ok(..)`
                        /// rather than as an error. On success the client key is combined with the secret's
                        /// server key into `ComputeCredentialKeys::AuthKeys`.
                        ///
                        /// # Errors
                        ///
                        /// Propagates any error returned by the exchange itself.
     165            2 : pub(crate) async fn validate_password_and_exchange(
     166            2 :     pool: &ThreadPool,
     167            2 :     endpoint: EndpointIdInt,
     168            2 :     password: &[u8],
     169            2 :     secret: AuthSecret,
     170            2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
                            // Single-arm match with no wildcard: adding a variant to `AuthSecret` will
                            // fail to compile here until it is handled.
     171            2 :     match secret {
     172              :         // perform scram authentication as both client and server to validate the keys
     173            2 :         AuthSecret::Scram(scram_secret) => {
     174            2 :             let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
     175              : 
                                    // A failed exchange is passed through unchanged, not turned into an error.
     176            2 :             let client_key = match outcome {
     177            2 :                 sasl::Outcome::Success(client_key) => client_key,
     178            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     179              :             };
     180              : 
     181            2 :             let keys = crate::compute::ScramKeys {
     182            2 :                 client_key: client_key.as_bytes(),
     183            2 :                 server_key: scram_secret.server_key.as_bytes(),
     184            2 :             };
     185              : 
     186            2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     187            2 :                 postgres_client::config::AuthKeys::ScramSha256(keys),
     188            2 :             )))
     189              :         }
     190              :     }
     191            2 : }
        

Generated by: LCOV version 2.1-beta