Line data Source code
1 : use std::ops::ControlFlow;
2 :
3 : use super::AuthSuccess;
4 : use crate::{
5 : auth::{self, AuthFlow, ClientCredentials},
6 : compute,
7 : console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
8 : proxy::{handle_try_wake, retry_after},
9 : sasl, scram,
10 : stream::PqStream,
11 : };
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tracing::{error, info, warn};
14 :
15 25 : pub(super) async fn authenticate(
16 25 : api: &impl console::Api,
17 25 : extra: &ConsoleReqExtra<'_>,
18 25 : creds: &ClientCredentials<'_>,
19 25 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
20 25 : ) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
21 25 : info!("fetching user's authentication info");
22 116 : let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
23 1 : // If we don't have an authentication secret, we mock one to
24 1 : // prevent malicious probing (possible due to missing protocol steps).
25 1 : // This mocked secret will never lead to successful authentication.
26 1 : info!("authentication info not found, mocking it");
27 1 : AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
28 25 : });
29 25 :
30 25 : let flow = AuthFlow::new(client);
31 25 : let scram_keys = match info {
32 : AuthInfo::Md5(_) => {
33 0 : info!("auth endpoint chooses MD5");
34 0 : return Err(auth::AuthError::bad_auth_method("MD5"));
35 : }
36 25 : AuthInfo::Scram(secret) => {
37 25 : info!("auth endpoint chooses SCRAM");
38 25 : let scram = auth::Scram(&secret);
39 :
40 25 : let auth_flow = flow.begin(scram).await.map_err(|error| {
41 0 : warn!(?error, "error sending scram acknowledgement");
42 0 : error
43 25 : })?;
44 :
45 50 : let auth_outcome = auth_flow.authenticate().await.map_err(|error| {
46 0 : warn!(?error, "error processing scram messages");
47 0 : error
48 25 : })?;
49 :
50 25 : let client_key = match auth_outcome {
51 22 : sasl::Outcome::Success(key) => key,
52 3 : sasl::Outcome::Failure(reason) => {
53 3 : info!("auth backend failed with an error: {reason}");
54 3 : return Err(auth::AuthError::auth_failed(creds.user));
55 : }
56 : };
57 :
58 22 : Some(compute::ScramKeys {
59 22 : client_key: client_key.as_bytes(),
60 22 : server_key: secret.server_key.as_bytes(),
61 22 : })
62 22 : }
63 22 : };
64 22 :
65 22 : let mut num_retries = 0;
66 22 : let mut node = loop {
67 22 : let wake_res = api.wake_compute(extra, creds).await;
68 22 : match handle_try_wake(wake_res, num_retries) {
69 0 : Err(e) => {
70 0 : error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
71 0 : return Err(e.into());
72 : }
73 0 : Ok(ControlFlow::Continue(e)) => {
74 0 : warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
75 : }
76 22 : Ok(ControlFlow::Break(n)) => break n,
77 : }
78 :
79 0 : let wait_duration = retry_after(num_retries);
80 0 : num_retries += 1;
81 0 : tokio::time::sleep(wait_duration).await;
82 : };
83 22 : if let Some(keys) = scram_keys {
84 22 : use tokio_postgres::config::AuthKeys;
85 22 : node.config.auth_keys(AuthKeys::ScramSha256(keys));
86 22 : }
87 :
88 22 : Ok(AuthSuccess {
89 22 : reported_auth_ok: false,
90 22 : value: node,
91 22 : })
92 25 : }
|