LCOV - code coverage report
Current view: top level - proxy/src/serverless - backend.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 98.2 % 57 56
Test Date: 2024-02-12 20:26:03 Functions: 73.9 % 23 17

            Line data    Source code
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use async_trait::async_trait;
       4              : use tracing::{field::display, info};
       5              : 
       6              : use crate::{
       7              :     auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
       8              :     compute,
       9              :     config::ProxyConfig,
      10              :     console::{
      11              :         errors::{GetAuthInfoError, WakeComputeError},
      12              :         CachedNodeInfo,
      13              :     },
      14              :     context::RequestMonitoring,
      15              :     proxy::connect_compute::ConnectMechanism,
      16              : };
      17              : 
      18              : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
      19              : 
      20              : pub struct PoolingBackend {
      21              :     pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
      22              :     pub config: &'static ProxyConfig,
      23              : }
      24              : 
      25              : impl PoolingBackend {
      26           47 :     pub async fn authenticate(
      27           47 :         &self,
      28           47 :         ctx: &mut RequestMonitoring,
      29           47 :         conn_info: &ConnInfo,
      30           47 :     ) -> Result<ComputeCredentialKeys, AuthError> {
      31           47 :         let user_info = conn_info.user_info.clone();
      32           47 :         let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
      33          366 :         let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
      34           47 :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
      35            1 :             return Err(AuthError::ip_address_not_allowed());
      36           46 :         }
      37           46 :         let cached_secret = match maybe_secret {
      38            0 :             Some(secret) => secret,
      39          317 :             None => backend.get_role_secret(ctx).await?,
      40              :         };
      41              : 
      42           46 :         let secret = match cached_secret.value.clone() {
      43           45 :             Some(secret) => secret,
      44              :             None => {
      45              :                 // If we don't have an authentication secret, for the http flow we can just return an error.
      46            1 :                 info!("authentication info not found");
      47            1 :                 return Err(AuthError::auth_failed(&*user_info.user));
      48              :             }
      49              :         };
      50           45 :         let auth_outcome =
      51           45 :             crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
      52           45 :         match auth_outcome {
      53           44 :             crate::sasl::Outcome::Success(key) => Ok(key),
      54            1 :             crate::sasl::Outcome::Failure(reason) => {
      55            1 :                 info!("auth backend failed with an error: {reason}");
      56            1 :                 Err(AuthError::auth_failed(&*conn_info.user_info.user))
      57              :             }
      58              :         }
      59           47 :     }
      60              : 
      61              :     // Wake up the destination if needed. Code here is a bit involved because
      62              :     // we reuse the code from the usual proxy and we need to prepare few structures
      63              :     // that this code expects.
      64           88 :     #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
      65              :     pub async fn connect_to_compute(
      66              :         &self,
      67              :         ctx: &mut RequestMonitoring,
      68              :         conn_info: ConnInfo,
      69              :         keys: ComputeCredentialKeys,
      70              :         force_new: bool,
      71              :     ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
      72              :         let maybe_client = if !force_new {
      73           24 :             info!("pool: looking for an existing connection");
      74              :             self.pool.get(ctx, &conn_info).await?
      75              :         } else {
      76           20 :             info!("pool: pool is disabled");
      77              :             None
      78              :         };
      79              : 
      80              :         if let Some(client) = maybe_client {
      81              :             return Ok(client);
      82              :         }
      83              :         let conn_id = uuid::Uuid::new_v4();
      84              :         tracing::Span::current().record("conn_id", display(conn_id));
      85           40 :         info!("pool: opening a new connection '{conn_info}'");
      86              :         let backend = self
      87              :             .config
      88              :             .auth_backend
      89              :             .as_ref()
      90           40 :             .map(|_| conn_info.user_info.clone());
      91              : 
      92              :         let mut node_info = backend
      93              :             .wake_compute(ctx)
      94              :             .await?
      95              :             .ok_or(HttpConnError::NoComputeInfo)?;
      96              : 
      97              :         match keys {
      98              :             #[cfg(any(test, feature = "testing"))]
      99              :             ComputeCredentialKeys::Password(password) => node_info.config.password(password),
     100              :             ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
     101              :         };
     102              : 
     103              :         ctx.set_project(node_info.aux.clone());
     104              : 
     105              :         crate::proxy::connect_compute::connect_to_compute(
     106              :             ctx,
     107              :             &TokioMechanism {
     108              :                 conn_id,
     109              :                 conn_info,
     110              :                 pool: self.pool.clone(),
     111              :             },
     112              :             node_info,
     113              :             &backend,
     114              :         )
     115              :         .await
     116              :     }
     117              : }
     118              : 
     119           18 : #[derive(Debug, thiserror::Error)]
     120              : pub enum HttpConnError {
     121              :     #[error("pooled connection closed at inconsistent state")]
     122              :     ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
     123              :     #[error("could not connection to compute")]
     124              :     ConnectionError(#[from] tokio_postgres::Error),
     125              : 
     126              :     #[error("could not get auth info")]
     127              :     GetAuthInfo(#[from] GetAuthInfoError),
     128              :     #[error("user not authenticated")]
     129              :     AuthError(#[from] AuthError),
     130              :     #[error("wake_compute returned error")]
     131              :     WakeCompute(#[from] WakeComputeError),
     132              :     #[error("wake_compute returned nothing")]
     133              :     NoComputeInfo,
     134              : }
     135              : 
     136              : struct TokioMechanism {
     137              :     pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
     138              :     conn_info: ConnInfo,
     139              :     conn_id: uuid::Uuid,
     140              : }
     141              : 
     142              : #[async_trait]
     143              : impl ConnectMechanism for TokioMechanism {
     144              :     type Connection = Client<tokio_postgres::Client>;
     145              :     type ConnectError = tokio_postgres::Error;
     146              :     type Error = HttpConnError;
     147              : 
     148           40 :     async fn connect_once(
     149           40 :         &self,
     150           40 :         ctx: &mut RequestMonitoring,
     151           40 :         node_info: &CachedNodeInfo,
     152           40 :         timeout: Duration,
     153           40 :     ) -> Result<Self::Connection, Self::ConnectError> {
     154           40 :         let mut config = (*node_info.config).clone();
     155           40 :         let config = config
     156           40 :             .user(&self.conn_info.user_info.user)
     157           40 :             .password(&*self.conn_info.password)
     158           40 :             .dbname(&self.conn_info.dbname)
     159           40 :             .connect_timeout(timeout);
     160              : 
     161          128 :         let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
     162              : 
     163           40 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     164           40 :         Ok(poll_client(
     165           40 :             self.pool.clone(),
     166           40 :             ctx,
     167           40 :             self.conn_info.clone(),
     168           40 :             client,
     169           40 :             connection,
     170           40 :             self.conn_id,
     171           40 :             node_info.aux.clone(),
     172           40 :         ))
     173          120 :     }
     174              : 
     175           40 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
     176              : }
        

Generated by: LCOV version 2.1-beta