1 : use dashmap::DashMap;
2 : use futures::{future::poll_fn, Future};
3 : use metrics::IntCounterPairGuard;
4 : use parking_lot::RwLock;
5 : use rand::Rng;
6 : use smallvec::SmallVec;
7 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
8 : use std::{
9 : fmt,
10 : task::{ready, Poll},
11 : };
12 : use std::{
13 : ops::Deref,
14 : sync::atomic::{self, AtomicUsize},
15 : };
16 : use tokio::time::Instant;
17 : use tokio_postgres::tls::NoTlsStream;
18 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
19 :
20 : use crate::console::messages::MetricsAuxInfo;
21 : use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 : use crate::{
24 : auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE,
25 : DbName, EndpointCacheKey, RoleName,
26 : };
27 :
28 : use tracing::{debug, error, warn, Span};
29 : use tracing::{info, info_span, Instrument};
30 :
31 : use super::backend::HttpConnError;
32 :
33 114 : #[derive(Debug, Clone)]
34 : pub struct ConnInfo {
35 : pub user_info: ComputeUserInfo,
36 : pub dbname: DbName,
37 : pub password: SmallVec<[u8; 16]>,
38 : }
39 :
40 : impl ConnInfo {
41 : // hm, change to hasher to avoid cloning?
42 110 : pub fn db_and_user(&self) -> (DbName, RoleName) {
43 110 : (self.dbname.clone(), self.user_info.user.clone())
44 110 : }
45 :
46 68 : pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
47 68 : self.user_info.endpoint_cache_key()
48 68 : }
49 : }
50 :
51 : impl fmt::Display for ConnInfo {
52 : // use custom display to avoid logging password
53 256 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 256 : write!(
55 256 : f,
56 256 : "{}@{}/{}?{}",
57 256 : self.user_info.user,
58 256 : self.user_info.endpoint,
59 256 : self.dbname,
60 256 : self.user_info.options.get_cache_key("")
61 256 : )
62 256 : }
63 : }
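// A minimal sketch of what the `Display` impl produces (hypothetical values, mirroring
// the test module at the bottom of this file); note that `password` is deliberately
// excluded from the output:
//
//     let conn_info = ConnInfo {
//         user_info: ComputeUserInfo {
//             user: "user".into(),
//             endpoint: "endpoint".into(),
//             options: Default::default(),
//         },
//         dbname: "dbname".into(),
//         password: "password".as_bytes().into(),
//     };
//     // formats as something like "user@endpoint/dbname?<options-cache-key>"
//     info!("pool: connecting to '{conn_info}'");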
64 :
65 : struct ConnPoolEntry<C: ClientInnerExt> {
66 : conn: ClientInner<C>,
67 : _last_access: std::time::Instant,
68 : }
69 :
70 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
71 : // The number of open connections is limited by `max_conns_per_endpoint`.
72 : pub struct EndpointConnPool<C: ClientInnerExt> {
73 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
74 : total_conns: usize,
75 : max_conns: usize,
76 : _guard: IntCounterPairGuard,
77 : global_connections_count: Arc<AtomicUsize>,
78 : global_pool_size_max_conns: usize,
79 : }
80 :
81 : impl<C: ClientInnerExt> EndpointConnPool<C> {
82 24 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
83 24 : let Self {
84 24 : pools,
85 24 : total_conns,
86 24 : global_connections_count,
87 24 : ..
88 24 : } = self;
89 24 : pools.get_mut(&db_user).and_then(|pool_entries| {
90 14 : pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
91 24 : })
92 24 : }
93 :
94 3 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
95 3 : let Self {
96 3 : pools,
97 3 : total_conns,
98 3 : global_connections_count,
99 3 : ..
100 3 : } = self;
101 3 : if let Some(pool) = pools.get_mut(&db_user) {
102 1 : let old_len = pool.conns.len();
103 1 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
104 1 : let new_len = pool.conns.len();
105 1 : let removed = old_len - new_len;
106 1 : if removed > 0 {
107 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
108 0 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
109 1 : }
110 1 : *total_conns -= removed;
111 1 : removed > 0
112 : } else {
113 2 : false
114 : }
115 3 : }
116 :
117 52 : fn put(
118 52 : pool: &RwLock<Self>,
119 52 : conn_info: &ConnInfo,
120 52 : client: ClientInner<C>,
121 52 : ) -> anyhow::Result<()> {
122 52 : let conn_id = client.conn_id;
123 52 :
124 52 : if client.is_closed() {
125 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
126 2 : return Ok(());
127 50 : }
128 50 : let global_max_conn = pool.read().global_pool_size_max_conns;
129 50 : if pool
130 50 : .read()
131 50 : .global_connections_count
132 50 : .load(atomic::Ordering::Relaxed)
133 50 : >= global_max_conn
134 : {
135 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
136 2 : return Ok(());
137 48 : }
138 48 :
139 48 : // return connection to the pool
140 48 : let mut returned = false;
141 48 : let mut per_db_size = 0;
142 48 : let total_conns = {
143 48 : let mut pool = pool.write();
144 48 :
145 48 : if pool.total_conns < pool.max_conns {
146 46 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
147 46 : pool_entries.conns.push(ConnPoolEntry {
148 46 : conn: client,
149 46 : _last_access: std::time::Instant::now(),
150 46 : });
151 46 :
152 46 : returned = true;
153 46 : per_db_size = pool_entries.conns.len();
154 46 :
155 46 : pool.total_conns += 1;
156 46 : pool.global_connections_count
157 46 : .fetch_add(1, atomic::Ordering::Relaxed);
158 46 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc();
159 46 : }
160 :
161 48 : pool.total_conns
162 48 : };
163 48 :
164 48 : // do logging outside of the mutex
165 48 : if returned {
166 46 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
167 : } else {
168 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
169 : }
170 :
171 48 : Ok(())
172 52 : }
173 : }
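// A minimal usage sketch for the per-endpoint pool, assuming an existing
// `pool: Arc<RwLock<EndpointConnPool<C>>>`, a `conn_info: ConnInfo`, and a spare
// `client: ClientInner<C>`. `put` quietly drops the client if it is already closed or
// if the per-endpoint / global limits are reached, so callers don't need to check:
//
//     // check out a cached connection for this (db, user) pair, if any
//     if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) {
//         let cached: ClientInner<C> = entry.conn;
//         // ... reuse `cached` ...
//     }
//
//     // hand a finished connection back to the pool
//     EndpointConnPool::put(&pool, &conn_info, client)?;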
174 :
175 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
176 17 : fn drop(&mut self) {
177 17 : if self.total_conns > 0 {
178 16 : self.global_connections_count
179 16 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
180 16 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64);
181 16 : }
182 17 : }
183 : }
184 :
185 : pub struct DbUserConnPool<C: ClientInnerExt> {
186 : conns: Vec<ConnPoolEntry<C>>,
187 : }
188 :
189 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
190 16 : fn default() -> Self {
191 16 : Self { conns: Vec::new() }
192 16 : }
193 : }
194 :
195 : impl<C: ClientInnerExt> DbUserConnPool<C> {
196 16 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
197 16 : let old_len = self.conns.len();
198 16 :
199 16 : self.conns.retain(|conn| !conn.conn.is_closed());
200 16 :
201 16 : let new_len = self.conns.len();
202 16 : let removed = old_len - new_len;
203 16 : *conns -= removed;
204 16 : removed
205 16 : }
206 :
207 14 : fn get_conn_entry(
208 14 : &mut self,
209 14 : conns: &mut usize,
210 14 : global_connections_count: Arc<AtomicUsize>,
211 14 : ) -> Option<ConnPoolEntry<C>> {
212 14 : let mut removed = self.clear_closed_clients(conns);
213 14 : let conn = self.conns.pop();
214 14 : if conn.is_some() {
215 4 : *conns -= 1;
216 4 : removed += 1;
217 10 : }
218 14 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
219 14 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
220 14 : conn
221 14 : }
222 : }
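// Bookkeeping note: `get_conn_entry` counts both the closed connections it prunes and the
// live one it pops towards `removed`, so the global counters only track connections that
// are actually sitting in the pool. For example, with 3 cached connections of which 1 is
// closed: pruning drops 1, popping hands out 1 more, the global count decreases by 2, and
// 1 connection stays cached.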
223 :
224 : pub struct GlobalConnPool<C: ClientInnerExt> {
225 : // endpoint -> per-endpoint connection pool
226 : //
227 : // This map is expected to be fairly contended, so return a reference to the
228 : // per-endpoint pool as early as possible and release the lock.
229 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
230 :
231 : /// Number of endpoint-connection pools
232 : ///
233 : /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
234 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
235 : /// It's only used for diagnostics.
236 : global_pool_size: AtomicUsize,
237 :
238 : /// Total number of connections in the pool
239 : global_connections_count: Arc<AtomicUsize>,
240 :
241 : config: &'static crate::config::HttpConfig,
242 : }
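// The overall nesting of the pool data structures, for orientation:
//
//     GlobalConnPool
//         global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>
//             pools: HashMap<(DbName, RoleName), DbUserConnPool>
//                 conns: Vec<ConnPoolEntry<C>>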
243 :
244 0 : #[derive(Debug, Clone, Copy)]
245 : pub struct GlobalConnPoolOptions {
246 : // Maximum number of connections per endpoint.
247 : // Connections for different (dbname, username) pairs count towards the same limit.
248 : // When an endpoint runs out of free slots,
249 : // we fall back to opening a new connection for each request.
250 : pub max_conns_per_endpoint: usize,
251 :
252 : pub gc_epoch: Duration,
253 :
254 : pub pool_shards: usize,
255 :
256 : pub idle_timeout: Duration,
257 :
258 : pub opt_in: bool,
259 :
260 : // Total number of connections in the pool.
261 : pub max_total_conns: usize,
262 : }
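// Illustrative construction of the options (these values mirror the test module at the
// bottom of this file and are not production defaults):
//
//     let opts = GlobalConnPoolOptions {
//         max_conns_per_endpoint: 2,
//         gc_epoch: Duration::from_secs(1),
//         pool_shards: 2,
//         idle_timeout: Duration::from_secs(1),
//         opt_in: false,
//         max_total_conns: 3,
//     };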
263 :
264 : impl<C: ClientInnerExt> GlobalConnPool<C> {
265 27 : pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
266 27 : let shards = config.pool_options.pool_shards;
267 27 : Arc::new(Self {
268 27 : global_pool: DashMap::with_shard_amount(shards),
269 27 : global_pool_size: AtomicUsize::new(0),
270 27 : config,
271 27 : global_connections_count: Arc::new(AtomicUsize::new(0)),
272 27 : })
273 27 : }
274 :
275 : #[cfg(test)]
276 18 : pub fn get_global_connections_count(&self) -> usize {
277 18 : self.global_connections_count
278 18 : .load(atomic::Ordering::Relaxed)
279 18 : }
280 :
281 40 : pub fn get_idle_timeout(&self) -> Duration {
282 40 : self.config.pool_options.idle_timeout
283 40 : }
284 :
285 25 : pub fn shutdown(&self) {
286 25 : // drops all strong references to endpoint-pools
287 25 : self.global_pool.clear();
288 25 : }
289 :
290 25 : pub async fn gc_worker(&self, mut rng: impl Rng) {
291 25 : let epoch = self.config.pool_options.gc_epoch;
292 25 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
293 : loop {
294 58 : interval.tick().await;
295 :
296 33 : let shard = rng.gen_range(0..self.global_pool.shards().len());
297 33 : self.gc(shard);
298 : }
299 : }
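// Note on the interval math: the tick interval is `gc_epoch / shard_count` and each tick
// reclaims one randomly chosen shard, so on average every shard is garbage-collected about
// once per `gc_epoch`. E.g. with a 60s epoch and 4 shards, one shard is processed every
// ~15s, and each individual shard is visited roughly once a minute.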
300 :
301 37 : fn gc(&self, shard: usize) {
302 37 : debug!(shard, "pool: performing epoch reclamation");
303 :
304 : // acquire a random shard lock
305 37 : let mut shard = self.global_pool.shards()[shard].write();
306 37 :
307 37 : let timer = GC_LATENCY.start_timer();
308 37 : let current_len = shard.len();
309 37 : let mut clients_removed = 0;
310 37 : shard.retain(|endpoint, x| {
311 : // if the current endpoint pool is unique (no other strong or weak references)
312 : // then it is currently not in use by any connections.
313 4 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
314 : let EndpointConnPool {
315 2 : pools, total_conns, ..
316 2 : } = pool.get_mut();
317 2 :
318 2 : // ensure that closed clients are removed
319 2 : pools.iter_mut().for_each(|(_, db_pool)| {
320 2 : clients_removed += db_pool.clear_closed_clients(total_conns);
321 2 : });
322 2 :
323 2 : // we only remove this pool if it has no active connections
324 2 : if *total_conns == 0 {
325 0 : info!("pool: discarding pool for endpoint {endpoint}");
326 0 : return false;
327 2 : }
328 2 : }
329 :
330 4 : true
331 37 : });
332 37 :
333 37 : let new_len = shard.len();
334 37 : drop(shard);
335 37 : timer.observe_duration();
336 37 :
337 37 : // Do logging outside of the lock.
338 37 : if clients_removed > 0 {
339 2 : let size = self
340 2 : .global_connections_count
341 2 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
342 2 : - clients_removed;
343 2 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64);
344 2 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
345 35 : }
346 37 : let removed = current_len - new_len;
347 37 :
348 37 : if removed > 0 {
349 0 : let global_pool_size = self
350 0 : .global_pool_size
351 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
352 0 : - removed;
353 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
354 37 : }
355 37 : }
356 :
357 24 : pub async fn get(
358 24 : self: &Arc<Self>,
359 24 : ctx: &mut RequestMonitoring,
360 24 : conn_info: &ConnInfo,
361 24 : ) -> Result<Option<Client<C>>, HttpConnError> {
362 24 : let mut client: Option<ClientInner<C>> = None;
363 24 :
364 24 : let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
365 24 : if let Some(entry) = endpoint_pool
366 24 : .write()
367 24 : .get_conn_entry(conn_info.db_and_user())
368 : {
369 4 : client = Some(entry.conn)
370 20 : }
371 24 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
372 :
373 : // return the cached connection if one was found; otherwise the caller establishes a new one
374 24 : if let Some(client) = client {
375 4 : if client.is_closed() {
376 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
377 0 : return Ok(None);
378 : } else {
379 4 : tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
380 4 : tracing::Span::current().record(
381 4 : "pid",
382 4 : &tracing::field::display(client.inner.get_process_id()),
383 4 : );
384 4 : info!("pool: reusing connection '{conn_info}'");
385 4 : client.session.send(ctx.session_id)?;
386 4 : ctx.latency_timer.pool_hit();
387 4 : ctx.latency_timer.success();
388 4 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
389 : }
390 20 : }
391 20 : Ok(None)
392 24 : }
393 :
394 68 : fn get_or_create_endpoint_pool(
395 68 : self: &Arc<Self>,
396 68 : endpoint: &EndpointCacheKey,
397 68 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
398 : // fast path
399 68 : if let Some(pool) = self.global_pool.get(endpoint) {
400 51 : return pool.clone();
401 17 : }
402 17 :
403 17 : // slow path
404 17 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
405 17 : pools: HashMap::new(),
406 17 : total_conns: 0,
407 17 : max_conns: self.config.pool_options.max_conns_per_endpoint,
408 17 : _guard: ENDPOINT_POOLS.guard(),
409 17 : global_connections_count: self.global_connections_count.clone(),
410 17 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
411 17 : }));
412 17 :
413 17 : // find or create a pool for this endpoint
414 17 : let mut created = false;
415 17 : let pool = self
416 17 : .global_pool
417 17 : .entry(endpoint.clone())
418 17 : .or_insert_with(|| {
419 17 : created = true;
420 17 : new_pool
421 17 : })
422 17 : .clone();
423 17 :
424 17 : // log new global pool size
425 17 : if created {
426 17 : let global_pool_size = self
427 17 : .global_pool_size
428 17 : .fetch_add(1, atomic::Ordering::Relaxed)
429 17 : + 1;
430 17 : info!(
431 13 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
432 13 : );
433 0 : }
434 :
435 17 : pool
436 68 : }
437 : }
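// A rough end-to-end sketch of the global pool lifecycle. The `config`, `ctx` and
// `conn_info` values are assumed to exist; constructing a `RequestMonitoring` context is
// out of scope here:
//
//     let pool = GlobalConnPool::<tokio_postgres::Client>::new(config);
//
//     // background reclamation of idle/closed connections; `StdRng` is used because the
//     // rng must be `Send` to live inside a spawned task
//     tokio::spawn({
//         let pool = pool.clone();
//         let rng = rand::rngs::StdRng::from_entropy(); // needs `rand::SeedableRng` in scope
//         async move { pool.gc_worker(rng).await }
//     });
//
//     // per request: try the pool first, fall back to a fresh connection
//     match pool.get(&mut ctx, &conn_info).await? {
//         Some(client) => { /* reuse the cached connection */ }
//         None => { /* connect, then register the new connection via `poll_client` below */ }
//     }
//
//     // on shutdown, drop all cached connections
//     pool.shutdown();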
438 :
439 40 : pub fn poll_client<C: ClientInnerExt>(
440 40 : global_pool: Arc<GlobalConnPool<C>>,
441 40 : ctx: &mut RequestMonitoring,
442 40 : conn_info: ConnInfo,
443 40 : client: C,
444 40 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
445 40 : conn_id: uuid::Uuid,
446 40 : aux: MetricsAuxInfo,
447 40 : ) -> Client<C> {
448 40 : let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
449 40 : .with_label_values(&[ctx.protocol])
450 40 : .guard();
451 40 : let mut session_id = ctx.session_id;
452 40 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
453 :
454 40 : let span = info_span!(parent: None, "connection", %conn_id);
455 40 : span.in_scope(|| {
456 40 : info!(%conn_info, %session_id, "new connection");
457 40 : });
458 40 : let pool =
459 40 : Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
460 40 : let pool_clone = pool.clone();
461 40 :
462 40 : let db_user = conn_info.db_and_user();
463 40 : let idle = global_pool.get_idle_timeout();
464 40 : tokio::spawn(
465 40 : async move {
466 40 : let _conn_gauge = conn_gauge;
467 40 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
468 7044 : poll_fn(move |cx| {
469 7044 : if matches!(rx.has_changed(), Ok(true)) {
470 4 : session_id = *rx.borrow_and_update();
471 4 : info!(%session_id, "changed session");
472 4 : idle_timeout.as_mut().reset(Instant::now() + idle);
473 7040 : }
474 :
475 : // idle connection timeout (driven by the configured `idle_timeout`)
476 7044 : if idle_timeout.as_mut().poll(cx).is_ready() {
477 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
478 0 : info!("connection idle");
479 0 : if let Some(pool) = pool.clone().upgrade() {
480 : // remove client from pool - should close the connection if it's idle.
481 : // does nothing if the client is currently checked-out and in-use
482 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
483 0 : info!("idle connection removed");
484 0 : }
485 0 : }
486 7044 : }
487 :
488 : loop {
489 7044 : let message = ready!(connection.poll_message(cx));
490 :
491 0 : match message {
492 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
493 0 : info!(%session_id, "notice: {}", notice);
494 : }
495 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
496 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
497 : }
498 : Some(Ok(_)) => {
499 0 : warn!(%session_id, "unknown message");
500 : }
501 0 : Some(Err(e)) => {
502 0 : error!(%session_id, "connection error: {}", e);
503 0 : break
504 : }
505 : None => {
506 39 : info!("connection closed");
507 39 : break
508 : }
509 : }
510 : }
511 :
512 : // remove from connection pool
513 39 : if let Some(pool) = pool.clone().upgrade() {
514 3 : if pool.write().remove_client(db_user.clone(), conn_id) {
515 0 : info!("closed connection removed");
516 3 : }
517 36 : }
518 :
519 39 : Poll::Ready(())
520 7044 : }).await;
521 :
522 40 : }
523 40 : .instrument(span));
524 40 : let inner = ClientInner {
525 40 : inner: client,
526 40 : session: tx,
527 40 : aux,
528 40 : conn_id,
529 40 : };
530 40 : Client::new(inner, conn_info, pool_clone)
531 40 : }
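// A sketch of how a freshly opened connection is wired into the pool via `poll_client`
// (the postgres `config`, and the `ctx`, `aux` and `global_pool` values, are assumptions
// for illustration):
//
//     let (pg_client, connection) = config.connect(tokio_postgres::NoTls).await?;
//     let client = poll_client(
//         global_pool.clone(),
//         &mut ctx,
//         conn_info.clone(),
//         pg_client,
//         connection,
//         uuid::Uuid::new_v4(),
//         aux,
//     );
//     // `client` can now run queries; dropping it returns the connection to the pool.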
532 :
533 : struct ClientInner<C: ClientInnerExt> {
534 : inner: C,
535 : session: tokio::sync::watch::Sender<uuid::Uuid>,
536 : aux: MetricsAuxInfo,
537 : conn_id: uuid::Uuid,
538 : }
539 :
540 : pub trait ClientInnerExt: Sync + Send + 'static {
541 : fn is_closed(&self) -> bool;
542 : fn get_process_id(&self) -> i32;
543 : }
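// This trait only needs liveness and the backend pid: it is implemented for the real
// `tokio_postgres::Client` just below, and for the `MockClient` in the test module at the
// bottom of this file, which simulates closed connections with a single `AtomicBool`.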
544 :
545 : impl ClientInnerExt for tokio_postgres::Client {
546 48 : fn is_closed(&self) -> bool {
547 48 : self.is_closed()
548 48 : }
549 4 : fn get_process_id(&self) -> i32 {
550 4 : self.get_process_id()
551 4 : }
552 : }
553 :
554 : impl<C: ClientInnerExt> ClientInner<C> {
555 64 : pub fn is_closed(&self) -> bool {
556 64 : self.inner.is_closed()
557 64 : }
558 : }
559 :
560 : impl<C: ClientInnerExt> Client<C> {
561 42 : pub fn metrics(&self) -> Arc<MetricCounter> {
562 42 : let aux = &self.inner.as_ref().unwrap().aux;
563 42 : USAGE_METRICS.register(Ids {
564 42 : endpoint_id: aux.endpoint_id.clone(),
565 42 : branch_id: aux.branch_id.clone(),
566 42 : })
567 42 : }
568 : }
569 :
570 : pub struct Client<C: ClientInnerExt> {
571 : span: Span,
572 : inner: Option<ClientInner<C>>,
573 : conn_info: ConnInfo,
574 : pool: Weak<RwLock<EndpointConnPool<C>>>,
575 : }
576 :
577 : pub struct Discard<'a, C: ClientInnerExt> {
578 : conn_info: &'a ConnInfo,
579 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
580 : }
581 :
582 : impl<C: ClientInnerExt> Client<C> {
583 58 : pub(self) fn new(
584 58 : inner: ClientInner<C>,
585 58 : conn_info: ConnInfo,
586 58 : pool: Weak<RwLock<EndpointConnPool<C>>>,
587 58 : ) -> Self {
588 58 : Self {
589 58 : inner: Some(inner),
590 58 : span: Span::current(),
591 58 : conn_info,
592 58 : pool,
593 58 : }
594 58 : }
595 46 : pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
596 46 : let Self {
597 46 : inner,
598 46 : pool,
599 46 : conn_info,
600 46 : span: _,
601 46 : } = self;
602 46 : let inner = inner.as_mut().expect("client inner should not be removed");
603 46 : (&mut inner.inner, Discard { pool, conn_info })
604 46 : }
605 :
606 39 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
607 39 : self.inner().1.check_idle(status)
608 39 : }
609 4 : pub fn discard(&mut self) {
610 4 : self.inner().1.discard()
611 4 : }
612 : }
613 :
614 : impl<C: ClientInnerExt> Discard<'_, C> {
615 42 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
616 42 : let conn_info = &self.conn_info;
617 42 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
618 2 : info!("pool: throwing away connection '{conn_info}' because connection is not idle")
619 40 : }
620 42 : }
621 4 : pub fn discard(&mut self) {
622 4 : let conn_info = &self.conn_info;
623 4 : if std::mem::take(self.pool).strong_count() > 0 {
624 4 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
625 0 : }
626 4 : }
627 : }
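// How the `Discard` handle is intended to be used after executing a statement on a
// checked-out client (a sketch; `run_query` is a hypothetical helper that yields a
// `ReadyForQueryStatus` on success):
//
//     let (inner, mut discard) = client.inner();
//     match run_query(inner).await {
//         // only keep the connection if the backend is idle (no open transaction)
//         Ok(status) => discard.check_idle(status),
//         // connection state is unknown after an error: never return it to the pool
//         Err(_) => discard.discard(),
//     }
//
// `Client::check_idle` and `Client::discard` above are thin wrappers over the same calls.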
628 :
629 : impl<C: ClientInnerExt> Deref for Client<C> {
630 : type Target = C;
631 :
632 41 : fn deref(&self) -> &Self::Target {
633 41 : &self
634 41 : .inner
635 41 : .as_ref()
636 41 : .expect("client inner should not be removed")
637 41 : .inner
638 41 : }
639 : }
640 :
641 : impl<C: ClientInnerExt> Client<C> {
642 58 : fn do_drop(&mut self) -> Option<impl FnOnce()> {
643 58 : let conn_info = self.conn_info.clone();
644 58 : let client = self
645 58 : .inner
646 58 : .take()
647 58 : .expect("client inner should not be removed");
648 58 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
649 52 : let current_span = self.span.clone();
650 52 : // return connection to the pool
651 52 : return Some(move || {
652 52 : let _span = current_span.enter();
653 52 : let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
654 52 : });
655 6 : }
656 6 : None
657 58 : }
658 : }
659 :
660 : impl<C: ClientInnerExt> Drop for Client<C> {
661 46 : fn drop(&mut self) {
662 46 : if let Some(drop) = self.do_drop() {
663 40 : tokio::task::spawn_blocking(drop);
664 40 : }
665 46 : }
666 : }
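// Note: returning a connection happens off the async path. `do_drop` only captures what it
// needs, and the actual `EndpointConnPool::put` call (which takes a synchronous
// `parking_lot` write lock) runs on the blocking thread pool via
// `tokio::task::spawn_blocking`, so dropping a `Client` never blocks the async executor on
// the pool lock.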
667 :
668 : #[cfg(test)]
669 : mod tests {
670 : use env_logger;
671 : use std::{mem, sync::atomic::AtomicBool};
672 :
673 : use super::*;
674 :
675 : struct MockClient(Arc<AtomicBool>);
676 : impl MockClient {
677 12 : fn new(is_closed: bool) -> Self {
678 12 : MockClient(Arc::new(is_closed.into()))
679 12 : }
680 : }
681 : impl ClientInnerExt for MockClient {
682 16 : fn is_closed(&self) -> bool {
683 16 : self.0.load(atomic::Ordering::Relaxed)
684 16 : }
685 0 : fn get_process_id(&self) -> i32 {
686 0 : 0
687 0 : }
688 : }
689 :
690 10 : fn create_inner() -> ClientInner<MockClient> {
691 10 : create_inner_with(MockClient::new(false))
692 10 : }
693 :
694 14 : fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
695 14 : ClientInner {
696 14 : inner: client,
697 14 : session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
698 14 : aux: Default::default(),
699 14 : conn_id: uuid::Uuid::new_v4(),
700 14 : }
701 14 : }
702 :
703 2 : #[tokio::test]
704 2 : async fn test_pool() {
705 2 : let _ = env_logger::try_init();
706 2 : let config = Box::leak(Box::new(crate::config::HttpConfig {
707 2 : pool_options: GlobalConnPoolOptions {
708 2 : max_conns_per_endpoint: 2,
709 2 : gc_epoch: Duration::from_secs(1),
710 2 : pool_shards: 2,
711 2 : idle_timeout: Duration::from_secs(1),
712 2 : opt_in: false,
713 2 : max_total_conns: 3,
714 2 : },
715 2 : request_timeout: Duration::from_secs(1),
716 2 : }));
717 2 : let pool = GlobalConnPool::new(config);
718 2 : let conn_info = ConnInfo {
719 2 : user_info: ComputeUserInfo {
720 2 : user: "user".into(),
721 2 : endpoint: "endpoint".into(),
722 2 : options: Default::default(),
723 2 : },
724 2 : dbname: "dbname".into(),
725 2 : password: "password".as_bytes().into(),
726 2 : };
727 2 : let ep_pool =
728 2 : Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
729 2 : {
730 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
731 2 : assert_eq!(0, pool.get_global_connections_count());
732 2 : client.discard();
733 2 : // Discard should not return the connection to the pool.
734 2 : assert_eq!(0, pool.get_global_connections_count());
735 2 : }
736 2 : {
737 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
738 2 : client.do_drop().unwrap()();
739 2 : mem::forget(client); // drop the client
740 2 : assert_eq!(1, pool.get_global_connections_count());
741 2 : }
742 2 : {
743 2 : let mut closed_client = Client::new(
744 2 : create_inner_with(MockClient::new(true)),
745 2 : conn_info.clone(),
746 2 : ep_pool.clone(),
747 2 : );
748 2 : closed_client.do_drop().unwrap()();
749 2 : mem::forget(closed_client); // drop the client
750 2 : // The closed client shouldn't be added to the pool.
751 2 : assert_eq!(1, pool.get_global_connections_count());
752 2 : }
753 2 : let is_closed: Arc<AtomicBool> = Arc::new(false.into());
754 2 : {
755 2 : let mut client = Client::new(
756 2 : create_inner_with(MockClient(is_closed.clone())),
757 2 : conn_info.clone(),
758 2 : ep_pool.clone(),
759 2 : );
760 2 : client.do_drop().unwrap()();
761 2 : mem::forget(client); // drop the client
762 2 :
763 2 : // The client should be added to the pool.
764 2 : assert_eq!(2, pool.get_global_connections_count());
765 2 : }
766 2 : {
767 2 : let mut client = Client::new(create_inner(), conn_info, ep_pool);
768 2 : client.do_drop().unwrap()();
769 2 : mem::forget(client); // drop the client
770 2 :
771 2 : // The client shouldn't be added to the pool. Because the ep-pool is full.
772 2 : assert_eq!(2, pool.get_global_connections_count());
773 2 : }
774 2 :
775 2 : let conn_info = ConnInfo {
776 2 : user_info: ComputeUserInfo {
777 2 : user: "user".into(),
778 2 : endpoint: "endpoint-2".into(),
779 2 : options: Default::default(),
780 2 : },
781 2 : dbname: "dbname".into(),
782 2 : password: "password".as_bytes().into(),
783 2 : };
784 2 : let ep_pool =
785 2 : Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()));
786 2 : {
787 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
788 2 : client.do_drop().unwrap()();
789 2 : mem::forget(client); // drop the client
790 2 : assert_eq!(3, pool.get_global_connections_count());
791 2 : }
792 2 : {
793 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
794 2 : client.do_drop().unwrap()();
795 2 : mem::forget(client); // drop the client
796 2 :
797 2 : // The client shouldn't be added to the pool, because the global pool is full.
798 2 : assert_eq!(3, pool.get_global_connections_count());
799 2 : }
800 2 :
801 2 : is_closed.store(true, atomic::Ordering::Relaxed);
802 2 : // Do gc for all shards.
803 2 : pool.gc(0);
804 2 : pool.gc(1);
805 2 : // Closed client should be removed from the pool.
806 2 : assert_eq!(2, pool.get_global_connections_count());
807 2 : }
808 : }