LCOV - code coverage report
Current view: top level - proxy/src/serverless - conn_pool.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 94.0 % 563 529
Test Date: 2024-02-12 20:26:03 Functions: 55.1 % 136 75

            Line data    Source code
       1              : use dashmap::DashMap;
       2              : use futures::{future::poll_fn, Future};
       3              : use metrics::IntCounterPairGuard;
       4              : use parking_lot::RwLock;
       5              : use rand::Rng;
       6              : use smallvec::SmallVec;
       7              : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
       8              : use std::{
       9              :     fmt,
      10              :     task::{ready, Poll},
      11              : };
      12              : use std::{
      13              :     ops::Deref,
      14              :     sync::atomic::{self, AtomicUsize},
      15              : };
      16              : use tokio::time::Instant;
      17              : use tokio_postgres::tls::NoTlsStream;
      18              : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
      19              : 
      20              : use crate::console::messages::MetricsAuxInfo;
      21              : use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL};
      22              : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
      23              : use crate::{
      24              :     auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE,
      25              :     DbName, EndpointCacheKey, RoleName,
      26              : };
      27              : 
      28              : use tracing::{debug, error, warn, Span};
      29              : use tracing::{info, info_span, Instrument};
      30              : 
      31              : use super::backend::HttpConnError;
      32              : 
      33          114 : #[derive(Debug, Clone)]
      34              : pub struct ConnInfo {
      35              :     pub user_info: ComputeUserInfo,
      36              :     pub dbname: DbName,
      37              :     pub password: SmallVec<[u8; 16]>,
      38              : }
      39              : 
      40              : impl ConnInfo {
      41              :     // hm, change to hasher to avoid cloning?
      42          110 :     pub fn db_and_user(&self) -> (DbName, RoleName) {
      43          110 :         (self.dbname.clone(), self.user_info.user.clone())
      44          110 :     }
      45              : 
      46           68 :     pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
      47           68 :         self.user_info.endpoint_cache_key()
      48           68 :     }
      49              : }
      50              : 
      51              : impl fmt::Display for ConnInfo {
      52              :     // use custom display to avoid logging password
      53          256 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      54          256 :         write!(
      55          256 :             f,
      56          256 :             "{}@{}/{}?{}",
      57          256 :             self.user_info.user,
      58          256 :             self.user_info.endpoint,
      59          256 :             self.dbname,
      60          256 :             self.user_info.options.get_cache_key("")
      61          256 :         )
      62          256 :     }
      63              : }
      64              : 
      65              : struct ConnPoolEntry<C: ClientInnerExt> {
      66              :     conn: ClientInner<C>,
      67              :     _last_access: std::time::Instant,
      68              : }
      69              : 
      70              : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
      71              : // Number of open connections is limited by the `max_conns_per_endpoint`.
      72              : pub struct EndpointConnPool<C: ClientInnerExt> {
      73              :     pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
      74              :     total_conns: usize,
      75              :     max_conns: usize,
      76              :     _guard: IntCounterPairGuard,
      77              :     global_connections_count: Arc<AtomicUsize>,
      78              :     global_pool_size_max_conns: usize,
      79              : }
      80              : 
      81              : impl<C: ClientInnerExt> EndpointConnPool<C> {
      82           24 :     fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
      83           24 :         let Self {
      84           24 :             pools,
      85           24 :             total_conns,
      86           24 :             global_connections_count,
      87           24 :             ..
      88           24 :         } = self;
      89           24 :         pools.get_mut(&db_user).and_then(|pool_entries| {
      90           14 :             pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
      91           24 :         })
      92           24 :     }
      93              : 
      94            3 :     fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
      95            3 :         let Self {
      96            3 :             pools,
      97            3 :             total_conns,
      98            3 :             global_connections_count,
      99            3 :             ..
     100            3 :         } = self;
     101            3 :         if let Some(pool) = pools.get_mut(&db_user) {
     102            1 :             let old_len = pool.conns.len();
     103            1 :             pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
     104            1 :             let new_len = pool.conns.len();
     105            1 :             let removed = old_len - new_len;
     106            1 :             if removed > 0 {
     107            0 :                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     108            0 :                 NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
     109            1 :             }
     110            1 :             *total_conns -= removed;
     111            1 :             removed > 0
     112              :         } else {
     113            2 :             false
     114              :         }
     115            3 :     }
     116              : 
     117           52 :     fn put(
     118           52 :         pool: &RwLock<Self>,
     119           52 :         conn_info: &ConnInfo,
     120           52 :         client: ClientInner<C>,
     121           52 :     ) -> anyhow::Result<()> {
     122           52 :         let conn_id = client.conn_id;
     123           52 : 
     124           52 :         if client.is_closed() {
     125            2 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
     126            2 :             return Ok(());
     127           50 :         }
     128           50 :         let global_max_conn = pool.read().global_pool_size_max_conns;
     129           50 :         if pool
     130           50 :             .read()
     131           50 :             .global_connections_count
     132           50 :             .load(atomic::Ordering::Relaxed)
     133           50 :             >= global_max_conn
     134              :         {
     135            2 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
     136            2 :             return Ok(());
     137           48 :         }
     138           48 : 
     139           48 :         // return connection to the pool
     140           48 :         let mut returned = false;
     141           48 :         let mut per_db_size = 0;
     142           48 :         let total_conns = {
     143           48 :             let mut pool = pool.write();
     144           48 : 
     145           48 :             if pool.total_conns < pool.max_conns {
     146           46 :                 let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
     147           46 :                 pool_entries.conns.push(ConnPoolEntry {
     148           46 :                     conn: client,
     149           46 :                     _last_access: std::time::Instant::now(),
     150           46 :                 });
     151           46 : 
     152           46 :                 returned = true;
     153           46 :                 per_db_size = pool_entries.conns.len();
     154           46 : 
     155           46 :                 pool.total_conns += 1;
     156           46 :                 pool.global_connections_count
     157           46 :                     .fetch_add(1, atomic::Ordering::Relaxed);
     158           46 :                 NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc();
     159           46 :             }
     160              : 
     161           48 :             pool.total_conns
     162           48 :         };
     163           48 : 
     164           48 :         // do logging outside of the mutex
     165           48 :         if returned {
     166           46 :             info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     167              :         } else {
     168            2 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     169              :         }
     170              : 
     171           48 :         Ok(())
     172           52 :     }
     173              : }
     174              : 
     175              : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
     176           17 :     fn drop(&mut self) {
     177           17 :         if self.total_conns > 0 {
     178           16 :             self.global_connections_count
     179           16 :                 .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
     180           16 :             NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64);
     181           16 :         }
     182           17 :     }
     183              : }
     184              : 
     185              : pub struct DbUserConnPool<C: ClientInnerExt> {
     186              :     conns: Vec<ConnPoolEntry<C>>,
     187              : }
     188              : 
     189              : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     190           16 :     fn default() -> Self {
     191           16 :         Self { conns: Vec::new() }
     192           16 :     }
     193              : }
     194              : 
     195              : impl<C: ClientInnerExt> DbUserConnPool<C> {
     196           16 :     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
     197           16 :         let old_len = self.conns.len();
     198           16 : 
     199           16 :         self.conns.retain(|conn| !conn.conn.is_closed());
     200           16 : 
     201           16 :         let new_len = self.conns.len();
     202           16 :         let removed = old_len - new_len;
     203           16 :         *conns -= removed;
     204           16 :         removed
     205           16 :     }
     206              : 
     207           14 :     fn get_conn_entry(
     208           14 :         &mut self,
     209           14 :         conns: &mut usize,
     210           14 :         global_connections_count: Arc<AtomicUsize>,
     211           14 :     ) -> Option<ConnPoolEntry<C>> {
     212           14 :         let mut removed = self.clear_closed_clients(conns);
     213           14 :         let conn = self.conns.pop();
     214           14 :         if conn.is_some() {
     215            4 :             *conns -= 1;
     216            4 :             removed += 1;
     217           10 :         }
     218           14 :         global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     219           14 :         NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
     220           14 :         conn
     221           14 :     }
     222              : }
     223              : 
     224              : pub struct GlobalConnPool<C: ClientInnerExt> {
     225              :     // endpoint -> per-endpoint connection pool
     226              :     //
     227              :     // That should be a fairly conteded map, so return reference to the per-endpoint
     228              :     // pool as early as possible and release the lock.
     229              :     global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
     230              : 
     231              :     /// Number of endpoint-connection pools
     232              :     ///
     233              :     /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
     234              :     /// That seems like far too much effort, so we're using a relaxed increment counter instead.
     235              :     /// It's only used for diagnostics.
     236              :     global_pool_size: AtomicUsize,
     237              : 
     238              :     /// Total number of connections in the pool
     239              :     global_connections_count: Arc<AtomicUsize>,
     240              : 
     241              :     config: &'static crate::config::HttpConfig,
     242              : }
     243              : 
     244            0 : #[derive(Debug, Clone, Copy)]
     245              : pub struct GlobalConnPoolOptions {
     246              :     // Maximum number of connections per one endpoint.
     247              :     // Can mix different (dbname, username) connections.
     248              :     // When running out of free slots for a particular endpoint,
     249              :     // falls back to opening a new connection for each request.
     250              :     pub max_conns_per_endpoint: usize,
     251              : 
     252              :     pub gc_epoch: Duration,
     253              : 
     254              :     pub pool_shards: usize,
     255              : 
     256              :     pub idle_timeout: Duration,
     257              : 
     258              :     pub opt_in: bool,
     259              : 
     260              :     // Total number of connections in the pool.
     261              :     pub max_total_conns: usize,
     262              : }
     263              : 
     264              : impl<C: ClientInnerExt> GlobalConnPool<C> {
     265           27 :     pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
     266           27 :         let shards = config.pool_options.pool_shards;
     267           27 :         Arc::new(Self {
     268           27 :             global_pool: DashMap::with_shard_amount(shards),
     269           27 :             global_pool_size: AtomicUsize::new(0),
     270           27 :             config,
     271           27 :             global_connections_count: Arc::new(AtomicUsize::new(0)),
     272           27 :         })
     273           27 :     }
     274              : 
     275              :     #[cfg(test)]
     276           18 :     pub fn get_global_connections_count(&self) -> usize {
     277           18 :         self.global_connections_count
     278           18 :             .load(atomic::Ordering::Relaxed)
     279           18 :     }
     280              : 
     281           40 :     pub fn get_idle_timeout(&self) -> Duration {
     282           40 :         self.config.pool_options.idle_timeout
     283           40 :     }
     284              : 
     285           25 :     pub fn shutdown(&self) {
     286           25 :         // drops all strong references to endpoint-pools
     287           25 :         self.global_pool.clear();
     288           25 :     }
     289              : 
     290           25 :     pub async fn gc_worker(&self, mut rng: impl Rng) {
     291           25 :         let epoch = self.config.pool_options.gc_epoch;
     292           25 :         let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
     293              :         loop {
     294           58 :             interval.tick().await;
     295              : 
     296           33 :             let shard = rng.gen_range(0..self.global_pool.shards().len());
     297           33 :             self.gc(shard);
     298              :         }
     299              :     }
     300              : 
     301           37 :     fn gc(&self, shard: usize) {
     302           37 :         debug!(shard, "pool: performing epoch reclamation");
     303              : 
     304              :         // acquire a random shard lock
     305           37 :         let mut shard = self.global_pool.shards()[shard].write();
     306           37 : 
     307           37 :         let timer = GC_LATENCY.start_timer();
     308           37 :         let current_len = shard.len();
     309           37 :         let mut clients_removed = 0;
     310           37 :         shard.retain(|endpoint, x| {
     311              :             // if the current endpoint pool is unique (no other strong or weak references)
     312              :             // then it is currently not in use by any connections.
     313            4 :             if let Some(pool) = Arc::get_mut(x.get_mut()) {
     314              :                 let EndpointConnPool {
     315            2 :                     pools, total_conns, ..
     316            2 :                 } = pool.get_mut();
     317            2 : 
     318            2 :                 // ensure that closed clients are removed
     319            2 :                 pools.iter_mut().for_each(|(_, db_pool)| {
     320            2 :                     clients_removed += db_pool.clear_closed_clients(total_conns);
     321            2 :                 });
     322            2 : 
     323            2 :                 // we only remove this pool if it has no active connections
     324            2 :                 if *total_conns == 0 {
     325            0 :                     info!("pool: discarding pool for endpoint {endpoint}");
     326            0 :                     return false;
     327            2 :                 }
     328            2 :             }
     329              : 
     330            4 :             true
     331           37 :         });
     332           37 : 
     333           37 :         let new_len = shard.len();
     334           37 :         drop(shard);
     335           37 :         timer.observe_duration();
     336           37 : 
     337           37 :         // Do logging outside of the lock.
     338           37 :         if clients_removed > 0 {
     339            2 :             let size = self
     340            2 :                 .global_connections_count
     341            2 :                 .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
     342            2 :                 - clients_removed;
     343            2 :             NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64);
     344            2 :             info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
     345           35 :         }
     346           37 :         let removed = current_len - new_len;
     347           37 : 
     348           37 :         if removed > 0 {
     349            0 :             let global_pool_size = self
     350            0 :                 .global_pool_size
     351            0 :                 .fetch_sub(removed, atomic::Ordering::Relaxed)
     352            0 :                 - removed;
     353            0 :             info!("pool: performed global pool gc. size now {global_pool_size}");
     354           37 :         }
     355           37 :     }
     356              : 
     357           24 :     pub async fn get(
     358           24 :         self: &Arc<Self>,
     359           24 :         ctx: &mut RequestMonitoring,
     360           24 :         conn_info: &ConnInfo,
     361           24 :     ) -> Result<Option<Client<C>>, HttpConnError> {
     362           24 :         let mut client: Option<ClientInner<C>> = None;
     363           24 : 
     364           24 :         let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
     365           24 :         if let Some(entry) = endpoint_pool
     366           24 :             .write()
     367           24 :             .get_conn_entry(conn_info.db_and_user())
     368              :         {
     369            4 :             client = Some(entry.conn)
     370           20 :         }
     371           24 :         let endpoint_pool = Arc::downgrade(&endpoint_pool);
     372              : 
     373              :         // ok return cached connection if found and establish a new one otherwise
     374           24 :         if let Some(client) = client {
     375            4 :             if client.is_closed() {
     376            0 :                 info!("pool: cached connection '{conn_info}' is closed, opening a new one");
     377            0 :                 return Ok(None);
     378              :             } else {
     379            4 :                 tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
     380            4 :                 tracing::Span::current().record(
     381            4 :                     "pid",
     382            4 :                     &tracing::field::display(client.inner.get_process_id()),
     383            4 :                 );
     384            4 :                 info!("pool: reusing connection '{conn_info}'");
     385            4 :                 client.session.send(ctx.session_id)?;
     386            4 :                 ctx.latency_timer.pool_hit();
     387            4 :                 ctx.latency_timer.success();
     388            4 :                 return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
     389              :             }
     390           20 :         }
     391           20 :         Ok(None)
     392           24 :     }
     393              : 
     394           68 :     fn get_or_create_endpoint_pool(
     395           68 :         self: &Arc<Self>,
     396           68 :         endpoint: &EndpointCacheKey,
     397           68 :     ) -> Arc<RwLock<EndpointConnPool<C>>> {
     398              :         // fast path
     399           68 :         if let Some(pool) = self.global_pool.get(endpoint) {
     400           51 :             return pool.clone();
     401           17 :         }
     402           17 : 
     403           17 :         // slow path
     404           17 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     405           17 :             pools: HashMap::new(),
     406           17 :             total_conns: 0,
     407           17 :             max_conns: self.config.pool_options.max_conns_per_endpoint,
     408           17 :             _guard: ENDPOINT_POOLS.guard(),
     409           17 :             global_connections_count: self.global_connections_count.clone(),
     410           17 :             global_pool_size_max_conns: self.config.pool_options.max_total_conns,
     411           17 :         }));
     412           17 : 
     413           17 :         // find or create a pool for this endpoint
     414           17 :         let mut created = false;
     415           17 :         let pool = self
     416           17 :             .global_pool
     417           17 :             .entry(endpoint.clone())
     418           17 :             .or_insert_with(|| {
     419           17 :                 created = true;
     420           17 :                 new_pool
     421           17 :             })
     422           17 :             .clone();
     423           17 : 
     424           17 :         // log new global pool size
     425           17 :         if created {
     426           17 :             let global_pool_size = self
     427           17 :                 .global_pool_size
     428           17 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     429           17 :                 + 1;
     430           17 :             info!(
     431           13 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     432           13 :             );
     433            0 :         }
     434              : 
     435           17 :         pool
     436           68 :     }
     437              : }
     438              : 
     439           40 : pub fn poll_client<C: ClientInnerExt>(
     440           40 :     global_pool: Arc<GlobalConnPool<C>>,
     441           40 :     ctx: &mut RequestMonitoring,
     442           40 :     conn_info: ConnInfo,
     443           40 :     client: C,
     444           40 :     mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
     445           40 :     conn_id: uuid::Uuid,
     446           40 :     aux: MetricsAuxInfo,
     447           40 : ) -> Client<C> {
     448           40 :     let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
     449           40 :         .with_label_values(&[ctx.protocol])
     450           40 :         .guard();
     451           40 :     let mut session_id = ctx.session_id;
     452           40 :     let (tx, mut rx) = tokio::sync::watch::channel(session_id);
     453              : 
     454           40 :     let span = info_span!(parent: None, "connection", %conn_id);
     455           40 :     span.in_scope(|| {
     456           40 :         info!(%conn_info, %session_id, "new connection");
     457           40 :     });
     458           40 :     let pool =
     459           40 :         Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
     460           40 :     let pool_clone = pool.clone();
     461           40 : 
     462           40 :     let db_user = conn_info.db_and_user();
     463           40 :     let idle = global_pool.get_idle_timeout();
     464           40 :     tokio::spawn(
     465           40 :     async move {
     466           40 :         let _conn_gauge = conn_gauge;
     467           40 :         let mut idle_timeout = pin!(tokio::time::sleep(idle));
     468         7044 :         poll_fn(move |cx| {
     469         7044 :             if matches!(rx.has_changed(), Ok(true)) {
     470            4 :                 session_id = *rx.borrow_and_update();
     471            4 :                 info!(%session_id, "changed session");
     472            4 :                 idle_timeout.as_mut().reset(Instant::now() + idle);
     473         7040 :             }
     474              : 
     475              :             // 5 minute idle connection timeout
     476         7044 :             if idle_timeout.as_mut().poll(cx).is_ready() {
     477            0 :                 idle_timeout.as_mut().reset(Instant::now() + idle);
     478            0 :                 info!("connection idle");
     479            0 :                 if let Some(pool) = pool.clone().upgrade() {
     480              :                     // remove client from pool - should close the connection if it's idle.
     481              :                     // does nothing if the client is currently checked-out and in-use
     482            0 :                     if pool.write().remove_client(db_user.clone(), conn_id) {
     483            0 :                         info!("idle connection removed");
     484            0 :                     }
     485            0 :                 }
     486         7044 :             }
     487              : 
     488              :             loop {
     489         7044 :                 let message = ready!(connection.poll_message(cx));
     490              : 
     491            0 :                 match message {
     492            0 :                     Some(Ok(AsyncMessage::Notice(notice))) => {
     493            0 :                         info!(%session_id, "notice: {}", notice);
     494              :                     }
     495            0 :                     Some(Ok(AsyncMessage::Notification(notif))) => {
     496            0 :                         warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
     497              :                     }
     498              :                     Some(Ok(_)) => {
     499            0 :                         warn!(%session_id, "unknown message");
     500              :                     }
     501            0 :                     Some(Err(e)) => {
     502            0 :                         error!(%session_id, "connection error: {}", e);
     503            0 :                         break
     504              :                     }
     505              :                     None => {
     506           39 :                         info!("connection closed");
     507           39 :                         break
     508              :                     }
     509              :                 }
     510              :             }
     511              : 
     512              :             // remove from connection pool
     513           39 :             if let Some(pool) = pool.clone().upgrade() {
     514            3 :                 if pool.write().remove_client(db_user.clone(), conn_id) {
     515            0 :                     info!("closed connection removed");
     516            3 :                 }
     517           36 :             }
     518              : 
     519           39 :             Poll::Ready(())
     520         7044 :         }).await;
     521              : 
     522           40 :     }
     523           40 :     .instrument(span));
     524           40 :     let inner = ClientInner {
     525           40 :         inner: client,
     526           40 :         session: tx,
     527           40 :         aux,
     528           40 :         conn_id,
     529           40 :     };
     530           40 :     Client::new(inner, conn_info, pool_clone)
     531           40 : }
     532              : 
     533              : struct ClientInner<C: ClientInnerExt> {
     534              :     inner: C,
     535              :     session: tokio::sync::watch::Sender<uuid::Uuid>,
     536              :     aux: MetricsAuxInfo,
     537              :     conn_id: uuid::Uuid,
     538              : }
     539              : 
     540              : pub trait ClientInnerExt: Sync + Send + 'static {
     541              :     fn is_closed(&self) -> bool;
     542              :     fn get_process_id(&self) -> i32;
     543              : }
     544              : 
     545              : impl ClientInnerExt for tokio_postgres::Client {
     546           48 :     fn is_closed(&self) -> bool {
     547           48 :         self.is_closed()
     548           48 :     }
     549            4 :     fn get_process_id(&self) -> i32 {
     550            4 :         self.get_process_id()
     551            4 :     }
     552              : }
     553              : 
     554              : impl<C: ClientInnerExt> ClientInner<C> {
     555           64 :     pub fn is_closed(&self) -> bool {
     556           64 :         self.inner.is_closed()
     557           64 :     }
     558              : }
     559              : 
     560              : impl<C: ClientInnerExt> Client<C> {
     561           42 :     pub fn metrics(&self) -> Arc<MetricCounter> {
     562           42 :         let aux = &self.inner.as_ref().unwrap().aux;
     563           42 :         USAGE_METRICS.register(Ids {
     564           42 :             endpoint_id: aux.endpoint_id.clone(),
     565           42 :             branch_id: aux.branch_id.clone(),
     566           42 :         })
     567           42 :     }
     568              : }
     569              : 
