       1              : use anyhow::Context;
       2              : use async_trait::async_trait;
       3              : use dashmap::DashMap;
       4              : use futures::{future::poll_fn, Future};
       5              : use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard};
       6              : use once_cell::sync::Lazy;
       7              : use parking_lot::RwLock;
       8              : use pbkdf2::{
       9              :     password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
      10              :     Params, Pbkdf2,
      11              : };
      12              : use prometheus::{exponential_buckets, register_histogram, Histogram};
      13              : use rand::Rng;
      14              : use smol_str::SmolStr;
      15              : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
      16              : use std::{
      17              :     fmt,
      18              :     task::{ready, Poll},
      19              : };
      20              : use std::{
      21              :     ops::Deref,
      22              :     sync::atomic::{self, AtomicUsize},
      23              : };
      24              : use tokio::time::{self, Instant};
      25              : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
      26              : 
      27              : use crate::{
      28              :     auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
      29              :     console::{self, messages::MetricsAuxInfo},
      30              :     context::RequestMonitoring,
      31              :     metrics::NUM_DB_CONNECTIONS_GAUGE,
      32              :     proxy::connect_compute::ConnectMechanism,
      33              :     usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
      34              :     DbName, EndpointCacheKey, RoleName,
      35              : };
      36              : use crate::{compute, config};
      37              : 
      38              : use tracing::{debug, error, warn, Span};
      39              : use tracing::{info, info_span, Instrument};
      40              : 
/// Application name recorded on the request context (via `ctx.set_application`)
/// when the SQL-over-HTTP pool opens a compute connection.
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
      42              : 
/// Connection parameters of one SQL-over-HTTP request: the user/endpoint to
/// authenticate against, the target database and the password.
///
/// NOTE: `Debug` is derived and therefore includes `password`; use the custom
/// `Display` impl (which omits it) when logging.
#[derive(Debug, Clone)]
pub struct ConnInfo {
    pub user_info: ComputeUserInfo,
    pub dbname: DbName,
    pub password: SmolStr,
}
      49              : 
      50              : impl ConnInfo {
      51              :     // hm, change to hasher to avoid cloning?
      52          110 :     pub fn db_and_user(&self) -> (DbName, RoleName) {
      53          110 :         (self.dbname.clone(), self.user_info.user.clone())
      54          110 :     }
      55              : 
      56           38 :     pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
      57           38 :         self.user_info.endpoint_cache_key()
      58           38 :     }
      59              : }
      60              : 
      61              : impl fmt::Display for ConnInfo {
      62              :     // use custom display to avoid logging password
      63          218 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      64          218 :         write!(
      65          218 :             f,
      66          218 :             "{}@{}/{}?{}",
      67          218 :             self.user_info.user,
      68          218 :             self.user_info.endpoint,
      69          218 :             self.dbname,
      70          218 :             self.user_info.options.get_cache_key("")
      71          218 :         )
      72          218 :     }
      73              : }
      74              : 
/// A connection stored in a [`DbUserConnPool`], waiting to be handed out again.
struct ConnPoolEntry {
    conn: ClientInner,
    // Set to `now()` when the entry is pushed in `EndpointConnPool::put`; not read
    // anywhere in this view (hence the leading underscore).
    _last_access: std::time::Instant,
}
      79              : 
