LCOV - code coverage report
Current view: top level - proxy/src/serverless - backend.rs (source / functions) Coverage Total Hit
Test: 792183ae0ef4f1f8b22e9ac7e8748740ab73f873.info Lines: 0.0 % 123 0
Test Date: 2024-06-26 01:04:33 Functions: 0.0 % 22 0

            Line data    Source code
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use async_trait::async_trait;
       4              : use tracing::{field::display, info};
       5              : 
       6              : use crate::{
       7              :     auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
       8              :     compute,
       9              :     config::{AuthenticationConfig, ProxyConfig},
      10              :     console::{
      11              :         errors::{GetAuthInfoError, WakeComputeError},
      12              :         locks::ApiLocks,
      13              :         provider::ApiLockError,
      14              :         CachedNodeInfo,
      15              :     },
      16              :     context::RequestMonitoring,
      17              :     error::{ErrorKind, ReportableError, UserFacingError},
      18              :     intern::EndpointIdInt,
      19              :     proxy::{
      20              :         connect_compute::ConnectMechanism,
      21              :         retry::{CouldRetry, ShouldRetryWakeCompute},
      22              :     },
      23              :     rate_limiter::EndpointRateLimiter,
      24              :     Host,
      25              : };
      26              : 
      27              : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
      28              : 
      29              : pub struct PoolingBackend { // Shared state behind `authenticate` and `connect_to_compute`.
      30              :     pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>, // consulted for an existing client before a new connection is opened
      31              :     pub config: &'static ProxyConfig, // `'static` proxy configuration: auth backend, retry configs, connect locks
      32              :     pub endpoint_rate_limiter: Arc<EndpointRateLimiter>, // checked once per endpoint in `authenticate`
      33              : }
      34              : 
      35              : impl PoolingBackend {
      36            0 :     pub async fn authenticate( // Checks the peer IP, the endpoint rate limit and the password; returns the compute credentials on success.
      37            0 :         &self,
      38            0 :         ctx: &mut RequestMonitoring,
      39            0 :         config: &AuthenticationConfig,
      40            0 :         conn_info: &ConnInfo,
      41            0 :     ) -> Result<ComputeCredentials, AuthError> {
      42            0 :         let user_info = conn_info.user_info.clone();
      43            0 :         let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); // presumably keeps the configured backend variant with `user_info` as payload — confirm against `map`
      44            0 :         let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
      45            0 :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { // reject peers that are not in the allow-list
      46            0 :             return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
      47            0 :         }
      48            0 :         if !self
      49            0 :             .endpoint_rate_limiter
      50            0 :             .check(conn_info.user_info.endpoint.clone().into(), 1) // NOTE(review): `1` looks like the number of permits consumed — confirm
      51              :         {
      52            0 :             return Err(AuthError::too_many_connections());
      53            0 :         }
      54            0 :         let cached_secret = match maybe_secret { // use the secret returned with the allow-list, else fetch it separately
      55            0 :             Some(secret) => secret,
      56            0 :             None => backend.get_role_secret(ctx).await?,
      57              :         };
      58              :         // A present secret must pass `check_rate_limit`; a missing one ends the request with `auth_failed`.
      59            0 :         let secret = match cached_secret.value.clone() {
      60            0 :             Some(secret) => self.config.authentication_config.check_rate_limit(
      61            0 :                 ctx,
      62            0 :                 config,
      63            0 :                 secret,
      64            0 :                 &user_info.endpoint,
      65            0 :                 true, // NOTE(review): the meaning of this flag is not visible in this file
      66            0 :             )?,
      67              :             None => {
      68              :                 // If we don't have an authentication secret, for the http flow we can just return an error.
      69            0 :                 info!("authentication info not found");
      70            0 :                 return Err(AuthError::auth_failed(&*user_info.user));
      71              :             }
      72              :         };
      73            0 :         let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
      74            0 :         let auth_outcome = crate::auth::validate_password_and_exchange( // validates `conn_info.password` against `secret`
      75            0 :             &config.thread_pool,
      76            0 :             ep,
      77            0 :             &conn_info.password,
      78            0 :             secret,
      79            0 :         )
      80            0 :         .await?;
      81            0 :         let res = match auth_outcome {
      82            0 :             crate::sasl::Outcome::Success(key) => {
      83            0 :                 info!("user successfully authenticated");
      84            0 :                 Ok(key)
      85              :             }
      86            0 :             crate::sasl::Outcome::Failure(reason) => {
      87            0 :                 info!("auth backend failed with an error: {reason}"); // the reason is only logged; the returned error carries just the user name
      88            0 :                 Err(AuthError::auth_failed(&*conn_info.user_info.user))
      89              :             }
      90              :         };
      91            0 :         res.map(|key| ComputeCredentials { // pair the validated keys with the user info cloned at the top
      92            0 :             info: user_info,
      93            0 :             keys: key,
      94            0 :         })
      95            0 :     }
      96              : 
      97              :     // Returns a pooled client for `conn_info` unless `force_new` is set; if none is found, opens
      98              :     // a new connection through the regular proxy `connect_to_compute` path, which wakes the
      99              :     // compute if needed. Span fields `pid` and `conn_id` start empty and are recorded once known.
     100            0 :     #[tracing::instrument(fields(pid = tracing::field::Empty, conn_id = tracing::field::Empty), skip_all)]
     101              :     pub async fn connect_to_compute(
     102              :         &self,
     103              :         ctx: &mut RequestMonitoring,
     104              :         conn_info: ConnInfo,
     105              :         keys: ComputeCredentials,
     106              :         force_new: bool,
     107              :     ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
     108              :         let maybe_client = if !force_new {
     109              :             info!("pool: looking for an existing connection");
     110              :             self.pool.get(ctx, &conn_info)?
     111              :         } else {
     112              :             info!("pool: pool is disabled");
     113              :             None
     114              :         };
     115              : 
     116              :         if let Some(client) = maybe_client {
     117              :             return Ok(client);
     118              :         }
     119              :         let conn_id = uuid::Uuid::new_v4();
     120              :         tracing::Span::current().record("conn_id", display(conn_id)); // `record` is a no-op unless the field is declared on the span (see `fields(...)` above)
     121              :         info!(%conn_id, "pool: opening a new connection '{conn_info}'");
     122            0 :         let backend = self.config.auth_backend.as_ref().map(|_| keys);
     123              :         crate::proxy::connect_compute::connect_to_compute(
     124              :             ctx,
     125              :             &TokioMechanism {
     126              :                 conn_id,
     127              :                 conn_info,
     128              :                 pool: self.pool.clone(),
     129              :                 locks: &self.config.connect_compute_locks,
     130              :             },
     131              :             &backend,
     132              :             false, // do not allow self signed compute for http flow
     133              :             self.config.wake_compute_retry_config,
     134              :             self.config.connect_to_compute_retry_config,
     135              :         )
     136              :         .await
     137              :     }
     138              : }
     139              : 
     140            0 : #[derive(Debug, thiserror::Error)]
     141              : pub enum HttpConnError { // Errors returned while obtaining a compute connection in this module.
     142              :     #[error("pooled connection closed at inconsistent state")]
     143              :     ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
     144              :     #[error("could not connect to compute")]
     145              :     ConnectionError(#[from] tokio_postgres::Error), // failure reported by `tokio_postgres` while connecting
     146              : 
     147              :     #[error("could not get auth info")]
     148              :     GetAuthInfo(#[from] GetAuthInfoError),
     149              :     #[error("user not authenticated")]
     150              :     AuthError(#[from] AuthError),
     151              :     #[error("wake_compute returned error")]
     152              :     WakeCompute(#[from] WakeComputeError),
     153              :     #[error("error acquiring resource permit: {0}")]
     154              :     TooManyConnectionAttempts(#[from] ApiLockError), // the per-host connect permit could not be acquired
     155              : }
     156              : 
     157              : impl ReportableError for HttpConnError {
     158            0 :     fn get_error_kind(&self) -> ErrorKind { // Classifies the error; all but the first variant defer to the wrapped error.
     159            0 :         match self {
     160            0 :             HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, // the only variant classified here directly
     161            0 :             HttpConnError::ConnectionError(p) => p.get_error_kind(),
     162            0 :             HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
     163            0 :             HttpConnError::AuthError(a) => a.get_error_kind(),
     164            0 :             HttpConnError::WakeCompute(w) => w.get_error_kind(),
     165            0 :             HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
     166              :         }
     167            0 :     }
     168              : }
     169              : 
     170              : impl UserFacingError for HttpConnError {
     171            0 :     fn to_string_client(&self) -> String { // Builds the message that is shown to the client for each variant.
     172            0 :         match self {
     173            0 :             HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), // this enum's own `#[error]` text
     174            0 :             HttpConnError::ConnectionError(p) => p.to_string(), // the wrapped `tokio_postgres` error's `Display` text, passed through as-is
     175            0 :             HttpConnError::GetAuthInfo(c) => c.to_string_client(),
     176            0 :             HttpConnError::AuthError(c) => c.to_string_client(),
     177            0 :             HttpConnError::WakeCompute(c) => c.to_string_client(),
     178              :             HttpConnError::TooManyConnectionAttempts(_) => { // fixed message; the wrapped lock error is not included
     179            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
     180              :             }
     181              :         }
     182            0 :     }
     183              : }
     184              : 
     185              : impl CouldRetry for HttpConnError {
     186            0 :     fn could_retry(&self) -> bool { // Only `ConnectionError` can be retried, and only when the wrapped error says so.
     187            0 :         match self {
     188            0 :             HttpConnError::ConnectionError(e) => e.could_retry(),
     189            0 :             HttpConnError::ConnectionClosedAbruptly(_) => false,
     190            0 :             HttpConnError::GetAuthInfo(_) => false,
     191            0 :             HttpConnError::AuthError(_) => false,
     192            0 :             HttpConnError::WakeCompute(_) => false,
     193            0 :             HttpConnError::TooManyConnectionAttempts(_) => false,
     194              :         }
     195            0 :     }
     196              : }
     197              : impl ShouldRetryWakeCompute for HttpConnError {
     198            0 :     fn should_retry_wake_compute(&self) -> bool { // Says whether this error should trigger another wake_compute attempt.
     199            0 :         match self {
     200            0 :             HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
     201              :             // we never checked cache validity
     202            0 :             HttpConnError::TooManyConnectionAttempts(_) => false,
     203            0 :             _ => true, // NOTE(review): catch-all covers ConnectionClosedAbruptly, GetAuthInfo, AuthError, WakeCompute — and any variant added later
     204              :         }
     205            0 :     }
     206              : }
     207              : 
     208              : struct TokioMechanism { // `ConnectMechanism` state built by `PoolingBackend::connect_to_compute` for one new connection.
     209              :     pool: Arc<GlobalConnPool<tokio_postgres::Client>>, // passed to `poll_client` together with the new client
     210              :     conn_info: ConnInfo, // supplies user, password and dbname for the connection
     211              :     conn_id: uuid::Uuid, // id generated by the caller for this connection
     212              : 
     213              :     /// connect_to_compute concurrency lock; `connect_once` takes a per-host permit from it before connecting
     214              :     locks: &'static ApiLocks<Host>,
     215              : }
     216              : 
     217              : #[async_trait]
     218              : impl ConnectMechanism for TokioMechanism {
     219              :     type Connection = Client<tokio_postgres::Client>;
     220              :     type ConnectError = HttpConnError;
     221              :     type Error = HttpConnError;
     222              : 
     223            0 :     async fn connect_once( // Makes one connection attempt to the node and wraps the result via `poll_client`.
     224            0 :         &self,
     225            0 :         ctx: &mut RequestMonitoring,
     226            0 :         node_info: &CachedNodeInfo,
     227            0 :         timeout: Duration,
     228            0 :     ) -> Result<Self::Connection, Self::ConnectError> {
     229            0 :         let host = node_info.config.get_host()?;
     230            0 :         let permit = self.locks.get_permit(&host).await?; // wait for a connect permit for this host
     231            0 : 
     232            0 :         let mut config = (*node_info.config).clone(); // copy the cached node config, then apply this request's settings
     233            0 :         let config = config
     234            0 :             .user(&self.conn_info.user_info.user)
     235            0 :             .password(&*self.conn_info.password)
     236            0 :             .dbname(&self.conn_info.dbname)
     237            0 :             .connect_timeout(timeout);
     238            0 : 
     239            0 :         config
     240            0 :             .param("client_encoding", "UTF8")
     241            0 :             .expect("client encoding UTF8 is always valid");
     242            0 : 
     243            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute); // presumably excludes the connect wait from the latency timer until `pause` is dropped — confirm
     244            0 :         let res = config.connect(tokio_postgres::NoTls).await; // connects without TLS (`NoTls`)
     245            0 :         drop(pause);
     246            0 :         let (client, connection) = permit.release_result(res)?; // hand the connect result to the permit before propagating any error
     247            0 : 
     248            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); // record the process id on the current span's `pid` field
     249            0 :         Ok(poll_client(
     250            0 :             self.pool.clone(),
     251            0 :             ctx,
     252            0 :             self.conn_info.clone(),
     253            0 :             client,
     254            0 :             connection,
     255            0 :             self.conn_id,
     256            0 :             node_info.aux.clone(),
     257            0 :         ))
     258            0 :     }
     259              : 
     260            0 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} // no-op: the config is left untouched
     261              : }
        

Generated by: LCOV version 2.1-beta