LCOV - code coverage report
Current view: top level - proxy/src/sasl - stream.rs (source / functions) Coverage Total Hit
Test: 465a86b0c1fda0069b3e0f6c1c126e6b635a1f72.info Lines: 100.0 % 39 39
Test Date: 2024-06-25 15:47:26 Functions: 33.3 % 45 15

            Line data    Source code
       1              : //! Abstraction for the string-oriented SASL protocols.
       2              : 
       3              : use super::{messages::ServerMessage, Mechanism};
       4              : use crate::stream::PqStream;
       5              : use std::io;
       6              : use tokio::io::{AsyncRead, AsyncWrite};
       7              : use tracing::info;
       8              : 
       9              : /// Abstracts away all peculiarities of the libpq's protocol.
      10              : pub struct SaslStream<'a, S> {
      11              :     /// The underlying stream.
      12              :     stream: &'a mut PqStream<S>,
      13              :     /// Current password message we received from client.
      14              :     current: bytes::Bytes,
      15              :     /// First SASL message produced by client.
      16              :     first: Option<&'a str>,
      17              : }
      18              : 
      19              : impl<'a, S> SaslStream<'a, S> {
      20           24 :     pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
      21           24 :         Self {
      22           24 :             stream,
      23           24 :             current: bytes::Bytes::new(),
      24           24 :             first: Some(first),
      25           24 :         }
      26           24 :     }
      27              : }
      28              : 
      29              : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
      30              :     // Receive a new SASL message from the client.
      31           46 :     async fn recv(&mut self) -> io::Result<&str> {
      32           46 :         if let Some(first) = self.first.take() {
      33           24 :             return Ok(first);
      34           22 :         }
      35              : 
      36           22 :         self.current = self.stream.read_password_message().await?;
      37           22 :         let s = std::str::from_utf8(&self.current)
      38           22 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
      39              : 
      40           22 :         Ok(s)
      41           46 :     }
      42              : }
      43              : 
      44              : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
      45              :     // Send a SASL message to the client.
      46           34 :     async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
      47           34 :         self.stream.write_message(&msg.to_reply()).await?;
      48           34 :         Ok(())
      49           34 :     }
      50              : }
      51              : 
      52              : /// SASL authentication outcome.
      53              : /// It's much easier to match on those two variants
      54              : /// than to peek into a noisy protocol error type.
      55              : #[must_use = "caller must explicitly check for success"]
      56              : pub enum Outcome<R> {
      57              :     /// Authentication succeeded and produced some value.
      58              :     Success(R),
      59              :     /// Authentication failed (reason attached).
      60              :     Failure(&'static str),
      61              : }
      62              : 
      63              : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
      64              :     /// Perform SASL message exchange according to the underlying algorithm
      65              :     /// until user is either authenticated or denied access.
      66           24 :     pub async fn authenticate<M: Mechanism>(
      67           24 :         mut self,
      68           24 :         mut mechanism: M,
      69           24 :     ) -> super::Result<Outcome<M::Output>> {
      70              :         loop {
      71           46 :             let input = self.recv().await?;
      72           46 :             let step = mechanism.exchange(input).map_err(|error| {
      73           10 :                 info!(?error, "error during SASL exchange");
      74           10 :                 error
      75           46 :             })?;
      76              : 
      77              :             use super::Step;
      78           36 :             return Ok(match step {
      79           22 :                 Step::Continue(moved_mechanism, reply) => {
      80           22 :                     self.send(&ServerMessage::Continue(&reply)).await?;
      81           22 :                     mechanism = moved_mechanism;
      82           22 :                     continue;
      83              :                 }
      84           12 :                 Step::Success(result, reply) => {
      85           12 :                     self.send(&ServerMessage::Final(&reply)).await?;
      86           12 :                     Outcome::Success(result)
      87              :                 }
      88            2 :                 Step::Failure(reason) => Outcome::Failure(reason),
      89              :             });
      90              :         }
      91           24 :     }
      92              : }
        

Generated by: LCOV version 2.1-beta