/// Per-endpoint connection pool: `(dbname, username)` -> [`DbUserConnPool`].
///
/// `total_conns` counts the idle connections currently stored across all of
/// `pools`. It is incremented when `put` stores a connection and decremented
/// whenever an entry leaves a sub-pool (taken for reuse, found closed, or
/// removed by id). `put` refuses to store a connection once `total_conns`
/// reaches `max_conns` (configured from `max_conns_per_endpoint`).
/// Connections that are currently checked out are not counted.
pub struct EndpointConnPool {
    pools: HashMap<(DbName, RoleName), DbUserConnPool>,
    total_conns: usize,
    max_conns: usize,
    // Created from `ENDPOINT_POOLS.guard()` when the pool is built; presumably
    // records the pool's unregistration when dropped — confirm in `metrics`.
    _guard: IntCounterPairGuard,
}
      88              : 
      89              : impl EndpointConnPool {
      90           13 :     fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry> {
      91           13 :         let Self {
      92           13 :             pools, total_conns, ..
      93           13 :         } = self;
      94           13 :         pools
      95           13 :             .get_mut(&db_user)
      96           13 :             .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
      97           13 :     }
      98              : 
      99            3 :     fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
     100            3 :         let Self {
     101            3 :             pools, total_conns, ..
     102            3 :         } = self;
     103            3 :         if let Some(pool) = pools.get_mut(&db_user) {
     104            3 :             let old_len = pool.conns.len();
     105            3 :             pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
     106            3 :             let new_len = pool.conns.len();
     107            3 :             let removed = old_len - new_len;
     108            3 :             *total_conns -= removed;
     109            3 :             removed > 0
     110              :         } else {
     111            0 :             false
     112              :         }
     113            3 :     }
     114              : 
     115           20 :     fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
     116           20 :         let conn_id = client.conn_id;
     117           20 : 
     118           20 :         if client.inner.is_closed() {
     119            0 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
     120            0 :             return Ok(());
     121           20 :         }
     122           20 : 
     123           20 :         // return connection to the pool
     124           20 :         let mut returned = false;
     125           20 :         let mut per_db_size = 0;
     126           20 :         let total_conns = {
     127           20 :             let mut pool = pool.write();
     128           20 : 
     129           20 :             if pool.total_conns < pool.max_conns {
     130              :                 // we create this db-user entry in get, so it should not be None
     131           20 :                 if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
     132           20 :                     pool_entries.conns.push(ConnPoolEntry {
     133           20 :                         conn: client,
     134           20 :                         _last_access: std::time::Instant::now(),
     135           20 :                     });
     136           20 : 
     137           20 :                     returned = true;
     138           20 :                     per_db_size = pool_entries.conns.len();
     139           20 : 
     140           20 :                     pool.total_conns += 1;
     141           20 :                 }
     142            0 :             }
     143              : 
     144           20 :             pool.total_conns
     145           20 :         };
     146           20 : 
     147           20 :         // do logging outside of the mutex
     148           20 :         if returned {
     149           20 :             info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     150              :         } else {
     151            0 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     152              :         }
     153              : 
     154           20 :         Ok(())
     155           20 :     }
     156              : }
     157              : 
     158              : /// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
     159              : /// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
     160              : ///
     161              : /// Still takes 1.4ms to hash on my hardware.
     162              : /// We don't want to ruin the latency improvements of using the pool by making password verification take too long
     163              : const PARAMS: Params = Params {
     164              :     rounds: 4096,
     165              :     output_length: 32,
     166              : };
     167              : 
/// Idle connections for one `(dbname, username)` pair, plus a PBKDF2 hash of
/// the password most recently accepted for it.
///
/// `GlobalConnPool::get` verifies the request's password against
/// `password_hash` before handing out any of `conns`. The hash is set after a
/// newly opened connection succeeds. It is cleared when the compute rejects a
/// password the hash had accepted.
#[derive(Default)]
pub struct DbUserConnPool {
    conns: Vec<ConnPoolEntry>,
    password_hash: Option<PasswordHashString>,
}
     173              : 
     174              : impl DbUserConnPool {
     175           13 :     fn clear_closed_clients(&mut self, conns: &mut usize) {
     176           13 :         let old_len = self.conns.len();
     177           13 : 
     178           13 :         self.conns.retain(|conn| !conn.conn.inner.is_closed());
     179           13 : 
     180           13 :         let new_len = self.conns.len();
     181           13 :         let removed = old_len - new_len;
     182           13 :         *conns -= removed;
     183           13 :     }
     184              : 
     185           13 :     fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry> {
     186           13 :         self.clear_closed_clients(conns);
     187           13 :         let conn = self.conns.pop();
     188           13 :         if conn.is_some() {
     189            4 :             *conns -= 1;
     190            9 :         }
     191           13 :         conn
     192           13 :     }
     193              : }
     194              : 
/// Process-wide registry of per-endpoint connection pools for SQL-over-HTTP.
pub struct GlobalConnPool {
    // endpoint -> per-endpoint connection pool
    //
    // This is expected to be a fairly contended map, so clone the per-endpoint
    // pool's `Arc` out of it as early as possible and release the shard lock.
    global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,

