|             Line data    Source code 
       1              : use super::{ComputeCredentials, ComputeUserInfo};
       2              : use crate::{
       3              :     auth::{self, backend::ComputeCredentialKeys, AuthFlow},
       4              :     compute,
       5              :     config::AuthenticationConfig,
       6              :     console::AuthSecret,
       7              :     context::RequestMonitoring,
       8              :     sasl,
       9              :     stream::{PqStream, Stream},
      10              : };
      11              : use tokio::io::{AsyncRead, AsyncWrite};
      12              : use tracing::{info, warn};
      13              : 
      14            2 : pub(super) async fn authenticate(
      15            2 :     ctx: &mut RequestMonitoring,
      16            2 :     creds: ComputeUserInfo,
      17            2 :     client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
      18            2 :     config: &'static AuthenticationConfig,
      19            2 :     secret: AuthSecret,
      20            2 : ) -> auth::Result<ComputeCredentials> {
      21            2 :     let flow = AuthFlow::new(client);
      22            2 :     let scram_keys = match secret {
      23              :         #[cfg(any(test, feature = "testing"))]
      24              :         AuthSecret::Md5(_) => {
      25            0 :             info!("auth endpoint chooses MD5");
      26            0 :             return Err(auth::AuthError::bad_auth_method("MD5"));
      27              :         }
      28            2 :         AuthSecret::Scram(secret) => {
      29            2 :             info!("auth endpoint chooses SCRAM");
      30            2 :             let scram = auth::Scram(&secret, &mut *ctx);
      31              : 
      32            2 :             let auth_outcome = tokio::time::timeout(
      33            2 :                 config.scram_protocol_timeout,
      34            2 :                 async {
      35            2 : 
      36            2 :                     flow.begin(scram).await.map_err(|error| {
      37            0 :                         warn!(?error, "error sending scram acknowledgement");
      38            0 :                         error
      39            4 :                     })?.authenticate().await.map_err(|error| {
      40            0 :                         warn!(?error, "error processing scram messages");
      41            0 :                         error
      42            2 :                     })
      43            2 :                 }
      44            2 :             )
      45            4 :             .await
      46            2 :             .map_err(|e| {
      47            0 :                 warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
      48            0 :                 auth::AuthError::user_timeout(e)
      49            2 :             })??;
      50              : 
      51            2 :             let client_key = match auth_outcome {
      52            2 :                 sasl::Outcome::Success(key) => key,
      53            0 :                 sasl::Outcome::Failure(reason) => {
      54            0 :                     info!("auth backend failed with an error: {reason}");
      55            0 :                     return Err(auth::AuthError::auth_failed(&*creds.user));
      56              :                 }
      57              :             };
      58              : 
      59            2 :             compute::ScramKeys {
      60            2 :                 client_key: client_key.as_bytes(),
      61            2 :                 server_key: secret.server_key.as_bytes(),
      62            2 :             }
      63            2 :         }
      64            2 :     };
      65            2 : 
      66            2 :     Ok(ComputeCredentials {
      67            2 :         info: creds,
      68            2 :         keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
      69            2 :             scram_keys,
      70            2 :         )),
      71            2 :     })
      72            2 : }
         |