Line data Source code
1 : //! Abstraction for the string-oriented SASL protocols.
2 :
3 : use std::io;
4 :
5 : use tokio::io::{AsyncRead, AsyncWrite};
6 : use tracing::info;
7 :
8 : use super::messages::ServerMessage;
9 : use super::Mechanism;
10 : use crate::stream::PqStream;
11 :
12 : /// Abstracts away all peculiarities of the libpq's protocol.
13 : pub(crate) struct SaslStream<'a, S> {
14 : /// The underlying stream.
15 : stream: &'a mut PqStream<S>,
16 : /// Current password message we received from client.
17 : current: bytes::Bytes,
18 : /// First SASL message produced by client.
19 : first: Option<&'a str>,
20 : }
21 :
22 : impl<'a, S> SaslStream<'a, S> {
23 12 : pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
24 12 : Self {
25 12 : stream,
26 12 : current: bytes::Bytes::new(),
27 12 : first: Some(first),
28 12 : }
29 12 : }
30 : }
31 :
32 : impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
33 : // Receive a new SASL message from the client.
34 23 : async fn recv(&mut self) -> io::Result<&str> {
35 23 : if let Some(first) = self.first.take() {
36 12 : return Ok(first);
37 11 : }
38 :
39 11 : self.current = self.stream.read_password_message().await?;
40 11 : let s = std::str::from_utf8(&self.current)
41 11 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
42 :
43 11 : Ok(s)
44 23 : }
45 : }
46 :
47 : impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
48 : // Send a SASL message to the client.
49 11 : async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
50 11 : self.stream.write_message(&msg.to_reply()).await?;
51 11 : Ok(())
52 11 : }
53 :
54 : // Queue a SASL message for the client.
55 6 : fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
56 6 : self.stream.write_message_noflush(&msg.to_reply())?;
57 6 : Ok(())
58 6 : }
59 : }
60 :
/// SASL authentication outcome.
///
/// It's much easier to match on these two variants than to peek into a noisy
/// protocol error type.
///
/// `Debug` is derived (available whenever `R: Debug`) so outcomes can be
/// logged and asserted on in tests.
#[derive(Debug)]
#[must_use = "caller must explicitly check for success"]
pub(crate) enum Outcome<R> {
    /// Authentication succeeded and produced some value.
    Success(R),
    /// Authentication failed (reason attached).
    Failure(&'static str),
}
71 :
72 : impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
73 : /// Perform SASL message exchange according to the underlying algorithm
74 : /// until user is either authenticated or denied access.
75 12 : pub(crate) async fn authenticate<M: Mechanism>(
76 12 : mut self,
77 12 : mut mechanism: M,
78 12 : ) -> super::Result<Outcome<M::Output>> {
79 : loop {
80 23 : let input = self.recv().await?;
81 23 : let step = mechanism.exchange(input).map_err(|error| {
82 5 : info!(?error, "error during SASL exchange");
83 5 : error
84 23 : })?;
85 :
86 : use super::Step;
87 18 : return Ok(match step {
88 11 : Step::Continue(moved_mechanism, reply) => {
89 11 : self.send(&ServerMessage::Continue(&reply)).await?;
90 11 : mechanism = moved_mechanism;
91 11 : continue;
92 : }
93 6 : Step::Success(result, reply) => {
94 6 : self.send_noflush(&ServerMessage::Final(&reply))?;
95 6 : Outcome::Success(result)
96 : }
97 1 : Step::Failure(reason) => Outcome::Failure(reason),
98 : });
99 : }
100 12 : }
101 : }
|