    /// Number of endpoint-connection pools
    ///
    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we keep a counter updated with relaxed
    /// atomics instead: incremented when a pool is created, decremented when `gc`
    /// removes pools. It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    // Proxy-wide configuration; the pool reads `http_config.pool_options` from it.
    proxy_config: &'static crate::config::ProxyConfig,
}
     211              : 
/// Tuning knobs for [`GlobalConnPool`], read from `http_config.pool_options`
/// of the proxy configuration.
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
    /// Maximum number of idle connections kept in the pool for one endpoint.
    /// Can mix different (dbname, username) connections.
    /// Once an endpoint's pool is full, returned connections are thrown away
    /// instead of pooled, so later requests open new connections.
    pub max_conns_per_endpoint: usize,

    /// GC period: the GC worker ticks every `gc_epoch / <number of shards>` and
    /// reclaims one randomly chosen shard per tick. On average, each shard is
    /// visited once per `gc_epoch`.
    pub gc_epoch: Duration,

    /// Number of shards of the global endpoint map (passed to
    /// `DashMap::with_shard_amount`).
    pub pool_shards: usize,

    /// Passed as `idle` to each newly opened compute connection; how it is
    /// enforced is outside this view — TODO confirm.
    pub idle_timeout: Duration,

    // NOTE(review): not read anywhere in this view; presumably makes pooling
    // opt-in per request — confirm against its users before documenting.
    pub opt_in: bool,
}
     228              : 
     229           23 : pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
     230           23 :     register_histogram!(
     231           23 :         "proxy_http_pool_reclaimation_lag_seconds",
     232           23 :         "Time it takes to reclaim unused connection pools",
     233           23 :         // 1us -> 65ms
     234           23 :         exponential_buckets(1e-6, 2.0, 16).unwrap(),
     235           23 :     )
     236           23 :     .unwrap()
     237           23 : });
     238              : 
     239            7 : pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
     240           14 :     register_int_counter_pair!(
     241           14 :         "proxy_http_pool_endpoints_registered_total",
     242           14 :         "Number of endpoints we have registered pools for",
     243           14 :         "proxy_http_pool_endpoints_unregistered_total",
     244           14 :         "Number of endpoints we have unregistered pools for",
     245           14 :     )
     246            7 :     .unwrap()
     247            7 : });
     248              : 
     249              : impl GlobalConnPool {
     250           23 :     pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
     251           23 :         let shards = config.http_config.pool_options.pool_shards;
     252           23 :         Arc::new(Self {
     253           23 :             global_pool: DashMap::with_shard_amount(shards),
     254           23 :             global_pool_size: AtomicUsize::new(0),
     255           23 :             proxy_config: config,
     256           23 :         })
     257           23 :     }
     258              : 
     259           23 :     pub fn shutdown(&self) {
     260           23 :         // drops all strong references to endpoint-pools
     261           23 :         self.global_pool.clear();
     262           23 :     }
     263              : 
     264           23 :     pub async fn gc_worker(&self, mut rng: impl Rng) {
     265           23 :         let epoch = self.proxy_config.http_config.pool_options.gc_epoch;
     266           23 :         let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
     267              :         loop {
     268           55 :             interval.tick().await;
     269              : 
     270           32 :             let shard = rng.gen_range(0..self.global_pool.shards().len());
     271           32 :             self.gc(shard);
     272              :         }
     273              :     }
     274              : 
     275           32 :     fn gc(&self, shard: usize) {
     276           32 :         debug!(shard, "pool: performing epoch reclamation");
     277              : 
     278              :         // acquire a random shard lock
     279           32 :         let mut shard = self.global_pool.shards()[shard].write();
     280           32 : 
     281           32 :         let timer = GC_LATENCY.start_timer();
     282           32 :         let current_len = shard.len();
     283           32 :         shard.retain(|endpoint, x| {
     284              :             // if the current endpoint pool is unique (no other strong or weak references)
     285              :             // then it is currently not in use by any connections.
     286            0 :             if let Some(pool) = Arc::get_mut(x.get_mut()) {
     287              :                 let EndpointConnPool {
     288            0 :                     pools, total_conns, ..
     289            0 :                 } = pool.get_mut();
     290            0 : 
     291            0 :                 // ensure that closed clients are removed
     292            0 :                 pools
     293            0 :                     .iter_mut()
     294            0 :                     .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns));
     295            0 : 
     296            0 :                 // we only remove this pool if it has no active connections
     297            0 :                 if *total_conns == 0 {
     298            0 :                     info!("pool: discarding pool for endpoint {endpoint}");
     299            0 :                     return false;
     300            0 :                 }
     301            0 :             }
     302              : 
     303            0 :             true
     304           32 :         });
     305           32 :         let new_len = shard.len();
     306           32 :         drop(shard);
     307           32 :         timer.observe_duration();
     308           32 : 
     309           32 :         let removed = current_len - new_len;
     310           32 : 
     311           32 :         if removed > 0 {
     312            0 :             let global_pool_size = self
     313            0 :                 .global_pool_size
     314            0 :                 .fetch_sub(removed, atomic::Ordering::Relaxed)
     315            0 :                 - removed;
     316            0 :             info!("pool: performed global pool gc. size now {global_pool_size}");
     317           32 :         }
     318           32 :     }
     319              : 
     320           46 :     pub async fn get(
     321           46 :         self: &Arc<Self>,
     322           46 :         ctx: &mut RequestMonitoring,
     323           46 :         conn_info: ConnInfo,
     324           46 :         force_new: bool,
     325           46 :     ) -> anyhow::Result<Client> {
     326           46 :         let mut client: Option<ClientInner> = None;
     327           46 : 
     328           46 :         let mut hash_valid = false;
     329           46 :         let mut endpoint_pool = Weak::new();
     330           46 :         if !force_new {
     331           27 :             let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
     332           27 :             endpoint_pool = Arc::downgrade(&pool);
     333           27 :             let mut hash = None;
     334           27 : 
     335           27 :             // find a pool entry by (dbname, username) if exists
     336           27 :             {
     337           27 :                 let pool =;
     338           27 :                 if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
     339           19 :                     if !pool_entries.conns.is_empty() {
     340           16 :                         hash = pool_entries.password_hash.clone();
     341           16 :                     }
     342            8 :                 }
     343              :             }
     344              : 
     345              :             // a connection exists in the pool, verify the password hash
     346           27 :             if let Some(hash) = hash {
     347           16 :                 let pw = conn_info.password.clone();
     348           16 :                 let validate = tokio::task::spawn_blocking(move || {
     349           16 :                     Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
     350           16 :                 })
     351           32 :                 .await?;
     352              : 
     353              :                 // if the hash is invalid, don't error
     354              :                 // we will continue with the regular connection flow
     355           16 :                 if validate.is_ok() {
     356           13 :                     hash_valid = true;
     357           13 :                     if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) {
     358            4 :                         client = Some(entry.conn)
     359            9 :                     }
     360            3 :                 }
     361           11 :             }
     362           19 :         }
     363              : 
     364              :         // ok return cached connection if found and establish a new one otherwise
     365           46 :         let new_client = if let Some(client) = client {
     366            4 :             ctx.set_project(client.aux.clone());
     367            4 :             if client.inner.is_closed() {
     368            0 :                 let conn_id = uuid::Uuid::new_v4();
     369            0 :                 info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
     370            0 :                 connect_to_compute(
     371            0 :                     self.proxy_config,
     372            0 :                     ctx,
     373            0 :                     &conn_info,
     374            0 :                     conn_id,
     375            0 :                     endpoint_pool.clone(),
     376            0 :                 )
     377            0 :                 .await
     378              :             } else {
     379            4 :                 info!("pool: reusing connection '{conn_info}'");
     380            4 :                 client.session.send(ctx.session_id)?;
     381            4 :                 tracing::Span::current().record(
     382            4 :                     "pid",
     383            4 :                     &tracing::field::display(client.inner.get_process_id()),
     384            4 :                 );
     385            4 :                 ctx.latency_timer.pool_hit();
     386            4 :                 ctx.latency_timer.success();
     387            4 :                 return Ok(Client::new(client, conn_info, endpoint_pool).await);
     388              :             }
     389              :         } else {
     390           42 :             let conn_id = uuid::Uuid::new_v4();
     391           42 :             info!(%conn_id, "pool: opening a new connection '{conn_info}'");
     392           42 :             connect_to_compute(
     393           42 :                 self.proxy_config,
     394           42 :                 ctx,
     395           42 :                 &conn_info,
     396           42 :                 conn_id,
     397           42 :                 endpoint_pool.clone(),
     398           42 :             )
     399          469 :             .await
     400              :         };
     401           42 :         if let Ok(client) = &new_client {
     402           39 :             tracing::Span::current().record(
     403           39 :                 "pid",
     404           39 :                 &tracing::field::display(client.inner.get_process_id()),
     405           39 :             );
     406           39 :         }
     407              : 
     408           39 :         match &new_client {
     409            3 :             // clear the hash. it's no longer valid
     410            3 :             // TODO: update tokio-postgres fork to allow access to this error kind directly
     411            3 :             Err(err)
     412            3 :                 if hash_valid && err.to_string().contains("password authentication failed") =>
     413            0 :             {
     414            0 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
     415            0 :                 let mut pool = pool.write();
     416            0 :                 if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
     417            0 :                     entry.password_hash = None;
     418            0 :                 }
     419              :             }
     420              :             // new password is valid and we should insert/update it
     421           39 :             Ok(_) if !force_new && !hash_valid => {
     422           11 :                 let pw = conn_info.password.clone();
     423           11 :                 let new_hash = tokio::task::spawn_blocking(move || {
     424           11 :                     let salt = SaltString::generate(rand::rngs::OsRng);
     425           11 :                     Pbkdf2
     426           11 :                         .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
     427           11 :                         .map(|s| s.serialize())
     428           11 :                 })
     429           11 :                 .await??;
     430              : 
     431           11 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
     432           11 :                 let mut pool = pool.write();
     433           11 :                 pool.pools
     434           11 :                     .entry(conn_info.db_and_user())
     435           11 :                     .or_default()
     436           11 :                     .password_hash = Some(new_hash);
     437              :             }
     438           31 :             _ => {}
     439              :         }
     440           42 :         let new_client = new_client?;
     441           39 :         Ok(Client::new(new_client, conn_info, endpoint_pool).await)
     442           46 :     }
     443              : 
     444           38 :     fn get_or_create_endpoint_pool(
     445           38 :         &self,
     446           38 :         endpoint: &EndpointCacheKey,
     447           38 :     ) -> Arc<RwLock<EndpointConnPool>> {
     448              :         // fast path
     449           38 :         if let Some(pool) = self.global_pool.get(endpoint) {
     450           31 :             return pool.clone();
     451            7 :         }
     452            7 : 
     453            7 :         // slow path
     454            7 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     455            7 :             pools: HashMap::new(),
     456            7 :             total_conns: 0,
     457            7 :             max_conns: self
     458            7 :                 .proxy_config
     459            7 :                 .http_config
     460            7 :                 .pool_options
     461            7 :                 .max_conns_per_endpoint,
     462            7 :             _guard: ENDPOINT_POOLS.guard(),
     463            7 :         }));
     464            7 : 
     465            7 :         // find or create a pool for this endpoint
     466            7 :         let mut created = false;
     467            7 :         let pool = self
     468            7 :             .global_pool
     469            7 :             .entry(endpoint.clone())
     470            7 :             .or_insert_with(|| {
     471            7 :                 created = true;
     472            7 :                 new_pool
     473            7 :             })
     474            7 :             .clone();
     475            7 : 
     476            7 :         // log new global pool size
     477            7 :         if created {
     478            7 :             let global_pool_size = self
     479            7 :                 .global_pool_size
     480            7 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     481            7 :                 + 1;
     482            7 :             info!(
     483            7 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     484            7 :             );
     485            0 :         }
     486              : 
     487            7 :         pool
     488           38 :     }
     489              : }
     490              : 
