|             Line data    Source code 
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use async_trait::async_trait;
       4              : use tracing::{field::display, info};
       5              : 
       6              : use crate::{
       7              :     auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
       8              :     compute,
       9              :     config::{AuthenticationConfig, ProxyConfig},
      10              :     console::{
      11              :         errors::{GetAuthInfoError, WakeComputeError},
      12              :         locks::ApiLocks,
      13              :         provider::ApiLockError,
      14              :         CachedNodeInfo,
      15              :     },
      16              :     context::RequestMonitoring,
      17              :     error::{ErrorKind, ReportableError, UserFacingError},
      18              :     intern::EndpointIdInt,
      19              :     proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
      20              :     rate_limiter::EndpointRateLimiter,
      21              :     Host,
      22              : };
      23              : 
      24              : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
      25              : 
      26              : pub struct PoolingBackend {
                            // Entry point of the pooled (HTTP) connection flow in this file: it
                            // authenticates a client (`authenticate`) and hands out compute
                            // connections (`connect_to_compute`).

                            /// Connection pool that `connect_to_compute` consults before opening a new
                            /// connection; a clone of the `Arc` is passed to every `TokioMechanism`.
      27              :     pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
                            /// Proxy configuration. This file reads `auth_backend`,
                            /// `authentication_config`, `connect_compute_locks` and the two retry
                            /// configs from it.
      28              :     pub config: &'static ProxyConfig,
                            /// Rate limiter checked once per `authenticate` call with the endpoint
                            /// taken from the connection info.
      29              :     pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      30              : }
      31              : 
      32              : impl PoolingBackend {
      33            0 :     pub async fn authenticate(
      34            0 :         &self,
      35            0 :         ctx: &mut RequestMonitoring,
      36            0 :         config: &AuthenticationConfig,
      37            0 :         conn_info: &ConnInfo,
      38            0 :     ) -> Result<ComputeCredentials, AuthError> {
      39            0 :         let user_info = conn_info.user_info.clone();
      40            0 :         let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
      41            0 :         let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
      42            0 :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
      43            0 :             return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
      44            0 :         }
      45            0 :         if !self
      46            0 :             .endpoint_rate_limiter
      47            0 :             .check(conn_info.user_info.endpoint.clone().into(), 1)
      48              :         {
      49            0 :             return Err(AuthError::too_many_connections());
      50            0 :         }
      51            0 :         let cached_secret = match maybe_secret {
      52            0 :             Some(secret) => secret,
      53            0 :             None => backend.get_role_secret(ctx).await?,
      54              :         };
      55              : 
      56            0 :         let secret = match cached_secret.value.clone() {
      57            0 :             Some(secret) => self.config.authentication_config.check_rate_limit(
      58            0 :                 ctx,
      59            0 :                 config,
      60            0 :                 secret,
      61            0 :                 &user_info.endpoint,
      62            0 :                 true,
      63            0 :             )?,
      64              :             None => {
      65              :                 // If we don't have an authentication secret, for the http flow we can just return an error.
      66            0 :                 info!("authentication info not found");
      67            0 :                 return Err(AuthError::auth_failed(&*user_info.user));
      68              :             }
      69              :         };
      70            0 :         let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
      71            0 :         let auth_outcome = crate::auth::validate_password_and_exchange(
      72            0 :             &config.thread_pool,
      73            0 :             ep,
      74            0 :             &conn_info.password,
      75            0 :             secret,
      76            0 :         )
      77            0 :         .await?;
      78            0 :         let res = match auth_outcome {
      79            0 :             crate::sasl::Outcome::Success(key) => {
      80            0 :                 info!("user successfully authenticated");
      81            0 :                 Ok(key)
      82              :             }
      83            0 :             crate::sasl::Outcome::Failure(reason) => {
      84            0 :                 info!("auth backend failed with an error: {reason}");
      85            0 :                 Err(AuthError::auth_failed(&*conn_info.user_info.user))
      86              :             }
      87              :         };
      88            0 :         res.map(|key| ComputeCredentials {
      89            0 :             info: user_info,
      90            0 :             keys: key,
      91            0 :         })
      92            0 :     }
      93              : 
      94              :     // Wake up the destination if needed. Code here is a bit involved because
      95              :     // we reuse the code from the usual proxy and we need to prepare few structures
      96              :     // that this code expects.
      97            0 :     #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
      98              :     pub async fn connect_to_compute(
      99              :         &self,
     100              :         ctx: &mut RequestMonitoring,
     101              :         conn_info: ConnInfo,
     102              :         keys: ComputeCredentials,
     103              :         force_new: bool,
     104              :     ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
     105              :         let maybe_client = if !force_new {
     106              :             info!("pool: looking for an existing connection");
     107              :             self.pool.get(ctx, &conn_info)?
     108              :         } else {
     109              :             info!("pool: pool is disabled");
     110              :             None
     111              :         };
     112              : 
     113              :         if let Some(client) = maybe_client {
     114              :             return Ok(client);
     115              :         }
     116              :         let conn_id = uuid::Uuid::new_v4();
     117              :         tracing::Span::current().record("conn_id", display(conn_id));
     118              :         info!(%conn_id, "pool: opening a new connection '{conn_info}'");
     119            0 :         let backend = self.config.auth_backend.as_ref().map(|_| keys);
     120              :         crate::proxy::connect_compute::connect_to_compute(
     121              :             ctx,
     122              :             &TokioMechanism {
     123              :                 conn_id,
     124              :                 conn_info,
     125              :                 pool: self.pool.clone(),
     126              :                 locks: &self.config.connect_compute_locks,
     127              :             },
     128              :             &backend,
     129              :             false, // do not allow self signed compute for http flow
     130              :             self.config.wake_compute_retry_config,
     131              :             self.config.connect_to_compute_retry_config,
     132              :         )
     133              :         .await
     134              :     }
     135              : }
     136              : 
     137            0 : #[derive(Debug, thiserror::Error)]
     138              : pub enum HttpConnError {
     139              :     #[error("pooled connection closed at inconsistent state")]
     140              :     ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
     141              :     #[error("could not connection to compute")]
     142              :     ConnectionError(#[from] tokio_postgres::Error),
     143              : 
     144              :     #[error("could not get auth info")]
     145              :     GetAuthInfo(#[from] GetAuthInfoError),
     146              :     #[error("user not authenticated")]
     147              :     AuthError(#[from] AuthError),
     148              :     #[error("wake_compute returned error")]
     149              :     WakeCompute(#[from] WakeComputeError),
     150              :     #[error("error acquiring resource permit: {0}")]
     151              :     TooManyConnectionAttempts(#[from] ApiLockError),
     152              : }
     153              : 
     154              : impl ReportableError for HttpConnError {
     155            0 :     fn get_error_kind(&self) -> ErrorKind {
     156            0 :         match self {
     157            0 :             HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
     158            0 :             HttpConnError::ConnectionError(p) => p.get_error_kind(),
     159            0 :             HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
     160            0 :             HttpConnError::AuthError(a) => a.get_error_kind(),
     161            0 :             HttpConnError::WakeCompute(w) => w.get_error_kind(),
     162            0 :             HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
     163              :         }
     164            0 :     }
     165              : }
     166              : 
     167              : impl UserFacingError for HttpConnError {
     168            0 :     fn to_string_client(&self) -> String {
     169            0 :         match self {
     170            0 :             HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
     171            0 :             HttpConnError::ConnectionError(p) => p.to_string(),
     172            0 :             HttpConnError::GetAuthInfo(c) => c.to_string_client(),
     173            0 :             HttpConnError::AuthError(c) => c.to_string_client(),
     174            0 :             HttpConnError::WakeCompute(c) => c.to_string_client(),
     175              :             HttpConnError::TooManyConnectionAttempts(_) => {
     176            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
     177              :             }
     178              :         }
     179            0 :     }
     180              : }
     181              : 
     182              : impl ShouldRetry for HttpConnError {
     183            0 :     fn could_retry(&self) -> bool {
     184            0 :         match self {
     185            0 :             HttpConnError::ConnectionError(e) => e.could_retry(),
     186            0 :             HttpConnError::ConnectionClosedAbruptly(_) => false,
     187            0 :             HttpConnError::GetAuthInfo(_) => false,
     188            0 :             HttpConnError::AuthError(_) => false,
     189            0 :             HttpConnError::WakeCompute(_) => false,
     190            0 :             HttpConnError::TooManyConnectionAttempts(_) => false,
     191              :         }
     192            0 :     }
     193            0 :     fn should_retry_database_address(&self) -> bool {
     194            0 :         match self {
     195            0 :             HttpConnError::ConnectionError(e) => e.should_retry_database_address(),
     196              :             // we never checked cache validity
     197            0 :             HttpConnError::TooManyConnectionAttempts(_) => false,
     198            0 :             _ => true,
     199              :         }
     200            0 :     }
     201              : }
     202              : 
     203              : struct TokioMechanism {
                            // `ConnectMechanism` implementation that opens `tokio_postgres`
                            // connections; it is built in `PoolingBackend::connect_to_compute`.

                            /// Pool passed on to `poll_client` for each connection that gets opened.
     204              :     pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
                            /// Supplies the user, password and database name for the connection and
                            /// is cloned into `poll_client`.
     205              :     conn_info: ConnInfo,
                            /// Connection id generated by the caller; forwarded to `poll_client`.
     206              :     conn_id: uuid::Uuid,
     207              : 
     208              :     /// connect_to_compute concurrency lock
                            // A permit for the target host is acquired from it before connecting.
     209              :     locks: &'static ApiLocks<Host>,
     210              : }
     211              : 
     212              : #[async_trait]
     213              : impl ConnectMechanism for TokioMechanism {
     214              :     type Connection = Client<tokio_postgres::Client>;
     215              :     type ConnectError = HttpConnError;
     216              :     type Error = HttpConnError;
     217              : 
     218            0 :     async fn connect_once(
     219            0 :         &self,
     220            0 :         ctx: &mut RequestMonitoring,
     221            0 :         node_info: &CachedNodeInfo,
     222            0 :         timeout: Duration,
     223            0 :     ) -> Result<Self::Connection, Self::ConnectError> {
     224            0 :         let host = node_info.config.get_host()?;
     225            0 :         let permit = self.locks.get_permit(&host).await?;
     226            0 : 
     227            0 :         let mut config = (*node_info.config).clone();
     228            0 :         let config = config
     229            0 :             .user(&self.conn_info.user_info.user)
     230            0 :             .password(&*self.conn_info.password)
     231            0 :             .dbname(&self.conn_info.dbname)
     232            0 :             .connect_timeout(timeout);
     233            0 : 
     234            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     235            0 :         let res = config.connect(tokio_postgres::NoTls).await;
     236            0 :         drop(pause);
     237            0 :         let (client, connection) = permit.release_result(res)?;
     238            0 : 
     239            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     240            0 :         Ok(poll_client(
     241            0 :             self.pool.clone(),
     242            0 :             ctx,
     243            0 :             self.conn_info.clone(),
     244            0 :             client,
     245            0 :             connection,
     246            0 :             self.conn_id,
     247            0 :             node_info.aux.clone(),
     248            0 :         ))
     249            0 :     }
     250              : 
     251            0 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
     252              : }
         |