Line data Source code
1 : //! Abstraction for the string-oriented SASL protocols.
2 :
3 : use std::io;
4 :
5 : use tokio::io::{AsyncRead, AsyncWrite};
6 : use tracing::info;
7 :
8 : use super::messages::ServerMessage;
9 : use super::Mechanism;
10 : use crate::stream::PqStream;
11 :
12 : /// Abstracts away all peculiarities of the libpq's protocol.
13 : pub(crate) struct SaslStream<'a, S> {
14 : /// The underlying stream.
15 : stream: &'a mut PqStream<S>,
16 : /// Current password message we received from client.
17 : current: bytes::Bytes,
18 : /// First SASL message produced by client.
19 : first: Option<&'a str>,
20 : }
21 :
22 : impl<'a, S> SaslStream<'a, S> {
23 12 : pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
24 12 : Self {
25 12 : stream,
26 12 : current: bytes::Bytes::new(),
27 12 : first: Some(first),
28 12 : }
29 12 : }
30 : }
31 :
32 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
33 : // Receive a new SASL message from the client.
34 23 : async fn recv(&mut self) -> io::Result<&str> {
35 23 : if let Some(first) = self.first.take() {
36 12 : return Ok(first);
37 11 : }
38 :
39 11 : self.current = self.stream.read_password_message().await?;
40 11 : let s = std::str::from_utf8(&self.current)
41 11 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
42 :
43 11 : Ok(s)
44 23 : }
45 : }
46 :
47 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
48 : // Send a SASL message to the client.
49 17 : async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
50 17 : self.stream.write_message(&msg.to_reply()).await?;
51 17 : Ok(())
52 17 : }
53 : }
54 :
/// Result of a finished SASL exchange.
///
/// Matching on these two variants is simpler than inspecting a protocol
/// error type to tell whether the client was let in.
#[must_use = "caller must explicitly check for success"]
pub(crate) enum Outcome<R> {
    /// The client authenticated; carries the mechanism's output value.
    Success(R),
    /// The client was denied access; carries a static reason string.
    Failure(&'static str),
}
65 :
66 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
67 : /// Perform SASL message exchange according to the underlying algorithm
68 : /// until user is either authenticated or denied access.
69 12 : pub(crate) async fn authenticate<M: Mechanism>(
70 12 : mut self,
71 12 : mut mechanism: M,
72 12 : ) -> super::Result<Outcome<M::Output>> {
73 : loop {
74 23 : let input = self.recv().await?;
75 23 : let step = mechanism.exchange(input).map_err(|error| {
76 5 : info!(?error, "error during SASL exchange");
77 5 : error
78 23 : })?;
79 :
80 : use super::Step;
81 18 : return Ok(match step {
82 11 : Step::Continue(moved_mechanism, reply) => {
83 11 : self.send(&ServerMessage::Continue(&reply)).await?;
84 11 : mechanism = moved_mechanism;
85 11 : continue;
86 : }
87 6 : Step::Success(result, reply) => {
88 6 : self.send(&ServerMessage::Final(&reply)).await?;
89 6 : Outcome::Success(result)
90 : }
91 1 : Step::Failure(reason) => Outcome::Failure(reason),
92 : });
93 : }
94 12 : }
95 : }
|