LCOV - differential code coverage report
Current view: top level - proxy/src/sasl - stream.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 94.9 % 39 37 2 37
Current Date: 2023-10-19 02:04:12 Functions: 35.0 % 40 14 26 14
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Abstraction for the string-oriented SASL protocols.
       2                 : 
       3                 : use super::{messages::ServerMessage, Mechanism};
       4                 : use crate::stream::PqStream;
       5                 : use std::io;
       6                 : use tokio::io::{AsyncRead, AsyncWrite};
       7                 : use tracing::info;
       8                 : 
       9                 : /// Abstracts away all peculiarities of the libpq's protocol.
      10                 : pub struct SaslStream<'a, S> {
      11                 :     /// The underlying stream.
      12                 :     stream: &'a mut PqStream<S>,
      13                 :     /// Current password message we received from client.
      14                 :     current: bytes::Bytes,
      15                 :     /// First SASL message produced by client.
      16                 :     first: Option<&'a str>,
      17                 : }
      18                 : 
      19                 : impl<'a, S> SaslStream<'a, S> {
      20 CBC          31 :     pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
      21              31 :         Self {
      22              31 :             stream,
      23              31 :             current: bytes::Bytes::new(),
      24              31 :             first: Some(first),
      25              31 :         }
      26              31 :     }
      27                 : }
      28                 : 
      29                 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
      30                 :     // Receive a new SASL message from the client.
      31              62 :     async fn recv(&mut self) -> io::Result<&str> {
      32              62 :         if let Some(first) = self.first.take() {
      33              31 :             return Ok(first);
      34              31 :         }
      35                 : 
      36              31 :         self.current = self.stream.read_password_message().await?;
      37              31 :         let s = std::str::from_utf8(&self.current)
      38              31 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
      39                 : 
      40              31 :         Ok(s)
      41              62 :     }
      42                 : }
      43                 : 
      44                 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
      45                 :     // Send a SASL message to the client.
      46              58 :     async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
      47              58 :         self.stream.write_message(&msg.to_reply()).await?;
      48              58 :         Ok(())
      49              58 :     }
      50                 : }
      51                 : 
      52                 : /// SASL authentication outcome.
      53                 : /// It's much easier to match on those two variants
      54                 : /// than to peek into a noisy protocol error type.
      55                 : #[must_use = "caller must explicitly check for success"]
      56                 : pub enum Outcome<R> {
      57                 :     /// Authentication succeeded and produced some value.
      58                 :     Success(R),
      59                 :     /// Authentication failed (reason attached).
      60                 :     Failure(&'static str),
      61                 : }
      62                 : 
      63                 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
      64                 :     /// Perform SASL message exchange according to the underlying algorithm
      65                 :     /// until user is either authenticated or denied access.
      66              31 :     pub async fn authenticate<M: Mechanism>(
      67              31 :         mut self,
      68              31 :         mut mechanism: M,
      69              31 :     ) -> super::Result<Outcome<M::Output>> {
      70                 :         loop {
      71              62 :             let input = self.recv().await?;
      72              62 :             let step = mechanism.exchange(input).map_err(|error| {
      73 UBC           0 :                 info!(?error, "error during SASL exchange");
      74               0 :                 error
      75 CBC          62 :             })?;
      76                 : 
      77                 :             use super::Step;
      78              62 :             return Ok(match step {
      79              31 :                 Step::Continue(moved_mechanism, reply) => {
      80              31 :                     self.send(&ServerMessage::Continue(&reply)).await?;
      81              31 :                     mechanism = moved_mechanism;
      82              31 :                     continue;
      83                 :                 }
      84              27 :                 Step::Success(result, reply) => {
      85              27 :                     self.send(&ServerMessage::Final(&reply)).await?;
      86              27 :                     Outcome::Success(result)
      87                 :                 }
      88               4 :                 Step::Failure(reason) => Outcome::Failure(reason),
      89                 :             });
      90                 :         }
      91              31 :     }
      92                 : }
        

Generated by: LCOV version 2.1-beta