/// [`ConnectMechanism`] that opens a `tokio_postgres` connection to compute for
/// the SQL-over-HTTP pool, one attempt per `connect_once` call.
struct TokioMechanism<'a> {
    // Endpoint pool the new connection belongs to; forwarded to
    // `connect_to_compute_once`.
    pool: Weak<RwLock<EndpointConnPool>>,
    conn_info: &'a ConnInfo,
    // Identifier of the new connection (appears in its "connection" span).
    conn_id: uuid::Uuid,
    // Taken from `pool_options.idle_timeout`; forwarded as `idle`.
    idle: Duration,
}
     497              : 
     498              : #[async_trait]
     499              : impl ConnectMechanism for TokioMechanism<'_> {
     500              :     type Connection = ClientInner;
     501              :     type ConnectError = tokio_postgres::Error;
     502              :     type Error = anyhow::Error;
     503              : 
     504           43 :     async fn connect_once(
     505           43 :         &self,
     506           43 :         ctx: &mut RequestMonitoring,
     507           43 :         node_info: &console::CachedNodeInfo,
     508           43 :         timeout: time::Duration,
     509           43 :     ) -> Result<Self::Connection, Self::ConnectError> {
     510           43 :         connect_to_compute_once(
     511           43 :             ctx,
     512           43 :             node_info,
     513           43 :             self.conn_info,
     514           43 :             timeout,
     515           43 :             self.conn_id,
     516           43 :             self.pool.clone(),
     517           43 :             self.idle,
     518           43 :         )
     519          149 :         .await
     520           86 :     }
     521              : 
     522           43 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
     523              : }
     524              : 
     525              : // Wake up the destination if needed. Code here is a bit involved because
     526              : // we reuse the code from the usual proxy and we need to prepare few structures
     527              : // that this code expects.
     528            0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
     529              : async fn connect_to_compute(
     530              :     config: &config::ProxyConfig,
     531              :     ctx: &mut RequestMonitoring,
     532              :     conn_info: &ConnInfo,
     533              :     conn_id: uuid::Uuid,
     534              :     pool: Weak<RwLock<EndpointConnPool>>,
     535              : ) -> anyhow::Result<ClientInner> {
     536              :     ctx.set_application(Some(APP_NAME));
     537              :     let backend = config
     538              :         .auth_backend
     539              :         .as_ref()
     540           42 :         .map(|_| conn_info.user_info.clone());
     541              : 
     542              :     if !config.disable_ip_check_for_http {
     543              :         let (allowed_ips, _) = backend.get_allowed_ips_and_secret(ctx).await?;
     544              :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
     545              :             return Err(auth::AuthError::ip_address_not_allowed().into());
     546              :         }
     547              :     }
     548              :     let node_info = backend
     549              :         .wake_compute(ctx)
     550              :         .await?
     551              :         .context("missing cache entry from wake_compute")?;
     552              : 
     553              :     ctx.set_project(node_info.aux.clone());
     554              : 
     555              :     crate::proxy::connect_compute::connect_to_compute(
     556              :         ctx,
     557              :         &TokioMechanism {
     558              :             conn_id,
     559              :             conn_info,
     560              :             pool,
     561              :             idle: config.http_config.pool_options.idle_timeout,
     562              :         },
     563              :         node_info,
     564              :         &backend,
     565              :     )
     566              :     .await
     567              : }
     568              : 
     569           43 : async fn connect_to_compute_once(
     570           43 :     ctx: &mut RequestMonitoring,
     571           43 :     node_info: &console::CachedNodeInfo,
     572           43 :     conn_info: &ConnInfo,
     573           43 :     timeout: time::Duration,
     574           43 :     conn_id: uuid::Uuid,
     575           43 :     pool: Weak<RwLock<EndpointConnPool>>,
     576           43 :     idle: Duration,
     577           43 : ) -> Result<ClientInner, tokio_postgres::Error> {
     578           43 :     let mut config = (*node_info.config).clone();
     579           43 :     let mut session = ctx.session_id;
     580              : 
     581           43 :     let (client, mut connection) = config
     582           43 :         .user(&conn_info.user_info.user)
     583           43 :         .password(&*conn_info.password)
     584           43 :         .dbname(&conn_info.dbname)
     585           43 :         .connect_timeout(timeout)
     586           43 :         .connect(tokio_postgres::NoTls)
     587          149 :         .await?;
     588              : 
     589           39 :     let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
     590           39 :         .with_label_values(&[ctx.protocol])
     591           39 :         .guard();
     592           39 : 
     593           39 :     tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     594           39 : 
     595           39 :     let (tx, mut rx) = tokio::sync::watch::channel(session);
     596              : 
     597           39 :     let span = info_span!(parent: None, "connection", %conn_id);
     598           39 :     span.in_scope(|| {
     599           39 :         info!(%conn_info, %session, "new connection");
     600           39 :     });
     601           39 : 
     602           39 :     let db_user = conn_info.db_and_user();
     603           39 :     tokio::spawn(
     604           39 :         async move {
     605           39 :             let _conn_gauge = conn_gauge;
     606           39 :             let mut idle_timeout = pin!(tokio::time::sleep(idle));
     607         7010 :             poll_fn(move |cx| {
     608         7010 :                 if matches!(rx.has_changed(), Ok(true)) {
     609            4 :                     session = *rx.borrow_and_update();
     610            4 :                     info!(%session, "changed session");
     611            4 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     612         7006 :                 }
     613              : 
     614              :                 // 5 minute idle connection timeout
     615         7010 :                 if idle_timeout.as_mut().poll(cx).is_ready() {
     616            0 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     617            0 :                     info!("connection idle");
     618            0 :                     if let Some(pool) = pool.clone().upgrade() {
     619              :                         // remove client from pool - should close the connection if it's idle.
     620              :                         // does nothing if the client is currently checked-out and in-use
     621            0 :                         if pool.write().remove_client(db_user.clone(), conn_id) {
     622            0 :                             info!("idle connection removed");
     623            0 :                         }
     624            0 :                     }
     625         7010 :                 }
     626              : 
     627              :                 loop {
     628         7010 :                     let message = ready!(connection.poll_message(cx));
     629              : 
     630            0 :                     match message {
     631            0 :                         Some(Ok(AsyncMessage::Notice(notice))) => {
     632            0 :                             info!(%session, "notice: {}", notice);
     633              :                         }
     634            0 :                         Some(Ok(AsyncMessage::Notification(notif))) => {
     635            0 :                             warn!(%session, pid = notif.process_id(), channel =, "notification received");
     636              :                         }
     637              :                         Some(Ok(_)) => {
     638            0 :                             warn!(%session, "unknown message");
     639              :                         }
     640            0 :                         Some(Err(e)) => {
     641            0 :                             error!(%session, "connection error: {}", e);
     642            0 :                             break
     643              :                         }
     644              :                         None => {
     645           38 :                             info!("connection closed");
     646           38 :                             break
     647              :                         }
     648              :                     }
     649              :                 }
     650              : 
     651              :                 // remove from connection pool
     652           38 :                 if let Some(pool) = pool.clone().upgrade() {
     653            3 :                     if pool.write().remove_client(db_user.clone(), conn_id) {
     654            0 :                         info!("closed connection removed");
     655            3 :                     }
     656           35 :                 }
     657              : 
     658           38 :                 Poll::Ready(())
     659         7010 :             }).await;
     660              : 
     661           39 :         }
     662           39 :         .instrument(span)
     663           39 :     );
     664           39 : 
     665           39 :     Ok(ClientInner {
     666           39 :         inner: client,
     667           39 :         session: tx,
     668           39 :         aux: node_info.aux.clone(),
     669           39 :         conn_id,
     670           39 :     })
     671           43 : }
     672              : 
