LCOV - differential code coverage report
Current view:  top level - proxy/src/serverless - conn_pool.rs (source / functions)
Current:       cd44433dd675caa99df17a61b18949c8387e2242.info  (2024-01-09 02:06:09)
Baseline:      66c52a629a0f4a503e193045e0df4c77139e344b.info  (2024-01-08 15:34:46)

                 Coverage    Total    Hit    UBC    CBC
Lines:             86.7 %      483    419     64    419
Functions:         70.7 %       82     58     24     58

           TLA  Line data    Source code
           (TLA: CBC = covered baseline code, UBC = uncovered baseline code)
       1                 : use anyhow::{anyhow, Context};
       2                 : use async_trait::async_trait;
       3                 : use dashmap::DashMap;
       4                 : use futures::{future::poll_fn, Future};
       5                 : use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard};
       6                 : use once_cell::sync::Lazy;
       7                 : use parking_lot::RwLock;
       8                 : use pbkdf2::{
       9                 :     password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
      10                 :     Params, Pbkdf2,
      11                 : };
      12                 : use pq_proto::StartupMessageParams;
      13                 : use prometheus::{exponential_buckets, register_histogram, Histogram};
      14                 : use rand::Rng;
      15                 : use smol_str::SmolStr;
      16                 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
      17                 : use std::{
      18                 :     fmt,
      19                 :     task::{ready, Poll},
      20                 : };
      21                 : use std::{
      22                 :     ops::Deref,
      23                 :     sync::atomic::{self, AtomicUsize},
      24                 : };
      25                 : use tokio::time::{self, Instant};
      26                 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
      27                 : 
      28                 : use crate::{
      29                 :     auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
      30                 :     console,
      31                 :     context::RequestMonitoring,
      32                 :     metrics::NUM_DB_CONNECTIONS_GAUGE,
      33                 :     proxy::{connect_compute::ConnectMechanism, neon_options},
      34                 :     usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
      35                 : };
      36                 : use crate::{compute, config};
      37                 : 
      38                 : use tracing::{debug, error, warn, Span};
      39                 : use tracing::{info, info_span, Instrument};
      40                 : 
      41                 : pub const APP_NAME: &str = "/sql_over_http";
      42                 : 
      43 CBC          42 : #[derive(Debug, Clone)]
      44                 : pub struct ConnInfo {
      45                 :     pub username: SmolStr,
      46                 :     pub dbname: SmolStr,
      47                 :     pub hostname: SmolStr,
      48                 :     pub password: SmolStr,
      49                 :     pub options: Option<SmolStr>,
      50                 : }
      51                 : 
      52                 : impl ConnInfo {
      53                 :     // hm, change to hasher to avoid cloning?
      54             109 :     pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
      55             109 :         (self.dbname.clone(), self.username.clone())
      56             109 :     }
      57                 : }
      58                 : 
      59                 : impl fmt::Display for ConnInfo {
      60                 :     // use custom display to avoid logging password
      61             214 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      62             214 :         write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
      63             214 :     }
      64                 : }
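
A quick illustration of what this buys us (the field values here are made up): formatting a ConnInfo yields user@host/db, so the password cannot leak into logs via Display.

    let info = ConnInfo {
        username: "alice".into(),
        dbname: "main".into(),
        hostname: "ep.example.net".into(),
        password: "super-secret".into(),
        options: None,
    };
    // Display comes from the impl above; note the derived Debug would still
    // print the password field.
    assert_eq!(info.to_string(), "alice@ep.example.net/main");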
      65                 : 
      66                 : struct ConnPoolEntry {
      67                 :     conn: ClientInner,
      68                 :     _last_access: std::time::Instant,
      69                 : }
      70                 : 
      71                 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
      72                 : // Number of open connections is limited by the `max_conns_per_endpoint`.
      73                 : pub struct EndpointConnPool {
      74                 :     pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
      75                 :     total_conns: usize,
      76                 :     max_conns: usize,
      77                 :     _guard: IntCounterPairGuard,
      78                 : }
      79                 : 
      80                 : impl EndpointConnPool {
      81              13 :     fn get_conn_entry(&mut self, db_user: (SmolStr, SmolStr)) -> Option<ConnPoolEntry> {
      82              13 :         let Self {
      83              13 :             pools, total_conns, ..
      84              13 :         } = self;
      85              13 :         pools
      86              13 :             .get_mut(&db_user)
      87              13 :             .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
      88              13 :     }
      89                 : 
      90               3 :     fn remove_client(&mut self, db_user: (SmolStr, SmolStr), conn_id: uuid::Uuid) -> bool {
      91               3 :         let Self {
      92               3 :             pools, total_conns, ..
      93               3 :         } = self;
      94               3 :         if let Some(pool) = pools.get_mut(&db_user) {
      95               3 :             let old_len = pool.conns.len();
      96               3 :             pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
      97               3 :             let new_len = pool.conns.len();
      98               3 :             let removed = old_len - new_len;
      99               3 :             *total_conns -= removed;
     100               3 :             removed > 0
     101                 :         } else {
     102 UBC           0 :             false
     103                 :         }
     104 CBC           3 :     }
     105                 : 
     106              20 :     fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
     107              20 :         let conn_id = client.conn_id;
     108              20 : 
     109              20 :         if client.inner.is_closed() {
     110 UBC           0 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
     111               0 :             return Ok(());
     112 CBC          20 :         }
     113              20 : 
     114              20 :         // return connection to the pool
     115              20 :         let mut returned = false;
     116              20 :         let mut per_db_size = 0;
     117              20 :         let total_conns = {
     118              20 :             let mut pool = pool.write();
     119              20 : 
     120              20 :             if pool.total_conns < pool.max_conns {
     121                 :                 // we create this db-user entry in get, so it should not be None
     122              20 :                 if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
     123              20 :                     pool_entries.conns.push(ConnPoolEntry {
     124              20 :                         conn: client,
     125              20 :                         _last_access: std::time::Instant::now(),
     126              20 :                     });
     127              20 : 
     128              20 :                     returned = true;
     129              20 :                     per_db_size = pool_entries.conns.len();
     130              20 : 
     131              20 :                     pool.total_conns += 1;
     132              20 :                 }
     133 UBC           0 :             }
     134                 : 
     135 CBC          20 :             pool.total_conns
     136              20 :         };
     137              20 : 
     138              20 :         // do logging outside of the mutex
     139              20 :         if returned {
     140              20 :             info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     141                 :         } else {
     142 UBC           0 :             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     143                 :         }
     144                 : 
     145 CBC          20 :         Ok(())
     146              20 :     }
     147                 : }
     148                 : 
     149                 : /// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
     150                 : /// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
     151                 : ///
     152                 : /// Still takes 1.4ms to hash on my hardware.
      153                 : /// We don't want to ruin the latency improvements of using the pool by making password verification take too long.
     154                 : const PARAMS: Params = Params {
     155                 :     rounds: 4096,
     156                 :     output_length: 32,
     157                 : };
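
A minimal sketch of the hash/verify round trip these parameters feed, reusing the exact pbkdf2 calls that appear in get() below (only the helper name is invented):

    use pbkdf2::{
        password_hash::{PasswordHasher, PasswordVerifier, SaltString},
        Pbkdf2,
    };

    fn hash_round_trip(pw: &str) -> bool {
        let salt = SaltString::generate(rand::rngs::OsRng);
        let hash = Pbkdf2
            .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
            .map(|s| s.serialize())
            .expect("hashing with fixed params should not fail");
        // verification recomputes all 4096 rounds, so it costs the same ~1.4ms
        Pbkdf2
            .verify_password(pw.as_bytes(), &hash.password_hash())
            .is_ok()
    }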
     158                 : 
     159               7 : #[derive(Default)]
     160                 : pub struct DbUserConnPool {
     161                 :     conns: Vec<ConnPoolEntry>,
     162                 :     password_hash: Option<PasswordHashString>,
     163                 : }
     164                 : 
     165                 : impl DbUserConnPool {
     166              13 :     fn clear_closed_clients(&mut self, conns: &mut usize) {
     167              13 :         let old_len = self.conns.len();
     168              13 : 
     169              13 :         self.conns.retain(|conn| !conn.conn.inner.is_closed());
     170              13 : 
     171              13 :         let new_len = self.conns.len();
     172              13 :         let removed = old_len - new_len;
     173              13 :         *conns -= removed;
     174              13 :     }
     175                 : 
     176              13 :     fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry> {
     177              13 :         self.clear_closed_clients(conns);
     178              13 :         let conn = self.conns.pop();
     179              13 :         if conn.is_some() {
     180               4 :             *conns -= 1;
     181               9 :         }
     182              13 :         conn
     183              13 :     }
     184                 : }
     185                 : 
     186                 : pub struct GlobalConnPool {
     187                 :     // endpoint -> per-endpoint connection pool
     188                 :     //
      189                 :     // This is expected to be a fairly contended map, so return a reference to
      190                 :     // the per-endpoint pool as early as possible and release the lock.
     191                 :     global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
     192                 : 
     193                 :     /// Number of endpoint-connection pools
     194                 :     ///
      195                 :     /// [`DashMap::len`] iterates over all inner shards and acquires a read lock on each.
     196                 :     /// That seems like far too much effort, so we're using a relaxed increment counter instead.
     197                 :     /// It's only used for diagnostics.
     198                 :     global_pool_size: AtomicUsize,
     199                 : 
     200                 :     proxy_config: &'static crate::config::ProxyConfig,
     201                 : }
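
Because the counter is purely diagnostic, relaxed ordering suffices; the increment/decrement pattern used by get_or_create_endpoint_pool and gc boils down to this sketch:

    use std::sync::atomic::{AtomicUsize, Ordering};

    let global_pool_size = AtomicUsize::new(0);
    // pool created: fetch_add returns the previous value, so add 1 for logging
    let size_now = global_pool_size.fetch_add(1, Ordering::Relaxed) + 1;
    // gc reclaimed one pool: the mirror image with fetch_sub
    let size_after_gc = global_pool_size.fetch_sub(1, Ordering::Relaxed) - 1;
    assert_eq!((size_now, size_after_gc), (1, 0));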
     202                 : 
     203 UBC           0 : #[derive(Debug, Clone, Copy)]
     204                 : pub struct GlobalConnPoolOptions {
      205                 :     // Maximum number of connections per endpoint; the budget is shared
      206                 :     // across all (dbname, username) pairs for that endpoint.
      207                 :     // When an endpoint runs out of free slots, we fall back to opening
      208                 :     // a new connection for each request.
     209                 :     pub max_conns_per_endpoint: usize,
     210                 : 
     211                 :     pub gc_epoch: Duration,
     212                 : 
     213                 :     pub pool_shards: usize,
     214                 : 
     215                 :     pub idle_timeout: Duration,
     216                 : 
     217                 :     pub opt_in: bool,
     218                 : }
     219                 : 
     220 CBC          22 : pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
     221              22 :     register_histogram!(
     222              22 :         "proxy_http_pool_reclaimation_lag_seconds",
     223              22 :         "Time it takes to reclaim unused connection pools",
      224              22 :         // 1us -> ~33ms (16 buckets, doubling each step)
     225              22 :         exponential_buckets(1e-6, 2.0, 16).unwrap(),
     226              22 :     )
     227              22 :     .unwrap()
     228              22 : });
     229                 : 
     230               7 : pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
     231              14 :     register_int_counter_pair!(
     232              14 :         "proxy_http_pool_endpoints_registered_total",
     233              14 :         "Number of endpoints we have registered pools for",
     234              14 :         "proxy_http_pool_endpoints_unregistered_total",
     235              14 :         "Number of endpoints we have unregistered pools for",
     236              14 :     )
     237               7 :     .unwrap()
     238               7 : });
     239                 : 
     240                 : impl GlobalConnPool {
     241              22 :     pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
     242              22 :         let shards = config.http_config.pool_options.pool_shards;
     243              22 :         Arc::new(Self {
     244              22 :             global_pool: DashMap::with_shard_amount(shards),
     245              22 :             global_pool_size: AtomicUsize::new(0),
     246              22 :             proxy_config: config,
     247              22 :         })
     248              22 :     }
     249                 : 
     250              22 :     pub fn shutdown(&self) {
     251              22 :         // drops all strong references to endpoint-pools
     252              22 :         self.global_pool.clear();
     253              22 :     }
     254                 : 
     255              22 :     pub async fn gc_worker(&self, mut rng: impl Rng) {
     256              22 :         let epoch = self.proxy_config.http_config.pool_options.gc_epoch;
     257              22 :         let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
     258                 :         loop {
     259              52 :             interval.tick().await;
     260                 : 
     261              30 :             let shard = rng.gen_range(0..self.global_pool.shards().len());
     262              30 :             self.gc(shard);
     263                 :         }
     264                 :     }
     265                 : 
     266              30 :     fn gc(&self, shard: usize) {
     267              30 :         debug!(shard, "pool: performing epoch reclamation");
     268                 : 
     269                 :         // acquire a random shard lock
     270              30 :         let mut shard = self.global_pool.shards()[shard].write();
     271              30 : 
     272              30 :         let timer = GC_LATENCY.start_timer();
     273              30 :         let current_len = shard.len();
     274              30 :         shard.retain(|endpoint, x| {
     275                 :             // if the current endpoint pool is unique (no other strong or weak references)
     276                 :             // then it is currently not in use by any connections.
     277 UBC           0 :             if let Some(pool) = Arc::get_mut(x.get_mut()) {
     278                 :                 let EndpointConnPool {
     279               0 :                     pools, total_conns, ..
     280               0 :                 } = pool.get_mut();
     281               0 : 
     282               0 :                 // ensure that closed clients are removed
     283               0 :                 pools
     284               0 :                     .iter_mut()
     285               0 :                     .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns));
     286               0 : 
     287               0 :                 // we only remove this pool if it has no active connections
     288               0 :                 if *total_conns == 0 {
     289               0 :                     info!("pool: discarding pool for endpoint {endpoint}");
     290               0 :                     return false;
     291               0 :                 }
     292               0 :             }
     293                 : 
     294               0 :             true
     295 CBC          30 :         });
     296              30 :         let new_len = shard.len();
     297              30 :         drop(shard);
     298              30 :         timer.observe_duration();
     299              30 : 
     300              30 :         let removed = current_len - new_len;
     301              30 : 
     302              30 :         if removed > 0 {
     303 UBC           0 :             let global_pool_size = self
     304               0 :                 .global_pool_size
     305               0 :                 .fetch_sub(removed, atomic::Ordering::Relaxed)
     306               0 :                 - removed;
     307               0 :             info!("pool: performed global pool gc. size now {global_pool_size}");
     308 CBC          30 :         }
     309              30 :     }
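
The uniqueness test gc leans on deserves a note: Arc::get_mut returns Some only when no other strong or weak references exist, and every live connection task holds a Weak to its pool (see connect_to_compute_once below), so an in-use pool is skipped automatically. In isolation:

    use std::sync::Arc;

    let mut pool = Arc::new("per-endpoint pool");
    let held_by_task = Arc::downgrade(&pool); // what a connection task keeps

    // a live Weak is enough to keep gc's retain() from touching this entry
    assert!(Arc::get_mut(&mut pool).is_none());

    drop(held_by_task);
    // no other strong or weak refs: safe to inspect and possibly discard
    assert!(Arc::get_mut(&mut pool).is_some());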
     310                 : 
     311              45 :     pub async fn get(
     312              45 :         self: &Arc<Self>,
     313              45 :         ctx: &mut RequestMonitoring,
     314              45 :         conn_info: ConnInfo,
     315              45 :         force_new: bool,
     316              45 :     ) -> anyhow::Result<Client> {
     317              45 :         let mut client: Option<ClientInner> = None;
     318              45 : 
     319              45 :         let mut hash_valid = false;
     320              45 :         let mut endpoint_pool = Weak::new();
     321              45 :         if !force_new {
     322              27 :             let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     323              27 :             endpoint_pool = Arc::downgrade(&pool);
     324              27 :             let mut hash = None;
     325              27 : 
     326              27 :             // find a pool entry by (dbname, username) if exists
     327              27 :             {
     328              27 :                 let pool = pool.read();
     329              27 :                 if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
     330              19 :                     if !pool_entries.conns.is_empty() {
     331              16 :                         hash = pool_entries.password_hash.clone();
     332              16 :                     }
     333               8 :                 }
     334                 :             }
     335                 : 
     336                 :             // a connection exists in the pool, verify the password hash
     337              27 :             if let Some(hash) = hash {
     338              16 :                 let pw = conn_info.password.clone();
     339              16 :                 let validate = tokio::task::spawn_blocking(move || {
     340              16 :                     Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
     341              16 :                 })
     342              32 :                 .await?;
     343                 : 
     344                 :                 // if the hash is invalid, don't error
     345                 :                 // we will continue with the regular connection flow
     346              16 :                 if validate.is_ok() {
     347              13 :                     hash_valid = true;
     348              13 :                     if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) {
     349               4 :                         client = Some(entry.conn)
     350               9 :                     }
     351               3 :                 }
     352              11 :             }
     353              18 :         }
     354                 : 
      355                 :         // return the cached connection if one was found; otherwise establish a new one
     356              45 :         let new_client = if let Some(client) = client {
     357               4 :             if client.inner.is_closed() {
     358 UBC           0 :                 let conn_id = uuid::Uuid::new_v4();
     359               0 :                 info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
     360               0 :                 connect_to_compute(
     361               0 :                     self.proxy_config,
     362               0 :                     ctx,
     363               0 :                     &conn_info,
     364               0 :                     conn_id,
     365               0 :                     endpoint_pool.clone(),
     366               0 :                 )
     367               0 :                 .await
     368                 :             } else {
     369 CBC           4 :                 info!("pool: reusing connection '{conn_info}'");
     370               4 :                 client.session.send(ctx.session_id)?;
     371               4 :                 tracing::Span::current().record(
     372               4 :                     "pid",
     373               4 :                     &tracing::field::display(client.inner.get_process_id()),
     374               4 :                 );
     375               4 :                 ctx.latency_timer.pool_hit();
     376               4 :                 ctx.latency_timer.success();
     377               4 :                 return Ok(Client::new(client, conn_info, endpoint_pool).await);
     378                 :             }
     379                 :         } else {
     380              41 :             let conn_id = uuid::Uuid::new_v4();
     381              41 :             info!(%conn_id, "pool: opening a new connection '{conn_info}'");
     382              41 :             connect_to_compute(
     383              41 :                 self.proxy_config,
     384              41 :                 ctx,
     385              41 :                 &conn_info,
     386              41 :                 conn_id,
     387              41 :                 endpoint_pool.clone(),
     388              41 :             )
     389             459 :             .await
     390                 :         };
     391              41 :         if let Ok(client) = &new_client {
     392              38 :             tracing::Span::current().record(
     393              38 :                 "pid",
     394              38 :                 &tracing::field::display(client.inner.get_process_id()),
     395              38 :             );
     396              38 :         }
     397                 : 
     398              38 :         match &new_client {
     399               3 :             // clear the hash. it's no longer valid
     400               3 :             // TODO: update tokio-postgres fork to allow access to this error kind directly
     401               3 :             Err(err)
     402               3 :                 if hash_valid && err.to_string().contains("password authentication failed") =>
     403 UBC           0 :             {
     404               0 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     405               0 :                 let mut pool = pool.write();
     406               0 :                 if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
     407               0 :                     entry.password_hash = None;
     408               0 :                 }
     409                 :             }
     410                 :             // new password is valid and we should insert/update it
     411 CBC          38 :             Ok(_) if !force_new && !hash_valid => {
     412              11 :                 let pw = conn_info.password.clone();
     413              11 :                 let new_hash = tokio::task::spawn_blocking(move || {
     414              11 :                     let salt = SaltString::generate(rand::rngs::OsRng);
     415              11 :                     Pbkdf2
     416              11 :                         .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
     417              11 :                         .map(|s| s.serialize())
     418              11 :                 })
     419              11 :                 .await??;
     420                 : 
     421              11 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     422              11 :                 let mut pool = pool.write();
     423              11 :                 pool.pools
     424              11 :                     .entry(conn_info.db_and_user())
     425              11 :                     .or_default()
     426              11 :                     .password_hash = Some(new_hash);
     427                 :             }
     428              30 :             _ => {}
     429                 :         }
     430              41 :         let new_client = new_client?;
     431              38 :         Ok(Client::new(new_client, conn_info, endpoint_pool).await)
     432              45 :     }
     433                 : 
     434              38 :     fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
     435                 :         // fast path
     436              38 :         if let Some(pool) = self.global_pool.get(endpoint) {
     437              31 :             return pool.clone();
     438               7 :         }
     439               7 : 
     440               7 :         // slow path
     441               7 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     442               7 :             pools: HashMap::new(),
     443               7 :             total_conns: 0,
     444               7 :             max_conns: self
     445               7 :                 .proxy_config
     446               7 :                 .http_config
     447               7 :                 .pool_options
     448               7 :                 .max_conns_per_endpoint,
     449               7 :             _guard: ENDPOINT_POOLS.guard(),
     450               7 :         }));
     451               7 : 
     452               7 :         // find or create a pool for this endpoint
     453               7 :         let mut created = false;
     454               7 :         let pool = self
     455               7 :             .global_pool
     456               7 :             .entry(endpoint.clone())
     457               7 :             .or_insert_with(|| {
     458               7 :                 created = true;
     459               7 :                 new_pool
     460               7 :             })
     461               7 :             .clone();
     462               7 : 
     463               7 :         // log new global pool size
     464               7 :         if created {
     465               7 :             let global_pool_size = self
     466               7 :                 .global_pool_size
     467               7 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     468               7 :                 + 1;
     469               7 :             info!(
     470               7 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     471               7 :             );
     472 UBC           0 :         }
     473                 : 
     474 CBC           7 :         pool
     475              38 :     }
     476                 : }
     477                 : 
     478                 : struct TokioMechanism<'a> {
     479                 :     pool: Weak<RwLock<EndpointConnPool>>,
     480                 :     conn_info: &'a ConnInfo,
     481                 :     conn_id: uuid::Uuid,
     482                 :     idle: Duration,
     483                 : }
     484                 : 
     485                 : #[async_trait]
     486                 : impl ConnectMechanism for TokioMechanism<'_> {
     487                 :     type Connection = ClientInner;
     488                 :     type ConnectError = tokio_postgres::Error;
     489                 :     type Error = anyhow::Error;
     490                 : 
     491              42 :     async fn connect_once(
     492              42 :         &self,
     493              42 :         ctx: &mut RequestMonitoring,
     494              42 :         node_info: &console::CachedNodeInfo,
     495              42 :         timeout: time::Duration,
     496              42 :     ) -> Result<Self::Connection, Self::ConnectError> {
     497              42 :         connect_to_compute_once(
     498              42 :             ctx,
     499              42 :             node_info,
     500              42 :             self.conn_info,
     501              42 :             timeout,
     502              42 :             self.conn_id,
     503              42 :             self.pool.clone(),
     504              42 :             self.idle,
     505              42 :         )
     506             146 :         .await
     507              84 :     }
     508                 : 
     509              42 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
     510                 : }
     511                 : 
     512                 : // Wake up the destination if needed. Code here is a bit involved because
      513                 : // we reuse the code from the usual proxy and we need to prepare a few structures
     514                 : // that this code expects.
     515 UBC           0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
     516                 : async fn connect_to_compute(
     517                 :     config: &config::ProxyConfig,
     518                 :     ctx: &mut RequestMonitoring,
     519                 :     conn_info: &ConnInfo,
     520                 :     conn_id: uuid::Uuid,
     521                 :     pool: Weak<RwLock<EndpointConnPool>>,
     522                 : ) -> anyhow::Result<ClientInner> {
     523                 :     let tls = config.tls_config.as_ref();
     524 CBC          41 :     let common_names = tls.and_then(|tls| tls.common_names.clone());
     525                 : 
     526                 :     let params = StartupMessageParams::new([
     527                 :         ("user", &conn_info.username),
     528                 :         ("database", &conn_info.dbname),
     529                 :         ("application_name", APP_NAME),
     530                 :         ("options", conn_info.options.as_deref().unwrap_or("")),
     531                 :     ]);
     532                 :     let creds =
     533                 :         auth::ClientCredentials::parse(ctx, &params, Some(&conn_info.hostname), common_names)?;
     534                 : 
     535                 :     let creds =
     536 UBC           0 :         ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
     537 CBC          41 :     let backend = config.auth_backend.as_ref().map(|_| creds);
     538                 : 
     539                 :     let console_options = neon_options(&params);
     540                 : 
     541                 :     if !config.disable_ip_check_for_http {
     542                 :         let allowed_ips = backend.get_allowed_ips(ctx).await?;
     543                 :         if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
     544                 :             return Err(auth::AuthError::ip_address_not_allowed().into());
     545                 :         }
     546                 :     }
     547                 :     let extra = console::ConsoleReqExtra {
     548                 :         options: console_options,
     549                 :     };
     550                 :     let node_info = backend
     551                 :         .wake_compute(ctx, &extra)
     552                 :         .await?
     553                 :         .context("missing cache entry from wake_compute")?;
     554                 : 
     555                 :     ctx.set_project(node_info.aux.clone());
     556                 : 
     557                 :     crate::proxy::connect_compute::connect_to_compute(
     558                 :         ctx,
     559                 :         &TokioMechanism {
     560                 :             conn_id,
     561                 :             conn_info,
     562                 :             pool,
     563                 :             idle: config.http_config.pool_options.idle_timeout,
     564                 :         },
     565                 :         node_info,
     566                 :         &extra,
     567                 :         &backend,
     568                 :     )
     569                 :     .await
     570                 : }
     571                 : 
     572              42 : async fn connect_to_compute_once(
     573              42 :     ctx: &mut RequestMonitoring,
     574              42 :     node_info: &console::CachedNodeInfo,
     575              42 :     conn_info: &ConnInfo,
     576              42 :     timeout: time::Duration,
     577              42 :     conn_id: uuid::Uuid,
     578              42 :     pool: Weak<RwLock<EndpointConnPool>>,
     579              42 :     idle: Duration,
     580              42 : ) -> Result<ClientInner, tokio_postgres::Error> {
     581              42 :     let mut config = (*node_info.config).clone();
     582              42 :     let mut session = ctx.session_id;
     583                 : 
     584              42 :     let (client, mut connection) = config
     585              42 :         .user(&conn_info.username)
     586              42 :         .password(&*conn_info.password)
     587              42 :         .dbname(&conn_info.dbname)
     588              42 :         .connect_timeout(timeout)
     589              42 :         .connect(tokio_postgres::NoTls)
     590             146 :         .await?;
     591                 : 
     592              38 :     let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
     593              38 :         .with_label_values(&[ctx.protocol])
     594              38 :         .guard();
     595              38 : 
     596              38 :     tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     597              38 : 
     598              38 :     let (tx, mut rx) = tokio::sync::watch::channel(session);
     599                 : 
     600              38 :     let span = info_span!(parent: None, "connection", %conn_id);
     601              38 :     span.in_scope(|| {
     602              38 :         info!(%conn_info, %session, "new connection");
     603              38 :     });
     604              38 :     let ids = Ids {
     605              38 :         endpoint_id: node_info.aux.endpoint_id.clone(),
     606              38 :         branch_id: node_info.aux.branch_id.clone(),
     607              38 :     };
     608              38 : 
     609              38 :     let db_user = conn_info.db_and_user();
     610              38 :     tokio::spawn(
     611              38 :         async move {
     612              38 :             let _conn_gauge = conn_gauge;
     613              38 :             let mut idle_timeout = pin!(tokio::time::sleep(idle));
     614            5846 :             poll_fn(move |cx| {
     615            5846 :                 if matches!(rx.has_changed(), Ok(true)) {
     616               4 :                     session = *rx.borrow_and_update();
     617               4 :                     info!(%session, "changed session");
     618               4 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     619            5842 :                 }
     620                 : 
      621                 :                 // idle connection timeout (the `idle` duration from the pool config)
     622            5846 :                 if idle_timeout.as_mut().poll(cx).is_ready() {
     623 UBC           0 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     624               0 :                     info!("connection idle");
     625               0 :                     if let Some(pool) = pool.clone().upgrade() {
     626                 :                         // remove client from pool - should close the connection if it's idle.
     627                 :                         // does nothing if the client is currently checked-out and in-use
     628               0 :                         if pool.write().remove_client(db_user.clone(), conn_id) {
     629               0 :                             info!("idle connection removed");
     630               0 :                         }
     631               0 :                     }
     632 CBC        5846 :                 }
     633                 : 
     634                 :                 loop {
     635            5846 :                     let message = ready!(connection.poll_message(cx));
     636                 : 
     637 UBC           0 :                     match message {
     638               0 :                         Some(Ok(AsyncMessage::Notice(notice))) => {
     639               0 :                             info!(%session, "notice: {}", notice);
     640                 :                         }
     641               0 :                         Some(Ok(AsyncMessage::Notification(notif))) => {
     642               0 :                             warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
     643                 :                         }
     644                 :                         Some(Ok(_)) => {
     645               0 :                             warn!(%session, "unknown message");
     646                 :                         }
     647               0 :                         Some(Err(e)) => {
     648               0 :                             error!(%session, "connection error: {}", e);
     649               0 :                             break
     650                 :                         }
     651                 :                         None => {
     652 CBC          37 :                             info!("connection closed");
     653              37 :                             break
     654                 :                         }
     655                 :                     }
     656                 :                 }
     657                 : 
     658                 :                 // remove from connection pool
     659              37 :                 if let Some(pool) = pool.clone().upgrade() {
     660               3 :                     if pool.write().remove_client(db_user.clone(), conn_id) {
     661 UBC           0 :                         info!("closed connection removed");
     662 CBC           3 :                     }
     663              34 :                 }
     664                 : 
     665              37 :                 Poll::Ready(())
     666            5846 :             }).await;
     667                 : 
     668              38 :         }
     669              38 :         .instrument(span)
     670              38 :     );
     671              38 : 
     672              38 :     Ok(ClientInner {
     673              38 :         inner: client,
     674              38 :         session: tx,
     675              38 :         ids,
     676              38 :         conn_id,
     677              38 :     })
     678              42 : }
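
The idle handling above is the standard resettable-sleep pattern: one pinned tokio::time::Sleep whose deadline is pushed forward on every sign of life. Reduced to its core (the duration here is illustrative):

    use std::{pin::pin, time::Duration};
    use tokio::time::Instant;

    #[tokio::main]
    async fn main() {
        let idle = Duration::from_millis(50);
        let mut idle_timeout = pin!(tokio::time::sleep(idle));

        // on activity (e.g. a session change), push the deadline forward
        idle_timeout.as_mut().reset(Instant::now() + idle);

        // resolves only after `idle` passes with no further resets
        idle_timeout.as_mut().await;
    }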
     679                 : 
     680                 : struct ClientInner {
     681                 :     inner: tokio_postgres::Client,
     682                 :     session: tokio::sync::watch::Sender<uuid::Uuid>,
     683                 :     ids: Ids,
     684                 :     conn_id: uuid::Uuid,
     685                 : }
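
The session field is the sending half of the watch channel polled by the spawned task above; the reuse handshake in get() (client.session.send(ctx.session_id)) reduces to this sketch:

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = tokio::sync::watch::channel(uuid::Uuid::new_v4());

        // pool hit: get() publishes the new request's session id ...
        let next_session = uuid::Uuid::new_v4();
        tx.send(next_session).unwrap();

        // ... and the connection task observes it on its next poll
        if matches!(rx.has_changed(), Ok(true)) {
            assert_eq!(*rx.borrow_and_update(), next_session);
        }
    }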
     686                 : 
     687                 : impl Client {
     688              40 :     pub fn metrics(&self) -> Arc<MetricCounter> {
     689              40 :         USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
     690              40 :     }
     691                 : }
     692                 : 
     693                 : pub struct Client {
     694                 :     conn_id: uuid::Uuid,
     695                 :     span: Span,
     696                 :     inner: Option<ClientInner>,
     697                 :     conn_info: ConnInfo,
     698                 :     pool: Weak<RwLock<EndpointConnPool>>,
     699                 : }
     700                 : 
     701                 : pub struct Discard<'a> {
     702                 :     conn_id: uuid::Uuid,
     703                 :     conn_info: &'a ConnInfo,
     704                 :     pool: &'a mut Weak<RwLock<EndpointConnPool>>,
     705                 : }
     706                 : 
     707                 : impl Client {
     708              42 :     pub(self) async fn new(
     709              42 :         inner: ClientInner,
     710              42 :         conn_info: ConnInfo,
     711              42 :         pool: Weak<RwLock<EndpointConnPool>>,
     712              42 :     ) -> Self {
     713              42 :         Self {
     714              42 :             conn_id: inner.conn_id,
     715              42 :             inner: Some(inner),
     716              42 :             span: Span::current(),
     717              42 :             conn_info,
     718              42 :             pool,
     719              42 :         }
     720              42 :     }
     721              42 :     pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
     722              42 :         let Self {
     723              42 :             inner,
     724              42 :             pool,
     725              42 :             conn_id,
     726              42 :             conn_info,
     727              42 :             span: _,
     728              42 :         } = self;
     729              42 :         (
     730              42 :             &mut inner
     731              42 :                 .as_mut()
     732              42 :                 .expect("client inner should not be removed")
     733              42 :                 .inner,
     734              42 :             Discard {
     735              42 :                 pool,
     736              42 :                 conn_info,
     737              42 :                 conn_id: *conn_id,
     738              42 :             },
     739              42 :         )
     740              42 :     }
     741                 : 
     742              38 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     743              38 :         self.inner().1.check_idle(status)
     744              38 :     }
     745               2 :     pub fn discard(&mut self) {
     746               2 :         self.inner().1.discard()
     747               2 :     }
     748                 : }
     749                 : 
     750                 : impl Discard<'_> {
     751              40 :     pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
     752              40 :         let conn_info = &self.conn_info;
     753              40 :         if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
     754               2 :             info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
     755              38 :         }
     756              40 :     }
     757               2 :     pub fn discard(&mut self) {
     758               2 :         let conn_info = &self.conn_info;
     759               2 :         if std::mem::take(self.pool).strong_count() > 0 {
     760               2 :             info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
     761 UBC           0 :         }
     762 CBC           2 :     }
     763                 : }
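
Both paths drop the connection via the same trick: std::mem::take swaps the Weak handle for an empty Weak::new(), so the Drop impl below sees a handle that no longer upgrades and quietly skips returning the connection to the pool. In isolation:

    use std::sync::{Arc, Weak};

    let pool = Arc::new(());
    let mut handle: Weak<()> = Arc::downgrade(&pool);

    // discard: taking the handle still reports the pool as alive ...
    assert!(std::mem::take(&mut handle).strong_count() > 0);

    // ... but the struct is left with an empty Weak, so Drop cannot re-pool
    assert!(handle.upgrade().is_none());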
     764                 : 
     765                 : impl Deref for Client {
     766                 :     type Target = tokio_postgres::Client;
     767                 : 
     768              40 :     fn deref(&self) -> &Self::Target {
     769              40 :         &self
     770              40 :             .inner
     771              40 :             .as_ref()
     772              40 :             .expect("client inner should not be removed")
     773              40 :             .inner
     774              40 :     }
     775                 : }
     776                 : 
     777                 : impl Drop for Client {
     778              42 :     fn drop(&mut self) {
     779              42 :         let conn_info = self.conn_info.clone();
     780              42 :         let client = self
     781              42 :             .inner
     782              42 :             .take()
     783              42 :             .expect("client inner should not be removed");
     784              42 :         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
     785              20 :             let current_span = self.span.clone();
     786              20 :             // return connection to the pool
     787              20 :             tokio::task::spawn_blocking(move || {
     788              20 :                 let _span = current_span.enter();
     789              20 :                 let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
     790              20 :             });
     791              22 :         }
     792              42 :     }
     793                 : }
        

Generated by: LCOV version 2.1-beta