Line data Source code
1 : //! Abstraction for the string-oriented SASL protocols.
2 :
3 : use std::io;
4 :
5 : use tokio::io::{AsyncRead, AsyncWrite};
6 :
7 : use super::{Mechanism, Step};
8 : use crate::context::RequestContext;
9 : use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
10 : use crate::stream::PqStream;
11 :
12 : /// SASL authentication outcome.
13 : /// It's much easier to match on those two variants
14 : /// than to peek into a noisy protocol error type.
15 : #[must_use = "caller must explicitly check for success"]
16 : pub(crate) enum Outcome<R> {
17 : /// Authentication succeeded and produced some value.
18 : Success(R),
19 : /// Authentication failed (reason attached).
20 : Failure(&'static str),
21 : }
22 :
23 13 : pub async fn authenticate<S, F, M>(
24 13 : ctx: &RequestContext,
25 13 : stream: &mut PqStream<S>,
26 13 : mechanism: F,
27 13 : ) -> super::Result<Outcome<M::Output>>
28 13 : where
29 13 : S: AsyncRead + AsyncWrite + Unpin,
30 13 : F: FnOnce(&str) -> super::Result<M>,
31 13 : M: Mechanism,
32 13 : {
33 12 : let (mut mechanism, mut input) = {
34 : // pause the timer while we communicate with the client
35 13 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
36 :
37 : // Initial client message contains the chosen auth method's name.
38 13 : let msg = stream.read_password_message().await?;
39 :
40 12 : let sasl = super::FirstMessage::parse(msg)
41 12 : .ok_or(super::Error::BadClientMessage("bad sasl message"))?;
42 :
43 12 : (mechanism(sasl.method)?, sasl.message)
44 : };
45 :
46 : loop {
47 23 : match mechanism.exchange(input) {
48 11 : Ok(Step::Continue(moved_mechanism, reply)) => {
49 11 : mechanism = moved_mechanism;
50 11 :
51 11 : // write reply
52 11 : let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
53 11 : stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
54 11 : drop(reply);
55 11 : }
56 6 : Ok(Step::Success(result, reply)) => {
57 : // write reply
58 6 : let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
59 6 : stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
60 6 : stream.write_message(BeMessage::AuthenticationOk);
61 :
62 : // exit with success
63 6 : break Ok(Outcome::Success(result));
64 : }
65 : // exit with failure
66 1 : Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
67 5 : Err(error) => {
68 5 : tracing::info!(?error, "error during SASL exchange");
69 5 : return Err(error);
70 : }
71 : }
72 :
73 : // pause the timer while we communicate with the client
74 11 : let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
75 :
76 : // get next input
77 11 : stream.flush().await?;
78 11 : let msg = stream.read_password_message().await?;
79 11 : input = std::str::from_utf8(msg)
80 11 : .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
81 : }
82 13 : }
|