/// Handle to a pooled connection, handed out to request-handling code.
///
/// When dropped, the connection is handed back to its endpoint pool unless the
/// pool no longer exists or the pool reference was cleared via `Discard`.
///
/// Note: struct fields are dropped in declaration order.
pub struct Client<C: ClientInnerExt> {
    /// Tracing span that was current when the client was created; it is re-entered
    /// while the connection is being put back into the pool.
    span: Span,
    /// The connection itself; `None` only after it has been taken during drop.
    inner: Option<ClientInner<C>>,
    /// Connection parameters, passed to `EndpointConnPool::put` when the
    /// connection is returned.
    conn_info: ConnInfo,
    /// Weak reference to the owning endpoint pool. It is replaced with an empty
    /// `Weak` when the connection must not be returned to the pool.
    pool: Weak<RwLock<EndpointConnPool<C>>>,
}
     576              : 
/// Borrowed view of a `Client` that can prevent its connection from going back to the pool.
///
/// `discard` always clears the client's pool reference. `check_idle` clears it when
/// the connection is not idle. Once the reference is cleared, dropping the `Client`
/// no longer returns the connection.
pub struct Discard<'a, C: ClientInnerExt> {
    /// Connection parameters; within this file they are only used in log messages.
    conn_info: &'a ConnInfo,
    /// The client's pool reference; replaced with an empty `Weak` to discard.
    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
     581              : 
     582              : impl<C: ClientInnerExt> Client<C> {
     583           58 :     pub(self) fn new(
     584           58 :         inner: ClientInner<C>,
     585           58 :         conn_info: ConnInfo,
     586           58 :         pool: Weak<RwLock<EndpointConnPool<C>>>,
     587           58 :     ) -> Self {
     588           58 :         Self {
     589           58 :             inner: Some(inner),
     590           58 :             span: Span::current(),
     591           58 :             conn_info,
     592           58 :             pool,
     593           58 :         }
     594           58 :     }
     595           46 :     pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
     596           46 :         let Self {
     597           46 :             inner,
     598           46 :             pool,
     599           46 :             conn_info,
     600           46 :             span: _,
     601           46 :         } = self;
     602           46 :         let inner = inner.as_mut().expect("client inner should not be removed");
     603           46 :         (&mut inner.inner, Discard { pool, conn_info })
     604           46 :     }
     605              : 
     606           39 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     607           39 :         self.inner().1.check_idle(status)
     608           39 :     }
     609            4 :     pub fn discard(&mut self) {
     610            4 :         self.inner().1.discard()
     611            4 :     }
     612              : }
     613              : 
     614              : impl<C: ClientInnerExt> Discard<'_, C> {
     615           42 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     616           42 :         let conn_info = &self.conn_info;
     617           42 :         if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
     618            2 :             info!("pool: throwing away connection '{conn_info}' because connection is not idle")
     619           40 :         }
     620           42 :     }
     621            4 :     pub fn discard(&mut self) {
     622            4 :         let conn_info = &self.conn_info;
     623            4 :         if std::mem::take(self.pool).strong_count() > 0 {
     624            4 :             info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
     625            0 :         }
     626            4 :     }
     627              : }
     628              : 
     629              : impl<C: ClientInnerExt> Deref for Client<C> {
     630              :     type Target = C;
     631              : 
     632           41 :     fn deref(&self) -> &Self::Target {
     633           41 :         &self
     634           41 :             .inner
     635           41 :             .as_ref()
     636           41 :             .expect("client inner should not be removed")
     637           41 :             .inner
     638           41 :     }
     639              : }
     640              : 
     641              : impl<C: ClientInnerExt> Client<C> {
     642           58 :     fn do_drop(&mut self) -> Option<impl FnOnce()> {
     643           58 :         let conn_info = self.conn_info.clone();
     644           58 :         let client = self
     645           58 :             .inner
     646           58 :             .take()
     647           58 :             .expect("client inner should not be removed");
     648           58 :         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
     649           52 :             let current_span = self.span.clone();
     650           52 :             // return connection to the pool
     651           52 :             return Some(move || {
     652           52 :                 let _span = current_span.enter();
     653           52 :                 let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
     654           52 :             });
     655            6 :         }
     656            6 :         None
     657           58 :     }
     658              : }
     659              : 
     660              : impl<C: ClientInnerExt> Drop for Client<C> {
     661           46 :     fn drop(&mut self) {
     662           46 :         if let Some(drop) = self.do_drop() {
     663           40 :             tokio::task::spawn_blocking(drop);
     664           40 :         }
     665           46 :     }
     666              : }
     667              : 
     668              : #[cfg(test)]
     669              : mod tests {
     670              :     use env_logger;
     671              :     use std::{mem, sync::atomic::AtomicBool};
     672              : 
     673              :     use super::*;
     674              : 
     675              :     struct MockClient(Arc<AtomicBool>);
     676              :     impl MockClient {
     677           12 :         fn new(is_closed: bool) -> Self {
     678           12 :             MockClient(Arc::new(is_closed.into()))
     679           12 :         }
     680              :     }
     681              :     impl ClientInnerExt for MockClient {
     682           16 :         fn is_closed(&self) -> bool {
     683           16 :             self.0.load(atomic::Ordering::Relaxed)
     684           16 :         }
     685            0 :         fn get_process_id(&self) -> i32 {
     686            0 :             0
     687            0 :         }
     688              :     }
     689              : 
     690           10 :     fn create_inner() -> ClientInner<MockClient> {
     691           10 :         create_inner_with(MockClient::new(false))
     692           10 :     }
     693              : 
     694           14 :     fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
     695           14 :         ClientInner {
     696           14 :             inner: client,
     697           14 :             session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
     698           14 :             aux: Default::default(),
     699           14 :             conn_id: uuid::Uuid::new_v4(),
     700           14 :         }
     701           14 :     }
     702              : 
     703            2 :     #[tokio::test]
     704            2 :     async fn test_pool() {
     705            2 :         let _ = env_logger::try_init();
     706            2 :         let config = Box::leak(Box::new(crate::config::HttpConfig {
     707            2 :             pool_options: GlobalConnPoolOptions {
     708            2 :                 max_conns_per_endpoint: 2,
     709            2 :                 gc_epoch: Duration::from_secs(1),
     710            2 :                 pool_shards: 2,
     711            2 :                 idle_timeout: Duration::from_secs(1),
     712            2 :                 opt_in: false,
     713            2 :                 max_total_conns: 3,
     714            2 :             },
     715            2 :             request_timeout: Duration::from_secs(1),
     716            2 :         }));
     717            2 :         let pool = GlobalConnPool::new(config);
     718            2 :         let conn_info = ConnInfo {
     719            2 :             user_info: ComputeUserInfo {
     720            2 :                 user: "user".into(),
     721            2 :                 endpoint: "endpoint".into(),
     722            2 :                 options: Default::default(),
     723            2 :             },
     724            2 :             dbname: "dbname".into(),
     725            2 :             password: "password".as_bytes().into(),
     726            2 :         };
     727            2 :         let ep_pool =
     728            2 :             Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
     729            2 :         {
     730            2 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     731            2 :             assert_eq!(0, pool.get_global_connections_count());
     732            2 :             client.discard();
     733            2 :             // Discard should not add the connection from the pool.
     734            2 :             assert_eq!(0, pool.get_global_connections_count());
     735            2 :         }
     736            2 :         {
     737            2 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     738            2 :             client.do_drop().unwrap()();
     739            2 :             mem::forget(client); // drop the client
     740            2 :             assert_eq!(1, pool.get_global_connections_count());
     741            2 :         }
     742            2 :         {
     743            2 :             let mut closed_client = Client::new(
     744            2 :                 create_inner_with(MockClient::new(true)),
     745            2 :                 conn_info.clone(),
     746            2 :                 ep_pool.clone(),
     747            2 :             );
     748            2 :             closed_client.do_drop().unwrap()();
     749            2 :             mem::forget(closed_client); // drop the client
     750            2 :                                         // The closed client shouldn't be added to the pool.
     751            2 :             assert_eq!(1, pool.get_global_connections_count());
     752            2 :         }
     753            2 :         let is_closed: Arc<AtomicBool> = Arc::new(false.into());
     754            2 :         {
     755            2 :             let mut client = Client::new(
     756            2 :                 create_inner_with(MockClient(is_closed.clone())),
     757            2 :                 conn_info.clone(),
     758            2 :                 ep_pool.clone(),
     759            2 :             );
     760            2 :             client.do_drop().unwrap()();
     761            2 :             mem::forget(client); // drop the client
     762            2 : 
     763            2 :             // The client should be added to the pool.
     764            2 :             assert_eq!(2, pool.get_global_connections_count());
     765            2 :         }
     766            2 :         {
     767            2 :             let mut client = Client::new(create_inner(), conn_info, ep_pool);
     768            2 :             client.do_drop().unwrap()();
     769            2 :             mem::forget(client); // drop the client
     770            2 : 
     771            2 :             // The client shouldn't be added to the pool. Because the ep-pool is full.
     772            2 :             assert_eq!(2, pool.get_global_connections_count());
     773            2 :         }
     774            2 : 
     775            2 :         let conn_info = ConnInfo {
     776            2 :             user_info: ComputeUserInfo {
     777            2 :                 user: "user".into(),
     778            2 :                 endpoint: "endpoint-2".into(),
     779            2 :                 options: Default::default(),
     780            2 :             },
     781            2 :             dbname: "dbname".into(),
     782            2 :             password: "password".as_bytes().into(),
     783            2 :         };
     784            2 :         let ep_pool =
     785            2 :             Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
     786            2 :         {
     787            2 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     788            2 :             client.do_drop().unwrap()();
     789            2 :             mem::forget(client); // drop the client
     790            2 :             assert_eq!(3, pool.get_global_connections_count());
     791            2 :         }
     792            2 :         {
     793            2 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     794            2 :             client.do_drop().unwrap()();
     795            2 :             mem::forget(client); // drop the client
     796            2 : 
     797            2 :             // The client shouldn't be added to the pool. Because the global pool is full.
     798            2 :             assert_eq!(3, pool.get_global_connections_count());
     799            2 :         }
     800            2 : 
     801            2 :         is_closed.store(true, atomic::Ordering::Relaxed);
     802            2 :         // Do gc for all shards.
     803            2 :         pool.gc(0);
     804            2 :         pool.gc(1);
     805            2 :         // Closed client should be removed from the pool.
     806            2 :         assert_eq!(2, pool.get_global_connections_count());
     807            2 :     }
     808              : }
        

Generated by: LCOV version 2.1-beta