Line data Source code
1 : //! Abstraction for the string-oriented SASL protocols.
2 :
3 : use super::{messages::ServerMessage, Mechanism};
4 : use crate::stream::PqStream;
5 : use std::io;
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 : use tracing::info;
8 :
9 : /// Abstracts away all peculiarities of the libpq's protocol.
10 : pub(crate) struct SaslStream<'a, S> {
11 : /// The underlying stream.
12 : stream: &'a mut PqStream<S>,
13 : /// Current password message we received from client.
14 : current: bytes::Bytes,
15 : /// First SASL message produced by client.
16 : first: Option<&'a str>,
17 : }
18 :
19 : impl<'a, S> SaslStream<'a, S> {
20 12 : pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
21 12 : Self {
22 12 : stream,
23 12 : current: bytes::Bytes::new(),
24 12 : first: Some(first),
25 12 : }
26 12 : }
27 : }
28 :
29 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
30 : // Receive a new SASL message from the client.
31 23 : async fn recv(&mut self) -> io::Result<&str> {
32 23 : if let Some(first) = self.first.take() {
33 12 : return Ok(first);
34 11 : }
35 :
36 11 : self.current = self.stream.read_password_message().await?;
37 11 : let s = std::str::from_utf8(&self.current)
38 11 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
39 :
40 11 : Ok(s)
41 23 : }
42 : }
43 :
44 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
45 : // Send a SASL message to the client.
46 17 : async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
47 17 : self.stream.write_message(&msg.to_reply()).await?;
48 17 : Ok(())
49 17 : }
50 : }
51 :
/// Result of a completed SASL exchange.
///
/// Callers can branch on these two variants directly instead of digging
/// through a noisy protocol error to learn whether the client was rejected.
#[must_use = "caller must explicitly check for success"]
pub(crate) enum Outcome<R> {
    /// The client authenticated; carries the value produced by the mechanism.
    Success(R),
    /// The client was rejected; carries a static description of the reason.
    Failure(&'static str),
}
62 :
63 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
64 : /// Perform SASL message exchange according to the underlying algorithm
65 : /// until user is either authenticated or denied access.
66 12 : pub(crate) async fn authenticate<M: Mechanism>(
67 12 : mut self,
68 12 : mut mechanism: M,
69 12 : ) -> super::Result<Outcome<M::Output>> {
70 : loop {
71 23 : let input = self.recv().await?;
72 23 : let step = mechanism.exchange(input).map_err(|error| {
73 5 : info!(?error, "error during SASL exchange");
74 5 : error
75 23 : })?;
76 :
77 : use super::Step;
78 18 : return Ok(match step {
79 11 : Step::Continue(moved_mechanism, reply) => {
80 11 : self.send(&ServerMessage::Continue(&reply)).await?;
81 11 : mechanism = moved_mechanism;
82 11 : continue;
83 : }
84 6 : Step::Success(result, reply) => {
85 6 : self.send(&ServerMessage::Final(&reply)).await?;
86 6 : Outcome::Success(result)
87 : }
88 1 : Step::Failure(reason) => Outcome::Failure(reason),
89 : });
90 : }
91 12 : }
92 : }
|