LCOV - code coverage report
Current view: top level - proxy/src/sasl - stream.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 100.0 % 40 40
Test Date: 2025-07-16 12:29:03 Functions: 22.2 % 9 2

            Line data    Source code
       1              : //! Abstraction for the string-oriented SASL protocols.
       2              : 
       3              : use std::io;
       4              : 
       5              : use tokio::io::{AsyncRead, AsyncWrite};
       6              : 
       7              : use super::{Mechanism, Step};
       8              : use crate::context::RequestContext;
       9              : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
      10              : use crate::stream::PqStream;
      11              : 
      12              : /// Outcome of a SASL authentication exchange.
      13              : /// Returning this as a plain value lets callers match on two variants
      14              : /// instead of peeking into a noisy protocol error type.
      15              : #[must_use = "caller must explicitly check for success"]
      16              : pub(crate) enum Outcome<R> {
      17              :     /// Authentication succeeded; carries the value produced by the mechanism.
      18              :     Success(R),
      19              :     /// Authentication failed; carries a static string describing the reason.
      20              :     Failure(&'static str),
      21              : }
      22              : 
                        /// Runs a SASL authentication exchange with the client over `stream`.
                        ///
                        /// The first password message read from `stream` is parsed with
                        /// `super::FirstMessage::parse`. Its `method` is passed to `mechanism` to build
                        /// the [`Mechanism`], and its `message` becomes the first input to
                        /// `exchange`. Each exchange result is then handled as follows:
                        ///
                        /// * `Step::Continue` - queue an `AuthenticationSasl` "continue" reply, flush
                        ///   the stream, and read the next password message (which must be valid
                        ///   UTF-8) as the next input.
                        /// * `Step::Success` - queue the "final" SASL reply followed by
                        ///   `AuthenticationOk` and return [`Outcome::Success`]. This function does
                        ///   not call `flush` on this path.
                        /// * `Step::Failure` - return [`Outcome::Failure`] with the given reason;
                        ///   nothing is written to the stream on this path.
                        ///
                        /// # Errors
                        ///
                        /// Returns an error if:
                        /// * reading a password message or flushing the stream fails;
                        /// * the first message cannot be parsed (`BadClientMessage("bad sasl message")`);
                        /// * `mechanism` returns an error for the requested method;
                        /// * `exchange` returns an error (this case is also logged at `info` level);
                        /// * a follow-up client message is not valid UTF-8
                        ///   (`io::ErrorKind::InvalidData`, "bad encoding").
      23           13 : pub async fn authenticate<S, F, M>(
      24           13 :     ctx: &RequestContext,
      25           13 :     stream: &mut PqStream<S>,
      26           13 :     mechanism: F,
      27           13 : ) -> super::Result<Outcome<M::Output>>
      28           13 : where
      29           13 :     S: AsyncRead + AsyncWrite + Unpin,
      30           13 :     F: FnOnce(&str) -> super::Result<M>,
      31           13 :     M: Mechanism,
      32           13 : {
                            // The block scope limits how long the `_paused` guard is held: it is
                            // dropped as soon as the first message has been read and parsed.
      33           12 :     let (mut mechanism, mut input) = {
      34              :         // Pause the latency timer while we wait on the client.
      35           13 :         let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
      36              : 
      37              :         // The initial client message contains the chosen auth method's name.
      38           13 :         let msg = stream.read_password_message().await?;
      39              : 
      40           12 :         let sasl = super::FirstMessage::parse(msg)
      41           12 :             .ok_or(super::Error::BadClientMessage("bad sasl message"))?;
      42              : 
                                // Build the mechanism for the requested method; the rest of the
                                // first message is the initial input to the exchange.
      43           12 :         (mechanism(sasl.method)?, sasl.message)
      44              :     };
      45              : 
      46              :     loop {
      47           23 :         match mechanism.exchange(input) {
      48           11 :             Ok(Step::Continue(moved_mechanism, reply)) => {
                                        // `Continue` hands the mechanism back; keep it for the next round.
      49           11 :                 mechanism = moved_mechanism;
      50           11 : 
      51           11 :                 // Queue the reply; it is flushed below, before the next read.
      52           11 :                 let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
      53           11 :                 stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
      54           11 :                 drop(reply);
      55           11 :             }
      56            6 :             Ok(Step::Success(result, reply)) => {
      57              :                 // Queue the final SASL reply, then AuthenticationOk (no flush here).
      58            6 :                 let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
      59            6 :                 stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
      60            6 :                 stream.write_message(BeMessage::AuthenticationOk);
      61              : 
      62              :                 // Exit with success.
      63            6 :                 break Ok(Outcome::Success(result));
      64              :             }
      65              :             // Exit with failure; nothing is written to the stream on this path.
      66            1 :             Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
      67            5 :             Err(error) => {
      68            5 :                 tracing::info!(?error, "error during SASL exchange");
      69            5 :                 return Err(error);
      70              :             }
      71              :         }
      72              : 
      73              :         // Only the `Continue` arm reaches this point. Pause the latency timer
                                // while we wait on the client; the guard lives until the end of this
                                // loop iteration.
      74           11 :         let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
      75              : 
      76              :         // Send the queued reply, then read the client's next message as input.
      77           11 :         stream.flush().await?;
      78           11 :         let msg = stream.read_password_message().await?;
      79           11 :         input = std::str::from_utf8(msg)
      80           11 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
      81              :     }
      82           13 : }
        

Generated by: LCOV version 2.1-beta