LCOV - code coverage report
Current view: top level - proxy/src/http - conn_pool.rs (source / functions)
Test:       8ac049b474321fdc72ddcb56d7165153a1a900e8.info
Test Date:  2023-09-06 10:18:01

              Coverage    Total    Hit
  Lines:        89.6 %      251    225
  Functions:    78.6 %       42     33

            Line data    Source code
       1              : use anyhow::Context;
       2              : use async_trait::async_trait;
       3              : use dashmap::DashMap;
       4              : use futures::future::poll_fn;
       5              : use parking_lot::RwLock;
       6              : use pbkdf2::{
       7              :     password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
       8              :     Params, Pbkdf2,
       9              : };
      10              : use pq_proto::StartupMessageParams;
      11              : use std::sync::atomic::{self, AtomicUsize};
      12              : use std::{collections::HashMap, sync::Arc};
      13              : use std::{
      14              :     fmt,
      15              :     task::{ready, Poll},
      16              : };
      17              : use tokio::time;
      18              : use tokio_postgres::AsyncMessage;
      19              : 
      20              : use crate::{auth, console};
      21              : use crate::{compute, config};
      22              : 
      23              : use super::sql_over_http::MAX_RESPONSE_SIZE;
      24              : 
      25              : use crate::proxy::ConnectMechanism;
      26              : 
      27              : use tracing::{error, warn};
      28              : use tracing::{info, info_span, Instrument};
      29              : 
      30              : pub const APP_NAME: &str = "sql_over_http";
      31              : const MAX_CONNS_PER_ENDPOINT: usize = 20;
      32              : 
      33            0 : #[derive(Debug)]
      34              : pub struct ConnInfo {
      35              :     pub username: String,
      36              :     pub dbname: String,
      37              :     pub hostname: String,
      38              :     pub password: String,
      39              : }
      40              : 
      41              : impl ConnInfo {
      42              :     // TODO: change to a hasher to avoid cloning?
      43           14 :     pub fn db_and_user(&self) -> (String, String) {
      44           14 :         (self.dbname.clone(), self.username.clone())
      45           14 :     }
      46              : }
      47              : 
      48              : impl fmt::Display for ConnInfo {
      49              :     // use custom display to avoid logging password
      50           88 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      51           88 :         write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
      52           88 :     }
      53              : }
      54              : 
      55              : struct ConnPoolEntry {
      56              :     conn: Client,
      57              :     _last_access: std::time::Instant,
      58              : }
      59              : 
      60              : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
      61              : // The number of open connections is limited by `max_conns_per_endpoint`.
      62              : pub struct EndpointConnPool {
      63              :     pools: HashMap<(String, String), DbUserConnPool>,
      64              :     total_conns: usize,
      65              : }
      66              : 
      67              : /// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
      68              : /// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
      69              : ///
      70              : /// Still takes 1.4ms to hash on my hardware.
      71              : /// We don't want to ruin the latency improvements of using the pool by making password verification take too long.
      72              : const PARAMS: Params = Params {
      73              :     rounds: 4096,
      74              :     output_length: 32,
      75              : };
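
A minimal sketch (not part of the measured file) of the hash-then-verify round trip these PARAMS drive; it mirrors the spawn_blocking bodies in GlobalConnPool::get below, and the helper name is hypothetical:

    fn hash_and_verify(password: &str) -> pbkdf2::password_hash::Result<()> {
        // Hash with the pool's cost parameters and a fresh random salt,
        // exactly as `get` does when caching a newly validated password.
        let salt = SaltString::generate(rand::rngs::OsRng);
        let hash: PasswordHashString = Pbkdf2
            .hash_password_customized(password.as_bytes(), None, None, PARAMS, &salt)?
            .serialize();
        // Verification re-runs all 4096 rounds, hence the ~1.4 ms per check.
        Pbkdf2.verify_password(password.as_bytes(), &hash.password_hash())
    }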
      76              : 
      77            1 : #[derive(Default)]
      78              : pub struct DbUserConnPool {
      79              :     conns: Vec<ConnPoolEntry>,
      80              :     password_hash: Option<PasswordHashString>,
      81              : }
      82              : 
      83              : pub struct GlobalConnPool {
      84              :     // endpoint -> per-endpoint connection pool
      85              :     //
      86              :     // This map is expected to be fairly contended, so we return a reference to the
      87              :     // per-endpoint pool as early as possible and release the lock.
      88              :     global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,
      89              : 
      90              :     /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
      91              :     /// That seems like far too much effort, so we're using a relaxed increment counter instead.
      92              :     /// It's only used for diagnostics.
      93              :     global_pool_size: AtomicUsize,
      94              : 
      95              :     // Maximum number of connections per endpoint.
      96              :     // Can mix different (dbname, username) connections.
      97              :     // When running out of free slots for a particular endpoint,
      98              :     // falls back to opening a new connection for each request.
      99              :     max_conns_per_endpoint: usize,
     100              : 
     101              :     proxy_config: &'static crate::config::ProxyConfig,
     102              : 
     103              :     // Using a lock to remove any race conditions.
     104              :     // E.g. cleaning up connections while a connection is being returned.
     105              :     closed: RwLock<bool>,
     106              : }
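
A sketch (assumed, not from this file) of the diagnostics read that the relaxed counter enables; unlike DashMap::len, which takes a read lock on every shard, this is a single atomic load:

    // One atomic load, no shard locks. Relaxed ordering is fine because
    // nothing synchronizes on this value; it is diagnostics-only and may lag.
    fn approximate_endpoint_count(pool: &GlobalConnPool) -> usize {
        pool.global_pool_size.load(atomic::Ordering::Relaxed)
    }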
     107              : 
     108              : impl GlobalConnPool {
     109           14 :     pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
     110           14 :         Arc::new(Self {
     111           14 :             global_pool: DashMap::new(),
     112           14 :             global_pool_size: AtomicUsize::new(0),
     113           14 :             max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
     114           14 :             proxy_config: config,
     115           14 :             closed: RwLock::new(false),
     116           14 :         })
     117           14 :     }
     118              : 
     119           14 :     pub fn shutdown(&self) {
     120           14 :         *self.closed.write() = true;
     121           14 : 
     122           14 :         self.global_pool.retain(|_, endpoint_pool| {
     123            1 :             let mut pool = endpoint_pool.write();
     124            1 :             // by clearing this hashmap, we remove the slots that a connection can be returned to.
     125            1 :             // when a connection is returned, `put` drops it if the slot no longer exists.
     126            1 :             pool.pools.clear();
     127            1 :             pool.total_conns = 0;
     128            1 : 
     129            1 :             false
     130           14 :         });
     131           14 :     }
     132              : 
     133           22 :     pub async fn get(
     134           22 :         &self,
     135           22 :         conn_info: &ConnInfo,
     136           22 :         force_new: bool,
     137           22 :         session_id: uuid::Uuid,
     138           22 :     ) -> anyhow::Result<Client> {
     139           22 :         let mut client: Option<Client> = None;
     140           22 : 
     141           22 :         let mut hash_valid = false;
     142           22 :         if !force_new {
     143            6 :             let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     144            6 :             let mut hash = None;
     145            6 : 
     146            6 :             // find a pool entry by (dbname, username) if exists
     147            6 :             {
     148            6 :                 let pool = pool.read();
     149            6 :                 if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
     150            5 :                     if !pool_entries.conns.is_empty() {
     151            5 :                         hash = pool_entries.password_hash.clone();
     152            5 :                     }
     153            1 :                 }
     154              :             }
     155              : 
     156              :             // a connection exists in the pool, verify the password hash
     157            6 :             if let Some(hash) = hash {
     158            5 :                 let pw = conn_info.password.clone();
     159            5 :                 let validate = tokio::task::spawn_blocking(move || {
     160            5 :                     Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
     161            5 :                 })
     162           10 :                 .await?;
     163              : 
     164              :                 // if the hash is invalid, don't error
     165              :                 // we will continue with the regular connection flow
     166            5 :                 if validate.is_ok() {
     167            2 :                     hash_valid = true;
     168            2 :                     let mut pool = pool.write();
     169            2 :                     if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
     170            2 :                         if let Some(entry) = pool_entries.conns.pop() {
     171            2 :                             client = Some(entry.conn);
     172            2 :                             pool.total_conns -= 1;
     173            2 :                         }
     174            0 :                     }
     175            3 :                 }
     176            1 :             }
     177           16 :         }
     178              : 
     179              :         // return the cached connection if one was found, and establish a new one otherwise
     180           22 :         let new_client = if let Some(client) = client {
     181            2 :             if client.inner.is_closed() {
     182            0 :                 info!("pool: cached connection '{conn_info}' is closed, opening a new one");
     183            0 :                 connect_to_compute(self.proxy_config, conn_info, session_id).await
     184              :             } else {
     185            2 :                 info!("pool: reusing connection '{conn_info}'");
     186            2 :                 client.session.send(session_id)?;
     187            2 :                 return Ok(client);
     188              :             }
     189              :         } else {
     190           20 :             info!("pool: opening a new connection '{conn_info}'");
     191           88 :             connect_to_compute(self.proxy_config, conn_info, session_id).await
     192              :         };
     193              : 
     194           18 :         match &new_client {
     195            2 :             // clear the hash. it's no longer valid
     196            2 :             // TODO: update tokio-postgres fork to allow access to this error kind directly
     197            2 :             Err(err)
     198            2 :                 if hash_valid && err.to_string().contains("password authentication failed") =>
     199            0 :             {
     200            0 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     201            0 :                 let mut pool = pool.write();
     202            0 :                 if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
     203            0 :                     entry.password_hash = None;
     204            0 :                 }
     205              :             }
     206              :             // new password is valid and we should insert/update it
     207           18 :             Ok(_) if !force_new && !hash_valid => {
     208            2 :                 let pw = conn_info.password.clone();
     209            2 :                 let new_hash = tokio::task::spawn_blocking(move || {
     210            2 :                     let salt = SaltString::generate(rand::rngs::OsRng);
     211            2 :                     Pbkdf2
     212            2 :                         .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
     213            2 :                         .map(|s| s.serialize())
     214            2 :                 })
     215            2 :                 .await??;
     216              : 
     217            2 :                 let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     218            2 :                 let mut pool = pool.write();
     219            2 :                 pool.pools
     220            2 :                     .entry(conn_info.db_and_user())
     221            2 :                     .or_default()
     222            2 :                     .password_hash = Some(new_hash);
     223              :             }
     224           18 :             _ => {}
     225              :         }
     226              : 
     227           20 :         new_client
     228           22 :     }
     229              : 
     230            4 :     pub fn put(&self, conn_info: &ConnInfo, client: Client) -> anyhow::Result<()> {
     231            4 :         // We want to hold this open while we return. This ensures that the pool can't close
     232            4 :         // while we are in the middle of returning the connection.
     233            4 :         let closed = self.closed.read();
     234            4 :         if *closed {
     235            0 :             info!("pool: throwing away connection '{conn_info}' because pool is closed");
     236            0 :             return Ok(());
     237            4 :         }
     238            4 : 
     239            4 :         if client.inner.is_closed() {
     240            0 :             info!("pool: throwing away connection '{conn_info}' because connection is closed");
     241            0 :             return Ok(());
     242            4 :         }
     243            4 : 
     244            4 :         let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
     245            4 : 
     246            4 :         // return connection to the pool
     247            4 :         let mut returned = false;
     248            4 :         let mut per_db_size = 0;
     249            4 :         let total_conns = {
     250            4 :             let mut pool = pool.write();
     251            4 : 
     252            4 :             if pool.total_conns < self.max_conns_per_endpoint {
     253              :                 // we create this db-user entry in get, so it should not be None
     254            4 :                 if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
     255            4 :                     pool_entries.conns.push(ConnPoolEntry {
     256            4 :                         conn: client,
     257            4 :                         _last_access: std::time::Instant::now(),
     258            4 :                     });
     259            4 : 
     260            4 :                     returned = true;
     261            4 :                     per_db_size = pool_entries.conns.len();
     262            4 : 
     263            4 :                     pool.total_conns += 1;
     264            4 :                 }
     265            0 :             }
     266              : 
     267            4 :             pool.total_conns
     268            4 :         };
     269            4 : 
     270            4 :         // do logging outside of the mutex
     271            4 :         if returned {
     272            4 :             info!("pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
     273              :         } else {
     274            0 :             info!("pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
     275              :         }
     276              : 
     277            4 :         Ok(())
     278            4 :     }
     279              : 
     280              :     fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
     281              :         // fast path
     282           12 :         if let Some(pool) = self.global_pool.get(endpoint) {
     283           11 :             return pool.clone();
     284            1 :         }
     285            1 : 
     286            1 :         // slow path
     287            1 :         let new_pool = Arc::new(RwLock::new(EndpointConnPool {
     288            1 :             pools: HashMap::new(),
     289            1 :             total_conns: 0,
     290            1 :         }));
     291            1 : 
     292            1 :         // find or create a pool for this endpoint
     293            1 :         let mut created = false;
     294            1 :         let pool = self
     295            1 :             .global_pool
     296            1 :             .entry(endpoint.clone())
     297            1 :             .or_insert_with(|| {
     298            1 :                 created = true;
     299            1 :                 new_pool
     300            1 :             })
     301            1 :             .clone();
     302            1 : 
     303            1 :         // log new global pool size
     304            1 :         if created {
     305            1 :             let global_pool_size = self
     306            1 :                 .global_pool_size
     307            1 :                 .fetch_add(1, atomic::Ordering::Relaxed)
     308            1 :                 + 1;
     309            1 :             info!(
     310            1 :                 "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
     311            1 :             );
     312            0 :         }
     313              : 
     314            1 :         pool
     315           12 :     }
     316              : }
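
An illustrative caller (hypothetical; only get, put, ConnInfo, and Client come from this file) showing the intended checkout/checkin cycle:

    async fn run_query(
        pool: &GlobalConnPool,
        conn_info: &ConnInfo,
    ) -> anyhow::Result<Vec<tokio_postgres::Row>> {
        // force_new = false: reuse a pooled connection when the cached
        // password hash for this (dbname, username) checks out.
        let client = pool.get(conn_info, false, uuid::Uuid::new_v4()).await?;
        let rows = client.inner.query("SELECT 1", &[]).await?;
        // Hand the connection back; `put` drops it if the pool is closed or full.
        pool.put(conn_info, client)?;
        Ok(rows)
    }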
     317              : 
     318              : struct TokioMechanism<'a> {
     319              :     conn_info: &'a ConnInfo,
     320              :     session_id: uuid::Uuid,
     321              : }
     322              : 
     323              : #[async_trait]
     324              : impl ConnectMechanism for TokioMechanism<'_> {
     325              :     type Connection = Client;
     326              :     type ConnectError = tokio_postgres::Error;
     327              :     type Error = anyhow::Error;
     328              : 
     329           22 :     async fn connect_once(
     330           22 :         &self,
     331           22 :         node_info: &console::CachedNodeInfo,
     332           22 :         timeout: time::Duration,
     333           22 :     ) -> Result<Self::Connection, Self::ConnectError> {
     334           88 :         connect_to_compute_once(node_info, self.conn_info, timeout, self.session_id).await
     335           44 :     }
     336              : 
     337           22 :     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
     338              : }
     339              : 
     340              : // Wake up the destination if needed. Code here is a bit involved because
     341              : // we reuse the code from the usual proxy and we need to prepare a few structures
     342              : // that this code expects.
     343           80 : #[tracing::instrument(skip_all)]
     344              : async fn connect_to_compute(
     345              :     config: &config::ProxyConfig,
     346              :     conn_info: &ConnInfo,
     347              :     session_id: uuid::Uuid,
     348              : ) -> anyhow::Result<Client> {
     349              :     let tls = config.tls_config.as_ref();
     350           20 :     let common_names = tls.and_then(|tls| tls.common_names.clone());
     351              : 
     352              :     let credential_params = StartupMessageParams::new([
     353              :         ("user", &conn_info.username),
     354              :         ("database", &conn_info.dbname),
     355              :         ("application_name", APP_NAME),
     356              :     ]);
     357              : 
     358              :     let creds = config
     359              :         .auth_backend
     360              :         .as_ref()
     361           20 :         .map(|_| {
     362           20 :             auth::ClientCredentials::parse(
     363           20 :                 &credential_params,
     364           20 :                 Some(&conn_info.hostname),
     365           20 :                 common_names,
     366           20 :             )
     367           20 :         })
     368              :         .transpose()?;
     369              :     let extra = console::ConsoleReqExtra {
     370              :         session_id: uuid::Uuid::new_v4(),
     371              :         application_name: Some(APP_NAME),
     372              :     };
     373              : 
     374              :     let node_info = creds
     375              :         .wake_compute(&extra)
     376              :         .await?
     377              :         .context("missing cache entry from wake_compute")?;
     378              : 
     379              :     crate::proxy::connect_to_compute(
     380              :         &TokioMechanism {
     381              :             conn_info,
     382              :             session_id,
     383              :         },
     384              :         node_info,
     385              :         &extra,
     386              :         &creds,
     387              :     )
     388              :     .await
     389              : }
     390              : 
     391           22 : async fn connect_to_compute_once(
     392           22 :     node_info: &console::CachedNodeInfo,
     393           22 :     conn_info: &ConnInfo,
     394           22 :     timeout: time::Duration,
     395           22 :     mut session: uuid::Uuid,
     396           22 : ) -> Result<Client, tokio_postgres::Error> {
     397           22 :     let mut config = (*node_info.config).clone();
     398              : 
     399           22 :     let (client, mut connection) = config
     400           22 :         .user(&conn_info.username)
     401           22 :         .password(&conn_info.password)
     402           22 :         .dbname(&conn_info.dbname)
     403           22 :         .max_backend_message_size(MAX_RESPONSE_SIZE)
     404           22 :         .connect_timeout(timeout)
     405           22 :         .connect(tokio_postgres::NoTls)
     406           88 :         .await?;
     407              : 
     408           18 :     let (tx, mut rx) = tokio::sync::watch::channel(session);
     409           18 : 
     410           18 :     let conn_id = uuid::Uuid::new_v4();
     411           18 :     let span = info_span!(parent: None, "connection", %conn_id);
     412           18 :     span.in_scope(|| {
     413           18 :         info!(%conn_info, %session, "new connection");
     414           18 :     });
     415           18 : 
     416           18 :     tokio::spawn(
     417          131 :         poll_fn(move |cx| {
     418          131 :             if matches!(rx.has_changed(), Ok(true)) {
     419            2 :                 session = *rx.borrow_and_update();
     420            2 :                 info!(%session, "changed session");
     421          129 :             }
     422              : 
     423              :             loop {
     424          131 :                 let message = ready!(connection.poll_message(cx));
     425              : 
     426            0 :                 match message {
     427            0 :                     Some(Ok(AsyncMessage::Notice(notice))) => {
     428            0 :                         info!(%session, "notice: {}", notice);
     429              :                     }
     430            0 :                     Some(Ok(AsyncMessage::Notification(notif))) => {
     431            0 :                         warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
     432              :                     }
     433              :                     Some(Ok(_)) => {
     434            0 :                         warn!(%session, "unknown message");
     435              :                     }
     436            0 :                     Some(Err(e)) => {
     437            0 :                         error!(%session, "connection error: {}", e);
     438            0 :                         return Poll::Ready(())
     439              :                     }
     440              :                     None => {
     441           18 :                         info!("connection closed");
     442           18 :                         return Poll::Ready(())
     443              :                     }
     444              :                 }
     445              :             }
     446          131 :         })
     447           18 :         .instrument(span)
     448           18 :     );
     449           18 : 
     450           18 :     Ok(Client {
     451           18 :         inner: client,
     452           18 :         session: tx,
     453           18 :     })
     454           22 : }
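
A standalone sketch (function name assumed) of the tokio::sync::watch pattern used above: the pool keeps the Sender half (Client::session) and the spawned connection task keeps the Receiver, so re-tagging a reused connection with a new session id is a single send:

    async fn watch_demo() {
        let first = uuid::Uuid::new_v4();
        let (tx, mut rx) = tokio::sync::watch::channel(first);

        // What `get` does when it hands out a pooled connection:
        // client.session.send(session_id)?
        let second = uuid::Uuid::new_v4();
        tx.send(second).unwrap();

        // What the connection task does at the top of its poll loop.
        if matches!(rx.has_changed(), Ok(true)) {
            let session = *rx.borrow_and_update();
            assert_eq!(session, second);
        }
    }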
     455              : 
     456              : pub struct Client {
     457              :     pub inner: tokio_postgres::Client,
     458              :     session: tokio::sync::watch::Sender<uuid::Uuid>,
     459              : }
        

Generated by: LCOV version 2.1-beta