Line data Source code
1 : use dashmap::DashMap;
2 : use futures::{future::poll_fn, Future};
3 : use parking_lot::RwLock;
4 : use rand::Rng;
5 : use smallvec::SmallVec;
6 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
7 : use std::{
8 : fmt,
9 : task::{ready, Poll},
10 : };
11 : use std::{
12 : ops::Deref,
13 : sync::atomic::{self, AtomicUsize},
14 : };
15 : use tokio::time::Instant;
16 : use tokio_postgres::tls::NoTlsStream;
17 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
18 : use tokio_util::sync::CancellationToken;
19 :
20 : use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
21 : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 : use crate::{
24 : auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
25 : };
26 :
27 : use tracing::{debug, error, warn, Span};
28 : use tracing::{info, info_span, Instrument};
29 :
30 : use super::backend::HttpConnError;
31 :
32 : #[derive(Debug, Clone)]
33 : pub(crate) struct ConnInfo {
34 : pub(crate) user_info: ComputeUserInfo,
35 : pub(crate) dbname: DbName,
36 : pub(crate) auth: AuthData,
37 : }
38 :
39 : #[derive(Debug, Clone)]
40 : pub(crate) enum AuthData {
41 : Password(SmallVec<[u8; 16]>),
42 : Jwt(String),
43 : }
44 :
45 : impl ConnInfo {
46 : // hm, change to hasher to avoid cloning?
47 18 : pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
48 18 : (self.dbname.clone(), self.user_info.user.clone())
49 18 : }
50 :
51 12 : pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
52 12 : // We don't want to cache http connections for ephemeral endpoints.
53 12 : if self.user_info.options.is_ephemeral() {
54 0 : None
55 : } else {
56 12 : Some(self.user_info.endpoint_cache_key())
57 : }
58 12 : }
59 : }
60 :
61 : impl fmt::Display for ConnInfo {
62 : // use a custom Display impl to avoid logging the password
63 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 0 : write!(
65 0 : f,
66 0 : "{}@{}/{}?{}",
67 0 : self.user_info.user,
68 0 : self.user_info.endpoint,
69 0 : self.dbname,
70 0 : self.user_info.options.get_cache_key("")
71 0 : )
72 0 : }
73 : }
74 :
75 : struct ConnPoolEntry<C: ClientInnerExt> {
76 : conn: ClientInner<C>,
77 : _last_access: std::time::Instant,
78 : }
79 :
80 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
81 : // Number of open connections is limited by the `max_conns_per_endpoint`.
82 : pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
83 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
84 : total_conns: usize,
85 : max_conns: usize,
86 : _guard: HttpEndpointPoolsGuard<'static>,
87 : global_connections_count: Arc<AtomicUsize>,
88 : global_pool_size_max_conns: usize,
89 : }
90 :
91 : impl<C: ClientInnerExt> EndpointConnPool<C> {
92 0 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
93 0 : let Self {
94 0 : pools,
95 0 : total_conns,
96 0 : global_connections_count,
97 0 : ..
98 0 : } = self;
99 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
100 0 : pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
101 0 : })
102 0 : }
103 :
104 0 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
105 0 : let Self {
106 0 : pools,
107 0 : total_conns,
108 0 : global_connections_count,
109 0 : ..
110 0 : } = self;
111 0 : if let Some(pool) = pools.get_mut(&db_user) {
112 0 : let old_len = pool.conns.len();
113 0 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
114 0 : let new_len = pool.conns.len();
115 0 : let removed = old_len - new_len;
116 0 : if removed > 0 {
117 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
118 0 : Metrics::get()
119 0 : .proxy
120 0 : .http_pool_opened_connections
121 0 : .get_metric()
122 0 : .dec_by(removed as i64);
123 0 : }
124 0 : *total_conns -= removed;
125 0 : removed > 0
126 : } else {
127 0 : false
128 : }
129 0 : }
130 :
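// Return a connection to the per-endpoint pool, dropping it instead if the client is
// already closed, if the global connection limit has been reached, or if this
// endpoint's pool is already at `max_conns`.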
131 36 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
132 36 : let conn_id = client.conn_id;
133 36 :
134 36 : if client.is_closed() {
135 6 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
136 6 : return;
137 30 : }
138 30 : let global_max_conn = pool.read().global_pool_size_max_conns;
139 30 : if pool
140 30 : .read()
141 30 : .global_connections_count
142 30 : .load(atomic::Ordering::Relaxed)
143 30 : >= global_max_conn
144 : {
145 6 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
146 6 : return;
147 24 : }
148 24 :
149 24 : // return connection to the pool
150 24 : let mut returned = false;
151 24 : let mut per_db_size = 0;
152 24 : let total_conns = {
153 24 : let mut pool = pool.write();
154 24 :
155 24 : if pool.total_conns < pool.max_conns {
156 18 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
157 18 : pool_entries.conns.push(ConnPoolEntry {
158 18 : conn: client,
159 18 : _last_access: std::time::Instant::now(),
160 18 : });
161 18 :
162 18 : returned = true;
163 18 : per_db_size = pool_entries.conns.len();
164 18 :
165 18 : pool.total_conns += 1;
166 18 : pool.global_connections_count
167 18 : .fetch_add(1, atomic::Ordering::Relaxed);
168 18 : Metrics::get()
169 18 : .proxy
170 18 : .http_pool_opened_connections
171 18 : .get_metric()
172 18 : .inc();
173 18 : }
174 :
175 24 : pool.total_conns
176 24 : };
177 24 :
178 24 : // do logging outside of the mutex
179 24 : if returned {
180 18 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
181 : } else {
182 6 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
183 : }
184 36 : }
185 : }
186 :
187 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
188 12 : fn drop(&mut self) {
189 12 : if self.total_conns > 0 {
190 12 : self.global_connections_count
191 12 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
192 12 : Metrics::get()
193 12 : .proxy
194 12 : .http_pool_opened_connections
195 12 : .get_metric()
196 12 : .dec_by(self.total_conns as i64);
197 12 : }
198 12 : }
199 : }
200 :
201 : pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
202 : conns: Vec<ConnPoolEntry<C>>,
203 : }
204 :
205 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
206 12 : fn default() -> Self {
207 12 : Self { conns: Vec::new() }
208 12 : }
209 : }
210 :
211 : impl<C: ClientInnerExt> DbUserConnPool<C> {
212 6 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
213 6 : let old_len = self.conns.len();
214 6 :
215 12 : self.conns.retain(|conn| !conn.conn.is_closed());
216 6 :
217 6 : let new_len = self.conns.len();
218 6 : let removed = old_len - new_len;
219 6 : *conns -= removed;
220 6 : removed
221 6 : }
222 :
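// Pop one pooled connection for this (db, user) pair, pruning closed clients first and
// keeping the caller's connection counter and the global metrics in sync.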
223 0 : fn get_conn_entry(
224 0 : &mut self,
225 0 : conns: &mut usize,
226 0 : global_connections_count: Arc<AtomicUsize>,
227 0 : ) -> Option<ConnPoolEntry<C>> {
228 0 : let mut removed = self.clear_closed_clients(conns);
229 0 : let conn = self.conns.pop();
230 0 : if conn.is_some() {
231 0 : *conns -= 1;
232 0 : removed += 1;
233 0 : }
234 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
235 0 : Metrics::get()
236 0 : .proxy
237 0 : .http_pool_opened_connections
238 0 : .get_metric()
239 0 : .dec_by(removed as i64);
240 0 : conn
241 0 : }
242 : }
243 :
244 : pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
245 : // endpoint -> per-endpoint connection pool
246 : //
247 : // That should be a fairly contended map, so return a reference to the per-endpoint
248 : // pool as early as possible and release the lock.
249 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
250 :
251 : /// Number of endpoint-connection pools
252 : ///
253 : /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
254 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
255 : /// It's only used for diagnostics.
256 : global_pool_size: AtomicUsize,
257 :
258 : /// Total number of connections in the pool
259 : global_connections_count: Arc<AtomicUsize>,
260 :
261 : config: &'static crate::config::HttpConfig,
262 : }
263 :
264 : #[derive(Debug, Clone, Copy)]
265 : pub struct GlobalConnPoolOptions {
266 : // Maximum number of connections per endpoint.
267 : // Can mix different (dbname, username) connections.
268 : // When running out of free slots for a particular endpoint,
269 : // falls back to opening a new connection for each request.
270 : pub max_conns_per_endpoint: usize,
271 :
272 : pub gc_epoch: Duration,
273 :
274 : pub pool_shards: usize,
275 :
276 : pub idle_timeout: Duration,
277 :
278 : pub opt_in: bool,
279 :
280 : // Total number of connections in the pool.
281 : pub max_total_conns: usize,
282 : }
283 :
284 : impl<C: ClientInnerExt> GlobalConnPool<C> {
285 6 : pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
286 6 : let shards = config.pool_options.pool_shards;
287 6 : Arc::new(Self {
288 6 : global_pool: DashMap::with_shard_amount(shards),
289 6 : global_pool_size: AtomicUsize::new(0),
290 6 : config,
291 6 : global_connections_count: Arc::new(AtomicUsize::new(0)),
292 6 : })
293 6 : }
294 :
295 : #[cfg(test)]
296 54 : pub(crate) fn get_global_connections_count(&self) -> usize {
297 54 : self.global_connections_count
298 54 : .load(atomic::Ordering::Relaxed)
299 54 : }
300 :
301 0 : pub(crate) fn get_idle_timeout(&self) -> Duration {
302 0 : self.config.pool_options.idle_timeout
303 0 : }
304 :
305 0 : pub(crate) fn shutdown(&self) {
306 0 : // drops all strong references to endpoint-pools
307 0 : self.global_pool.clear();
308 0 : }
309 :
310 0 : pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
311 0 : let epoch = self.config.pool_options.gc_epoch;
312 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
313 : loop {
314 0 : interval.tick().await;
315 :
316 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
317 0 : self.gc(shard);
318 : }
319 : }
320 :
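// Reclaim one shard of the global pool: drop closed clients from every per-endpoint
// pool in the shard and discard endpoint pools that are left empty and unreferenced.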
321 12 : fn gc(&self, shard: usize) {
322 12 : debug!(shard, "pool: performing epoch reclamation");
323 :
324 : // acquire the write lock for the chosen shard
325 12 : let mut shard = self.global_pool.shards()[shard].write();
326 12 :
327 12 : let timer = Metrics::get()
328 12 : .proxy
329 12 : .http_pool_reclaimation_lag_seconds
330 12 : .start_timer();
331 12 : let current_len = shard.len();
332 12 : let mut clients_removed = 0;
333 12 : shard.retain(|endpoint, x| {
334 : // if the current endpoint pool is unique (no other strong or weak references)
335 : // then it is currently not in use by any connections.
336 12 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
337 : let EndpointConnPool {
338 6 : pools, total_conns, ..
339 6 : } = pool.get_mut();
340 :
341 : // ensure that closed clients are removed
342 6 : for db_pool in pools.values_mut() {
343 6 : clients_removed += db_pool.clear_closed_clients(total_conns);
344 6 : }
345 :
346 : // we only remove this pool if it has no active connections
347 6 : if *total_conns == 0 {
348 0 : info!("pool: discarding pool for endpoint {endpoint}");
349 0 : return false;
350 6 : }
351 6 : }
352 :
353 12 : true
354 12 : });
355 12 :
356 12 : let new_len = shard.len();
357 12 : drop(shard);
358 12 : timer.observe();
359 12 :
360 12 : // Do logging outside of the lock.
361 12 : if clients_removed > 0 {
362 6 : let size = self
363 6 : .global_connections_count
364 6 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
365 6 : - clients_removed;
366 6 : Metrics::get()
367 6 : .proxy
368 6 : .http_pool_opened_connections
369 6 : .get_metric()
370 6 : .dec_by(clients_removed as i64);
371 6 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
372 6 : }
373 12 : let removed = current_len - new_len;
374 12 :
375 12 : if removed > 0 {
376 0 : let global_pool_size = self
377 0 : .global_pool_size
378 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
379 0 : - removed;
380 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
381 12 : }
382 12 : }
383 :
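// Fetch a pooled connection for `conn_info`, if one exists. Returns `Ok(None)` when the
// endpoint is ephemeral, when the pool has no live connection for this (db, user), or
// when the cached connection turns out to be closed; the caller then opens a new one.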
384 0 : pub(crate) fn get(
385 0 : self: &Arc<Self>,
386 0 : ctx: &RequestMonitoring,
387 0 : conn_info: &ConnInfo,
388 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
389 0 : let mut client: Option<ClientInner<C>> = None;
390 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
391 0 : return Ok(None);
392 : };
393 :
394 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
395 0 : if let Some(entry) = endpoint_pool
396 0 : .write()
397 0 : .get_conn_entry(conn_info.db_and_user())
398 0 : {
399 0 : client = Some(entry.conn);
400 0 : }
401 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
402 :
403 : // return the cached connection if one was found; otherwise the caller establishes a new one
404 0 : if let Some(client) = client {
405 0 : if client.is_closed() {
406 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
407 0 : return Ok(None);
408 0 : }
409 0 : tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
410 0 : tracing::Span::current().record(
411 0 : "pid",
412 0 : tracing::field::display(client.inner.get_process_id()),
413 0 : );
414 0 : info!(
415 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
416 0 : "pool: reusing connection '{conn_info}'"
417 : );
418 0 : client.session.send(ctx.session_id())?;
419 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
420 0 : ctx.success();
421 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
422 0 : }
423 0 : Ok(None)
424 0 : }
425 :
426 12 : fn get_or_create_endpoint_pool(
427 12 : self: &Arc<Self>,
428 12 : endpoint: &EndpointCacheKey,
429 12 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
430 : // fast path
431 12 : if let Some(pool) = self.global_pool.get(endpoint) {
432 0 : return pool.clone();
433 12 : }
434 12 :
435 12 : // slow path
436 12 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
437 12 : pools: HashMap::new(),
438 12 : total_conns: 0,
439 12 : max_conns: self.config.pool_options.max_conns_per_endpoint,
440 12 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
441 12 : global_connections_count: self.global_connections_count.clone(),
442 12 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
443 12 : }));
444 12 :
445 12 : // find or create a pool for this endpoint
446 12 : let mut created = false;
447 12 : let pool = self
448 12 : .global_pool
449 12 : .entry(endpoint.clone())
450 12 : .or_insert_with(|| {
451 12 : created = true;
452 12 : new_pool
453 12 : })
454 12 : .clone();
455 12 :
456 12 : // log new global pool size
457 12 : if created {
458 12 : let global_pool_size = self
459 12 : .global_pool_size
460 12 : .fetch_add(1, atomic::Ordering::Relaxed)
461 12 : + 1;
462 12 : info!(
463 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
464 : );
465 0 : }
466 :
467 12 : pool
468 12 : }
469 : }
470 :
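// Wrap a freshly established connection in a `Client` and spawn a background task that
// drives the underlying tokio_postgres connection: it resets the idle timer whenever the
// session changes, logs notices and notifications, and removes the client from the pool
// once the connection goes idle, errors, or closes.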
471 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
472 0 : global_pool: Arc<GlobalConnPool<C>>,
473 0 : ctx: &RequestMonitoring,
474 0 : conn_info: ConnInfo,
475 0 : client: C,
476 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
477 0 : conn_id: uuid::Uuid,
478 0 : aux: MetricsAuxInfo,
479 0 : ) -> Client<C> {
480 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
481 0 : let mut session_id = ctx.session_id();
482 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
483 :
484 0 : let span = info_span!(parent: None, "connection", %conn_id);
485 0 : let cold_start_info = ctx.cold_start_info();
486 0 : span.in_scope(|| {
487 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
488 0 : });
489 0 : let pool = match conn_info.endpoint_cache_key() {
490 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
491 0 : None => Weak::new(),
492 : };
493 0 : let pool_clone = pool.clone();
494 0 :
495 0 : let db_user = conn_info.db_and_user();
496 0 : let idle = global_pool.get_idle_timeout();
497 0 : let cancel = CancellationToken::new();
498 0 : let cancelled = cancel.clone().cancelled_owned();
499 0 :
500 0 : tokio::spawn(
501 0 : async move {
502 0 : let _conn_gauge = conn_gauge;
503 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
504 0 : let mut cancelled = pin!(cancelled);
505 0 :
506 0 : poll_fn(move |cx| {
507 0 : if cancelled.as_mut().poll(cx).is_ready() {
508 0 : info!("connection dropped");
509 0 : return Poll::Ready(())
510 0 : }
511 0 :
512 0 : match rx.has_changed() {
513 : Ok(true) => {
514 0 : session_id = *rx.borrow_and_update();
515 0 : info!(%session_id, "changed session");
516 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
517 : }
518 : Err(_) => {
519 0 : info!("connection dropped");
520 0 : return Poll::Ready(())
521 : }
522 0 : _ => {}
523 : }
524 :
525 : // idle connection timeout, taken from the pool's configured idle_timeout
526 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
527 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
528 0 : info!("connection idle");
529 0 : if let Some(pool) = pool.clone().upgrade() {
530 : // remove client from pool - should close the connection if it's idle.
531 : // does nothing if the client is currently checked-out and in-use
532 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
533 0 : info!("idle connection removed");
534 0 : }
535 0 : }
536 0 : }
537 :
538 : loop {
539 0 : let message = ready!(connection.poll_message(cx));
540 :
541 0 : match message {
542 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
543 0 : info!(%session_id, "notice: {}", notice);
544 : }
545 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
546 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
547 : }
548 : Some(Ok(_)) => {
549 0 : warn!(%session_id, "unknown message");
550 : }
551 0 : Some(Err(e)) => {
552 0 : error!(%session_id, "connection error: {}", e);
553 0 : break
554 : }
555 : None => {
556 0 : info!("connection closed");
557 0 : break
558 : }
559 : }
560 : }
561 :
562 : // remove from connection pool
563 0 : if let Some(pool) = pool.clone().upgrade() {
564 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
565 0 : info!("closed connection removed");
566 0 : }
567 0 : }
568 :
569 0 : Poll::Ready(())
570 0 : }).await;
571 :
572 0 : }
573 0 : .instrument(span));
574 0 : let inner = ClientInner {
575 0 : inner: client,
576 0 : session: tx,
577 0 : cancel,
578 0 : aux,
579 0 : conn_id,
580 0 : };
581 0 : Client::new(inner, conn_info, pool_clone)
582 0 : }
583 :
584 : struct ClientInner<C: ClientInnerExt> {
585 : inner: C,
586 : session: tokio::sync::watch::Sender<uuid::Uuid>,
587 : cancel: CancellationToken,
588 : aux: MetricsAuxInfo,
589 : conn_id: uuid::Uuid,
590 : }
591 :
592 : impl<C: ClientInnerExt> Drop for ClientInner<C> {
593 42 : fn drop(&mut self) {
594 42 : // on client drop, tell the conn to shut down
595 42 : self.cancel.cancel();
596 42 : }
597 : }
598 :
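// Minimal interface the pool needs from a client. Implemented for
// `tokio_postgres::Client` and, in tests, by a mock client so the pooling logic can be
// exercised without a real database connection.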
599 : pub(crate) trait ClientInnerExt: Sync + Send + 'static {
600 : fn is_closed(&self) -> bool;
601 : fn get_process_id(&self) -> i32;
602 : }
603 :
604 : impl ClientInnerExt for tokio_postgres::Client {
605 0 : fn is_closed(&self) -> bool {
606 0 : self.is_closed()
607 0 : }
608 0 : fn get_process_id(&self) -> i32 {
609 0 : self.get_process_id()
610 0 : }
611 : }
612 :
613 : impl<C: ClientInnerExt> ClientInner<C> {
614 48 : pub(crate) fn is_closed(&self) -> bool {
615 48 : self.inner.is_closed()
616 48 : }
617 : }
618 :
619 : impl<C: ClientInnerExt> Client<C> {
620 0 : pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
621 0 : let aux = &self.inner.as_ref().unwrap().aux;
622 0 : USAGE_METRICS.register(Ids {
623 0 : endpoint_id: aux.endpoint_id,
624 0 : branch_id: aux.branch_id,
625 0 : })
626 0 : }
627 : }
628 :
629 : pub(crate) struct Client<C: ClientInnerExt> {
630 : span: Span,
631 : inner: Option<ClientInner<C>>,
632 : conn_info: ConnInfo,
633 : pool: Weak<RwLock<EndpointConnPool<C>>>,
634 : }
635 :
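// Returned alongside the inner client by `Client::inner`. `discard` (and `check_idle`
// with a non-idle status) clears the client's weak pool reference, so the connection is
// thrown away instead of being returned to the pool when the `Client` is dropped.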
636 : pub(crate) struct Discard<'a, C: ClientInnerExt> {
637 : conn_info: &'a ConnInfo,
638 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
639 : }
640 :
641 : impl<C: ClientInnerExt> Client<C> {
642 42 : pub(self) fn new(
643 42 : inner: ClientInner<C>,
644 42 : conn_info: ConnInfo,
645 42 : pool: Weak<RwLock<EndpointConnPool<C>>>,
646 42 : ) -> Self {
647 42 : Self {
648 42 : inner: Some(inner),
649 42 : span: Span::current(),
650 42 : conn_info,
651 42 : pool,
652 42 : }
653 42 : }
654 6 : pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
655 6 : let Self {
656 6 : inner,
657 6 : pool,
658 6 : conn_info,
659 6 : span: _,
660 6 : } = self;
661 6 : let inner = inner.as_mut().expect("client inner should not be removed");
662 6 : (&mut inner.inner, Discard { conn_info, pool })
663 6 : }
664 : }
665 :
666 : impl<C: ClientInnerExt> Discard<'_, C> {
667 0 : pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
668 0 : let conn_info = &self.conn_info;
669 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
670 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle");
671 0 : }
672 0 : }
673 6 : pub(crate) fn discard(&mut self) {
674 6 : let conn_info = &self.conn_info;
675 6 : if std::mem::take(self.pool).strong_count() > 0 {
676 6 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
677 0 : }
678 6 : }
679 : }
680 :
681 : impl<C: ClientInnerExt> Deref for Client<C> {
682 : type Target = C;
683 :
684 0 : fn deref(&self) -> &Self::Target {
685 0 : &self
686 0 : .inner
687 0 : .as_ref()
688 0 : .expect("client inner should not be removed")
689 0 : .inner
690 0 : }
691 : }
692 :
693 : impl<C: ClientInnerExt> Client<C> {
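// Take the inner client out of this handle and, if the endpoint pool is still alive,
// return a closure that puts the connection back into it; `Drop` runs that closure via
// `tokio::task::spawn_blocking`.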
694 42 : fn do_drop(&mut self) -> Option<impl FnOnce()> {
695 42 : let conn_info = self.conn_info.clone();
696 42 : let client = self
697 42 : .inner
698 42 : .take()
699 42 : .expect("client inner should not be removed");
700 42 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
701 36 : let current_span = self.span.clone();
702 36 : // return connection to the pool
703 36 : return Some(move || {
704 36 : let _span = current_span.enter();
705 36 : EndpointConnPool::put(&conn_pool, &conn_info, client);
706 36 : });
707 6 : }
708 6 : None
709 42 : }
710 : }
711 :
712 : impl<C: ClientInnerExt> Drop for Client<C> {
713 6 : fn drop(&mut self) {
714 6 : if let Some(drop) = self.do_drop() {
715 0 : tokio::task::spawn_blocking(drop);
716 6 : }
717 6 : }
718 : }
719 :
720 : #[cfg(test)]
721 : mod tests {
722 : use std::{mem, sync::atomic::AtomicBool};
723 :
724 : use crate::{
725 : proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
726 : };
727 :
728 : use super::*;
729 :
730 : struct MockClient(Arc<AtomicBool>);
731 : impl MockClient {
732 36 : fn new(is_closed: bool) -> Self {
733 36 : MockClient(Arc::new(is_closed.into()))
734 36 : }
735 : }
736 : impl ClientInnerExt for MockClient {
737 48 : fn is_closed(&self) -> bool {
738 48 : self.0.load(atomic::Ordering::Relaxed)
739 48 : }
740 0 : fn get_process_id(&self) -> i32 {
741 0 : 0
742 0 : }
743 : }
744 :
745 30 : fn create_inner() -> ClientInner<MockClient> {
746 30 : create_inner_with(MockClient::new(false))
747 30 : }
748 :
749 42 : fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
750 42 : ClientInner {
751 42 : inner: client,
752 42 : session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
753 42 : cancel: CancellationToken::new(),
754 42 : aux: MetricsAuxInfo {
755 42 : endpoint_id: (&EndpointId::from("endpoint")).into(),
756 42 : project_id: (&ProjectId::from("project")).into(),
757 42 : branch_id: (&BranchId::from("branch")).into(),
758 42 : cold_start_info: crate::console::messages::ColdStartInfo::Warm,
759 42 : },
760 42 : conn_id: uuid::Uuid::new_v4(),
761 42 : }
762 42 : }
763 :
764 : #[tokio::test]
765 6 : async fn test_pool() {
766 6 : let _ = env_logger::try_init();
767 6 : let config = Box::leak(Box::new(crate::config::HttpConfig {
768 6 : accept_websockets: false,
769 6 : pool_options: GlobalConnPoolOptions {
770 6 : max_conns_per_endpoint: 2,
771 6 : gc_epoch: Duration::from_secs(1),
772 6 : pool_shards: 2,
773 6 : idle_timeout: Duration::from_secs(1),
774 6 : opt_in: false,
775 6 : max_total_conns: 3,
776 6 : },
777 6 : cancel_set: CancelSet::new(0),
778 6 : client_conn_threshold: u64::MAX,
779 6 : }));
780 6 : let pool = GlobalConnPool::new(config);
781 6 : let conn_info = ConnInfo {
782 6 : user_info: ComputeUserInfo {
783 6 : user: "user".into(),
784 6 : endpoint: "endpoint".into(),
785 6 : options: NeonOptions::default(),
786 6 : },
787 6 : dbname: "dbname".into(),
788 6 : auth: AuthData::Password("password".as_bytes().into()),
789 6 : };
790 6 : let ep_pool = Arc::downgrade(
791 6 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
792 6 : );
793 6 : {
794 6 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
795 6 : assert_eq!(0, pool.get_global_connections_count());
796 6 : client.inner().1.discard();
797 6 : // Discard should not return the connection to the pool.
798 6 : assert_eq!(0, pool.get_global_connections_count());
799 6 : }
800 6 : {
801 6 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
802 6 : client.do_drop().unwrap()();
803 6 : mem::forget(client); // drop the client
804 6 : assert_eq!(1, pool.get_global_connections_count());
805 6 : }
806 6 : {
807 6 : let mut closed_client = Client::new(
808 6 : create_inner_with(MockClient::new(true)),
809 6 : conn_info.clone(),
810 6 : ep_pool.clone(),
811 6 : );
812 6 : closed_client.do_drop().unwrap()();
813 6 : mem::forget(closed_client); // prevent Drop from running; do_drop already took the inner client
814 6 : // The closed client shouldn't be added to the pool.
815 6 : assert_eq!(1, pool.get_global_connections_count());
816 6 : }
817 6 : let is_closed: Arc<AtomicBool> = Arc::new(false.into());
818 6 : {
819 6 : let mut client = Client::new(
820 6 : create_inner_with(MockClient(is_closed.clone())),
821 6 : conn_info.clone(),
822 6 : ep_pool.clone(),
823 6 : );
824 6 : client.do_drop().unwrap()();
825 6 : mem::forget(client); // prevent Drop from running; do_drop already took the inner client
826 6 :
827 6 : // The client should be added to the pool.
828 6 : assert_eq!(2, pool.get_global_connections_count());
829 6 : }
830 6 : {
831 6 : let mut client = Client::new(create_inner(), conn_info, ep_pool);
832 6 : client.do_drop().unwrap()();
833 6 : mem::forget(client); // prevent Drop from running; do_drop already took the inner client
834 6 :
835 6 : // The client shouldn't be added to the pool because the per-endpoint pool is full.
836 6 : assert_eq!(2, pool.get_global_connections_count());
837 6 : }
838 6 :
839 6 : let conn_info = ConnInfo {
840 6 : user_info: ComputeUserInfo {
841 6 : user: "user".into(),
842 6 : endpoint: "endpoint-2".into(),
843 6 : options: NeonOptions::default(),
844 6 : },
845 6 : dbname: "dbname".into(),
846 6 : auth: AuthData::Password("password".as_bytes().into()),
847 6 : };
848 6 : let ep_pool = Arc::downgrade(
849 6 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
850 6 : );
851 6 : {
852 6 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
853 6 : client.do_drop().unwrap()();
854 6 : mem::forget(client); // prevent Drop from running; do_drop already took the inner client
855 6 : assert_eq!(3, pool.get_global_connections_count());
856 6 : }
857 6 : {
858 6 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
859 6 : client.do_drop().unwrap()();
860 6 : mem::forget(client); // prevent Drop from running; do_drop already took the inner client
861 6 :
862 6 : // The client shouldn't be added to the pool because the global pool is full.
863 6 : assert_eq!(3, pool.get_global_connections_count());
864 6 : }
865 6 :
866 6 : is_closed.store(true, atomic::Ordering::Relaxed);
867 6 : // Do gc for all shards.
868 6 : pool.gc(0);
869 6 : pool.gc(1);
870 6 : // Closed client should be removed from the pool.
871 6 : assert_eq!(2, pool.get_global_connections_count());
872 6 : }
873 : }