LCOV - code coverage report
Current view: top level - proxy/src/sasl - stream.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 100.0 % 43 43
Test Date: 2025-02-20 13:11:02 Functions: 34.0 % 50 17

            Line data    Source code
       1              : //! Abstraction for the string-oriented SASL protocols.
       2              : 
       3              : use std::io;
       4              : 
       5              : use tokio::io::{AsyncRead, AsyncWrite};
       6              : use tracing::info;
       7              : 
       8              : use super::messages::ServerMessage;
       9              : use super::Mechanism;
      10              : use crate::stream::PqStream;
      11              : 
      12              : /// Abstracts away all peculiarities of the libpq's protocol.
      13              : pub(crate) struct SaslStream<'a, S> {
      14              :     /// The underlying stream.
      15              :     stream: &'a mut PqStream<S>,
      16              :     /// Current password message we received from client.
      17              :     current: bytes::Bytes,
      18              :     /// First SASL message produced by client.
      19              :     first: Option<&'a str>,
      20              : }
      21              : 
      22              : impl<'a, S> SaslStream<'a, S> {
      23           12 :     pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
      24           12 :         Self {
      25           12 :             stream,
      26           12 :             current: bytes::Bytes::new(),
      27           12 :             first: Some(first),
      28           12 :         }
      29           12 :     }
      30              : }
      31              : 
      32              : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
      33              :     // Receive a new SASL message from the client.
      34           23 :     async fn recv(&mut self) -> io::Result<&str> {
      35           23 :         if let Some(first) = self.first.take() {
      36           12 :             return Ok(first);
      37           11 :         }
      38              : 
      39           11 :         self.current = self.stream.read_password_message().await?;
      40           11 :         let s = std::str::from_utf8(&self.current)
      41           11 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
      42              : 
      43           11 :         Ok(s)
      44           23 :     }
      45              : }
      46              : 
      47              : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
      48              :     // Send a SASL message to the client.
      49           11 :     async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
      50           11 :         self.stream.write_message(&msg.to_reply()).await?;
      51           11 :         Ok(())
      52           11 :     }
      53              : 
      54              :     // Queue a SASL message for the client.
      55            6 :     fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
      56            6 :         self.stream.write_message_noflush(&msg.to_reply())?;
      57            6 :         Ok(())
      58            6 :     }
      59              : }
      60              : 
      61              : /// SASL authentication outcome.
      62              : /// It's much easier to match on those two variants
      63              : /// than to peek into a noisy protocol error type.
      64              : #[must_use = "caller must explicitly check for success"]
      65              : pub(crate) enum Outcome<R> {
      66              :     /// Authentication succeeded and produced some value.
      67              :     Success(R),
      68              :     /// Authentication failed (reason attached).
      69              :     Failure(&'static str),
      70              : }
      71              : 
      72              : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
      73              :     /// Perform SASL message exchange according to the underlying algorithm
      74              :     /// until user is either authenticated or denied access.
      75           12 :     pub(crate) async fn authenticate<M: Mechanism>(
      76           12 :         mut self,
      77           12 :         mut mechanism: M,
      78           12 :     ) -> super::Result<Outcome<M::Output>> {
      79              :         loop {
      80           23 :             let input = self.recv().await?;
      81           23 :             let step = mechanism.exchange(input).map_err(|error| {
      82            5 :                 info!(?error, "error during SASL exchange");
      83            5 :                 error
      84           23 :             })?;
      85              : 
      86              :             use super::Step;
      87           18 :             return Ok(match step {
      88           11 :                 Step::Continue(moved_mechanism, reply) => {
      89           11 :                     self.send(&ServerMessage::Continue(&reply)).await?;
      90           11 :                     mechanism = moved_mechanism;
      91           11 :                     continue;
      92              :                 }
      93            6 :                 Step::Success(result, reply) => {
      94            6 :                     self.send_noflush(&ServerMessage::Final(&reply))?;
      95            6 :                     Outcome::Success(result)
      96              :                 }
      97            1 :                 Step::Failure(reason) => Outcome::Failure(reason),
      98              :             });
      99              :         }
     100           12 :     }
     101              : }
        

Generated by: LCOV version 2.1-beta