LCOV - code coverage report
Current view: top level - proxy/src/serverless - backend.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 98.4 % 61 60
Test Date: 2024-02-14 18:05:35 Functions: 75.0 % 24 18

            Line data    Source code
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use async_trait::async_trait;
       4              : use tracing::{field::display, info};
       5              : 
       6              : use crate::{
       7              :     auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
       8              :     compute,
       9              :     config::ProxyConfig,
      10              :     console::{
      11              :         errors::{GetAuthInfoError, WakeComputeError},
      12              :         CachedNodeInfo,
      13              :     },
      14              :     context::RequestMonitoring,
      15              :     proxy::connect_compute::ConnectMechanism,
      16              : };
      17              : 
      18              : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
      19              : 
      20              : pub struct PoolingBackend { // Entry point of the serverless backend: authenticates requests and hands out pooled compute connections.
      21              :     pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>, // shared pool; `connect_to_compute` looks here before opening a new connection
      22              :     pub config: &'static ProxyConfig, // proxy configuration; the methods below read `auth_backend` from it
      23              : }
      24              : 
      25              : impl PoolingBackend {
      26           47 :     pub async fn authenticate( // Authenticates a request: allow-list check on the peer address, then password validation.
      27           47 :         &self,
      28           47 :         ctx: &mut RequestMonitoring, // per-request context; `ctx.peer_addr` is matched against the allowed IPs
      29           47 :         conn_info: &ConnInfo, // supplies the user info and the password to validate
      30           47 :     ) -> Result<ComputeCredentials, AuthError> { // Ok: the user info plus the keys from a successful exchange
      31           47 :         let user_info = conn_info.user_info.clone();
      32           47 :         let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); // re-wrap the configured auth backend around this user's info
      33          372 :         let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
      34           47 :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
      35            1 :             return Err(AuthError::ip_address_not_allowed());
      36           46 :         }
      37           46 :         let cached_secret = match maybe_secret { // reuse the secret that came back with the allow-list, if there was one
      38            0 :             Some(secret) => secret,
      39          320 :             None => backend.get_role_secret(ctx).await?, // otherwise ask the backend for the role secret
      40              :         };
      41              : 
      42           46 :         let secret = match cached_secret.value.clone() {
      43           45 :             Some(secret) => secret,
      44              :             None => {
      45              :                 // If we don't have an authentication secret, for the http flow we can just return an error.
      46            1 :                 info!("authentication info not found");
      47            1 :                 return Err(AuthError::auth_failed(&*user_info.user));
      48              :             }
      49              :         };
      50           45 :         let auth_outcome =
      51           45 :             crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
      52           45 :         let res = match auth_outcome {
      53           44 :             crate::sasl::Outcome::Success(key) => Ok(key),
      54            1 :             crate::sasl::Outcome::Failure(reason) => {
      55            1 :                 info!("auth backend failed with an error: {reason}"); // the reason is only logged; the returned error carries just the user name
      56            1 :                 Err(AuthError::auth_failed(&*conn_info.user_info.user))
      57              :             }
      58              :         };
      59           45 :         res.map(|key| ComputeCredentials {
      60           44 :             info: user_info,
      61           44 :             keys: key,
      62           45 :         })
      63           47 :     }
      64              : 
      65              :     // Wake up the destination if needed. Code here is a bit involved because
      66              :     // we reuse the code from the usual proxy and we need to prepare few structures
      67              :     // that this code expects.
      68           88 :     #[tracing::instrument(fields(pid = tracing::field::Empty, conn_id = tracing::field::Empty), skip_all)]
      69              :     pub async fn connect_to_compute( // Returns a pooled client when one is available, otherwise opens a new connection.
      70              :         &self,
      71              :         ctx: &mut RequestMonitoring,
      72              :         conn_info: ConnInfo, // identifies the connection; moved into the connect mechanism below
      73              :         keys: ComputeCredentials, // credentials wrapped into the auth backend for the connect call
      74              :         force_new: bool, // when true, the pool lookup is skipped and a new connection is always opened
      75              :     ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
      76              :         let maybe_client = if !force_new {
      77           24 :             info!("pool: looking for an existing connection");
      78              :             self.pool.get(ctx, &conn_info).await?
      79              :         } else {
      80           20 :             info!("pool: pool is disabled");
      81              :             None
      82              :         };
      83              : 
      84              :         if let Some(client) = maybe_client {
      85              :             return Ok(client);
      86              :         }
      87              :         let conn_id = uuid::Uuid::new_v4();
      88              :         tracing::Span::current().record("conn_id", display(conn_id)); // only takes effect because `conn_id` is declared in `fields(...)` on the attribute; recording an undeclared field does nothing
      89           40 :         info!(%conn_id, "pool: opening a new connection '{conn_info}'");
      90           40 :         let backend = self.config.auth_backend.as_ref().map(|_| keys); // re-wrap the configured auth backend around the credentials
      91              :         crate::proxy::connect_compute::connect_to_compute(
      92              :             ctx,
      93              :             &TokioMechanism {
      94              :                 conn_id,
      95              :                 conn_info,
      96              :                 pool: self.pool.clone(),
      97              :             },
      98              :             &backend,
      99              :             false, // do not allow self signed compute for http flow
     100              :         )
     101              :         .await
     102              :     }
     103              : }
     104              : 
     105           18 : #[derive(Debug, thiserror::Error)]
     106              : pub enum HttpConnError { // Error type returned by `PoolingBackend::connect_to_compute`; each variant converts from its source error via `#[from]`.
     107              :     #[error("pooled connection closed at inconsistent state")]
     108              :     ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
     109              :     #[error("could not connect to compute")]
     110              :     ConnectionError(#[from] tokio_postgres::Error),
     111              : 
     112              :     #[error("could not get auth info")]
     113              :     GetAuthInfo(#[from] GetAuthInfoError),
     114              :     #[error("user not authenticated")]
     115              :     AuthError(#[from] AuthError),
     116              :     #[error("wake_compute returned error")]
     117              :     WakeCompute(#[from] WakeComputeError),
     118              : }
     119              : 
     120              : struct TokioMechanism { // State needed to open one new connection; built in `PoolingBackend::connect_to_compute`.
     121              :     pool: Arc<GlobalConnPool<tokio_postgres::Client>>, // pool handle passed on to `poll_client` together with the new client
     122              :     conn_info: ConnInfo, // user, password and database name applied to the connection config
     123              :     conn_id: uuid::Uuid, // id generated by the caller for this connection
     124              : }
     125              : 
     126              : #[async_trait]
     127              : impl ConnectMechanism for TokioMechanism { // Connect mechanism that opens the connection with `tokio_postgres`.
     128              :     type Connection = Client<tokio_postgres::Client>;
     129              :     type ConnectError = tokio_postgres::Error;
     130              :     type Error = HttpConnError;
     131              : 
     132           40 :     async fn connect_once( // Makes a single connection attempt to the node described by `node_info`.
     133           40 :         &self,
     134           40 :         ctx: &mut RequestMonitoring,
     135           40 :         node_info: &CachedNodeInfo, // provides the base connection config and the `aux` data
     136           40 :         timeout: Duration, // applied as the connect timeout of this attempt
     137           40 :     ) -> Result<Self::Connection, Self::ConnectError> {
     138           40 :         let mut config = (*node_info.config).clone(); // work on a copy; the overrides below do not touch the cached node config
     139           40 :         let config = config
     140           40 :             .user(&self.conn_info.user_info.user)
     141           40 :             .password(&*self.conn_info.password)
     142           40 :             .dbname(&self.conn_info.dbname)
     143           40 :             .connect_timeout(timeout);
     144              : 
     145          121 :         let (client, connection) = config.connect(tokio_postgres::NoTls).await?; // connects with `NoTls`, i.e. no TLS connector is supplied
     146              : 
     147           40 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); // fills the `pid` field of the current span
     148           40 :         Ok(poll_client( // hands the client and its connection object, plus pool, ids and aux data, to `poll_client`
     149           40 :             self.pool.clone(),
     150           40 :             ctx,
     151           40 :             self.conn_info.clone(),
     152           40 :             client,
     153           40 :             connection,
     154           40 :             self.conn_id,
     155           40 :             node_info.aux.clone(),
     156           40 :         ))
     157          120 :     }
     158              : 
     159           40 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} // no-op: the config is left untouched
     160              : }
        

Generated by: LCOV version 2.1-beta