LCOV - code coverage report
Current view: top level - proxy/src/auth - flow.rs (source / functions) Coverage Total Hit
Test: fabb29a6339542ee130cd1d32b534fafdc0be240.info Lines: 93.1 % 116 108
Test Date: 2024-06-25 13:20:00 Functions: 39.7 % 58 23

            Line data    Source code
       1              : //! Main authentication flow.
       2              : 
       3              : use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
       4              : use crate::{
       5              :     config::TlsServerEndPoint,
       6              :     console::AuthSecret,
       7              :     context::RequestMonitoring,
       8              :     intern::EndpointIdInt,
       9              :     sasl,
      10              :     scram::{self, threadpool::ThreadPool},
      11              :     stream::{PqStream, Stream},
      12              : };
      13              : use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
      14              : use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
      15              : use std::{io, sync::Arc};
      16              : use tokio::io::{AsyncRead, AsyncWrite};
      17              : use tracing::info;
      18              : 
      19              : /// Every authentication selector is supposed to implement this trait.
      20              : pub trait AuthMethod {
      21              :     /// Any authentication selector should provide initial backend message
      22              :     /// containing auth method name and parameters, e.g. md5 salt.
      23              :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
      24              : }
      25              : 
      26              : /// Initial state of [`AuthFlow`].
      27              : pub struct Begin;
      28              : 
      29              : /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
      30              : pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
      31              : 
      32              : impl AuthMethod for Scram<'_> {
      33              :     #[inline(always)]
      34           26 :     fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
      35           26 :         if channel_binding {
      36           24 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
      37              :         } else {
      38            2 :             Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
      39            2 :                 scram::METHODS_WITHOUT_PLUS,
      40            2 :             ))
      41              :         }
      42           26 :     }
      43              : }
      44              : 
      45              : /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
      46              : /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
      47              : pub struct PasswordHack;
      48              : 
      49              : impl AuthMethod for PasswordHack {
      50              :     #[inline(always)]
      51            2 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      52            2 :         Be::AuthenticationCleartextPassword
      53            2 :     }
      54              : }
      55              : 
      56              : /// Use clear-text password auth called `password` in docs
      57              : /// <https://www.postgresql.org/docs/current/auth-password.html>
      58              : pub struct CleartextPassword {
      59              :     pub pool: Arc<ThreadPool>,
      60              :     pub endpoint: EndpointIdInt,
      61              :     pub secret: AuthSecret,
      62              : }
      63              : 
      64              : impl AuthMethod for CleartextPassword {
      65              :     #[inline(always)]
      66            2 :     fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
      67            2 :         Be::AuthenticationCleartextPassword
      68            2 :     }
      69              : }
      70              : 
      71              : /// This wrapper for [`PqStream`] performs client authentication.
      72              : #[must_use]
      73              : pub struct AuthFlow<'a, S, State> {
      74              :     /// The underlying stream which implements libpq's protocol.
      75              :     stream: &'a mut PqStream<Stream<S>>,
      76              :     /// State might contain ancillary data (see [`Self::begin`]).
      77              :     state: State,
      78              :     tls_server_end_point: TlsServerEndPoint,
      79              : }
      80              : 
      81              : /// Initial state of the stream wrapper.
      82              : impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
      83              :     /// Create a new wrapper for client authentication.
      84           30 :     pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
      85           30 :         let tls_server_end_point = stream.get_ref().tls_server_end_point();
      86           30 : 
      87           30 :         Self {
      88           30 :             stream,
      89           30 :             state: Begin,
      90           30 :             tls_server_end_point,
      91           30 :         }
      92           30 :     }
      93              : 
      94              :     /// Move to the next step by sending auth method's name & params to client.
      95           30 :     pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
      96           30 :         self.stream
      97           30 :             .write_message(&method.first_message(self.tls_server_end_point.supported()))
      98            0 :             .await?;
      99              : 
     100           30 :         Ok(AuthFlow {
     101           30 :             stream: self.stream,
     102           30 :             state: method,
     103           30 :             tls_server_end_point: self.tls_server_end_point,
     104           30 :         })
     105           30 :     }
     106              : }
     107              : 
     108              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     109              :     /// Perform user authentication. Raise an error in case authentication failed.
     110            2 :     pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
     111            2 :         let msg = self.stream.read_password_message().await?;
     112            2 :         let password = msg
     113            2 :             .strip_suffix(&[0])
     114            2 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     115              : 
     116            2 :         let payload = PasswordHackPayload::parse(password)
     117            2 :             // If we ended up here and the payload is malformed, it means that
     118            2 :             // the user neither enabled SNI nor resorted to any other method
     119            2 :             // for passing the project name we rely on. We should show them
     120            2 :             // the most helpful error message and point to the documentation.
     121            2 :             .ok_or(AuthErrorImpl::MissingEndpointName)?;
     122              : 
     123            2 :         Ok(payload)
     124            2 :     }
     125              : }
     126              : 
     127              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     128              :     /// Perform user authentication. Raise an error in case authentication failed.
     129            2 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     130            2 :         let msg = self.stream.read_password_message().await?;
     131            2 :         let password = msg
     132            2 :             .strip_suffix(&[0])
     133            2 :             .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
     134              : 
     135            2 :         let outcome = validate_password_and_exchange(
     136            2 :             &self.state.pool,
     137            2 :             self.state.endpoint,
     138            2 :             password,
     139            2 :             self.state.secret,
     140            2 :         )
     141            2 :         .await?;
     142              : 
     143            2 :         if let sasl::Outcome::Success(_) = &outcome {
     144            2 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     145            0 :         }
     146              : 
     147            2 :         Ok(outcome)
     148            2 :     }
     149              : }
     150              : 
     151              : /// Stream wrapper for handling [SCRAM](crate::scram) auth.
     152              : impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     153              :     /// Perform user authentication. Raise an error in case authentication failed.
     154           26 :     pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
     155           26 :         let Scram(secret, ctx) = self.state;
     156           26 : 
     157           26 :         // pause the timer while we communicate with the client
     158           26 :         let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
     159              : 
     160              :         // Initial client message contains the chosen auth method's name.
     161           26 :         let msg = self.stream.read_password_message().await?;
     162           24 :         let sasl = sasl::FirstMessage::parse(&msg)
     163           24 :             .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
     164              : 
     165              :         // Currently, the only supported SASL method is SCRAM.
     166           24 :         if !scram::METHODS.contains(&sasl.method) {
     167            0 :             return Err(super::AuthError::bad_auth_method(sasl.method));
     168           24 :         }
     169           24 : 
     170           24 :         match sasl.method {
     171           24 :             SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
     172           12 :             SCRAM_SHA_256_PLUS => {
     173           12 :                 ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
     174              :             }
     175            0 :             _ => {}
     176              :         }
     177           24 :         info!("client chooses {}", sasl.method);
     178              : 
     179           24 :         let outcome = sasl::SaslStream::new(self.stream, sasl.message)
     180           24 :             .authenticate(scram::Exchange::new(
     181           24 :                 secret,
     182           24 :                 rand::random,
     183           24 :                 self.tls_server_end_point,
     184           24 :             ))
     185           22 :             .await?;
     186              : 
     187           14 :         if let sasl::Outcome::Success(_) = &outcome {
     188           12 :             self.stream.write_message_noflush(&Be::AuthenticationOk)?;
     189            2 :         }
     190              : 
     191           14 :         Ok(outcome)
     192           26 :     }
     193              : }
     194              : 
     195            4 : pub(crate) async fn validate_password_and_exchange(
     196            4 :     pool: &ThreadPool,
     197            4 :     endpoint: EndpointIdInt,
     198            4 :     password: &[u8],
     199            4 :     secret: AuthSecret,
     200            4 : ) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
     201            4 :     match secret {
     202              :         #[cfg(any(test, feature = "testing"))]
     203              :         AuthSecret::Md5(_) => {
     204              :             // test only
     205            0 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
     206            0 :                 password.to_owned(),
     207            0 :             )))
     208              :         }
     209              :         // perform scram authentication as both client and server to validate the keys
     210            4 :         AuthSecret::Scram(scram_secret) => {
     211            4 :             let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
     212              : 
     213            4 :             let client_key = match outcome {
     214            4 :                 sasl::Outcome::Success(client_key) => client_key,
     215            0 :                 sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
     216              :             };
     217              : 
     218            4 :             let keys = crate::compute::ScramKeys {
     219            4 :                 client_key: client_key.as_bytes(),
     220            4 :                 server_key: scram_secret.server_key.as_bytes(),
     221            4 :             };
     222            4 : 
     223            4 :             Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
     224            4 :                 tokio_postgres::config::AuthKeys::ScramSha256(keys),
     225            4 :             )))
     226              :         }
     227              :     }
     228            4 : }
        

Generated by: LCOV version 2.1-beta