LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - classic.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 78.1 % 64 50
Test Date: 2023-09-06 10:18:01 Functions: 8.6 % 70 6

            Line data    Source code
       1              : use std::ops::ControlFlow;
       2              : 
       3              : use super::AuthSuccess;
       4              : use crate::{
       5              :     auth::{self, AuthFlow, ClientCredentials},
       6              :     compute,
       7              :     console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
       8              :     proxy::{handle_try_wake, retry_after},
       9              :     sasl, scram,
      10              :     stream::PqStream,
      11              : };
      12              : use tokio::io::{AsyncRead, AsyncWrite};
      13              : use tracing::{error, info, warn};
      14              : 
      15           25 : pub(super) async fn authenticate(
      16           25 :     api: &impl console::Api,
      17           25 :     extra: &ConsoleReqExtra<'_>,
      18           25 :     creds: &ClientCredentials<'_>,
      19           25 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      20           25 : ) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
      21           25 :     info!("fetching user's authentication info");
      22          116 :     let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
      23            1 :         // If we don't have an authentication secret, we mock one to
      24            1 :         // prevent malicious probing (possible due to missing protocol steps).
      25            1 :         // This mocked secret will never lead to successful authentication.
      26            1 :         info!("authentication info not found, mocking it");
      27            1 :         AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
      28           25 :     });
      29           25 : 
      30           25 :     let flow = AuthFlow::new(client);
      31           25 :     let scram_keys = match info {
      32              :         AuthInfo::Md5(_) => {
      33            0 :             info!("auth endpoint chooses MD5");
      34            0 :             return Err(auth::AuthError::bad_auth_method("MD5"));
      35              :         }
      36           25 :         AuthInfo::Scram(secret) => {
      37           25 :             info!("auth endpoint chooses SCRAM");
      38           25 :             let scram = auth::Scram(&secret);
      39              : 
      40           25 :             let auth_flow = flow.begin(scram).await.map_err(|error| {
      41            0 :                 warn!(?error, "error sending scram acknowledgement");
      42            0 :                 error
      43           25 :             })?;
      44              : 
      45           50 :             let auth_outcome = auth_flow.authenticate().await.map_err(|error| {
      46            0 :                 warn!(?error, "error processing scram messages");
      47            0 :                 error
      48           25 :             })?;
      49              : 
      50           25 :             let client_key = match auth_outcome {
      51           22 :                 sasl::Outcome::Success(key) => key,
      52            3 :                 sasl::Outcome::Failure(reason) => {
      53            3 :                     info!("auth backend failed with an error: {reason}");
      54            3 :                     return Err(auth::AuthError::auth_failed(creds.user));
      55              :                 }
      56              :             };
      57              : 
      58           22 :             Some(compute::ScramKeys {
      59           22 :                 client_key: client_key.as_bytes(),
      60           22 :                 server_key: secret.server_key.as_bytes(),
      61           22 :             })
      62           22 :         }
      63           22 :     };
      64           22 : 
      65           22 :     let mut num_retries = 0;
      66           22 :     let mut node = loop {
      67           22 :         let wake_res = api.wake_compute(extra, creds).await;
      68           22 :         match handle_try_wake(wake_res, num_retries) {
      69            0 :             Err(e) => {
      70            0 :                 error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
      71            0 :                 return Err(e.into());
      72              :             }
      73            0 :             Ok(ControlFlow::Continue(e)) => {
      74            0 :                 warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
      75              :             }
      76           22 :             Ok(ControlFlow::Break(n)) => break n,
      77              :         }
      78              : 
      79            0 :         let wait_duration = retry_after(num_retries);
      80            0 :         num_retries += 1;
      81            0 :         tokio::time::sleep(wait_duration).await;
      82              :     };
      83           22 :     if let Some(keys) = scram_keys {
      84           22 :         use tokio_postgres::config::AuthKeys;
      85           22 :         node.config.auth_keys(AuthKeys::ScramSha256(keys));
      86           22 :     }
      87              : 
      88           22 :     Ok(AuthSuccess {
      89           22 :         reported_auth_ok: false,
      90           22 :         value: node,
      91           22 :     })
      92           25 : }
        

Generated by: LCOV version 2.1-beta