Line data Source code
1 : use dashmap::DashMap;
2 : use futures::{future::poll_fn, Future};
3 : use parking_lot::RwLock;
4 : use rand::Rng;
5 : use smallvec::SmallVec;
6 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
7 : use std::{
8 : fmt,
9 : task::{ready, Poll},
10 : };
11 : use std::{
12 : ops::Deref,
13 : sync::atomic::{self, AtomicUsize},
14 : };
15 : use tokio::time::Instant;
16 : use tokio_postgres::tls::NoTlsStream;
17 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
18 : use tokio_util::sync::CancellationToken;
19 :
20 : use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
21 : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 : use crate::{
24 : auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
25 : };
26 :
27 : use tracing::{debug, error, warn, Span};
28 : use tracing::{info, info_span, Instrument};
29 :
30 : use super::backend::HttpConnError;
31 :
32 : #[derive(Debug, Clone)]
33 : pub(crate) struct ConnInfoWithAuth {
34 : pub(crate) conn_info: ConnInfo,
35 : pub(crate) auth: AuthData,
36 : }
37 :
38 : #[derive(Debug, Clone)]
39 : pub(crate) struct ConnInfo {
40 : pub(crate) user_info: ComputeUserInfo,
41 : pub(crate) dbname: DbName,
42 : }
43 :
44 : #[derive(Debug, Clone)]
45 : pub(crate) enum AuthData {
46 : Password(SmallVec<[u8; 16]>),
47 : Jwt(String),
48 : }
49 :
50 : impl ConnInfo {
51 : // hm, change to hasher to avoid cloning?
52 3 : pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
53 3 : (self.dbname.clone(), self.user_info.user.clone())
54 3 : }
55 :
56 2 : pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
57 2 : // We don't want to cache http connections for ephemeral endpoints.
58 2 : if self.user_info.options.is_ephemeral() {
59 0 : None
60 : } else {
61 2 : Some(self.user_info.endpoint_cache_key())
62 : }
63 2 : }
64 : }
65 :
66 : impl fmt::Display for ConnInfo {
67 : // use custom display to avoid logging password
68 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 0 : write!(
70 0 : f,
71 0 : "{}@{}/{}?{}",
72 0 : self.user_info.user,
73 0 : self.user_info.endpoint,
74 0 : self.dbname,
75 0 : self.user_info.options.get_cache_key("")
76 0 : )
77 0 : }
78 : }
79 :
80 : struct ConnPoolEntry<C: ClientInnerExt> {
81 : conn: ClientInner<C>,
82 : _last_access: std::time::Instant,
83 : }
84 :
85 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
86 : // The number of open connections is limited by `max_conns_per_endpoint`.
87 : pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
88 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
89 : total_conns: usize,
90 : max_conns: usize,
91 : _guard: HttpEndpointPoolsGuard<'static>,
92 : global_connections_count: Arc<AtomicUsize>,
93 : global_pool_size_max_conns: usize,
94 : }
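
// --- Editor's sketch (illustrative, not part of the original file) ---
// A minimal, self-contained model of the two-level layout used here: an
// endpoint pool maps (dbname, role) to a list of idle connections and
// enforces the per-endpoint cap. All names below are hypothetical
// simplifications (String keys, `()` standing in for a pooled client).
struct EndpointPoolSketch {
    per_user: std::collections::HashMap<(String, String), Vec<()>>,
    total_conns: usize,
    max_conns: usize,
}

impl EndpointPoolSketch {
    // Mirrors the `total_conns < max_conns` check in `EndpointConnPool::put`:
    // a connection is only kept if the per-endpoint cap has not been reached.
    fn put(&mut self, key: (String, String)) -> bool {
        if self.total_conns < self.max_conns {
            self.per_user.entry(key).or_default().push(());
            self.total_conns += 1;
            true
        } else {
            false
        }
    }
}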
95 :
96 : impl<C: ClientInnerExt> EndpointConnPool<C> {
97 0 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
98 0 : let Self {
99 0 : pools,
100 0 : total_conns,
101 0 : global_connections_count,
102 0 : ..
103 0 : } = self;
104 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
105 0 : pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
106 0 : })
107 0 : }
108 :
109 0 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
110 0 : let Self {
111 0 : pools,
112 0 : total_conns,
113 0 : global_connections_count,
114 0 : ..
115 0 : } = self;
116 0 : if let Some(pool) = pools.get_mut(&db_user) {
117 0 : let old_len = pool.conns.len();
118 0 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
119 0 : let new_len = pool.conns.len();
120 0 : let removed = old_len - new_len;
121 0 : if removed > 0 {
122 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
123 0 : Metrics::get()
124 0 : .proxy
125 0 : .http_pool_opened_connections
126 0 : .get_metric()
127 0 : .dec_by(removed as i64);
128 0 : }
129 0 : *total_conns -= removed;
130 0 : removed > 0
131 : } else {
132 0 : false
133 : }
134 0 : }
135 :
136 6 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
137 6 : let conn_id = client.conn_id;
138 6 :
139 6 : if client.is_closed() {
140 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
141 1 : return;
142 5 : }
143 5 : let global_max_conn = pool.read().global_pool_size_max_conns;
144 5 : if pool
145 5 : .read()
146 5 : .global_connections_count
147 5 : .load(atomic::Ordering::Relaxed)
148 5 : >= global_max_conn
149 : {
150 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
151 1 : return;
152 4 : }
153 4 :
154 4 : // return connection to the pool
155 4 : let mut returned = false;
156 4 : let mut per_db_size = 0;
157 4 : let total_conns = {
158 4 : let mut pool = pool.write();
159 4 :
160 4 : if pool.total_conns < pool.max_conns {
161 3 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
162 3 : pool_entries.conns.push(ConnPoolEntry {
163 3 : conn: client,
164 3 : _last_access: std::time::Instant::now(),
165 3 : });
166 3 :
167 3 : returned = true;
168 3 : per_db_size = pool_entries.conns.len();
169 3 :
170 3 : pool.total_conns += 1;
171 3 : pool.global_connections_count
172 3 : .fetch_add(1, atomic::Ordering::Relaxed);
173 3 : Metrics::get()
174 3 : .proxy
175 3 : .http_pool_opened_connections
176 3 : .get_metric()
177 3 : .inc();
178 3 : }
179 :
180 4 : pool.total_conns
181 4 : };
182 4 :
183 4 : // do logging outside of the mutex
184 4 : if returned {
185 3 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
186 : } else {
187 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
188 : }
189 6 : }
190 : }
191 :
192 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
193 2 : fn drop(&mut self) {
194 2 : if self.total_conns > 0 {
195 2 : self.global_connections_count
196 2 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
197 2 : Metrics::get()
198 2 : .proxy
199 2 : .http_pool_opened_connections
200 2 : .get_metric()
201 2 : .dec_by(self.total_conns as i64);
202 2 : }
203 2 : }
204 : }
205 :
206 : pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
207 : conns: Vec<ConnPoolEntry<C>>,
208 : }
209 :
210 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
211 2 : fn default() -> Self {
212 2 : Self { conns: Vec::new() }
213 2 : }
214 : }
215 :
216 : impl<C: ClientInnerExt> DbUserConnPool<C> {
217 1 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
218 1 : let old_len = self.conns.len();
219 1 :
220 2 : self.conns.retain(|conn| !conn.conn.is_closed());
221 1 :
222 1 : let new_len = self.conns.len();
223 1 : let removed = old_len - new_len;
224 1 : *conns -= removed;
225 1 : removed
226 1 : }
227 :
228 0 : fn get_conn_entry(
229 0 : &mut self,
230 0 : conns: &mut usize,
231 0 : global_connections_count: Arc<AtomicUsize>,
232 0 : ) -> Option<ConnPoolEntry<C>> {
233 0 : let mut removed = self.clear_closed_clients(conns);
234 0 : let conn = self.conns.pop();
235 0 : if conn.is_some() {
236 0 : *conns -= 1;
237 0 : removed += 1;
238 0 : }
239 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
240 0 : Metrics::get()
241 0 : .proxy
242 0 : .http_pool_opened_connections
243 0 : .get_metric()
244 0 : .dec_by(removed as i64);
245 0 : conn
246 0 : }
247 : }
248 :
249 : pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
250 : // endpoint -> per-endpoint connection pool
251 : //
252 : // That should be a fairly contended map, so return a reference to the per-endpoint
253 : // pool as early as possible and release the lock.
254 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
255 :
256 : /// Number of endpoint-connection pools
257 : ///
258 : /// [`DashMap::len`] iterates over all inner shards and acquires a read lock on each.
259 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
260 : /// It's only used for diagnostics.
261 : global_pool_size: AtomicUsize,
262 :
263 : /// Total number of connections in the pool
264 : global_connections_count: Arc<AtomicUsize>,
265 :
266 : config: &'static crate::config::HttpConfig,
267 : }
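
// --- Editor's sketch (illustrative, not part of the original file) ---
// The counters above are plain `AtomicUsize`s updated with `Ordering::Relaxed`:
// they only feed diagnostics and logging, so an approximate value is acceptable
// and no ordering with the map contents is required. The same pattern,
// self-contained:
fn _relaxed_counter_sketch() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let global_pool_size = AtomicUsize::new(0);
    // when a new per-endpoint pool is created
    let size_after_create = global_pool_size.fetch_add(1, Ordering::Relaxed) + 1;
    // when gc removes an empty per-endpoint pool
    let size_after_gc = global_pool_size.fetch_sub(1, Ordering::Relaxed) - 1;
    let _ = (size_after_create, size_after_gc);
}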
268 :
269 : #[derive(Debug, Clone, Copy)]
270 : pub struct GlobalConnPoolOptions {
271 : // Maximum number of connections per one endpoint.
272 : // Can mix different (dbname, username) connections.
273 : // When running out of free slots for a particular endpoint,
274 : // falls back to opening a new connection for each request.
275 : pub max_conns_per_endpoint: usize,
276 :
277 : pub gc_epoch: Duration,
278 :
279 : pub pool_shards: usize,
280 :
281 : pub idle_timeout: Duration,
282 :
283 : pub opt_in: bool,
284 :
285 : // Total number of connections in the pool.
286 : pub max_total_conns: usize,
287 : }
288 :
289 : impl<C: ClientInnerExt> GlobalConnPool<C> {
290 1 : pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
291 1 : let shards = config.pool_options.pool_shards;
292 1 : Arc::new(Self {
293 1 : global_pool: DashMap::with_shard_amount(shards),
294 1 : global_pool_size: AtomicUsize::new(0),
295 1 : config,
296 1 : global_connections_count: Arc::new(AtomicUsize::new(0)),
297 1 : })
298 1 : }
299 :
300 : #[cfg(test)]
301 9 : pub(crate) fn get_global_connections_count(&self) -> usize {
302 9 : self.global_connections_count
303 9 : .load(atomic::Ordering::Relaxed)
304 9 : }
305 :
306 0 : pub(crate) fn get_idle_timeout(&self) -> Duration {
307 0 : self.config.pool_options.idle_timeout
308 0 : }
309 :
310 0 : pub(crate) fn shutdown(&self) {
311 0 : // drops all strong references to endpoint-pools
312 0 : self.global_pool.clear();
313 0 : }
314 :
315 0 : pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
316 0 : let epoch = self.config.pool_options.gc_epoch;
317 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
318 : loop {
319 0 : interval.tick().await;
320 :
321 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
322 0 : self.gc(shard);
323 : }
324 : }
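
    // --- Editor's sketch (illustrative, not part of the original file) ---
    // `gc_worker` staggers reclamation: with N shards and a gc epoch E it ticks
    // every E / N and collects one randomly picked shard per tick, so each shard
    // is reclaimed roughly once per epoch on average. The numbers below are
    // hypothetical and only show the arithmetic.
    fn _gc_interval_sketch() {
        let epoch = Duration::from_secs(60); // assumed gc_epoch
        let shards: u32 = 4; // assumed shard count
        let tick = epoch / shards;
        assert_eq!(tick, Duration::from_secs(15));
    }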
325 :
326 2 : fn gc(&self, shard: usize) {
327 2 : debug!(shard, "pool: performing epoch reclamation");
328 :
329 : // acquire a random shard lock
330 2 : let mut shard = self.global_pool.shards()[shard].write();
331 2 :
332 2 : let timer = Metrics::get()
333 2 : .proxy
334 2 : .http_pool_reclaimation_lag_seconds
335 2 : .start_timer();
336 2 : let current_len = shard.len();
337 2 : let mut clients_removed = 0;
338 2 : shard.retain(|endpoint, x| {
339 : // if the current endpoint pool is unique (no other strong or weak references)
340 : // then it is currently not in use by any connections.
341 2 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
342 : let EndpointConnPool {
343 1 : pools, total_conns, ..
344 1 : } = pool.get_mut();
345 :
346 : // ensure that closed clients are removed
347 1 : for db_pool in pools.values_mut() {
348 1 : clients_removed += db_pool.clear_closed_clients(total_conns);
349 1 : }
350 :
351 : // we only remove this pool if it has no active connections
352 1 : if *total_conns == 0 {
353 0 : info!("pool: discarding pool for endpoint {endpoint}");
354 0 : return false;
355 1 : }
356 1 : }
357 :
358 2 : true
359 2 : });
360 2 :
361 2 : let new_len = shard.len();
362 2 : drop(shard);
363 2 : timer.observe();
364 2 :
365 2 : // Do logging outside of the lock.
366 2 : if clients_removed > 0 {
367 1 : let size = self
368 1 : .global_connections_count
369 1 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
370 1 : - clients_removed;
371 1 : Metrics::get()
372 1 : .proxy
373 1 : .http_pool_opened_connections
374 1 : .get_metric()
375 1 : .dec_by(clients_removed as i64);
376 1 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
377 1 : }
378 2 : let removed = current_len - new_len;
379 2 :
380 2 : if removed > 0 {
381 0 : let global_pool_size = self
382 0 : .global_pool_size
383 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
384 0 : - removed;
385 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
386 2 : }
387 2 : }
388 :
389 0 : pub(crate) fn get(
390 0 : self: &Arc<Self>,
391 0 : ctx: &RequestMonitoring,
392 0 : conn_info: &ConnInfo,
393 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
394 0 : let mut client: Option<ClientInner<C>> = None;
395 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
396 0 : return Ok(None);
397 : };
398 :
399 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
400 0 : if let Some(entry) = endpoint_pool
401 0 : .write()
402 0 : .get_conn_entry(conn_info.db_and_user())
403 0 : {
404 0 : client = Some(entry.conn);
405 0 : }
406 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
407 :
408 : // return the cached connection if one was found; otherwise the caller establishes a new one
409 0 : if let Some(client) = client {
410 0 : if client.is_closed() {
411 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
412 0 : return Ok(None);
413 0 : }
414 0 : tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
415 0 : tracing::Span::current().record(
416 0 : "pid",
417 0 : tracing::field::display(client.inner.get_process_id()),
418 0 : );
419 0 : info!(
420 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
421 0 : "pool: reusing connection '{conn_info}'"
422 : );
423 0 : client.session.send(ctx.session_id())?;
424 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
425 0 : ctx.success();
426 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
427 0 : }
428 0 : Ok(None)
429 0 : }
430 :
431 2 : fn get_or_create_endpoint_pool(
432 2 : self: &Arc<Self>,
433 2 : endpoint: &EndpointCacheKey,
434 2 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
435 : // fast path
436 2 : if let Some(pool) = self.global_pool.get(endpoint) {
437 0 : return pool.clone();
438 2 : }
439 2 :
440 2 : // slow path
441 2 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
442 2 : pools: HashMap::new(),
443 2 : total_conns: 0,
444 2 : max_conns: self.config.pool_options.max_conns_per_endpoint,
445 2 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
446 2 : global_connections_count: self.global_connections_count.clone(),
447 2 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
448 2 : }));
449 2 :
450 2 : // find or create a pool for this endpoint
451 2 : let mut created = false;
452 2 : let pool = self
453 2 : .global_pool
454 2 : .entry(endpoint.clone())
455 2 : .or_insert_with(|| {
456 2 : created = true;
457 2 : new_pool
458 2 : })
459 2 : .clone();
460 2 :
461 2 : // log new global pool size
462 2 : if created {
463 2 : let global_pool_size = self
464 2 : .global_pool_size
465 2 : .fetch_add(1, atomic::Ordering::Relaxed)
466 2 : + 1;
467 2 : info!(
468 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
469 : );
470 0 : }
471 :
472 2 : pool
473 2 : }
474 : }
475 :
476 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
477 0 : global_pool: Arc<GlobalConnPool<C>>,
478 0 : ctx: &RequestMonitoring,
479 0 : conn_info: ConnInfo,
480 0 : client: C,
481 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
482 0 : conn_id: uuid::Uuid,
483 0 : aux: MetricsAuxInfo,
484 0 : ) -> Client<C> {
485 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
486 0 : let mut session_id = ctx.session_id();
487 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
488 :
489 0 : let span = info_span!(parent: None, "connection", %conn_id);
490 0 : let cold_start_info = ctx.cold_start_info();
491 0 : span.in_scope(|| {
492 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
493 0 : });
494 0 : let pool = match conn_info.endpoint_cache_key() {
495 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
496 0 : None => Weak::new(),
497 : };
498 0 : let pool_clone = pool.clone();
499 0 :
500 0 : let db_user = conn_info.db_and_user();
501 0 : let idle = global_pool.get_idle_timeout();
502 0 : let cancel = CancellationToken::new();
503 0 : let cancelled = cancel.clone().cancelled_owned();
504 0 :
505 0 : tokio::spawn(
506 0 : async move {
507 0 : let _conn_gauge = conn_gauge;
508 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
509 0 : let mut cancelled = pin!(cancelled);
510 0 :
511 0 : poll_fn(move |cx| {
512 0 : if cancelled.as_mut().poll(cx).is_ready() {
513 0 : info!("connection dropped");
514 0 : return Poll::Ready(())
515 0 : }
516 0 :
517 0 : match rx.has_changed() {
518 : Ok(true) => {
519 0 : session_id = *rx.borrow_and_update();
520 0 : info!(%session_id, "changed session");
521 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
522 : }
523 : Err(_) => {
524 0 : info!("connection dropped");
525 0 : return Poll::Ready(())
526 : }
527 0 : _ => {}
528 : }
529 :
530 : // idle connection timeout (the duration comes from `pool_options.idle_timeout`)
531 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
532 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
533 0 : info!("connection idle");
534 0 : if let Some(pool) = pool.clone().upgrade() {
535 : // remove client from pool - should close the connection if it's idle.
536 : // does nothing if the client is currently checked-out and in-use
537 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
538 0 : info!("idle connection removed");
539 0 : }
540 0 : }
541 0 : }
542 :
543 : loop {
544 0 : let message = ready!(connection.poll_message(cx));
545 :
546 0 : match message {
547 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
548 0 : info!(%session_id, "notice: {}", notice);
549 : }
550 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
551 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
552 : }
553 : Some(Ok(_)) => {
554 0 : warn!(%session_id, "unknown message");
555 : }
556 0 : Some(Err(e)) => {
557 0 : error!(%session_id, "connection error: {}", e);
558 0 : break
559 : }
560 : None => {
561 0 : info!("connection closed");
562 0 : break
563 : }
564 : }
565 : }
566 :
567 : // remove from connection pool
568 0 : if let Some(pool) = pool.clone().upgrade() {
569 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
570 0 : info!("closed connection removed");
571 0 : }
572 0 : }
573 :
574 0 : Poll::Ready(())
575 0 : }).await;
576 :
577 0 : }
578 0 : .instrument(span));
579 0 : let inner = ClientInner {
580 0 : inner: client,
581 0 : session: tx,
582 0 : cancel,
583 0 : aux,
584 0 : conn_id,
585 0 : };
586 0 : Client::new(inner, conn_info, pool_clone)
587 0 : }
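
// --- Editor's sketch (illustrative, not part of the original file) ---
// `poll_client` drives three concerns from one `poll_fn` future: the
// cancellation token, a resettable idle timer, and the underlying connection.
// A stripped-down version of the same shape, with a oneshot receiver standing
// in for `connection.poll_message` (the name and signature are hypothetical):
async fn _poll_driver_sketch(
    cancel: CancellationToken,
    mut work_done: tokio::sync::oneshot::Receiver<()>,
    idle: Duration,
) {
    use std::pin::Pin;

    let mut idle_timeout = pin!(tokio::time::sleep(idle));
    let mut cancelled = pin!(cancel.cancelled_owned());

    poll_fn(move |cx| {
        // external cancellation wins immediately
        if cancelled.as_mut().poll(cx).is_ready() {
            return Poll::Ready(());
        }
        // the idle timer fires periodically and is re-armed each time
        if idle_timeout.as_mut().poll(cx).is_ready() {
            idle_timeout.as_mut().reset(Instant::now() + idle);
        }
        // finally poll the actual work; finish when it completes
        Future::poll(Pin::new(&mut work_done), cx).map(|_| ())
    })
    .await;
}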
588 :
589 : struct ClientInner<C: ClientInnerExt> {
590 : inner: C,
591 : session: tokio::sync::watch::Sender<uuid::Uuid>,
592 : cancel: CancellationToken,
593 : aux: MetricsAuxInfo,
594 : conn_id: uuid::Uuid,
595 : }
596 :
597 : impl<C: ClientInnerExt> Drop for ClientInner<C> {
598 7 : fn drop(&mut self) {
599 7 : // on client drop, tell the conn to shut down
600 7 : self.cancel.cancel();
601 7 : }
602 : }
603 :
604 : pub(crate) trait ClientInnerExt: Sync + Send + 'static {
605 : fn is_closed(&self) -> bool;
606 : fn get_process_id(&self) -> i32;
607 : }
608 :
609 : impl ClientInnerExt for tokio_postgres::Client {
610 0 : fn is_closed(&self) -> bool {
611 0 : self.is_closed()
612 0 : }
613 0 : fn get_process_id(&self) -> i32 {
614 0 : self.get_process_id()
615 0 : }
616 : }
617 :
618 : impl<C: ClientInnerExt> ClientInner<C> {
619 8 : pub(crate) fn is_closed(&self) -> bool {
620 8 : self.inner.is_closed()
621 8 : }
622 : }
623 :
624 : impl<C: ClientInnerExt> Client<C> {
625 0 : pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
626 0 : let aux = &self.inner.as_ref().unwrap().aux;
627 0 : USAGE_METRICS.register(Ids {
628 0 : endpoint_id: aux.endpoint_id,
629 0 : branch_id: aux.branch_id,
630 0 : })
631 0 : }
632 : }
633 :
634 : pub(crate) struct Client<C: ClientInnerExt> {
635 : span: Span,
636 : inner: Option<ClientInner<C>>,
637 : conn_info: ConnInfo,
638 : pool: Weak<RwLock<EndpointConnPool<C>>>,
639 : }
640 :
641 : pub(crate) struct Discard<'a, C: ClientInnerExt> {
642 : conn_info: &'a ConnInfo,
643 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
644 : }
645 :
646 : impl<C: ClientInnerExt> Client<C> {
647 7 : pub(self) fn new(
648 7 : inner: ClientInner<C>,
649 7 : conn_info: ConnInfo,
650 7 : pool: Weak<RwLock<EndpointConnPool<C>>>,
651 7 : ) -> Self {
652 7 : Self {
653 7 : inner: Some(inner),
654 7 : span: Span::current(),
655 7 : conn_info,
656 7 : pool,
657 7 : }
658 7 : }
659 1 : pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
660 1 : let Self {
661 1 : inner,
662 1 : pool,
663 1 : conn_info,
664 1 : span: _,
665 1 : } = self;
666 1 : let inner = inner.as_mut().expect("client inner should not be removed");
667 1 : (&mut inner.inner, Discard { conn_info, pool })
668 1 : }
669 : }
670 :
671 : impl<C: ClientInnerExt> Discard<'_, C> {
672 0 : pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
673 0 : let conn_info = &self.conn_info;
674 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
675 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle");
676 0 : }
677 0 : }
678 1 : pub(crate) fn discard(&mut self) {
679 1 : let conn_info = &self.conn_info;
680 1 : if std::mem::take(self.pool).strong_count() > 0 {
681 1 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
682 0 : }
683 1 : }
684 : }
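
// --- Editor's sketch (illustrative, not part of the original file) ---
// `Discard` works purely through the `Weak` pool handle: `std::mem::take`
// swaps it for a default (dead) `Weak::new()`, so when the client is later
// dropped, `do_drop` cannot upgrade the handle and the connection is closed
// instead of being returned to the pool. The same mechanism, self-contained:
fn _discard_via_weak_sketch() {
    use std::sync::Arc;

    let pool = Arc::new(());
    let mut handle = Arc::downgrade(&pool);

    assert!(handle.upgrade().is_some()); // still attached to the live pool
    let taken = std::mem::take(&mut handle); // what `Discard` does
    assert!(taken.upgrade().is_some()); // the taken handle was still live
    assert!(handle.upgrade().is_none()); // the client is now detached
}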
685 :
686 : impl<C: ClientInnerExt> Deref for Client<C> {
687 : type Target = C;
688 :
689 0 : fn deref(&self) -> &Self::Target {
690 0 : &self
691 0 : .inner
692 0 : .as_ref()
693 0 : .expect("client inner should not be removed")
694 0 : .inner
695 0 : }
696 : }
697 :
698 : impl<C: ClientInnerExt> Client<C> {
699 7 : fn do_drop(&mut self) -> Option<impl FnOnce()> {
700 7 : let conn_info = self.conn_info.clone();
701 7 : let client = self
702 7 : .inner
703 7 : .take()
704 7 : .expect("client inner should not be removed");
705 7 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
706 6 : let current_span = self.span.clone();
707 6 : // return connection to the pool
708 6 : return Some(move || {
709 6 : let _span = current_span.enter();
710 6 : EndpointConnPool::put(&conn_pool, &conn_info, client);
711 6 : });
712 1 : }
713 1 : None
714 7 : }
715 : }
716 :
717 : impl<C: ClientInnerExt> Drop for Client<C> {
718 1 : fn drop(&mut self) {
719 1 : if let Some(drop) = self.do_drop() {
720 0 : tokio::task::spawn_blocking(drop);
721 1 : }
722 1 : }
723 : }
724 :
725 : #[cfg(test)]
726 : mod tests {
727 : use std::{mem, sync::atomic::AtomicBool};
728 :
729 : use crate::{
730 : proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
731 : };
732 :
733 : use super::*;
734 :
735 : struct MockClient(Arc<AtomicBool>);
736 : impl MockClient {
737 6 : fn new(is_closed: bool) -> Self {
738 6 : MockClient(Arc::new(is_closed.into()))
739 6 : }
740 : }
741 : impl ClientInnerExt for MockClient {
742 8 : fn is_closed(&self) -> bool {
743 8 : self.0.load(atomic::Ordering::Relaxed)
744 8 : }
745 0 : fn get_process_id(&self) -> i32 {
746 0 : 0
747 0 : }
748 : }
749 :
750 5 : fn create_inner() -> ClientInner<MockClient> {
751 5 : create_inner_with(MockClient::new(false))
752 5 : }
753 :
754 7 : fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
755 7 : ClientInner {
756 7 : inner: client,
757 7 : session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
758 7 : cancel: CancellationToken::new(),
759 7 : aux: MetricsAuxInfo {
760 7 : endpoint_id: (&EndpointId::from("endpoint")).into(),
761 7 : project_id: (&ProjectId::from("project")).into(),
762 7 : branch_id: (&BranchId::from("branch")).into(),
763 7 : cold_start_info: crate::console::messages::ColdStartInfo::Warm,
764 7 : },
765 7 : conn_id: uuid::Uuid::new_v4(),
766 7 : }
767 7 : }
768 :
769 : #[tokio::test]
770 1 : async fn test_pool() {
771 1 : let _ = env_logger::try_init();
772 1 : let config = Box::leak(Box::new(crate::config::HttpConfig {
773 1 : accept_websockets: false,
774 1 : pool_options: GlobalConnPoolOptions {
775 1 : max_conns_per_endpoint: 2,
776 1 : gc_epoch: Duration::from_secs(1),
777 1 : pool_shards: 2,
778 1 : idle_timeout: Duration::from_secs(1),
779 1 : opt_in: false,
780 1 : max_total_conns: 3,
781 1 : },
782 1 : cancel_set: CancelSet::new(0),
783 1 : client_conn_threshold: u64::MAX,
784 1 : max_request_size_bytes: u64::MAX,
785 1 : max_response_size_bytes: usize::MAX,
786 1 : }));
787 1 : let pool = GlobalConnPool::new(config);
788 1 : let conn_info = ConnInfo {
789 1 : user_info: ComputeUserInfo {
790 1 : user: "user".into(),
791 1 : endpoint: "endpoint".into(),
792 1 : options: NeonOptions::default(),
793 1 : },
794 1 : dbname: "dbname".into(),
795 1 : };
796 1 : let ep_pool = Arc::downgrade(
797 1 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
798 1 : );
799 1 : {
800 1 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
801 1 : assert_eq!(0, pool.get_global_connections_count());
802 1 : client.inner().1.discard();
803 1 : // Discard should not add the connection to the pool.
804 1 : assert_eq!(0, pool.get_global_connections_count());
805 1 : }
806 1 : {
807 1 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
808 1 : client.do_drop().unwrap()();
809 1 : mem::forget(client); // drop the client
810 1 : assert_eq!(1, pool.get_global_connections_count());
811 1 : }
812 1 : {
813 1 : let mut closed_client = Client::new(
814 1 : create_inner_with(MockClient::new(true)),
815 1 : conn_info.clone(),
816 1 : ep_pool.clone(),
817 1 : );
818 1 : closed_client.do_drop().unwrap()();
819 1 : mem::forget(closed_client); // drop the client
820 1 : // The closed client shouldn't be added to the pool.
821 1 : assert_eq!(1, pool.get_global_connections_count());
822 1 : }
823 1 : let is_closed: Arc<AtomicBool> = Arc::new(false.into());
824 1 : {
825 1 : let mut client = Client::new(
826 1 : create_inner_with(MockClient(is_closed.clone())),
827 1 : conn_info.clone(),
828 1 : ep_pool.clone(),
829 1 : );
830 1 : client.do_drop().unwrap()();
831 1 : mem::forget(client); // drop the client
832 1 :
833 1 : // The client should be added to the pool.
834 1 : assert_eq!(2, pool.get_global_connections_count());
835 1 : }
836 1 : {
837 1 : let mut client = Client::new(create_inner(), conn_info, ep_pool);
838 1 : client.do_drop().unwrap()();
839 1 : mem::forget(client); // drop the client
840 1 :
841 1 : // The client shouldn't be added to the pool because the per-endpoint pool is full.
842 1 : assert_eq!(2, pool.get_global_connections_count());
843 1 : }
844 1 :
845 1 : let conn_info = ConnInfo {
846 1 : user_info: ComputeUserInfo {
847 1 : user: "user".into(),
848 1 : endpoint: "endpoint-2".into(),
849 1 : options: NeonOptions::default(),
850 1 : },
851 1 : dbname: "dbname".into(),
852 1 : };
853 1 : let ep_pool = Arc::downgrade(
854 1 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
855 1 : );
856 1 : {
857 1 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
858 1 : client.do_drop().unwrap()();
859 1 : mem::forget(client); // drop the client
860 1 : assert_eq!(3, pool.get_global_connections_count());
861 1 : }
862 1 : {
863 1 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
864 1 : client.do_drop().unwrap()();
865 1 : mem::forget(client); // drop the client
866 1 :
867 1 : // The client shouldn't be added to the pool because the global pool is full.
868 1 : assert_eq!(3, pool.get_global_connections_count());
869 1 : }
870 1 :
871 1 : is_closed.store(true, atomic::Ordering::Relaxed);
872 1 : // Do gc for all shards.
873 1 : pool.gc(0);
874 1 : pool.gc(1);
875 1 : // Closed client should be removed from the pool.
876 1 : assert_eq!(2, pool.get_global_connections_count());
877 1 : }
878 : }