/// A live compute connection plus the state needed to pool and meter it.
///
/// Built by `connect_to_compute_once`, which also spawns the task that drives
/// the underlying connection.
struct ClientInner {
    // Client half of the `tokio_postgres` connection.
    inner: tokio_postgres::Client,
    // Sender half of a watch channel. The background connection task reads the
    // session id from it, logs "changed session" and resets its idle timer
    // when a new value is sent.
    session: tokio::sync::watch::Sender<uuid::Uuid>,
    // Endpoint/branch info taken from the woken node; used for usage metrics.
    aux: MetricsAuxInfo,
    // Identifier of this connection; passed to `remove_client` by the background task.
    conn_id: uuid::Uuid,
}
     679              : 
     680              : impl Client {
     681           41 :     pub fn metrics(&self) -> Arc<MetricCounter> {
     682           41 :         let aux = &self.inner.as_ref().unwrap().aux;
     683           41 :         USAGE_METRICS.register(Ids {
     684           41 :             endpoint_id: aux.endpoint_id.clone(),
     685           41 :             branch_id: aux.branch_id.clone(),
     686           41 :         })
     687           41 :     }
     688              : }
     689              : 
/// Handle to a compute connection. On drop, it is returned to its
/// `EndpointConnPool` if that pool is still alive.
pub struct Client {
    // Identifier of the connection (copied from `ClientInner::conn_id`).
    conn_id: uuid::Uuid,
    // Span captured at construction; entered while the connection is put back
    // into the pool on drop.
    span: Span,
    // The live connection. `None` only after `Drop` has taken it out.
    inner: Option<ClientInner>,
    // Connection parameters passed to `EndpointConnPool::put` on drop.
    conn_info: ConnInfo,
    // Pool to return the connection to. `Discard` replaces it with an empty
    // `Weak` so the connection is not reused.
    pool: Weak<RwLock<EndpointConnPool>>,
}
     697              : 
