LCOV - differential code coverage report
Current view: top level - proxy/src/sasl - stream.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 100.0 % 39 39 39
Current Date: 2024-01-09 02:06:09 Functions: 37.5 % 40 15 25 15
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Abstraction for the string-oriented SASL protocols.
       2                 : 
       3                 : use super::{messages::ServerMessage, Mechanism};
       4                 : use crate::stream::PqStream;
       5                 : use std::io;
       6                 : use tokio::io::{AsyncRead, AsyncWrite};
       7                 : use tracing::info;
       8                 : 
       9                 : /// Abstracts away all peculiarities of the libpq's protocol.
      10                 : pub struct SaslStream<'a, S> {
      11                 :     /// The underlying stream.
      12                 :     stream: &'a mut PqStream<S>,
      13                 :     /// Current password message we received from client.
      14                 :     current: bytes::Bytes,
      15                 :     /// First SASL message produced by client.
      16                 :     first: Option<&'a str>,
      17                 : }
      18                 : 
      19                 : impl<'a, S> SaslStream<'a, S> {
      20 CBC          47 :     pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
      21              47 :         Self {
      22              47 :             stream,
      23              47 :             current: bytes::Bytes::new(),
      24              47 :             first: Some(first),
      25              47 :         }
      26              47 :     }
      27                 : }
      28                 : 
      29                 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
      30                 :     // Receive a new SASL message from the client.
      31              93 :     async fn recv(&mut self) -> io::Result<&str> {
      32              93 :         if let Some(first) = self.first.take() {
      33              47 :             return Ok(first);
      34              46 :         }
      35                 : 
      36              46 :         self.current = self.stream.read_password_message().await?;
      37              46 :         let s = std::str::from_utf8(&self.current)
      38              46 :             .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
      39                 : 
      40              46 :         Ok(s)
      41              93 :     }
      42                 : }
      43                 : 
      44                 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
      45                 :     // Send a SASL message to the client.
      46              84 :     async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
      47              84 :         self.stream.write_message(&msg.to_reply()).await?;
      48              84 :         Ok(())
      49              84 :     }
      50                 : }
      51                 : 
      52                 : /// SASL authentication outcome.
      53                 : /// It's much easier to match on those two variants
      54                 : /// than to peek into a noisy protocol error type.
      55                 : #[must_use = "caller must explicitly check for success"]
      56                 : pub enum Outcome<R> {
      57                 :     /// Authentication succeeded and produced some value.
      58                 :     Success(R),
      59                 :     /// Authentication failed (reason attached).
      60                 :     Failure(&'static str),
      61                 : }
      62                 : 
      63                 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
      64                 :     /// Perform SASL message exchange according to the underlying algorithm
      65                 :     /// until user is either authenticated or denied access.
      66              47 :     pub async fn authenticate<M: Mechanism>(
      67              47 :         mut self,
      68              47 :         mut mechanism: M,
      69              47 :     ) -> super::Result<Outcome<M::Output>> {
      70                 :         loop {
      71              93 :             let input = self.recv().await?;
      72              93 :             let step = mechanism.exchange(input).map_err(|error| {
      73               5 :                 info!(?error, "error during SASL exchange");
      74               5 :                 error
      75              93 :             })?;
      76                 : 
      77                 :             use super::Step;
      78              88 :             return Ok(match step {
      79              46 :                 Step::Continue(moved_mechanism, reply) => {
      80              46 :                     self.send(&ServerMessage::Continue(&reply)).await?;
      81              46 :                     mechanism = moved_mechanism;
      82              46 :                     continue;
      83                 :                 }
      84              38 :                 Step::Success(result, reply) => {
      85              38 :                     self.send(&ServerMessage::Final(&reply)).await?;
      86              38 :                     Outcome::Success(result)
      87                 :                 }
      88               4 :                 Step::Failure(reason) => Outcome::Failure(reason),
      89                 :             });
      90                 :         }
      91              47 :     }
      92                 : }
        

Generated by: LCOV version 2.1-beta