LCOV - code coverage report
Current view: top level - proxy/src/serverless - conn_pool_lib.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 53.2 % 468 249
Test Date: 2025-03-12 18:28:53 Functions: 24.7 % 93 23

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::marker::PhantomData;
       3              : use std::ops::Deref;
       4              : use std::sync::atomic::{self, AtomicUsize};
       5              : use std::sync::{Arc, Weak};
       6              : use std::time::Duration;
       7              : 
       8              : use clashmap::ClashMap;
       9              : use parking_lot::RwLock;
      10              : use postgres_client::ReadyForQueryStatus;
      11              : use rand::Rng;
      12              : use smol_str::ToSmolStr;
      13              : use tracing::{Span, debug, info};
      14              : 
      15              : use super::backend::HttpConnError;
      16              : use super::conn_pool::ClientDataRemote;
      17              : use super::http_conn_pool::ClientDataHttp;
      18              : use super::local_conn_pool::ClientDataLocal;
      19              : use crate::auth::backend::ComputeUserInfo;
      20              : use crate::context::RequestContext;
      21              : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
      22              : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
      23              : use crate::protocol2::ConnectionInfoExtra;
      24              : use crate::types::{DbName, EndpointCacheKey, RoleName};
      25              : use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS};
      26              : 
      27              : #[derive(Debug, Clone)]
      28              : pub(crate) struct ConnInfo {
      29              :     pub(crate) user_info: ComputeUserInfo,
      30              :     pub(crate) dbname: DbName,
      31              : }
      32              : 
      33              : impl ConnInfo {
      34              :     // hm, change to hasher to avoid cloning?
      35            3 :     pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
      36            3 :         (self.dbname.clone(), self.user_info.user.clone())
      37            3 :     }
      38              : 
      39            2 :     pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
      40            2 :         // We don't want to cache http connections for ephemeral endpoints.
      41            2 :         if self.user_info.options.is_ephemeral() {
      42            0 :             None
      43              :         } else {
      44            2 :             Some(self.user_info.endpoint_cache_key())
      45              :         }
      46            2 :     }
      47              : }
      48              : 
      49              : #[derive(Clone)]
      50              : pub(crate) enum ClientDataEnum {
      51              :     Remote(ClientDataRemote),
      52              :     Local(ClientDataLocal),
      53              :     Http(ClientDataHttp),
      54              : }
      55              : 
      56              : #[derive(Clone)]
      57              : pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
      58              :     pub(crate) inner: C,
      59              :     pub(crate) aux: MetricsAuxInfo,
      60              :     pub(crate) conn_id: uuid::Uuid,
      61              :     pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
      62              : }
      63              : 
      64              : impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
      65            7 :     fn drop(&mut self) {
      66            7 :         match &mut self.data {
      67            7 :             ClientDataEnum::Remote(remote_data) => {
      68            7 :                 remote_data.cancel();
      69            7 :             }
      70            0 :             ClientDataEnum::Local(local_data) => {
      71            0 :                 local_data.cancel();
      72            0 :             }
      73            0 :             ClientDataEnum::Http(_http_data) => (),
      74              :         }
      75            7 :     }
      76              : }
      77              : 
      78              : impl<C: ClientInnerExt> ClientInnerCommon<C> {
      79            6 :     pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
      80            6 :         self.conn_id
      81            6 :     }
      82              : 
      83            0 :     pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
      84            0 :         &mut self.data
      85            0 :     }
      86              : }
      87              : 
      88              : pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
      89              :     pub(crate) conn: ClientInnerCommon<C>,
      90              :     pub(crate) _last_access: std::time::Instant,
      91              : }
      92              : 
      93              : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
      94              : // Number of open connections is limited by the `max_conns_per_endpoint`.
      95              : pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
      96              :     pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
      97              :     total_conns: usize,
      98              :     /// max # connections per endpoint
      99              :     max_conns: usize,
     100              :     _guard: HttpEndpointPoolsGuard<'static>,
     101              :     global_connections_count: Arc<AtomicUsize>,
     102              :     global_pool_size_max_conns: usize,
     103              :     pool_name: String,
     104              : }
     105              : 
     106              : impl<C: ClientInnerExt> EndpointConnPool<C> {
     107            0 :     pub(crate) fn new(
     108            0 :         hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
     109            0 :         tconns: usize,
     110            0 :         max_conns_per_endpoint: usize,
     111            0 :         global_connections_count: Arc<AtomicUsize>,
     112            0 :         max_total_conns: usize,
     113            0 :         pname: String,
     114            0 :     ) -> Self {
     115            0 :         Self {
     116            0 :             pools: hmap,
     117            0 :             total_conns: tconns,
     118            0 :             max_conns: max_conns_per_endpoint,
     119            0 :             _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
     120            0 :             global_connections_count,
     121            0 :             global_pool_size_max_conns: max_total_conns,
     122            0 :             pool_name: pname,
     123            0 :         }
     124            0 :     }
     125              : 
     126            0 :     pub(crate) fn get_conn_entry(
     127            0 :         &mut self,
     128            0 :         db_user: (DbName, RoleName),
     129            0 :     ) -> Option<ConnPoolEntry<C>> {
     130            0 :         let Self {
     131            0 :             pools,
     132            0 :             total_conns,
     133            0 :             global_connections_count,
     134            0 :             ..
     135            0 :         } = self;
     136            0 :         pools.get_mut(&db_user).and_then(|pool_entries| {
     137            0 :             let (entry, removed) = pool_entries.get_conn_entry(total_conns);
     138            0 :             global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     139            0 :             entry
     140            0 :         })
     141            0 :     }
     142              : 
     143            0 :     pub(crate) fn remove_client(
     144            0 :         &mut self,
     145            0 :         db_user: (DbName, RoleName),
     146            0 :         conn_id: uuid::Uuid,
     147            0 :     ) -> bool {
     148            0 :         let Self {
     149            0 :             pools,
     150            0 :             total_conns,
     151            0 :             global_connections_count,
     152            0 :             ..
     153            0 :         } = self;
     154            0 :         if let Some(pool) = pools.get_mut(&db_user) {
     155            0 :             let old_len = pool.get_conns().len();
     156            0 :             pool.get_conns()
     157            0 :                 .retain(|conn| conn.conn.get_conn_id() != conn_id);
     158            0 :             let new_len = pool.get_conns().len();
     159            0 :             let removed = old_len - new_len;
     160            0 :             if removed > 0 {
     161            0 :                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     162            0 :                 Metrics::get()
     163            0 :                     .proxy
     164            0 :                     .http_pool_opened_connections
     165            0 :                     .get_metric()
     166            0 :                     .dec_by(removed as i64);
     167            0 :             }
     168            0 :             *total_conns -= removed;
     169            0 :             removed > 0
     170              :         } else {
     171            0 :             false
     172              :         }
     173            0 :     }
     174              : 
     175            6 :     pub(crate) fn get_name(&self) -> &str {
     176            6 :         &self.pool_name
     177            6 :     }
     178              : 
     179            0 :     pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
     180            0 :         self.pools.get(&db_user)
     181            0 :     }
     182              : 
     183            0 :     pub(crate) fn get_pool_mut(
     184            0 :         &mut self,
     185            0 :         db_user: (DbName, RoleName),
     186            0 :     ) -> Option<&mut DbUserConnPool<C>> {
     187            0 :         self.pools.get_mut(&db_user)
     188            0 :     }
     189              : 
     190            6 :     pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
     191            6 :         let conn_id = client.get_conn_id();
     192            6 :         let (max_conn, conn_count, pool_name) = {
     193            6 :             let pool = pool.read();
     194            6 :             (
     195            6 :                 pool.global_pool_size_max_conns,
     196            6 :                 pool.global_connections_count
     197            6 :                     .load(atomic::Ordering::Relaxed),
     198            6 :                 pool.get_name().to_string(),
     199            6 :             )
     200            6 :         };
     201            6 : 
     202            6 :         if client.inner.is_closed() {
     203            1 :             info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
     204            1 :             return;
     205            5 :         }
     206            5 : 
     207            5 :         if conn_count >= max_conn {
     208            1 :             info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
     209            1 :             return;
     210            4 :         }
     211            4 : 
     212            4 :         // return connection to the pool
     213            4 :         let mut returned = false;
     214            4 :         let mut per_db_size = 0;
     215            4 :         let total_conns = {
     216            4 :             let mut pool = pool.write();
     217            4 : 
     218            4 :             if pool.total_conns < pool.max_conns {
     219            3 :                 let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
     220            3 :                 pool_entries.get_conns().push(ConnPoolEntry {
     221            3 :                     conn: client,
     222            3 :                     _last_access: std::time::Instant::now(),
     223            3 :                 });
     224            3 : 
     225            3 :                 returned = true;
     226            3 :                 per_db_size = pool_entries.get_conns().len();
     227            3 : 
     228            3 :                 pool.total_conns += 1;
     229            3 :                 pool.global_connections_count
     230            3 :                     .fetch_add(1, atomic::Ordering::Relaxed);
     231            3 :                 Metrics::get()
     232            3 :                     .proxy
     233            3 :                     .http_pool_opened_connections
     234            3 :                     .get_metric()
     235            3 :                     .inc();
     236            3 :             }
     237              : 
     238            4 :             pool.total_conns
     239            4 :         };
     240            4 : 
     241            4 :         // do logging outside of the mutex
     242            4 :         if returned {
     243            3 :             debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     244              :         } else {
     245            1 :             info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     246              :         }
     247            6 :     }
     248              : }
     249              : 
     250              : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
     251            2 :     fn drop(&mut self) {
     252            2 :         if self.total_conns > 0 {
     253            2 :             self.global_connections_count
     254            2 :                 .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
     255            2 :             Metrics::get()
     256            2 :                 .proxy
     257            2 :                 .http_pool_opened_connections
     258            2 :                 .get_metric()
     259            2 :                 .dec_by(self.total_conns as i64);
     260            2 :         }
     261            2 :     }
     262              : }
     263              : 
     264              : pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
     265              :     pub(crate) conns: Vec<ConnPoolEntry<C>>,
     266              :     pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
     267              : }
     268              : 
     269              : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     270            2 :     fn default() -> Self {
     271            2 :         Self {
     272            2 :             conns: Vec::new(),
     273            2 :             initialized: None,
     274            2 :         }
     275            2 :     }
     276              : }
     277              : 
     278              : pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
     279              :     fn set_initialized(&mut self);
     280              :     fn is_initialized(&self) -> bool;
     281              :     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
     282              :     fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
     283              :     fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
     284              : }
     285              : 
     286              : impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
     287            0 :     fn set_initialized(&mut self) {
     288            0 :         self.initialized = Some(true);
     289            0 :     }
     290              : 
     291            0 :     fn is_initialized(&self) -> bool {
     292            0 :         self.initialized.unwrap_or(false)
     293            0 :     }
     294              : 
     295            1 :     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
     296            1 :         let old_len = self.conns.len();
     297            1 : 
     298            2 :         self.conns.retain(|conn| !conn.conn.inner.is_closed());
     299            1 : 
     300            1 :         let new_len = self.conns.len();
     301            1 :         let removed = old_len - new_len;
     302            1 :         *conns -= removed;
     303            1 :         removed
     304            1 :     }
     305              : 
     306            0 :     fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
     307            0 :         let mut removed = self.clear_closed_clients(conns);
     308            0 :         let conn = self.conns.pop();
     309            0 :         if conn.is_some() {
     310            0 :             *conns -= 1;
     311            0 :             removed += 1;
     312            0 :         }
     313              : 
     314            0 :         Metrics::get()
     315            0 :             .proxy
     316            0 :             .http_pool_opened_connections
     317            0 :             .get_metric()
     318            0 :             .dec_by(removed as i64);
     319            0 : 
     320            0 :         (conn, removed)
     321            0 :     }
     322              : 
     323            6 :     fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
     324            6 :         &mut self.conns
     325            6 :     }
     326              : }
     327              : 
     328              : pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
     329              :     fn clear_closed(&mut self) -> usize;
     330              :     fn total_conns(&self) -> usize;
     331              : }
     332              : 
     333              : impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
     334            1 :     fn clear_closed(&mut self) -> usize {
     335            1 :         let mut clients_removed: usize = 0;
     336            1 :         for db_pool in self.pools.values_mut() {
     337            1 :             clients_removed += db_pool.clear_closed_clients(&mut self.total_conns);
     338            1 :         }
     339            1 :         clients_removed
     340            1 :     }
     341              : 
     342            1 :     fn total_conns(&self) -> usize {
     343            1 :         self.total_conns
     344            1 :     }
     345              : }
     346              : 
     347              : pub(crate) struct GlobalConnPool<C, P>
     348              : where
     349              :     C: ClientInnerExt,
     350              :     P: EndpointConnPoolExt<C>,
     351              : {
     352              :     // endpoint -> per-endpoint connection pool
     353              :     //
     354              :     // That should be a fairly conteded map, so return reference to the per-endpoint
     355              :     // pool as early as possible and release the lock.
     356              :     pub(crate) global_pool: ClashMap<EndpointCacheKey, Arc<RwLock<P>>>,
     357              : 
     358              :     /// Number of endpoint-connection pools
     359              :     ///
     360              :     /// [`ClashMap::len`] iterates over all inner pools and acquires a read lock on each.
     361              :     /// That seems like far too much effort, so we're using a relaxed increment counter instead.
     362              :     /// It's only used for diagnostics.
     363              :     pub(crate) global_pool_size: AtomicUsize,
     364              : 
     365              :     /// Total number of connections in the pool
     366              :     pub(crate) global_connections_count: Arc<AtomicUsize>,
     367              : 
     368              :     pub(crate) config: &'static crate::config::HttpConfig,
     369              : 
     370              :     _marker: PhantomData<C>,
     371              : }
     372              : 
     373              : #[derive(Debug, Clone, Copy)]
     374              : pub struct GlobalConnPoolOptions {
     375              :     // Maximum number of connections per one endpoint.
     376              :     // Can mix different (dbname, username) connections.
     377              :     // When running out of free slots for a particular endpoint,
     378              :     // falls back to opening a new connection for each request.
     379              :     pub max_conns_per_endpoint: usize,
     380              : 
     381              :     pub gc_epoch: Duration,
     382              : 
     383              :     pub pool_shards: usize,
     384              : 
     385              :     pub idle_timeout: Duration,
     386              : 
     387              :     pub opt_in: bool,
     388              : 
     389              :     // Total number of connections in the pool.
     390              :     pub max_total_conns: usize,
     391              : }
     392              : 
     393              : impl<C, P> GlobalConnPool<C, P>
     394              : where
     395              :     C: ClientInnerExt,
     396              :     P: EndpointConnPoolExt<C>,
     397              : {
     398            1 :     pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
     399            1 :         let shards = config.pool_options.pool_shards;
     400            1 :         Arc::new(Self {
     401            1 :             global_pool: ClashMap::with_shard_amount(shards),
     402            1 :             global_pool_size: AtomicUsize::new(0),
     403            1 :             config,
     404            1 :             global_connections_count: Arc::new(AtomicUsize::new(0)),
     405            1 :             _marker: PhantomData,
     406            1 :         })
     407            1 :     }
     408              : 
     409              :     #[cfg(test)]
     410            9 :     pub(crate) fn get_global_connections_count(&self) -> usize {
     411            9 :         self.global_connections_count
     412            9 :             .load(atomic::Ordering::Relaxed)
     413            9 :     }
     414              : 
     415            0 :     pub(crate) fn get_idle_timeout(&self) -> Duration {
     416            0 :         self.config.pool_options.idle_timeout
     417            0 :     }
     418              : 
     419            0 :     pub(crate) fn shutdown(&self) {
     420            0 :         // drops all strong references to endpoint-pools
     421            0 :         self.global_pool.clear();
     422            0 :     }
     423              : 
     424            0 :     pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
     425            0 :         let epoch = self.config.pool_options.gc_epoch;
     426            0 :         let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
     427              :         loop {
     428            0 :             interval.tick().await;
     429              : 
     430            0 :             let shard = rng.gen_range(0..self.global_pool.shards().len());
     431            0 :             self.gc(shard);
     432              :         }
     433              :     }
     434              : 
     435            2 :     pub(crate) fn gc(&self, shard: usize) {
     436            2 :         debug!(shard, "pool: performing epoch reclamation");
     437              : 
     438              :         // acquire a random shard lock
     439            2 :         let mut shard = self.global_pool.shards()[shard].write();
     440            2 : 
     441            2 :         let timer = Metrics::get()
     442            2 :             .proxy
     443            2 :             .http_pool_reclaimation_lag_seconds
     444            2 :             .start_timer();
     445            2 :         let current_len = shard.len();
     446            2 :         let mut clients_removed = 0;
     447            2 :         shard.retain(|(endpoint, x)| {
     448              :             // if the current endpoint pool is unique (no other strong or weak references)
     449              :             // then it is currently not in use by any connections.
     450            2 :             if let Some(pool) = Arc::get_mut(x) {
     451            1 :                 let endpoints = pool.get_mut();
     452            1 :                 clients_removed = endpoints.clear_closed();
     453            1 : 
     454            1 :                 if endpoints.total_conns() == 0 {
     455            0 :                     info!("pool: discarding pool for endpoint {endpoint}");
     456            0 :                     return false;
     457            1 :                 }
     458            1 :             }
     459              : 
     460            2 :             true
     461            2 :         });
     462            2 : 
     463            2 :         let new_len = shard.len();
     464            2 :         drop(shard);
     465            2 :         timer.observe();
     466            2 : 
     467            2 :         // Do logging outside of the lock.
     468            2 :         if clients_removed > 0 {
     469            1 :             let size = self
     470            1 :                 .global_connections_count
     471            1 :                 .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
     472            1 :                 - clients_removed;
     473            1 :             Metrics::get()
     474            1 :                 .proxy
     475            1 :                 .http_pool_opened_connections
     476            1 :                 .get_metric()
     477            1 :                 .dec_by(clients_removed as i64);
     478            1 :             info!(
     479            0 :                 "pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"
     480              :             );
     481            1 :         }
     482            2 :         let removed = current_len - new_len;
     483            2 : 
     484            2 :         if removed > 0 {
     485            0 :             let global_pool_size = self
     486            0 :                 .global_pool_size
     487            0 :                 .fetch_sub(removed, atomic::Ordering::Relaxed)
     488            0 :                 - removed;
     489            0 :             info!("pool: performed global pool gc. size now {global_pool_size}");
     490            2 :         }
     491            2 :     }
     492              : }
     493              : 
     494              : impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
     495            0 :     pub(crate) fn get(
     496            0 :         self: &Arc<Self>,
     497            0 :         ctx: &RequestContext,
     498            0 :         conn_info: &ConnInfo,
     499            0 :     ) -> Result<Option<Client<C>>, HttpConnError> {
     500            0 :         let mut client: Option<ClientInnerCommon<C>> = None;
     501            0 :         let Some(endpoint) = conn_info.endpoint_cache_key() else {
     502            0 :             return Ok(None);
     503              :         };
     504              : 
     505            0 :         let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
     506            0 :         if let Some(entry) = endpoint_pool
     507            0 :             .write()
     508            0 :             .get_conn_entry(conn_info.db_and_user())
     509            0 :         {
     510            0 :             client = Some(entry.conn);
     511            0 :         }
     512            0 :         let endpoint_pool = Arc::downgrade(&endpoint_pool);
     513              : 
     514              :         // ok return cached connection if found and establish a new one otherwise
     515            0 :         if let Some(mut client) = client {
     516            0 :             if client.inner.is_closed() {
     517            0 :                 info!("pool: cached connection '{conn_info}' is closed, opening a new one");
     518            0 :                 return Ok(None);
     519            0 :             }
     520            0 :             tracing::Span::current()
     521            0 :                 .record("conn_id", tracing::field::display(client.get_conn_id()));
     522            0 :             tracing::Span::current().record(
     523            0 :                 "pid",
     524            0 :                 tracing::field::display(client.inner.get_process_id()),
     525            0 :             );
     526            0 :             debug!(
     527            0 :                 cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
     528            0 :                 "pool: reusing connection '{conn_info}'"
     529              :             );
     530              : 
     531            0 :             match client.get_data() {
     532            0 :                 ClientDataEnum::Local(data) => {
     533            0 :                     data.session().send(ctx.session_id())?;
     534              :                 }
     535              : 
     536            0 :                 ClientDataEnum::Remote(data) => {
     537            0 :                     data.session().send(ctx.session_id())?;
     538              :                 }
     539            0 :                 ClientDataEnum::Http(_) => (),
     540              :             }
     541              : 
     542            0 :             ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
     543            0 :             ctx.success();
     544            0 :             return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
     545            0 :         }
     546            0 :         Ok(None)
     547            0 :     }
     548              : 
     549            2 :     pub(crate) fn get_or_create_endpoint_pool(
     550            2 :         self: &Arc<Self>,
     551            2 :         endpoint: &EndpointCacheKey,
     552            2 :     ) -> Arc<RwLock<EndpointConnPool<C>>> {
     553              :         // fast path
     554            2 :         if let Some(pool) = self.global_pool.get(endpoint) {
     555            0 :             return pool.clone();
     556            2 :         }
     557            2 : 
     558            2 :         // slow path
     559            2 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     560            2 :             pools: HashMap::new(),
     561            2 :             total_conns: 0,
     562            2 :             max_conns: self.config.pool_options.max_conns_per_endpoint,
     563            2 :             _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
     564            2 :             global_connections_count: self.global_connections_count.clone(),
     565            2 :             global_pool_size_max_conns: self.config.pool_options.max_total_conns,
     566            2 :             pool_name: String::from("remote"),
     567            2 :         }));
     568            2 : 
     569            2 :         // find or create a pool for this endpoint
     570            2 :         let mut created = false;
     571            2 :         let pool = self
     572            2 :             .global_pool
     573            2 :             .entry(endpoint.clone())
     574            2 :             .or_insert_with(|| {
     575            2 :                 created = true;
     576            2 :                 new_pool
     577            2 :             })
     578            2 :             .clone();
     579            2 : 
     580            2 :         // log new global pool size
     581            2 :         if created {
     582            2 :             let global_pool_size = self
     583            2 :                 .global_pool_size
     584            2 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     585            2 :                 + 1;
     586            2 :             info!(
     587            0 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     588              :             );
     589            0 :         }
     590              : 
     591            2 :         pool
     592            2 :     }
     593              : }
     594              : pub(crate) struct Client<C: ClientInnerExt> {
     595              :     span: Span,
     596              :     inner: Option<ClientInnerCommon<C>>,
     597              :     conn_info: ConnInfo,
     598              :     pool: Weak<RwLock<EndpointConnPool<C>>>,
     599              : }
     600              : 
     601              : pub(crate) struct Discard<'a, C: ClientInnerExt> {
     602              :     conn_info: &'a ConnInfo,
     603              :     pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
     604              : }
     605              : 
     606              : impl<C: ClientInnerExt> Client<C> {
     607            7 :     pub(crate) fn new(
     608            7 :         inner: ClientInnerCommon<C>,
     609            7 :         conn_info: ConnInfo,
     610            7 :         pool: Weak<RwLock<EndpointConnPool<C>>>,
     611            7 :     ) -> Self {
     612            7 :         Self {
     613            7 :             inner: Some(inner),
     614            7 :             span: Span::current(),
     615            7 :             conn_info,
     616            7 :             pool,
     617            7 :         }
     618            7 :     }
     619              : 
     620            0 :     pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
     621            0 :         let Self {
     622            0 :             inner,
     623            0 :             pool,
     624            0 :             conn_info,
     625            0 :             span: _,
     626            0 :         } = self;
     627            0 :         let inner_m = inner.as_mut().expect("client inner should not be removed");
     628            0 :         (inner_m, Discard { conn_info, pool })
     629            0 :     }
     630              : 
     631            1 :     pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
     632            1 :         let Self {
     633            1 :             inner,
     634            1 :             pool,
     635            1 :             conn_info,
     636            1 :             span: _,
     637            1 :         } = self;
     638            1 :         let inner = inner.as_mut().expect("client inner should not be removed");
     639            1 :         (&mut inner.inner, Discard { conn_info, pool })
     640            1 :     }
     641              : 
     642            0 :     pub(crate) fn metrics(
     643            0 :         &self,
     644            0 :         direction: TrafficDirection,
     645            0 :         ctx: &RequestContext,
     646            0 :     ) -> Arc<MetricCounter> {
     647            0 :         let aux = &self
     648            0 :             .inner
     649            0 :             .as_ref()
     650            0 :             .expect("client inner should not be removed")
     651            0 :             .aux;
     652              : 
     653            0 :         let private_link_id = match ctx.extra() {
     654            0 :             None => None,
     655            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
     656            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
     657              :         };
     658              : 
     659            0 :         USAGE_METRICS.register(Ids {
     660            0 :             endpoint_id: aux.endpoint_id,
     661            0 :             branch_id: aux.branch_id,
     662            0 :             direction,
     663            0 :             private_link_id,
     664            0 :         })
     665            0 :     }
     666              : }
     667              : 
     668              : impl<C: ClientInnerExt> Drop for Client<C> {
     669            7 :     fn drop(&mut self) {
     670            7 :         let conn_info = self.conn_info.clone();
     671            7 :         let client = self
     672            7 :             .inner
     673            7 :             .take()
     674            7 :             .expect("client inner should not be removed");
     675            7 :         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
     676            6 :             let _current_span = self.span.enter();
     677            6 :             // return connection to the pool
     678            6 :             EndpointConnPool::put(&conn_pool, &conn_info, client);
     679            6 :         }
     680            7 :     }
     681              : }
     682              : 
     683              : impl<C: ClientInnerExt> Deref for Client<C> {
     684              :     type Target = C;
     685              : 
     686            0 :     fn deref(&self) -> &Self::Target {
     687            0 :         &self
     688            0 :             .inner
     689            0 :             .as_ref()
     690            0 :             .expect("client inner should not be removed")
     691            0 :             .inner
     692            0 :     }
     693              : }
     694              : 
     695              : pub(crate) trait ClientInnerExt: Sync + Send + 'static {
     696              :     fn is_closed(&self) -> bool;
     697              :     fn get_process_id(&self) -> i32;
     698              : }
     699              : 
     700              : impl ClientInnerExt for postgres_client::Client {
     701            0 :     fn is_closed(&self) -> bool {
     702            0 :         self.is_closed()
     703            0 :     }
     704              : 
     705            0 :     fn get_process_id(&self) -> i32 {
     706            0 :         self.get_process_id()
     707            0 :     }
     708              : }
     709              : 
     710              : impl<C: ClientInnerExt> Discard<'_, C> {
     711            0 :     pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
     712            0 :         let conn_info = &self.conn_info;
     713            0 :         if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
     714            0 :             info!("pool: throwing away connection '{conn_info}' because connection is not idle");
     715            0 :         }
     716            0 :     }
     717            1 :     pub(crate) fn discard(&mut self) {
     718            1 :         let conn_info = &self.conn_info;
     719            1 :         if std::mem::take(self.pool).strong_count() > 0 {
     720            1 :             info!(
     721            0 :                 "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"
     722              :             );
     723            0 :         }
     724            1 :     }
     725              : }
        

Generated by: LCOV version 2.1-beta