LCOV - code coverage report
Current view: top level - proxy/src/serverless - conn_pool.rs (source / functions) Coverage Total Hit
Test: 7eb96e224e685167ad85f58f858387d8cf253f63.info Lines: 60.1 % 602 362
Test Date: 2024-09-23 21:23:07 Functions: 32.9 % 82 27

            Line data    Source code
       1              : use dashmap::DashMap;
       2              : use futures::{future::poll_fn, Future};
       3              : use parking_lot::RwLock;
       4              : use rand::Rng;
       5              : use smallvec::SmallVec;
       6              : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
       7              : use std::{
       8              :     fmt,
       9              :     task::{ready, Poll},
      10              : };
      11              : use std::{
      12              :     ops::Deref,
      13              :     sync::atomic::{self, AtomicUsize},
      14              : };
      15              : use tokio::time::Instant;
      16              : use tokio_postgres::tls::NoTlsStream;
      17              : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
      18              : use tokio_util::sync::CancellationToken;
      19              : 
      20              : use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
      21              : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
      22              : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
      23              : use crate::{
      24              :     auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
      25              : };
      26              : 
      27              : use tracing::{debug, error, warn, Span};
      28              : use tracing::{info, info_span, Instrument};
      29              : 
      30              : use super::backend::HttpConnError;
      31              : 
      32              : #[derive(Debug, Clone)]
      33              : pub(crate) struct ConnInfoWithAuth {
      34              :     pub(crate) conn_info: ConnInfo,
      35              :     pub(crate) auth: AuthData,
      36              : }
      37              : 
      38              : #[derive(Debug, Clone)]
      39              : pub(crate) struct ConnInfo {
      40              :     pub(crate) user_info: ComputeUserInfo,
      41              :     pub(crate) dbname: DbName,
      42              : }
      43              : 
      44              : #[derive(Debug, Clone)]
      45              : pub(crate) enum AuthData {
      46              :     Password(SmallVec<[u8; 16]>),
      47              :     Jwt(String),
      48              : }
      49              : 
      50              : impl ConnInfo {
      51              :     // hm, change to hasher to avoid cloning?
      52            3 :     pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
      53            3 :         (self.dbname.clone(), self.user_info.user.clone())
      54            3 :     }
      55              : 
      56            2 :     pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
      57            2 :         // We don't want to cache http connections for ephemeral endpoints.
      58            2 :         if self.user_info.options.is_ephemeral() {
      59            0 :             None
      60              :         } else {
      61            2 :             Some(self.user_info.endpoint_cache_key())
      62              :         }
      63            2 :     }
      64              : }
      65              : 
      66              : impl fmt::Display for ConnInfo {
      67              :     // use custom display to avoid logging password
      68            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      69            0 :         write!(
      70            0 :             f,
      71            0 :             "{}@{}/{}?{}",
      72            0 :             self.user_info.user,
      73            0 :             self.user_info.endpoint,
      74            0 :             self.dbname,
      75            0 :             self.user_info.options.get_cache_key("")
      76            0 :         )
      77            0 :     }
      78              : }
      79              : 
      80              : struct ConnPoolEntry<C: ClientInnerExt> {
      81              :     conn: ClientInner<C>,
      82              :     _last_access: std::time::Instant,
      83              : }
      84              : 
      85              : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
      86              : // Number of open connections is limited by the `max_conns_per_endpoint`.
      87              : pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
      88              :     pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
      89              :     total_conns: usize,
      90              :     max_conns: usize,
      91              :     _guard: HttpEndpointPoolsGuard<'static>,
      92              :     global_connections_count: Arc<AtomicUsize>,
      93              :     global_pool_size_max_conns: usize,
      94              : }
      95              : 
      96              : impl<C: ClientInnerExt> EndpointConnPool<C> {
      97            0 :     fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
      98            0 :         let Self {
      99            0 :             pools,
     100            0 :             total_conns,
     101            0 :             global_connections_count,
     102            0 :             ..
     103            0 :         } = self;
     104            0 :         pools.get_mut(&db_user).and_then(|pool_entries| {
     105            0 :             pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
     106            0 :         })
     107            0 :     }
     108              : 
     109            0 :     fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
     110            0 :         let Self {
     111            0 :             pools,
     112            0 :             total_conns,
     113            0 :             global_connections_count,
     114            0 :             ..
     115            0 :         } = self;
     116            0 :         if let Some(pool) = pools.get_mut(&db_user) {
     117            0 :             let old_len = pool.conns.len();
     118            0 :             pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
     119            0 :             let new_len = pool.conns.len();
     120            0 :             let removed = old_len - new_len;
     121            0 :             if removed > 0 {
     122            0 :                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     123            0 :                 Metrics::get()
     124            0 :                     .proxy
     125            0 :                     .http_pool_opened_connections
     126            0 :                     .get_metric()
     127            0 :                     .dec_by(removed as i64);
     128            0 :             }
     129            0 :             *total_conns -= removed;
     130            0 :             removed > 0
     131              :         } else {
     132            0 :             false
     133              :         }
     134            0 :     }
     135              : 
     136            6 :     fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
     137            6 :         let conn_id = client.conn_id;
     138            6 : 
     139            6 :         if client.is_closed() {
     140            1 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
     141            1 :             return;
     142            5 :         }
     143            5 :         let global_max_conn = pool.read().global_pool_size_max_conns;
     144            5 :         if pool
     145            5 :             .read()
     146            5 :             .global_connections_count
     147            5 :             .load(atomic::Ordering::Relaxed)
     148            5 :             >= global_max_conn
     149              :         {
     150            1 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
     151            1 :             return;
     152            4 :         }
     153            4 : 
     154            4 :         // return connection to the pool
     155            4 :         let mut returned = false;
     156            4 :         let mut per_db_size = 0;
     157            4 :         let total_conns = {
     158            4 :             let mut pool = pool.write();
     159            4 : 
     160            4 :             if pool.total_conns < pool.max_conns {
     161            3 :                 let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
     162            3 :                 pool_entries.conns.push(ConnPoolEntry {
     163            3 :                     conn: client,
     164            3 :                     _last_access: std::time::Instant::now(),
     165            3 :                 });
     166            3 : 
     167            3 :                 returned = true;
     168            3 :                 per_db_size = pool_entries.conns.len();
     169            3 : 
     170            3 :                 pool.total_conns += 1;
     171            3 :                 pool.global_connections_count
     172            3 :                     .fetch_add(1, atomic::Ordering::Relaxed);
     173            3 :                 Metrics::get()
     174            3 :                     .proxy
     175            3 :                     .http_pool_opened_connections
     176            3 :                     .get_metric()
     177            3 :                     .inc();
     178            3 :             }
     179              : 
     180            4 :             pool.total_conns
     181            4 :         };
     182            4 : 
     183            4 :         // do logging outside of the mutex
     184            4 :         if returned {
     185            3 :             info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     186              :         } else {
     187            1 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     188              :         }
     189            6 :     }
     190              : }
     191              : 
     192              : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
     193            2 :     fn drop(&mut self) {
     194            2 :         if self.total_conns > 0 {
     195            2 :             self.global_connections_count
     196            2 :                 .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
     197            2 :             Metrics::get()
     198            2 :                 .proxy
     199            2 :                 .http_pool_opened_connections
     200            2 :                 .get_metric()
     201            2 :                 .dec_by(self.total_conns as i64);
     202            2 :         }
     203            2 :     }
     204              : }
     205              : 
     206              : pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
     207              :     conns: Vec<ConnPoolEntry<C>>,
     208              : }
     209              : 
     210              : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     211            2 :     fn default() -> Self {
     212            2 :         Self { conns: Vec::new() }
     213            2 :     }
     214              : }
     215              : 
     216              : impl<C: ClientInnerExt> DbUserConnPool<C> {
     217            1 :     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
     218            1 :         let old_len = self.conns.len();
     219            1 : 
     220            2 :         self.conns.retain(|conn| !conn.conn.is_closed());
     221            1 : 
     222            1 :         let new_len = self.conns.len();
     223            1 :         let removed = old_len - new_len;
     224            1 :         *conns -= removed;
     225            1 :         removed
     226            1 :     }
     227              : 
     228            0 :     fn get_conn_entry(
     229            0 :         &mut self,
     230            0 :         conns: &mut usize,
     231            0 :         global_connections_count: Arc<AtomicUsize>,
     232            0 :     ) -> Option<ConnPoolEntry<C>> {
     233            0 :         let mut removed = self.clear_closed_clients(conns);
     234            0 :         let conn = self.conns.pop();
     235            0 :         if conn.is_some() {
     236            0 :             *conns -= 1;
     237            0 :             removed += 1;
     238            0 :         }
     239            0 :         global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
     240            0 :         Metrics::get()
     241            0 :             .proxy
     242            0 :             .http_pool_opened_connections
     243            0 :             .get_metric()
     244            0 :             .dec_by(removed as i64);
     245            0 :         conn
     246            0 :     }
     247              : }
     248              : 
     249              : pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
     250              :     // endpoint -> per-endpoint connection pool
     251              :     //
     252              :     // That should be a fairly conteded map, so return reference to the per-endpoint
     253              :     // pool as early as possible and release the lock.
     254              :     global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
     255              : 
     256              :     /// Number of endpoint-connection pools
     257              :     ///
     258              :     /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
     259              :     /// That seems like far too much effort, so we're using a relaxed increment counter instead.
     260              :     /// It's only used for diagnostics.
     261              :     global_pool_size: AtomicUsize,
     262              : 
     263              :     /// Total number of connections in the pool
     264              :     global_connections_count: Arc<AtomicUsize>,
     265              : 
     266              :     config: &'static crate::config::HttpConfig,
     267              : }
     268              : 
     269              : #[derive(Debug, Clone, Copy)]
     270              : pub struct GlobalConnPoolOptions {
     271              :     // Maximum number of connections per one endpoint.
     272              :     // Can mix different (dbname, username) connections.
     273              :     // When running out of free slots for a particular endpoint,
     274              :     // falls back to opening a new connection for each request.
     275              :     pub max_conns_per_endpoint: usize,
     276              : 
     277              :     pub gc_epoch: Duration,
     278              : 
     279              :     pub pool_shards: usize,
     280              : 
     281              :     pub idle_timeout: Duration,
     282              : 
     283              :     pub opt_in: bool,
     284              : 
     285              :     // Total number of connections in the pool.
     286              :     pub max_total_conns: usize,
     287              : }
     288              : 
     289              : impl<C: ClientInnerExt> GlobalConnPool<C> {
     290            1 :     pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
     291            1 :         let shards = config.pool_options.pool_shards;
     292            1 :         Arc::new(Self {
     293            1 :             global_pool: DashMap::with_shard_amount(shards),
     294            1 :             global_pool_size: AtomicUsize::new(0),
     295            1 :             config,
     296            1 :             global_connections_count: Arc::new(AtomicUsize::new(0)),
     297            1 :         })
     298            1 :     }
     299              : 
     300              :     #[cfg(test)]
     301            9 :     pub(crate) fn get_global_connections_count(&self) -> usize {
     302            9 :         self.global_connections_count
     303            9 :             .load(atomic::Ordering::Relaxed)
     304            9 :     }
     305              : 
     306            0 :     pub(crate) fn get_idle_timeout(&self) -> Duration {
     307            0 :         self.config.pool_options.idle_timeout
     308            0 :     }
     309              : 
     310            0 :     pub(crate) fn shutdown(&self) {
     311            0 :         // drops all strong references to endpoint-pools
     312            0 :         self.global_pool.clear();
     313            0 :     }
     314              : 
     315            0 :     pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
     316            0 :         let epoch = self.config.pool_options.gc_epoch;
     317            0 :         let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
     318              :         loop {
     319            0 :             interval.tick().await;
     320              : 
     321            0 :             let shard = rng.gen_range(0..self.global_pool.shards().len());
     322            0 :             self.gc(shard);
     323              :         }
     324              :     }
     325              : 
     326            2 :     fn gc(&self, shard: usize) {
     327            2 :         debug!(shard, "pool: performing epoch reclamation");
     328              : 
     329              :         // acquire a random shard lock
     330            2 :         let mut shard = self.global_pool.shards()[shard].write();
     331            2 : 
     332            2 :         let timer = Metrics::get()
     333            2 :             .proxy
     334            2 :             .http_pool_reclaimation_lag_seconds
     335            2 :             .start_timer();
     336            2 :         let current_len = shard.len();
     337            2 :         let mut clients_removed = 0;
     338            2 :         shard.retain(|endpoint, x| {
     339              :             // if the current endpoint pool is unique (no other strong or weak references)
     340              :             // then it is currently not in use by any connections.
     341            2 :             if let Some(pool) = Arc::get_mut(x.get_mut()) {
     342              :                 let EndpointConnPool {
     343            1 :                     pools, total_conns, ..
     344            1 :                 } = pool.get_mut();
     345              : 
     346              :                 // ensure that closed clients are removed
     347            1 :                 for db_pool in pools.values_mut() {
     348            1 :                     clients_removed += db_pool.clear_closed_clients(total_conns);
     349            1 :                 }
     350              : 
     351              :                 // we only remove this pool if it has no active connections
     352            1 :                 if *total_conns == 0 {
     353            0 :                     info!("pool: discarding pool for endpoint {endpoint}");
     354            0 :                     return false;
     355            1 :                 }
     356            1 :             }
     357              : 
     358            2 :             true
     359            2 :         });
     360            2 : 
     361            2 :         let new_len = shard.len();
     362            2 :         drop(shard);
     363            2 :         timer.observe();
     364            2 : 
     365            2 :         // Do logging outside of the lock.
     366            2 :         if clients_removed > 0 {
     367            1 :             let size = self
     368            1 :                 .global_connections_count
     369            1 :                 .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
     370            1 :                 - clients_removed;
     371            1 :             Metrics::get()
     372            1 :                 .proxy
     373            1 :                 .http_pool_opened_connections
     374            1 :                 .get_metric()
     375            1 :                 .dec_by(clients_removed as i64);
     376            1 :             info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
     377            1 :         }
     378            2 :         let removed = current_len - new_len;
     379            2 : 
     380            2 :         if removed > 0 {
     381            0 :             let global_pool_size = self
     382            0 :                 .global_pool_size
     383            0 :                 .fetch_sub(removed, atomic::Ordering::Relaxed)
     384            0 :                 - removed;
     385            0 :             info!("pool: performed global pool gc. size now {global_pool_size}");
     386            2 :         }
     387            2 :     }
     388              : 
     389            0 :     pub(crate) fn get(
     390            0 :         self: &Arc<Self>,
     391            0 :         ctx: &RequestMonitoring,
     392            0 :         conn_info: &ConnInfo,
     393            0 :     ) -> Result<Option<Client<C>>, HttpConnError> {
     394            0 :         let mut client: Option<ClientInner<C>> = None;
     395            0 :         let Some(endpoint) = conn_info.endpoint_cache_key() else {
     396            0 :             return Ok(None);
     397              :         };
     398              : 
     399            0 :         let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
     400            0 :         if let Some(entry) = endpoint_pool
     401            0 :             .write()
     402            0 :             .get_conn_entry(conn_info.db_and_user())
     403            0 :         {
     404            0 :             client = Some(entry.conn);
     405            0 :         }
     406            0 :         let endpoint_pool = Arc::downgrade(&endpoint_pool);
     407              : 
     408              :         // ok return cached connection if found and establish a new one otherwise
     409            0 :         if let Some(client) = client {
     410            0 :             if client.is_closed() {
     411            0 :                 info!("pool: cached connection '{conn_info}' is closed, opening a new one");
     412            0 :                 return Ok(None);
     413            0 :             }
     414            0 :             tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
     415            0 :             tracing::Span::current().record(
     416            0 :                 "pid",
     417            0 :                 tracing::field::display(client.inner.get_process_id()),
     418            0 :             );
     419            0 :             info!(
     420            0 :                 cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
     421            0 :                 "pool: reusing connection '{conn_info}'"
     422              :             );
     423            0 :             client.session.send(ctx.session_id())?;
     424            0 :             ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
     425            0 :             ctx.success();
     426            0 :             return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
     427            0 :         }
     428            0 :         Ok(None)
     429            0 :     }
     430              : 
     431            2 :     fn get_or_create_endpoint_pool(
     432            2 :         self: &Arc<Self>,
     433            2 :         endpoint: &EndpointCacheKey,
     434            2 :     ) -> Arc<RwLock<EndpointConnPool<C>>> {
     435              :         // fast path
     436            2 :         if let Some(pool) = self.global_pool.get(endpoint) {
     437            0 :             return pool.clone();
     438            2 :         }
     439            2 : 
     440            2 :         // slow path
     441            2 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     442            2 :             pools: HashMap::new(),
     443            2 :             total_conns: 0,
     444            2 :             max_conns: self.config.pool_options.max_conns_per_endpoint,
     445            2 :             _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
     446            2 :             global_connections_count: self.global_connections_count.clone(),
     447            2 :             global_pool_size_max_conns: self.config.pool_options.max_total_conns,
     448            2 :         }));
     449            2 : 
     450            2 :         // find or create a pool for this endpoint
     451            2 :         let mut created = false;
     452            2 :         let pool = self
     453            2 :             .global_pool
     454            2 :             .entry(endpoint.clone())
     455            2 :             .or_insert_with(|| {
     456            2 :                 created = true;
     457            2 :                 new_pool
     458            2 :             })
     459            2 :             .clone();
     460            2 : 
     461            2 :         // log new global pool size
     462            2 :         if created {
     463            2 :             let global_pool_size = self
     464            2 :                 .global_pool_size
     465            2 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     466            2 :                 + 1;
     467            2 :             info!(
     468            0 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     469              :             );
     470            0 :         }
     471              : 
     472            2 :         pool
     473            2 :     }
     474              : }
     475              : 
     476            0 : pub(crate) fn poll_client<C: ClientInnerExt>(
     477            0 :     global_pool: Arc<GlobalConnPool<C>>,
     478            0 :     ctx: &RequestMonitoring,
     479            0 :     conn_info: ConnInfo,
     480            0 :     client: C,
     481            0 :     mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
     482            0 :     conn_id: uuid::Uuid,
     483            0 :     aux: MetricsAuxInfo,
     484            0 : ) -> Client<C> {
     485            0 :     let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
     486            0 :     let mut session_id = ctx.session_id();
     487            0 :     let (tx, mut rx) = tokio::sync::watch::channel(session_id);
     488              : 
     489            0 :     let span = info_span!(parent: None, "connection", %conn_id);
     490            0 :     let cold_start_info = ctx.cold_start_info();
     491            0 :     span.in_scope(|| {
     492            0 :         info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
     493            0 :     });
     494            0 :     let pool = match conn_info.endpoint_cache_key() {
     495            0 :         Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
     496            0 :         None => Weak::new(),
     497              :     };
     498            0 :     let pool_clone = pool.clone();
     499            0 : 
     500            0 :     let db_user = conn_info.db_and_user();
     501            0 :     let idle = global_pool.get_idle_timeout();
     502            0 :     let cancel = CancellationToken::new();
     503            0 :     let cancelled = cancel.clone().cancelled_owned();
     504            0 : 
     505            0 :     tokio::spawn(
     506            0 :     async move {
     507            0 :         let _conn_gauge = conn_gauge;
     508            0 :         let mut idle_timeout = pin!(tokio::time::sleep(idle));
     509            0 :         let mut cancelled = pin!(cancelled);
     510            0 : 
     511            0 :         poll_fn(move |cx| {
     512            0 :             if cancelled.as_mut().poll(cx).is_ready() {
     513            0 :                 info!("connection dropped");
     514            0 :                 return Poll::Ready(())
     515            0 :             }
     516            0 : 
     517            0 :             match rx.has_changed() {
     518              :                 Ok(true) => {
     519            0 :                     session_id = *rx.borrow_and_update();
     520            0 :                     info!(%session_id, "changed session");
     521            0 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     522              :                 }
     523              :                 Err(_) => {
     524            0 :                     info!("connection dropped");
     525            0 :                     return Poll::Ready(())
     526              :                 }
     527            0 :                 _ => {}
     528              :             }
     529              : 
     530              :             // 5 minute idle connection timeout
     531            0 :             if idle_timeout.as_mut().poll(cx).is_ready() {
     532            0 :                 idle_timeout.as_mut().reset(Instant::now() + idle);
     533            0 :                 info!("connection idle");
     534            0 :                 if let Some(pool) = pool.clone().upgrade() {
     535              :                     // remove client from pool - should close the connection if it's idle.
     536              :                     // does nothing if the client is currently checked-out and in-use
     537            0 :                     if pool.write().remove_client(db_user.clone(), conn_id) {
     538            0 :                         info!("idle connection removed");
     539            0 :                     }
     540            0 :                 }
     541            0 :             }
     542              : 
     543              :             loop {
     544            0 :                 let message = ready!(connection.poll_message(cx));
     545              : 
     546            0 :                 match message {
     547            0 :                     Some(Ok(AsyncMessage::Notice(notice))) => {
     548            0 :                         info!(%session_id, "notice: {}", notice);
     549              :                     }
     550            0 :                     Some(Ok(AsyncMessage::Notification(notif))) => {
     551            0 :                         warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
     552              :                     }
     553              :                     Some(Ok(_)) => {
     554            0 :                         warn!(%session_id, "unknown message");
     555              :                     }
     556            0 :                     Some(Err(e)) => {
     557            0 :                         error!(%session_id, "connection error: {}", e);
     558            0 :                         break
     559              :                     }
     560              :                     None => {
     561            0 :                         info!("connection closed");
     562            0 :                         break
     563              :                     }
     564              :                 }
     565              :             }
     566              : 
     567              :             // remove from connection pool
     568            0 :             if let Some(pool) = pool.clone().upgrade() {
     569            0 :                 if pool.write().remove_client(db_user.clone(), conn_id) {
     570            0 :                     info!("closed connection removed");
     571            0 :                 }
     572            0 :             }
     573              : 
     574            0 :             Poll::Ready(())
     575            0 :         }).await;
     576              : 
     577            0 :     }
     578            0 :     .instrument(span));
     579            0 :     let inner = ClientInner {
     580            0 :         inner: client,
     581            0 :         session: tx,
     582            0 :         cancel,
     583            0 :         aux,
     584            0 :         conn_id,
     585            0 :     };
     586            0 :     Client::new(inner, conn_info, pool_clone)
     587            0 : }
     588              : 
     589              : struct ClientInner<C: ClientInnerExt> {
     590              :     inner: C,
     591              :     session: tokio::sync::watch::Sender<uuid::Uuid>,
     592              :     cancel: CancellationToken,
     593              :     aux: MetricsAuxInfo,
     594              :     conn_id: uuid::Uuid,
     595              : }
     596              : 
     597              : impl<C: ClientInnerExt> Drop for ClientInner<C> {
     598            7 :     fn drop(&mut self) {
     599            7 :         // on client drop, tell the conn to shut down
     600            7 :         self.cancel.cancel();
     601            7 :     }
     602              : }
     603              : 
     604              : pub(crate) trait ClientInnerExt: Sync + Send + 'static {
     605              :     fn is_closed(&self) -> bool;
     606              :     fn get_process_id(&self) -> i32;
     607              : }
     608              : 
     609              : impl ClientInnerExt for tokio_postgres::Client {
     610            0 :     fn is_closed(&self) -> bool {
     611            0 :         self.is_closed()
     612            0 :     }
     613            0 :     fn get_process_id(&self) -> i32 {
     614            0 :         self.get_process_id()
     615            0 :     }
     616              : }
     617              : 
     618              : impl<C: ClientInnerExt> ClientInner<C> {
     619            8 :     pub(crate) fn is_closed(&self) -> bool {
     620            8 :         self.inner.is_closed()
     621            8 :     }
     622              : }
     623              : 
     624              : impl<C: ClientInnerExt> Client<C> {
     625            0 :     pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
     626            0 :         let aux = &self.inner.as_ref().unwrap().aux;
     627            0 :         USAGE_METRICS.register(Ids {
     628            0 :             endpoint_id: aux.endpoint_id,
     629            0 :             branch_id: aux.branch_id,
     630            0 :         })
     631            0 :     }
     632              : }
     633              : 
     634              : pub(crate) struct Client<C: ClientInnerExt> {
     635              :     span: Span,
     636              :     inner: Option<ClientInner<C>>,
     637              :     conn_info: ConnInfo,
     638              :     pool: Weak<RwLock<EndpointConnPool<C>>>,
     639              : }
     640              : 
     641              : pub(crate) struct Discard<'a, C: ClientInnerExt> {
     642              :     conn_info: &'a ConnInfo,
     643              :     pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
     644              : }
     645              : 
     646              : impl<C: ClientInnerExt> Client<C> {
     647            7 :     pub(self) fn new(
     648            7 :         inner: ClientInner<C>,
     649            7 :         conn_info: ConnInfo,
     650            7 :         pool: Weak<RwLock<EndpointConnPool<C>>>,
     651            7 :     ) -> Self {
     652            7 :         Self {
     653            7 :             inner: Some(inner),
     654            7 :             span: Span::current(),
     655            7 :             conn_info,
     656            7 :             pool,
     657            7 :         }
     658            7 :     }
     659            1 :     pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
     660            1 :         let Self {
     661            1 :             inner,
     662            1 :             pool,
     663            1 :             conn_info,
     664            1 :             span: _,
     665            1 :         } = self;
     666            1 :         let inner = inner.as_mut().expect("client inner should not be removed");
     667            1 :         (&mut inner.inner, Discard { conn_info, pool })
     668            1 :     }
     669              : }
     670              : 
     671              : impl<C: ClientInnerExt> Discard<'_, C> {
     672            0 :     pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
     673            0 :         let conn_info = &self.conn_info;
     674            0 :         if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
     675            0 :             info!("pool: throwing away connection '{conn_info}' because connection is not idle");
     676            0 :         }
     677            0 :     }
     678            1 :     pub(crate) fn discard(&mut self) {
     679            1 :         let conn_info = &self.conn_info;
     680            1 :         if std::mem::take(self.pool).strong_count() > 0 {
     681            1 :             info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
     682            0 :         }
     683            1 :     }
     684              : }
     685              : 
     686              : impl<C: ClientInnerExt> Deref for Client<C> {
     687              :     type Target = C;
     688              : 
     689            0 :     fn deref(&self) -> &Self::Target {
     690            0 :         &self
     691            0 :             .inner
     692            0 :             .as_ref()
     693            0 :             .expect("client inner should not be removed")
     694            0 :             .inner
     695            0 :     }
     696              : }
     697              : 
     698              : impl<C: ClientInnerExt> Client<C> {
     699            7 :     fn do_drop(&mut self) -> Option<impl FnOnce()> {
     700            7 :         let conn_info = self.conn_info.clone();
     701            7 :         let client = self
     702            7 :             .inner
     703            7 :             .take()
     704            7 :             .expect("client inner should not be removed");
     705            7 :         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
     706            6 :             let current_span = self.span.clone();
     707            6 :             // return connection to the pool
     708            6 :             return Some(move || {
     709            6 :                 let _span = current_span.enter();
     710            6 :                 EndpointConnPool::put(&conn_pool, &conn_info, client);
     711            6 :             });
     712            1 :         }
     713            1 :         None
     714            7 :     }
     715              : }
     716              : 
     717              : impl<C: ClientInnerExt> Drop for Client<C> {
     718            1 :     fn drop(&mut self) {
     719            1 :         if let Some(drop) = self.do_drop() {
     720            0 :             tokio::task::spawn_blocking(drop);
     721            1 :         }
     722            1 :     }
     723              : }
     724              : 
     725              : #[cfg(test)]
     726              : mod tests {
     727              :     use std::{mem, sync::atomic::AtomicBool};
     728              : 
     729              :     use crate::{
     730              :         proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
     731              :     };
     732              : 
     733              :     use super::*;
     734              : 
     735              :     struct MockClient(Arc<AtomicBool>);
     736              :     impl MockClient {
     737            6 :         fn new(is_closed: bool) -> Self {
     738            6 :             MockClient(Arc::new(is_closed.into()))
     739            6 :         }
     740              :     }
     741              :     impl ClientInnerExt for MockClient {
     742            8 :         fn is_closed(&self) -> bool {
     743            8 :             self.0.load(atomic::Ordering::Relaxed)
     744            8 :         }
     745            0 :         fn get_process_id(&self) -> i32 {
     746            0 :             0
     747            0 :         }
     748              :     }
     749              : 
     750            5 :     fn create_inner() -> ClientInner<MockClient> {
     751            5 :         create_inner_with(MockClient::new(false))
     752            5 :     }
     753              : 
     754            7 :     fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
     755            7 :         ClientInner {
     756            7 :             inner: client,
     757            7 :             session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
     758            7 :             cancel: CancellationToken::new(),
     759            7 :             aux: MetricsAuxInfo {
     760            7 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     761            7 :                 project_id: (&ProjectId::from("project")).into(),
     762            7 :                 branch_id: (&BranchId::from("branch")).into(),
     763            7 :                 cold_start_info: crate::console::messages::ColdStartInfo::Warm,
     764            7 :             },
     765            7 :             conn_id: uuid::Uuid::new_v4(),
     766            7 :         }
     767            7 :     }
     768              : 
     769              :     #[tokio::test]
     770            1 :     async fn test_pool() {
     771            1 :         let _ = env_logger::try_init();
     772            1 :         let config = Box::leak(Box::new(crate::config::HttpConfig {
     773            1 :             accept_websockets: false,
     774            1 :             pool_options: GlobalConnPoolOptions {
     775            1 :                 max_conns_per_endpoint: 2,
     776            1 :                 gc_epoch: Duration::from_secs(1),
     777            1 :                 pool_shards: 2,
     778            1 :                 idle_timeout: Duration::from_secs(1),
     779            1 :                 opt_in: false,
     780            1 :                 max_total_conns: 3,
     781            1 :             },
     782            1 :             cancel_set: CancelSet::new(0),
     783            1 :             client_conn_threshold: u64::MAX,
     784            1 :             max_request_size_bytes: u64::MAX,
     785            1 :             max_response_size_bytes: usize::MAX,
     786            1 :         }));
     787            1 :         let pool = GlobalConnPool::new(config);
     788            1 :         let conn_info = ConnInfo {
     789            1 :             user_info: ComputeUserInfo {
     790            1 :                 user: "user".into(),
     791            1 :                 endpoint: "endpoint".into(),
     792            1 :                 options: NeonOptions::default(),
     793            1 :             },
     794            1 :             dbname: "dbname".into(),
     795            1 :         };
     796            1 :         let ep_pool = Arc::downgrade(
     797            1 :             &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
     798            1 :         );
     799            1 :         {
     800            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     801            1 :             assert_eq!(0, pool.get_global_connections_count());
     802            1 :             client.inner().1.discard();
     803            1 :             // Discard should not add the connection from the pool.
     804            1 :             assert_eq!(0, pool.get_global_connections_count());
     805            1 :         }
     806            1 :         {
     807            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     808            1 :             client.do_drop().unwrap()();
     809            1 :             mem::forget(client); // drop the client
     810            1 :             assert_eq!(1, pool.get_global_connections_count());
     811            1 :         }
     812            1 :         {
     813            1 :             let mut closed_client = Client::new(
     814            1 :                 create_inner_with(MockClient::new(true)),
     815            1 :                 conn_info.clone(),
     816            1 :                 ep_pool.clone(),
     817            1 :             );
     818            1 :             closed_client.do_drop().unwrap()();
     819            1 :             mem::forget(closed_client); // drop the client
     820            1 :                                         // The closed client shouldn't be added to the pool.
     821            1 :             assert_eq!(1, pool.get_global_connections_count());
     822            1 :         }
     823            1 :         let is_closed: Arc<AtomicBool> = Arc::new(false.into());
     824            1 :         {
     825            1 :             let mut client = Client::new(
     826            1 :                 create_inner_with(MockClient(is_closed.clone())),
     827            1 :                 conn_info.clone(),
     828            1 :                 ep_pool.clone(),
     829            1 :             );
     830            1 :             client.do_drop().unwrap()();
     831            1 :             mem::forget(client); // drop the client
     832            1 : 
     833            1 :             // The client should be added to the pool.
     834            1 :             assert_eq!(2, pool.get_global_connections_count());
     835            1 :         }
     836            1 :         {
     837            1 :             let mut client = Client::new(create_inner(), conn_info, ep_pool);
     838            1 :             client.do_drop().unwrap()();
     839            1 :             mem::forget(client); // drop the client
     840            1 : 
     841            1 :             // The client shouldn't be added to the pool. Because the ep-pool is full.
     842            1 :             assert_eq!(2, pool.get_global_connections_count());
     843            1 :         }
     844            1 : 
     845            1 :         let conn_info = ConnInfo {
     846            1 :             user_info: ComputeUserInfo {
     847            1 :                 user: "user".into(),
     848            1 :                 endpoint: "endpoint-2".into(),
     849            1 :                 options: NeonOptions::default(),
     850            1 :             },
     851            1 :             dbname: "dbname".into(),
     852            1 :         };
     853            1 :         let ep_pool = Arc::downgrade(
     854            1 :             &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
     855            1 :         );
     856            1 :         {
     857            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     858            1 :             client.do_drop().unwrap()();
     859            1 :             mem::forget(client); // drop the client
     860            1 :             assert_eq!(3, pool.get_global_connections_count());
     861            1 :         }
     862            1 :         {
     863            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     864            1 :             client.do_drop().unwrap()();
     865            1 :             mem::forget(client); // drop the client
     866            1 : 
     867            1 :             // The client shouldn't be added to the pool. Because the global pool is full.
     868            1 :             assert_eq!(3, pool.get_global_connections_count());
     869            1 :         }
     870            1 : 
     871            1 :         is_closed.store(true, atomic::Ordering::Relaxed);
     872            1 :         // Do gc for all shards.
     873            1 :         pool.gc(0);
     874            1 :         pool.gc(1);
     875            1 :         // Closed client should be removed from the pool.
     876            1 :         assert_eq!(2, pool.get_global_connections_count());
     877            1 :     }
     878              : }
        

Generated by: LCOV version 2.1-beta