/// Borrowed view of a `Client`'s pooling state. It lets a caller stop the
/// connection from being returned to the pool.
pub struct Discard<'a> {
    // Connection id, used only for logging.
    conn_id: uuid::Uuid,
    // Connection parameters, used only for logging.
    conn_info: &'a ConnInfo,
    // The owning `Client`'s pool reference. Taking it (leaving an empty `Weak`)
    // means the client's `Drop` will not return the connection to the pool.
    pool: &'a mut Weak<RwLock<EndpointConnPool>>,
}
     703              : 
     704              : impl Client {
     705           43 :     pub(self) async fn new(
     706           43 :         inner: ClientInner,
     707           43 :         conn_info: ConnInfo,
     708           43 :         pool: Weak<RwLock<EndpointConnPool>>,
     709           43 :     ) -> Self {
     710           43 :         Self {
     711           43 :             conn_id: inner.conn_id,
     712           43 :             inner: Some(inner),
     713           43 :             span: Span::current(),
     714           43 :             conn_info,
     715           43 :             pool,
     716           43 :         }
     717           43 :     }
     718           43 :     pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
     719           43 :         let Self {
     720           43 :             inner,
     721           43 :             pool,
     722           43 :             conn_id,
     723           43 :             conn_info,
     724           43 :             span: _,
     725           43 :         } = self;
     726           43 :         (
     727           43 :             &mut inner
     728           43 :                 .as_mut()
     729           43 :                 .expect("client inner should not be removed")
     730           43 :                 .inner,
     731           43 :             Discard {
     732           43 :                 pool,
     733           43 :                 conn_info,
     734           43 :                 conn_id: *conn_id,
     735           43 :             },
     736           43 :         )
     737           43 :     }
     738              : 
     739           39 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     740           39 :         self.inner().1.check_idle(status)
     741           39 :     }
     742            2 :     pub fn discard(&mut self) {
     743            2 :         self.inner().1.discard()
     744            2 :     }
     745              : }
     746              : 
     747              : impl Discard<'_> {
     748           41 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     749           41 :         let conn_info = &self.conn_info;
     750           41 :         if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
     751            2 :             info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
     752           39 :         }
     753           41 :     }
     754            2 :     pub fn discard(&mut self) {
     755            2 :         let conn_info = &self.conn_info;
     756            2 :         if std::mem::take(self.pool).strong_count() > 0 {
     757            2 :             info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
     758            0 :         }
     759            2 :     }
     760              : }
     761              : 
     762              : impl Deref for Client {
     763              :     type Target = tokio_postgres::Client;
     764              : 
     765           41 :     fn deref(&self) -> &Self::Target {
     766           41 :         &self
     767           41 :             .inner
     768           41 :             .as_ref()
     769           41 :             .expect("client inner should not be removed")
     770           41 :             .inner
     771           41 :     }
     772              : }
     773              : 
     774              : impl Drop for Client {
     775           43 :     fn drop(&mut self) {
     776           43 :         let conn_info = self.conn_info.clone();
     777           43 :         let client = self
     778           43 :             .inner
     779           43 :             .take()
     780           43 :             .expect("client inner should not be removed");
     781           43 :         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
     782           20 :             let current_span = self.span.clone();
     783           20 :             // return connection to the pool
     784           20 :             tokio::task::spawn_blocking(move || {
     785           20 :                 let _span = current_span.enter();
     786           20 :                 let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
     787           20 :             });
     788           23 :         }
     789           43 :     }
     790              : }

