LCOV - code coverage report
Current view: top level - proxy/src/serverless - conn_pool.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 57.4 % 251 144
Test Date: 2024-10-22 22:13:45 Functions: 40.0 % 25 10

            Line data    Source code
       1              : use std::fmt;
       2              : use std::pin::pin;
       3              : use std::sync::{Arc, Weak};
       4              : use std::task::{ready, Poll};
       5              : 
       6              : use futures::future::poll_fn;
       7              : use futures::Future;
       8              : use smallvec::SmallVec;
       9              : use tokio::time::Instant;
      10              : use tokio_postgres::tls::NoTlsStream;
      11              : use tokio_postgres::{AsyncMessage, Socket};
      12              : use tokio_util::sync::CancellationToken;
      13              : use tracing::{error, info, info_span, warn, Instrument};
      14              : #[cfg(test)]
      15              : use {
      16              :     super::conn_pool_lib::GlobalConnPoolOptions,
      17              :     crate::auth::backend::ComputeUserInfo,
      18              :     std::{sync::atomic, time::Duration},
      19              : };
      20              : 
      21              : use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
      22              : use crate::context::RequestMonitoring;
      23              : use crate::control_plane::messages::MetricsAuxInfo;
      24              : use crate::metrics::Metrics;
      25              : 
      26              : #[derive(Debug, Clone)]
      27              : pub(crate) struct ConnInfoWithAuth {
      28              :     pub(crate) conn_info: ConnInfo,
      29              :     pub(crate) auth: AuthData,
      30              : }
      31              : 
      32              : #[derive(Debug, Clone)]
      33              : pub(crate) enum AuthData {
      34              :     Password(SmallVec<[u8; 16]>),
      35              :     Jwt(String),
      36              : }
      37              : 
      38              : impl fmt::Display for ConnInfo {
      39              :     // use custom display to avoid logging password
      40            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      41            0 :         write!(
      42            0 :             f,
      43            0 :             "{}@{}/{}?{}",
      44            0 :             self.user_info.user,
      45            0 :             self.user_info.endpoint,
      46            0 :             self.dbname,
      47            0 :             self.user_info.options.get_cache_key("")
      48            0 :         )
      49            0 :     }
      50              : }
      51              : 
      52            0 : pub(crate) fn poll_client<C: ClientInnerExt>(
      53            0 :     global_pool: Arc<GlobalConnPool<C>>,
      54            0 :     ctx: &RequestMonitoring,
      55            0 :     conn_info: ConnInfo,
      56            0 :     client: C,
      57            0 :     mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
      58            0 :     conn_id: uuid::Uuid,
      59            0 :     aux: MetricsAuxInfo,
      60            0 : ) -> Client<C> {
      61            0 :     let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
      62            0 :     let mut session_id = ctx.session_id();
      63            0 :     let (tx, mut rx) = tokio::sync::watch::channel(session_id);
      64              : 
      65            0 :     let span = info_span!(parent: None, "connection", %conn_id);
      66            0 :     let cold_start_info = ctx.cold_start_info();
      67            0 :     span.in_scope(|| {
      68            0 :         info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
      69            0 :     });
      70            0 :     let pool = match conn_info.endpoint_cache_key() {
      71            0 :         Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
      72            0 :         None => Weak::new(),
      73              :     };
      74            0 :     let pool_clone = pool.clone();
      75            0 : 
      76            0 :     let db_user = conn_info.db_and_user();
      77            0 :     let idle = global_pool.get_idle_timeout();
      78            0 :     let cancel = CancellationToken::new();
      79            0 :     let cancelled = cancel.clone().cancelled_owned();
      80            0 : 
      81            0 :     tokio::spawn(
      82            0 :     async move {
      83            0 :         let _conn_gauge = conn_gauge;
      84            0 :         let mut idle_timeout = pin!(tokio::time::sleep(idle));
      85            0 :         let mut cancelled = pin!(cancelled);
      86            0 : 
      87            0 :         poll_fn(move |cx| {
      88            0 :             if cancelled.as_mut().poll(cx).is_ready() {
      89            0 :                 info!("connection dropped");
      90            0 :                 return Poll::Ready(())
      91            0 :             }
      92            0 : 
      93            0 :             match rx.has_changed() {
      94              :                 Ok(true) => {
      95            0 :                     session_id = *rx.borrow_and_update();
      96            0 :                     info!(%session_id, "changed session");
      97            0 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
      98              :                 }
      99              :                 Err(_) => {
     100            0 :                     info!("connection dropped");
     101            0 :                     return Poll::Ready(())
     102              :                 }
     103            0 :                 _ => {}
     104              :             }
     105              : 
     106              :             // 5 minute idle connection timeout
     107            0 :             if idle_timeout.as_mut().poll(cx).is_ready() {
     108            0 :                 idle_timeout.as_mut().reset(Instant::now() + idle);
     109            0 :                 info!("connection idle");
     110            0 :                 if let Some(pool) = pool.clone().upgrade() {
     111              :                     // remove client from pool - should close the connection if it's idle.
     112              :                     // does nothing if the client is currently checked-out and in-use
     113            0 :                     if pool.write().remove_client(db_user.clone(), conn_id) {
     114            0 :                         info!("idle connection removed");
     115            0 :                     }
     116            0 :                 }
     117            0 :             }
     118              : 
     119              :             loop {
     120            0 :                 let message = ready!(connection.poll_message(cx));
     121              : 
     122            0 :                 match message {
     123            0 :                     Some(Ok(AsyncMessage::Notice(notice))) => {
     124            0 :                         info!(%session_id, "notice: {}", notice);
     125              :                     }
     126            0 :                     Some(Ok(AsyncMessage::Notification(notif))) => {
     127            0 :                         warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
     128              :                     }
     129              :                     Some(Ok(_)) => {
     130            0 :                         warn!(%session_id, "unknown message");
     131              :                     }
     132            0 :                     Some(Err(e)) => {
     133            0 :                         error!(%session_id, "connection error: {}", e);
     134            0 :                         break
     135              :                     }
     136              :                     None => {
     137            0 :                         info!("connection closed");
     138            0 :                         break
     139              :                     }
     140              :                 }
     141              :             }
     142              : 
     143              :             // remove from connection pool
     144            0 :             if let Some(pool) = pool.clone().upgrade() {
     145            0 :                 if pool.write().remove_client(db_user.clone(), conn_id) {
     146            0 :                     info!("closed connection removed");
     147            0 :                 }
     148            0 :             }
     149              : 
     150            0 :             Poll::Ready(())
     151            0 :         }).await;
     152              : 
     153            0 :     }
     154            0 :     .instrument(span));
     155            0 :     let inner = ClientInnerRemote {
     156            0 :         inner: client,
     157            0 :         session: tx,
     158            0 :         cancel,
     159            0 :         aux,
     160            0 :         conn_id,
     161            0 :     };
     162            0 :     Client::new(inner, conn_info, pool_clone)
     163            0 : }
     164              : 
     165              : pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
     166              :     inner: C,
     167              :     session: tokio::sync::watch::Sender<uuid::Uuid>,
     168              :     cancel: CancellationToken,
     169              :     aux: MetricsAuxInfo,
     170              :     conn_id: uuid::Uuid,
     171              : }
     172              : 
     173              : impl<C: ClientInnerExt> ClientInnerRemote<C> {
     174            1 :     pub(crate) fn inner_mut(&mut self) -> &mut C {
     175            1 :         &mut self.inner
     176            1 :     }
     177              : 
     178            0 :     pub(crate) fn inner(&self) -> &C {
     179            0 :         &self.inner
     180            0 :     }
     181              : 
     182            0 :     pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
     183            0 :         &mut self.session
     184            0 :     }
     185              : 
     186            0 :     pub(crate) fn aux(&self) -> &MetricsAuxInfo {
     187            0 :         &self.aux
     188            0 :     }
     189              : 
     190            6 :     pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
     191            6 :         self.conn_id
     192            6 :     }
     193              : 
     194            8 :     pub(crate) fn is_closed(&self) -> bool {
     195            8 :         self.inner.is_closed()
     196            8 :     }
     197              : }
     198              : 
     199              : impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
     200            7 :     fn drop(&mut self) {
     201            7 :         // on client drop, tell the conn to shut down
     202            7 :         self.cancel.cancel();
     203            7 :     }
     204              : }
     205              : 
     206              : #[cfg(test)]
     207              : mod tests {
     208              :     use std::mem;
     209              :     use std::sync::atomic::AtomicBool;
     210              : 
     211              :     use super::*;
     212              :     use crate::proxy::NeonOptions;
     213              :     use crate::serverless::cancel_set::CancelSet;
     214              :     use crate::{BranchId, EndpointId, ProjectId};
     215              : 
     216              :     struct MockClient(Arc<AtomicBool>);
     217              :     impl MockClient {
     218            6 :         fn new(is_closed: bool) -> Self {
     219            6 :             MockClient(Arc::new(is_closed.into()))
     220            6 :         }
     221              :     }
     222              :     impl ClientInnerExt for MockClient {
     223            8 :         fn is_closed(&self) -> bool {
     224            8 :             self.0.load(atomic::Ordering::Relaxed)
     225            8 :         }
     226            0 :         fn get_process_id(&self) -> i32 {
     227            0 :             0
     228            0 :         }
     229              :     }
     230              : 
     231            5 :     fn create_inner() -> ClientInnerRemote<MockClient> {
     232            5 :         create_inner_with(MockClient::new(false))
     233            5 :     }
     234              : 
     235            7 :     fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
     236            7 :         ClientInnerRemote {
     237            7 :             inner: client,
     238            7 :             session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
     239            7 :             cancel: CancellationToken::new(),
     240            7 :             aux: MetricsAuxInfo {
     241            7 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     242            7 :                 project_id: (&ProjectId::from("project")).into(),
     243            7 :                 branch_id: (&BranchId::from("branch")).into(),
     244            7 :                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     245            7 :             },
     246            7 :             conn_id: uuid::Uuid::new_v4(),
     247            7 :         }
     248            7 :     }
     249              : 
     250              :     #[tokio::test]
     251            1 :     async fn test_pool() {
     252            1 :         let _ = env_logger::try_init();
     253            1 :         let config = Box::leak(Box::new(crate::config::HttpConfig {
     254            1 :             accept_websockets: false,
     255            1 :             pool_options: GlobalConnPoolOptions {
     256            1 :                 max_conns_per_endpoint: 2,
     257            1 :                 gc_epoch: Duration::from_secs(1),
     258            1 :                 pool_shards: 2,
     259            1 :                 idle_timeout: Duration::from_secs(1),
     260            1 :                 opt_in: false,
     261            1 :                 max_total_conns: 3,
     262            1 :             },
     263            1 :             cancel_set: CancelSet::new(0),
     264            1 :             client_conn_threshold: u64::MAX,
     265            1 :             max_request_size_bytes: u64::MAX,
     266            1 :             max_response_size_bytes: usize::MAX,
     267            1 :         }));
     268            1 :         let pool = GlobalConnPool::new(config);
     269            1 :         let conn_info = ConnInfo {
     270            1 :             user_info: ComputeUserInfo {
     271            1 :                 user: "user".into(),
     272            1 :                 endpoint: "endpoint".into(),
     273            1 :                 options: NeonOptions::default(),
     274            1 :             },
     275            1 :             dbname: "dbname".into(),
     276            1 :         };
     277            1 :         let ep_pool = Arc::downgrade(
     278            1 :             &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
     279            1 :         );
     280            1 :         {
     281            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     282            1 :             assert_eq!(0, pool.get_global_connections_count());
     283            1 :             client.inner_mut().1.discard();
     284            1 :             // Discard should not add the connection from the pool.
     285            1 :             assert_eq!(0, pool.get_global_connections_count());
     286            1 :         }
     287            1 :         {
     288            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     289            1 :             client.do_drop().unwrap()();
     290            1 :             mem::forget(client); // drop the client
     291            1 :             assert_eq!(1, pool.get_global_connections_count());
     292            1 :         }
     293            1 :         {
     294            1 :             let mut closed_client = Client::new(
     295            1 :                 create_inner_with(MockClient::new(true)),
     296            1 :                 conn_info.clone(),
     297            1 :                 ep_pool.clone(),
     298            1 :             );
     299            1 :             closed_client.do_drop().unwrap()();
     300            1 :             mem::forget(closed_client); // drop the client
     301            1 :                                         // The closed client shouldn't be added to the pool.
     302            1 :             assert_eq!(1, pool.get_global_connections_count());
     303            1 :         }
     304            1 :         let is_closed: Arc<AtomicBool> = Arc::new(false.into());
     305            1 :         {
     306            1 :             let mut client = Client::new(
     307            1 :                 create_inner_with(MockClient(is_closed.clone())),
     308            1 :                 conn_info.clone(),
     309            1 :                 ep_pool.clone(),
     310            1 :             );
     311            1 :             client.do_drop().unwrap()();
     312            1 :             mem::forget(client); // drop the client
     313            1 : 
     314            1 :             // The client should be added to the pool.
     315            1 :             assert_eq!(2, pool.get_global_connections_count());
     316            1 :         }
     317            1 :         {
     318            1 :             let mut client = Client::new(create_inner(), conn_info, ep_pool);
     319            1 :             client.do_drop().unwrap()();
     320            1 :             mem::forget(client); // drop the client
     321            1 : 
     322            1 :             // The client shouldn't be added to the pool. Because the ep-pool is full.
     323            1 :             assert_eq!(2, pool.get_global_connections_count());
     324            1 :         }
     325            1 : 
     326            1 :         let conn_info = ConnInfo {
     327            1 :             user_info: ComputeUserInfo {
     328            1 :                 user: "user".into(),
     329            1 :                 endpoint: "endpoint-2".into(),
     330            1 :                 options: NeonOptions::default(),
     331            1 :             },
     332            1 :             dbname: "dbname".into(),
     333            1 :         };
     334            1 :         let ep_pool = Arc::downgrade(
     335            1 :             &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
     336            1 :         );
     337            1 :         {
     338            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     339            1 :             client.do_drop().unwrap()();
     340            1 :             mem::forget(client); // drop the client
     341            1 :             assert_eq!(3, pool.get_global_connections_count());
     342            1 :         }
     343            1 :         {
     344            1 :             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
     345            1 :             client.do_drop().unwrap()();
     346            1 :             mem::forget(client); // drop the client
     347            1 : 
     348            1 :             // The client shouldn't be added to the pool. Because the global pool is full.
     349            1 :             assert_eq!(3, pool.get_global_connections_count());
     350            1 :         }
     351            1 : 
     352            1 :         is_closed.store(true, atomic::Ordering::Relaxed);
     353            1 :         // Do gc for all shards.
     354            1 :         pool.gc(0);
     355            1 :         pool.gc(1);
     356            1 :         // Closed client should be removed from the pool.
     357            1 :         assert_eq!(2, pool.get_global_connections_count());
     358            1 :     }
     359              : }
        

Generated by: LCOV version 2.1-beta