TLA Line data Source code
1 : use std::ops::ControlFlow;
2 :
3 : use super::AuthSuccess;
4 : use crate::{
5 : auth::{self, AuthFlow, ClientCredentials},
6 : compute,
7 : console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
8 : proxy::{handle_try_wake, retry_after},
9 : sasl, scram,
10 : stream::PqStream,
11 : };
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tracing::{error, info, warn};
14 :
15 CBC 27 : pub(super) async fn authenticate(
16 27 : api: &impl console::Api,
17 27 : extra: &ConsoleReqExtra<'_>,
18 27 : creds: &ClientCredentials<'_>,
19 27 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
20 27 : ) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
21 27 : info!("fetching user's authentication info");
22 135 : let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
23 1 : // If we don't have an authentication secret, we mock one to
24 1 : // prevent malicious probing (possible due to missing protocol steps).
25 1 : // This mocked secret will never lead to successful authentication.
26 1 : info!("authentication info not found, mocking it");
27 1 : AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
28 27 : });
29 27 :
30 27 : let flow = AuthFlow::new(client);
31 27 : let scram_keys = match info {
32 : AuthInfo::Md5(_) => {
33 UBC 0 : info!("auth endpoint chooses MD5");
34 0 : return Err(auth::AuthError::bad_auth_method("MD5"));
35 : }
36 CBC 27 : AuthInfo::Scram(secret) => {
37 27 : info!("auth endpoint chooses SCRAM");
38 27 : let scram = auth::Scram(&secret);
39 :
40 27 : let auth_flow = flow.begin(scram).await.map_err(|error| {
41 UBC 0 : warn!(?error, "error sending scram acknowledgement");
42 0 : error
43 CBC 27 : })?;
44 :
45 54 : let auth_outcome = auth_flow.authenticate().await.map_err(|error| {
46 UBC 0 : warn!(?error, "error processing scram messages");
47 0 : error
48 CBC 27 : })?;
49 :
50 27 : let client_key = match auth_outcome {
51 24 : sasl::Outcome::Success(key) => key,
52 3 : sasl::Outcome::Failure(reason) => {
53 3 : info!("auth backend failed with an error: {reason}");
54 3 : return Err(auth::AuthError::auth_failed(creds.user));
55 : }
56 : };
57 :
58 24 : Some(compute::ScramKeys {
59 24 : client_key: client_key.as_bytes(),
60 24 : server_key: secret.server_key.as_bytes(),
61 24 : })
62 24 : }
63 24 : };
64 24 :
65 24 : let mut num_retries = 0;
66 24 : let mut node = loop {
67 24 : let wake_res = api.wake_compute(extra, creds).await;
68 24 : match handle_try_wake(wake_res, num_retries) {
69 UBC 0 : Err(e) => {
70 0 : error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
71 0 : return Err(e.into());
72 : }
73 0 : Ok(ControlFlow::Continue(e)) => {
74 0 : warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
75 : }
76 CBC 24 : Ok(ControlFlow::Break(n)) => break n,
77 : }
78 :
79 UBC 0 : let wait_duration = retry_after(num_retries);
80 0 : num_retries += 1;
81 0 : tokio::time::sleep(wait_duration).await;
82 : };
83 CBC 24 : if let Some(keys) = scram_keys {
84 24 : use tokio_postgres::config::AuthKeys;
85 24 : node.config.auth_keys(AuthKeys::ScramSha256(keys));
86 24 : }
87 :
88 24 : Ok(AuthSuccess {
89 24 : reported_auth_ok: false,
90 24 : value: node,
91 24 : })
92 27 : }
|