Line data Source code
1 : //! Abstraction for the string-oriented SASL protocols.
2 :
3 : use super::{messages::ServerMessage, Mechanism};
4 : use crate::stream::PqStream;
5 : use std::io;
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 : use tracing::info;
8 :
9 : /// Abstracts away all peculiarities of the libpq's protocol.
10 : pub struct SaslStream<'a, S> {
11 : /// The underlying stream.
12 : stream: &'a mut PqStream<S>,
13 : /// Current password message we received from client.
14 : current: bytes::Bytes,
15 : /// First SASL message produced by client.
16 : first: Option<&'a str>,
17 : }
18 :
19 : impl<'a, S> SaslStream<'a, S> {
20 24 : pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
21 24 : Self {
22 24 : stream,
23 24 : current: bytes::Bytes::new(),
24 24 : first: Some(first),
25 24 : }
26 24 : }
27 : }
28 :
29 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
30 : // Receive a new SASL message from the client.
31 46 : async fn recv(&mut self) -> io::Result<&str> {
32 46 : if let Some(first) = self.first.take() {
33 24 : return Ok(first);
34 22 : }
35 :
36 22 : self.current = self.stream.read_password_message().await?;
37 22 : let s = std::str::from_utf8(&self.current)
38 22 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
39 :
40 22 : Ok(s)
41 46 : }
42 : }
43 :
44 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
45 : // Send a SASL message to the client.
46 34 : async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
47 34 : self.stream.write_message(&msg.to_reply()).await?;
48 34 : Ok(())
49 34 : }
50 : }
51 :
/// Result of a SASL authentication attempt.
///
/// Matching on one of these two variants is far simpler than digging
/// through a noisy protocol error type.
#[must_use = "caller must explicitly check for success"]
pub enum Outcome<R> {
    /// The exchange completed and the mechanism produced this value.
    Success(R),
    /// The exchange was rejected; the attached text explains why.
    Failure(&'static str),
}
62 :
63 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
64 : /// Perform SASL message exchange according to the underlying algorithm
65 : /// until user is either authenticated or denied access.
66 24 : pub async fn authenticate<M: Mechanism>(
67 24 : mut self,
68 24 : mut mechanism: M,
69 24 : ) -> super::Result<Outcome<M::Output>> {
70 : loop {
71 46 : let input = self.recv().await?;
72 46 : let step = mechanism.exchange(input).map_err(|error| {
73 10 : info!(?error, "error during SASL exchange");
74 10 : error
75 46 : })?;
76 :
77 : use super::Step;
78 36 : return Ok(match step {
79 22 : Step::Continue(moved_mechanism, reply) => {
80 22 : self.send(&ServerMessage::Continue(&reply)).await?;
81 22 : mechanism = moved_mechanism;
82 22 : continue;
83 : }
84 12 : Step::Success(result, reply) => {
85 12 : self.send(&ServerMessage::Final(&reply)).await?;
86 12 : Outcome::Success(result)
87 : }
88 2 : Step::Failure(reason) => Outcome::Failure(reason),
89 : });
90 : }
91 24 : }
92 : }
|