LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: 20b6afc7b7f34578dcaab2b3acdaecfe91cd8bf1.info Lines: 93.0 % 115 107
Test Date: 2024-11-25 17:48:16 Functions: 39.7 % 58 23

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use std::io;
       4              : use std::sync::Arc;
       5              : 
       6              : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
       7              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
       8              : use tokio::io::{AsyncRead, AsyncWrite};
       9              : use tracing::info;
      10              : 
      11              : use super::backend::ComputeCredentialKeys;
      12              : use super::{AuthError, PasswordHackPayload};
      13              : use crate::config::TlsServerEndPoint;
      14              : use crate::context::RequestContext;
      15              : use crate::control_plane::AuthSecret;
      16              : use crate::intern::EndpointIdInt;
      17              : use crate::sasl;
      18              : use crate::scram::threadpool::ThreadPool;
      19              : use crate::scram::{self};
      20              : use crate::stream::{PqStream, Stream};
      21              : 
      22              : /// Every authentication selector is supposed to implement this trait.
      23              : pub(crate) trait AuthMethod {
      24              :     /// Any authentication selector should provide initial backend message
      25              :     /// containing auth method name and parameters, e.g. md5 salt.
      26              :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
      27              : }
      28              : 
      29              : /// Initial state of [`AuthFlow`].
      30              : pub(crate) struct Begin;
      31              : 
      32              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      33              : pub(crate) struct Scram<'a>(
      34              :     pub(crate) &'a scram::ServerSecret,
      35              :     pub(crate) &'a RequestContext,
      36              : );
      37              : 
      38              : impl AuthMethod for Scram<'_> {
      39              :     #[inline(always)]
      40           13 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      41           13 :         if channel_binding {
      42           12 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      43              :         } else {
      44            1 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      45            1 :                 scram::METHODS_WITHOUT_PLUS,
      46            1 :             ))
      47              :         }
      48           13 :     }
      49              : }
      50              : 
      51              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      52              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      53              : pub(crate) struct PasswordHack;
      54              : 
      55              : impl AuthMethod for PasswordHack {
      56              :     #[inline(always)]
      57            1 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      58            1 :         Be::AuthenticationCleartextPassword
      59            1 :     }
      60              : }
      61              : 
      62              : /// Use clear-text password auth called `password` in docs
      63              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      64              : pub(crate) struct CleartextPassword {
      65              :     pub(crate) pool: Arc<ThreadPool>,
      66              :     pub(crate) endpoint: EndpointIdInt,
      67              :     pub(crate) secret: AuthSecret,
      68              : }
      69              : 
      70              : impl AuthMethod for CleartextPassword {
      71              :     #[inline(always)]
      72            1 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      73            1 :         Be::AuthenticationCleartextPassword
      74            1 :     }
      75              : }
      76              : 
      77              : /// This wrapper for [`PqStream`] performs client authentication.
      78              : #[must_use]
      79              : pub(crate) struct AuthFlow<'a, S, State> {
      80              :     /// The underlying stream which implements libpq's protocol.
      81              :     stream: &'a mut PqStream<Stream<S>>,
      82              :     /// State might contain ancillary data (see [`Self::begin`]).
      83              :     state: State,
      84              :     tls_server_end_point: TlsServerEndPoint,
      85              : }
      86              : 
      87              : /// Initial state of the stream wrapper.
      88              : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      89              :     /// Create a new wrapper for client authentication.
      90           15 :     pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
      91           15 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      92           15 : 
      93           15 :         Self {
      94           15 :             stream,
      95           15 :             state: Begin,
      96           15 :             tls_server_end_point,
      97           15 :         }
      98           15 :     }
      99              : 
     100              :     /// Move to the next step by sending auth method's name & params to client.
     101           15 :     pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
     102           15 :         self.stream
     103           15 :             .write_message(&method.first_message(self.tls_server_end_point.supported()))
     104            0 :             .await?;
     105              : 
     106           15 :         Ok(AuthFlow {
     107           15 :             stream: self.stream,
     108           15 :             state: method,
     109           15 :             tls_server_end_point: self.tls_server_end_point,
     110           15 :         })
     111           15 :     }
     112              : }
     113              : 
     114              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     115              :     /// Perform user authentication. Raise an error in case authentication failed.
     116            1 :     pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
     117            1 :         let msg = self.stream.read_password_message().await?;
     118            1 :         let password = msg
     119            1 :             .strip_suffix(&[0])
     120            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
     121              : 
     122            1 :         let payload = PasswordHackPayload::parse(password)
     123            1 :             // If we ended up here and the payload is malformed, it means that
     124            1 :             // the user neither enabled SNI nor resorted to any other method
     125            1 :             // for passing the project name we rely on. We should show them
     126            1 :             // the most helpful error message and point to the documentation.
     127            1 :             .ok_or(AuthError::MissingEndpointName)?;
     128              : 
     129            1 :         Ok(payload)
     130            1 :     }
     131              : }
     132              : 
     133              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     134              :     /// Perform user authentication. Raise an error in case authentication failed.
     135            1 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     136            1 :         let msg = self.stream.read_password_message().await?;
     137            1 :         let password = msg
     138            1 :             .strip_suffix(&[0])
     139            1 :             .ok_or(AuthError::MalformedPassword("missing terminator"))?;
     140              : 
     141            1 :         let outcome = validate_password_and_exchange(
     142            1 :             &self.state.pool,
     143            1 :             self.state.endpoint,
     144            1 :             password,
     145            1 :             self.state.secret,
     146            1 :         )
     147            1 :         .await?;
     148              : 
     149            1 :         if let sasl::Outcome::Success(_) = &outcome {
     150            1 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     151            0 :         }
     152              : 
     153            1 :         Ok(outcome)
     154            1 :     }
     155              : }
     156              : 
     157              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     158              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     159              :     /// Perform user authentication. Raise an error in case authentication failed.
     160           13 :     pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     161           13 :         let Scram(secret, ctx) = self.state;
     162           13 : 
     163           13 :         // pause the timer while we communicate with the client
     164           13 :         let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
     165              : 
     166              :         // Initial client message contains the chosen auth method's name.
     167           13 :         let msg = self.stream.read_password_message().await?;
     168           12 :         let sasl = sasl::FirstMessage::parse(&msg)
     169           12 :             .ok_or(AuthError::MalformedPassword("bad sasl message"))?;
     170              : 
     171              :         // Currently, the only supported SASL method is SCRAM.
     172           12 :         if !scram::METHODS.contains(&sasl.method) {
     173            0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     174           12 :         }
     175           12 : 
     176           12 :         match sasl.method {
     177           12 :             SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
     178            6 :             SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
     179            0 :             _ => {}
     180              :         }
     181              : 
     182              :         // TODO: make this a metric instead
     183           12 :         info!("client chooses {}", sasl.method);
     184              : 
     185           12 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     186           12 :             .authenticate(scram::Exchange::new(
     187           12 :                 secret,
     188           12 :                 rand::random,
     189           12 :                 self.tls_server_end_point,
     190           12 :             ))
     191           11 :             .await?;
     192              : 
     193            7 :         if let sasl::Outcome::Success(_) = &outcome {
     194            6 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     195            1 :         }
     196              : 
     197            7 :         Ok(outcome)
     198           13 :     }
     199              : }
     200              : 
     201            2 : pub(crate) async fn validate_password_and_exchange(
     202            2 :     pool: &ThreadPool,
     203            2 :     endpoint: EndpointIdInt,
     204            2 :     password: &[u8],
     205            2 :     secret: AuthSecret,
     206            2 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     207            2 :     match secret {
     208              :         #[cfg(any(test, feature = "testing"))]
     209              :         AuthSecret::Md5(_) => {
     210              :             // test only
     211            0 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
     212            0 :                 password.to_owned(),
     213            0 :             )))
     214              :         }
     215              :         // perform scram authentication as both client and server to validate the keys
     216            2 :         AuthSecret::Scram(scram_secret) => {
     217            2 :             let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
     218              : 
     219            2 :             let client_key = match outcome {
     220            2 :                 sasl::Outcome::Success(client_key) => client_key,
     221            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     222              :             };
     223              : 
     224            2 :             let keys = crate::compute::ScramKeys {
     225            2 :                 client_key: client_key.as_bytes(),
     226            2 :                 server_key: scram_secret.server_key.as_bytes(),
     227            2 :             };
     228            2 : 
     229            2 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     230            2 :                 tokio_postgres::config::AuthKeys::ScramSha256(keys),
     231            2 :             )))
     232              :         }
     233              :     }
     234            2 : }
        

Generated by: LCOV version 2.1-beta