LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 97.8 % 90 88
Test Date: 2025-07-31 15:59:03 Functions: 39.4 % 33 13

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use std::sync::Arc;
       4              : 
       5              : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
       6              : use tokio::io::{AsyncRead, AsyncWrite};
       7              : use tracing::info;
       8              : 
       9              : use super::backend::ComputeCredentialKeys;
      10              : use super::{AuthError, PasswordHackPayload};
      11              : use crate::context::RequestContext;
      12              : use crate::control_plane::AuthSecret;
      13              : use crate::intern::{EndpointIdInt, RoleNameInt};
      14              : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
      15              : use crate::sasl;
      16              : use crate::scram::threadpool::ThreadPool;
      17              : use crate::scram::{self};
      18              : use crate::stream::{PqStream, Stream};
      19              : use crate::tls::TlsServerEndPoint;
      20              : 
      21              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      22              : pub(crate) struct Scram<'a>(
      23              :     pub(crate) &'a scram::ServerSecret,
      24              :     pub(crate) &'a RequestContext,
      25              : );
      26              : 
      27              : impl Scram<'_> {
      28              :     #[inline(always)]
      29           13 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      30           13 :         if channel_binding {
      31           12 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      32              :         } else {
      33            1 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      34            1 :                 scram::METHODS_WITHOUT_PLUS,
      35            1 :             ))
      36              :         }
      37           13 :     }
      38              : }
      39              : 
      40              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      41              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      42              : pub(crate) struct PasswordHack;
      43              : 
      44              : /// Use clear-text password auth called `password` in docs
      45              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      46              : pub(crate) struct CleartextPassword {
      47              :     pub(crate) pool: Arc<ThreadPool>,
      48              :     pub(crate) endpoint: EndpointIdInt,
      49              :     pub(crate) role: RoleNameInt,
      50              :     pub(crate) secret: AuthSecret,
      51              : }
      52              : 
      53              : /// This wrapper for [`PqStream`] performs client authentication.
      54              : #[must_use]
      55              : pub(crate) struct AuthFlow<'a, S, State> {
      56              :     /// The underlying stream which implements libpq's protocol.
      57              :     stream: &'a mut PqStream<Stream<S>>,
      58              :     /// State might contain ancillary data.
      59              :     state: State,
      60              :     tls_server_end_point: TlsServerEndPoint,
      61              : }
      62              : 
      63              : /// Initial state of the stream wrapper.
      64              : impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
      65              :     /// Create a new wrapper for client authentication.
      66           15 :     pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
      67           15 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      68              : 
      69           15 :         Self {
      70           15 :             stream,
      71           15 :             state: method,
      72           15 :             tls_server_end_point,
      73           15 :         }
      74           15 :     }
      75              : }
      76              : 
      77              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
      78              :     /// Perform user authentication. Raise an error in case authentication failed.
      79            1 :     pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
      80            1 :         self.stream
      81            1 :             .write_message(BeMessage::AuthenticationCleartextPassword);
      82            1 :         self.stream.flush().await?;
      83              : 
      84            1 :         let msg = self.stream.read_password_message().await?;
      85            1 :         let password = msg
      86            1 :             .strip_suffix(&[0])
      87            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
      88              : 
      89            1 :         let payload = PasswordHackPayload::parse(password)
      90              :             // If we ended up here and the payload is malformed, it means that
      91              :             // the user neither enabled SNI nor resorted to any other method
      92              :             // for passing the project name we rely on. We should show them
      93              :             // the most helpful error message and point to the documentation.
      94            1 :             .ok_or(AuthError::MissingEndpointName)?;
      95              : 
      96            1 :         Ok(payload)
      97            1 :     }
      98              : }
      99              : 
     100              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     101              :     /// Perform user authentication. Raise an error in case authentication failed.
     102            1 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     103            1 :         self.stream
     104            1 :             .write_message(BeMessage::AuthenticationCleartextPassword);
     105            1 :         self.stream.flush().await?;
     106              : 
     107            1 :         let msg = self.stream.read_password_message().await?;
     108            1 :         let password = msg
     109            1 :             .strip_suffix(&[0])
     110            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
     111              : 
     112            1 :         let outcome = validate_password_and_exchange(
     113            1 :             &self.state.pool,
     114            1 :             self.state.endpoint,
     115            1 :             self.state.role,
     116            1 :             password,
     117            1 :             self.state.secret,
     118            1 :         )
     119            1 :         .await?;
     120              : 
     121            1 :         if let sasl::Outcome::Success(_) = &outcome {
     122            1 :             self.stream.write_message(BeMessage::AuthenticationOk);
     123            1 :         }
     124              : 
     125            1 :         Ok(outcome)
     126            1 :     }
     127              : }
     128              : 
     129              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     130              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     131              :     /// Perform user authentication. Raise an error in case authentication failed.
     132           13 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     133           13 :         let Scram(secret, ctx) = self.state;
     134           13 :         let channel_binding = self.tls_server_end_point;
     135              : 
     136              :         // send sasl message.
     137              :         {
     138              :             // pause the timer while we communicate with the client
     139           13 :             let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
     140              : 
     141           13 :             let sasl = self.state.first_message(channel_binding.supported());
     142           13 :             self.stream.write_message(sasl);
     143           13 :             self.stream.flush().await?;
     144              :         }
     145              : 
     146              :         // complete sasl handshake.
     147           13 :         sasl::authenticate(ctx, self.stream, |method| {
     148              :             // Currently, the only supported SASL method is SCRAM.
     149           12 :             match method {
     150           12 :                 SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
     151            6 :                 SCRAM_SHA_256_PLUS => {
     152            6 :                     ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
     153            6 :                 }
     154            0 :                 method => return Err(sasl::Error::BadAuthMethod(method.into())),
     155              :             }
     156              : 
     157              :             // TODO: make this a metric instead
     158           12 :             info!("client chooses {}", method);
     159              : 
     160           12 :             Ok(scram::Exchange::new(secret, rand::random, channel_binding))
     161           12 :         })
     162           13 :         .await
     163           13 :         .map_err(AuthError::Sasl)
     164           13 :     }
     165              : }
     166              : 
     167            2 : pub(crate) async fn validate_password_and_exchange(
     168            2 :     pool: &ThreadPool,
     169            2 :     endpoint: EndpointIdInt,
     170            2 :     role: RoleNameInt,
     171            2 :     password: &[u8],
     172            2 :     secret: AuthSecret,
     173            2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     174            2 :     match secret {
     175              :         // perform scram authentication as both client and server to validate the keys
     176            2 :         AuthSecret::Scram(scram_secret) => {
     177            2 :             let outcome =
     178            2 :                 crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
     179              : 
     180            2 :             let client_key = match outcome {
     181            2 :                 sasl::Outcome::Success(client_key) => client_key,
     182            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     183              :             };
     184              : 
     185            2 :             let keys = crate::compute::ScramKeys {
     186            2 :                 client_key: client_key.as_bytes(),
     187            2 :                 server_key: scram_secret.server_key.as_bytes(),
     188            2 :             };
     189              : 
     190            2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     191            2 :                 postgres_client::config::AuthKeys::ScramSha256(keys),
     192            2 :             )))
     193              :         }
     194              :     }
     195            2 : }
        

Generated by: LCOV version 2.1-beta