Line data Source code
1 : use std::collections::HashMap;
2 : use std::ops::Deref;
3 : use std::sync::atomic::{self, AtomicUsize};
4 : use std::sync::{Arc, Weak};
5 : use std::time::Duration;
6 :
7 : use dashmap::DashMap;
8 : use parking_lot::RwLock;
9 : use rand::Rng;
10 : use tokio_postgres::ReadyForQueryStatus;
11 : use tracing::{debug, info, Span};
12 :
13 : use super::backend::HttpConnError;
14 : use super::conn_pool::ClientDataRemote;
15 : use super::http_conn_pool::ClientDataHttp;
16 : use super::local_conn_pool::ClientDataLocal;
17 : use crate::auth::backend::ComputeUserInfo;
18 : use crate::context::RequestContext;
19 : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
20 : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
21 : use crate::types::{DbName, EndpointCacheKey, RoleName};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 :
/// Identifies the target of a pooled connection: the authenticated user info
/// plus the database name.
///
/// Inside an endpoint pool, connections are keyed by [`ConnInfo::db_and_user`];
/// the endpoint pool itself is looked up by [`ConnInfo::endpoint_cache_key`].
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
    /// User/endpoint information; its `user` and `options` are read by the pool.
    pub(crate) user_info: ComputeUserInfo,
    /// Target database name.
    pub(crate) dbname: DbName,
}
29 :
30 : impl ConnInfo {
31 : // hm, change to hasher to avoid cloning?
32 3 : pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
33 3 : (self.dbname.clone(), self.user_info.user.clone())
34 3 : }
35 :
36 2 : pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
37 2 : // We don't want to cache http connections for ephemeral endpoints.
38 2 : if self.user_info.options.is_ephemeral() {
39 0 : None
40 : } else {
41 2 : Some(self.user_info.endpoint_cache_key())
42 : }
43 2 : }
44 : }
45 :
/// Extra per-client data attached to a pooled connection, one variant per
/// pool flavour.
///
/// When the owning `ClientInnerCommon` is dropped, `Remote` and `Local` data
/// are cancelled; `Http` data is left untouched.
pub(crate) enum ClientDataEnum {
    Remote(ClientDataRemote),
    Local(ClientDataLocal),
    // NOTE(review): presumably the HTTP variant is not constructed in every
    // build, hence the lint suppression — confirm before removing it.
    #[allow(dead_code)]
    Http(ClientDataHttp),
}
52 :
/// A database client together with the bookkeeping the pool needs.
///
/// Dropping it cancels any attached `Remote`/`Local` client data.
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
    /// The underlying client.
    pub(crate) inner: C,
    /// Metrics metadata; `endpoint_id` and `branch_id` feed usage metrics.
    pub(crate) aux: MetricsAuxInfo,
    /// Unique id of this connection, used for logging and targeted removal.
    pub(crate) conn_id: uuid::Uuid,
    pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
}
59 :
60 : impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
61 7 : fn drop(&mut self) {
62 7 : match &mut self.data {
63 7 : ClientDataEnum::Remote(remote_data) => {
64 7 : remote_data.cancel();
65 7 : }
66 0 : ClientDataEnum::Local(local_data) => {
67 0 : local_data.cancel();
68 0 : }
69 0 : ClientDataEnum::Http(_http_data) => (),
70 : }
71 7 : }
72 : }
73 :
74 : impl<C: ClientInnerExt> ClientInnerCommon<C> {
75 6 : pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
76 6 : self.conn_id
77 6 : }
78 :
79 0 : pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
80 0 : &mut self.data
81 0 : }
82 : }
83 :
/// A connection parked in a per-(db, user) pool.
pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
    /// The pooled connection.
    pub(crate) conn: ClientInnerCommon<C>,
    /// When the connection was returned to the pool.
    // NOTE(review): set when returning a connection; no reader is visible here.
    pub(crate) _last_access: std::time::Instant,
}
88 :
/// Per-endpoint connection pool, mapping `(dbname, username)` to a `DbUserConnPool`.
///
/// The number of pooled connections is limited by `max_conns`
/// (`max_conns_per_endpoint` in the options). A connection can also be refused
/// when the shared `global_connections_count` reaches `global_pool_size_max_conns`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
    /// Pooled connections grouped by `(dbname, username)`.
    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
    /// Connections currently held across all `(db, user)` pools of this endpoint.
    total_conns: usize,
    /// Per-endpoint connection limit.
    max_conns: usize,
    // NOTE(review): presumably keeps the endpoint-pools gauge raised for the
    // lifetime of this pool — confirm against `HttpEndpointPoolsGuard`.
    _guard: HttpEndpointPoolsGuard<'static>,
    /// Connection counter shared (via `Arc`) across endpoint pools.
    global_connections_count: Arc<AtomicUsize>,
    /// Global connection limit checked when returning a connection.
    global_pool_size_max_conns: usize,
    /// Name used as a prefix in log messages (e.g. `"remote"`).
    pool_name: String,
}
100 :
101 : impl<C: ClientInnerExt> EndpointConnPool<C> {
102 0 : pub(crate) fn new(
103 0 : hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
104 0 : tconns: usize,
105 0 : max_conns_per_endpoint: usize,
106 0 : global_connections_count: Arc<AtomicUsize>,
107 0 : max_total_conns: usize,
108 0 : pname: String,
109 0 : ) -> Self {
110 0 : Self {
111 0 : pools: hmap,
112 0 : total_conns: tconns,
113 0 : max_conns: max_conns_per_endpoint,
114 0 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
115 0 : global_connections_count,
116 0 : global_pool_size_max_conns: max_total_conns,
117 0 : pool_name: pname,
118 0 : }
119 0 : }
120 :
121 0 : pub(crate) fn get_conn_entry(
122 0 : &mut self,
123 0 : db_user: (DbName, RoleName),
124 0 : ) -> Option<ConnPoolEntry<C>> {
125 0 : let Self {
126 0 : pools,
127 0 : total_conns,
128 0 : global_connections_count,
129 0 : ..
130 0 : } = self;
131 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
132 0 : let (entry, removed) = pool_entries.get_conn_entry(total_conns);
133 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
134 0 : entry
135 0 : })
136 0 : }
137 :
138 0 : pub(crate) fn remove_client(
139 0 : &mut self,
140 0 : db_user: (DbName, RoleName),
141 0 : conn_id: uuid::Uuid,
142 0 : ) -> bool {
143 0 : let Self {
144 0 : pools,
145 0 : total_conns,
146 0 : global_connections_count,
147 0 : ..
148 0 : } = self;
149 0 : if let Some(pool) = pools.get_mut(&db_user) {
150 0 : let old_len = pool.get_conns().len();
151 0 : pool.get_conns()
152 0 : .retain(|conn| conn.conn.get_conn_id() != conn_id);
153 0 : let new_len = pool.get_conns().len();
154 0 : let removed = old_len - new_len;
155 0 : if removed > 0 {
156 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
157 0 : Metrics::get()
158 0 : .proxy
159 0 : .http_pool_opened_connections
160 0 : .get_metric()
161 0 : .dec_by(removed as i64);
162 0 : }
163 0 : *total_conns -= removed;
164 0 : removed > 0
165 : } else {
166 0 : false
167 : }
168 0 : }
169 :
170 6 : pub(crate) fn get_name(&self) -> &str {
171 6 : &self.pool_name
172 6 : }
173 :
174 0 : pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
175 0 : self.pools.get(&db_user)
176 0 : }
177 :
178 0 : pub(crate) fn get_pool_mut(
179 0 : &mut self,
180 0 : db_user: (DbName, RoleName),
181 0 : ) -> Option<&mut DbUserConnPool<C>> {
182 0 : self.pools.get_mut(&db_user)
183 0 : }
184 :
185 6 : pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
186 6 : let conn_id = client.get_conn_id();
187 6 : let pool_name = pool.read().get_name().to_string();
188 6 : if client.inner.is_closed() {
189 1 : info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
190 1 : return;
191 5 : }
192 5 :
193 5 : let global_max_conn = pool.read().global_pool_size_max_conns;
194 5 : if pool
195 5 : .read()
196 5 : .global_connections_count
197 5 : .load(atomic::Ordering::Relaxed)
198 5 : >= global_max_conn
199 : {
200 1 : info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
201 1 : return;
202 4 : }
203 4 :
204 4 : // return connection to the pool
205 4 : let mut returned = false;
206 4 : let mut per_db_size = 0;
207 4 : let total_conns = {
208 4 : let mut pool = pool.write();
209 4 :
210 4 : if pool.total_conns < pool.max_conns {
211 3 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
212 3 : pool_entries.get_conns().push(ConnPoolEntry {
213 3 : conn: client,
214 3 : _last_access: std::time::Instant::now(),
215 3 : });
216 3 :
217 3 : returned = true;
218 3 : per_db_size = pool_entries.get_conns().len();
219 3 :
220 3 : pool.total_conns += 1;
221 3 : pool.global_connections_count
222 3 : .fetch_add(1, atomic::Ordering::Relaxed);
223 3 : Metrics::get()
224 3 : .proxy
225 3 : .http_pool_opened_connections
226 3 : .get_metric()
227 3 : .inc();
228 3 : }
229 :
230 4 : pool.total_conns
231 4 : };
232 4 :
233 4 : // do logging outside of the mutex
234 4 : if returned {
235 3 : debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
236 : } else {
237 1 : info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
238 : }
239 6 : }
240 : }
241 :
242 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
243 2 : fn drop(&mut self) {
244 2 : if self.total_conns > 0 {
245 2 : self.global_connections_count
246 2 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
247 2 : Metrics::get()
248 2 : .proxy
249 2 : .http_pool_opened_connections
250 2 : .get_metric()
251 2 : .dec_by(self.total_conns as i64);
252 2 : }
253 2 : }
254 : }
255 :
/// The connections pooled for one `(dbname, username)` pair.
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
    /// Pooled connections. Returned connections are pushed onto the end and
    /// reuse pops from the end.
    pub(crate) conns: Vec<ConnPoolEntry<C>>,
    pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
}
260 :
261 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
262 2 : fn default() -> Self {
263 2 : Self {
264 2 : conns: Vec::new(),
265 2 : initialized: None,
266 2 : }
267 2 : }
268 : }
269 :
/// Operations on a per-`(dbname, username)` connection pool.
///
/// Methods taking `conns: &mut usize` decrement that counter by the number of
/// connections they remove from the pool. In this file's callers the counter
/// is the owning endpoint pool's `total_conns`.
pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
    /// Marks the pool as initialized.
    fn set_initialized(&mut self);
    /// Whether the pool has been marked initialized.
    fn is_initialized(&self) -> bool;
    /// Drops closed connections and returns how many were removed.
    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
    /// Clears closed connections, then takes one connection if any remain.
    /// Also returns the total number of connections removed, including the
    /// one handed out.
    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
    /// Direct mutable access to the pooled connections.
    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
}
277 :
278 : impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
279 0 : fn set_initialized(&mut self) {
280 0 : self.initialized = Some(true);
281 0 : }
282 :
283 0 : fn is_initialized(&self) -> bool {
284 0 : self.initialized.unwrap_or(false)
285 0 : }
286 :
287 1 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
288 1 : let old_len = self.conns.len();
289 1 :
290 2 : self.conns.retain(|conn| !conn.conn.inner.is_closed());
291 1 :
292 1 : let new_len = self.conns.len();
293 1 : let removed = old_len - new_len;
294 1 : *conns -= removed;
295 1 : removed
296 1 : }
297 :
298 0 : fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
299 0 : let mut removed = self.clear_closed_clients(conns);
300 0 : let conn = self.conns.pop();
301 0 : if conn.is_some() {
302 0 : *conns -= 1;
303 0 : removed += 1;
304 0 : }
305 :
306 0 : Metrics::get()
307 0 : .proxy
308 0 : .http_pool_opened_connections
309 0 : .get_metric()
310 0 : .dec_by(removed as i64);
311 0 :
312 0 : (conn, removed)
313 0 : }
314 :
315 6 : fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
316 6 : &mut self.conns
317 6 : }
318 : }
319 :
/// Global HTTP connection pool: a sharded map from endpoint to that
/// endpoint's connection pool.
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
    // endpoint -> per-endpoint connection pool
    //
    // This is expected to be a fairly contended map, so return a reference to
    // the per-endpoint pool as early as possible and release the lock.
    global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,

    /// Number of endpoint-connection pools
    ///
    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    /// Total number of connections in the pool, shared with every endpoint pool
    /// this global pool creates.
    global_connections_count: Arc<AtomicUsize>,

    /// HTTP configuration; `pool_options` supplies the limits and timings.
    config: &'static crate::config::HttpConfig,
}
339 :
/// Tuning options for the global HTTP connection pool.
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
    /// Maximum number of connections per one endpoint.
    /// Can mix different (dbname, username) connections.
    /// When running out of free slots for a particular endpoint,
    /// falls back to opening a new connection for each request.
    pub max_conns_per_endpoint: usize,

    /// Period over which the GC worker covers the shards. It ticks every
    /// `gc_epoch / number_of_shards` and reclaims one randomly chosen shard per
    /// tick.
    pub gc_epoch: Duration,

    /// Number of shards in the endpoint map.
    pub pool_shards: usize,

    /// Idle timeout exposed through the pool's `get_idle_timeout`.
    // NOTE(review): how it is enforced is not visible here — confirm with callers.
    pub idle_timeout: Duration,

    // NOTE(review): not read in this module; presumably toggles whether pooling
    // is opt-in — confirm.
    pub opt_in: bool,

    /// Total number of connections in the pool (across all endpoints).
    pub max_total_conns: usize,
}
359 :
360 : impl<C: ClientInnerExt> GlobalConnPool<C> {
361 1 : pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
362 1 : let shards = config.pool_options.pool_shards;
363 1 : Arc::new(Self {
364 1 : global_pool: DashMap::with_shard_amount(shards),
365 1 : global_pool_size: AtomicUsize::new(0),
366 1 : config,
367 1 : global_connections_count: Arc::new(AtomicUsize::new(0)),
368 1 : })
369 1 : }
370 :
371 : #[cfg(test)]
372 9 : pub(crate) fn get_global_connections_count(&self) -> usize {
373 9 : self.global_connections_count
374 9 : .load(atomic::Ordering::Relaxed)
375 9 : }
376 :
377 0 : pub(crate) fn get_idle_timeout(&self) -> Duration {
378 0 : self.config.pool_options.idle_timeout
379 0 : }
380 :
381 0 : pub(crate) fn get(
382 0 : self: &Arc<Self>,
383 0 : ctx: &RequestContext,
384 0 : conn_info: &ConnInfo,
385 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
386 0 : let mut client: Option<ClientInnerCommon<C>> = None;
387 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
388 0 : return Ok(None);
389 : };
390 :
391 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
392 0 : if let Some(entry) = endpoint_pool
393 0 : .write()
394 0 : .get_conn_entry(conn_info.db_and_user())
395 0 : {
396 0 : client = Some(entry.conn);
397 0 : }
398 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
399 :
400 : // ok return cached connection if found and establish a new one otherwise
401 0 : if let Some(mut client) = client {
402 0 : if client.inner.is_closed() {
403 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
404 0 : return Ok(None);
405 0 : }
406 0 : tracing::Span::current()
407 0 : .record("conn_id", tracing::field::display(client.get_conn_id()));
408 0 : tracing::Span::current().record(
409 0 : "pid",
410 0 : tracing::field::display(client.inner.get_process_id()),
411 0 : );
412 0 : debug!(
413 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
414 0 : "pool: reusing connection '{conn_info}'"
415 : );
416 :
417 0 : match client.get_data() {
418 0 : ClientDataEnum::Local(data) => {
419 0 : data.session().send(ctx.session_id())?;
420 : }
421 :
422 0 : ClientDataEnum::Remote(data) => {
423 0 : data.session().send(ctx.session_id())?;
424 : }
425 0 : ClientDataEnum::Http(_) => (),
426 : }
427 :
428 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
429 0 : ctx.success();
430 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
431 0 : }
432 0 : Ok(None)
433 0 : }
434 :
435 0 : pub(crate) fn shutdown(&self) {
436 0 : // drops all strong references to endpoint-pools
437 0 : self.global_pool.clear();
438 0 : }
439 :
440 0 : pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
441 0 : let epoch = self.config.pool_options.gc_epoch;
442 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
443 : loop {
444 0 : interval.tick().await;
445 :
446 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
447 0 : self.gc(shard);
448 : }
449 : }
450 :
451 2 : pub(crate) fn gc(&self, shard: usize) {
452 2 : debug!(shard, "pool: performing epoch reclamation");
453 :
454 : // acquire a random shard lock
455 2 : let mut shard = self.global_pool.shards()[shard].write();
456 2 :
457 2 : let timer = Metrics::get()
458 2 : .proxy
459 2 : .http_pool_reclaimation_lag_seconds
460 2 : .start_timer();
461 2 : let current_len = shard.len();
462 2 : let mut clients_removed = 0;
463 2 : shard.retain(|endpoint, x| {
464 : // if the current endpoint pool is unique (no other strong or weak references)
465 : // then it is currently not in use by any connections.
466 2 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
467 : let EndpointConnPool {
468 1 : pools, total_conns, ..
469 1 : } = pool.get_mut();
470 :
471 : // ensure that closed clients are removed
472 1 : for db_pool in pools.values_mut() {
473 1 : clients_removed += db_pool.clear_closed_clients(total_conns);
474 1 : }
475 :
476 : // we only remove this pool if it has no active connections
477 1 : if *total_conns == 0 {
478 0 : info!("pool: discarding pool for endpoint {endpoint}");
479 0 : return false;
480 1 : }
481 1 : }
482 :
483 2 : true
484 2 : });
485 2 :
486 2 : let new_len = shard.len();
487 2 : drop(shard);
488 2 : timer.observe();
489 2 :
490 2 : // Do logging outside of the lock.
491 2 : if clients_removed > 0 {
492 1 : let size = self
493 1 : .global_connections_count
494 1 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
495 1 : - clients_removed;
496 1 : Metrics::get()
497 1 : .proxy
498 1 : .http_pool_opened_connections
499 1 : .get_metric()
500 1 : .dec_by(clients_removed as i64);
501 1 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
502 1 : }
503 2 : let removed = current_len - new_len;
504 2 :
505 2 : if removed > 0 {
506 0 : let global_pool_size = self
507 0 : .global_pool_size
508 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
509 0 : - removed;
510 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
511 2 : }
512 2 : }
513 :
514 2 : pub(crate) fn get_or_create_endpoint_pool(
515 2 : self: &Arc<Self>,
516 2 : endpoint: &EndpointCacheKey,
517 2 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
518 : // fast path
519 2 : if let Some(pool) = self.global_pool.get(endpoint) {
520 0 : return pool.clone();
521 2 : }
522 2 :
523 2 : // slow path
524 2 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
525 2 : pools: HashMap::new(),
526 2 : total_conns: 0,
527 2 : max_conns: self.config.pool_options.max_conns_per_endpoint,
528 2 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
529 2 : global_connections_count: self.global_connections_count.clone(),
530 2 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
531 2 : pool_name: String::from("remote"),
532 2 : }));
533 2 :
534 2 : // find or create a pool for this endpoint
535 2 : let mut created = false;
536 2 : let pool = self
537 2 : .global_pool
538 2 : .entry(endpoint.clone())
539 2 : .or_insert_with(|| {
540 2 : created = true;
541 2 : new_pool
542 2 : })
543 2 : .clone();
544 2 :
545 2 : // log new global pool size
546 2 : if created {
547 2 : let global_pool_size = self
548 2 : .global_pool_size
549 2 : .fetch_add(1, atomic::Ordering::Relaxed)
550 2 : + 1;
551 2 : info!(
552 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
553 : );
554 0 : }
555 :
556 2 : pool
557 2 : }
558 : }
559 :
/// A connection checked out of an endpoint pool.
///
/// On drop, the connection is handed back to `pool` if that pool is still
/// alive and the connection has not been discarded.
pub(crate) struct Client<C: ClientInnerExt> {
    /// Span current at creation; re-entered when the connection is returned.
    span: Span,
    /// The connection; `None` only after `do_drop` has taken it.
    inner: Option<ClientInnerCommon<C>>,
    /// Target of the connection, used for logging and for re-pooling.
    conn_info: ConnInfo,
    /// Weak handle to the owning endpoint pool; replaced by an empty `Weak`
    /// when the connection is discarded.
    pool: Weak<RwLock<EndpointConnPool<C>>>,
}
566 :
/// Borrowed handle that lets a caller stop a `Client`'s connection from being
/// returned to its pool.
pub(crate) struct Discard<'a, C: ClientInnerExt> {
    conn_info: &'a ConnInfo,
    /// The client's pool handle; taking it detaches the connection.
    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
571 :
572 : impl<C: ClientInnerExt> Client<C> {
573 7 : pub(crate) fn new(
574 7 : inner: ClientInnerCommon<C>,
575 7 : conn_info: ConnInfo,
576 7 : pool: Weak<RwLock<EndpointConnPool<C>>>,
577 7 : ) -> Self {
578 7 : Self {
579 7 : inner: Some(inner),
580 7 : span: Span::current(),
581 7 : conn_info,
582 7 : pool,
583 7 : }
584 7 : }
585 :
586 0 : pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
587 0 : let Self {
588 0 : inner,
589 0 : pool,
590 0 : conn_info,
591 0 : span: _,
592 0 : } = self;
593 0 : let inner_m = inner.as_mut().expect("client inner should not be removed");
594 0 : (inner_m, Discard { conn_info, pool })
595 0 : }
596 :
597 1 : pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
598 1 : let Self {
599 1 : inner,
600 1 : pool,
601 1 : conn_info,
602 1 : span: _,
603 1 : } = self;
604 1 : let inner = inner.as_mut().expect("client inner should not be removed");
605 1 : (&mut inner.inner, Discard { conn_info, pool })
606 1 : }
607 :
608 0 : pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
609 0 : let aux = &self.inner.as_ref().unwrap().aux;
610 0 : USAGE_METRICS.register(Ids {
611 0 : endpoint_id: aux.endpoint_id,
612 0 : branch_id: aux.branch_id,
613 0 : })
614 0 : }
615 :
616 7 : pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
617 7 : let conn_info = self.conn_info.clone();
618 7 : let client = self
619 7 : .inner
620 7 : .take()
621 7 : .expect("client inner should not be removed");
622 7 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
623 6 : let current_span = self.span.clone();
624 6 : // return connection to the pool
625 6 : return Some(move || {
626 6 : let _span = current_span.enter();
627 6 : EndpointConnPool::put(&conn_pool, &conn_info, client);
628 6 : });
629 1 : }
630 1 : None
631 7 : }
632 : }
633 :
634 : impl<C: ClientInnerExt> Drop for Client<C> {
635 1 : fn drop(&mut self) {
636 1 : if let Some(drop) = self.do_drop() {
637 0 : tokio::task::spawn_blocking(drop);
638 1 : }
639 1 : }
640 : }
641 :
642 : impl<C: ClientInnerExt> Deref for Client<C> {
643 : type Target = C;
644 :
645 0 : fn deref(&self) -> &Self::Target {
646 0 : &self
647 0 : .inner
648 0 : .as_ref()
649 0 : .expect("client inner should not be removed")
650 0 : .inner
651 0 : }
652 : }
653 :
/// The minimal interface the pool needs from an underlying database client.
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
    /// Whether the connection is closed. Closed connections are never pooled
    /// or reused.
    fn is_closed(&self) -> bool;
    /// Backend process id, recorded on the tracing span when a connection is
    /// reused.
    fn get_process_id(&self) -> i32;
}
658 :
659 : impl ClientInnerExt for tokio_postgres::Client {
660 0 : fn is_closed(&self) -> bool {
661 0 : self.is_closed()
662 0 : }
663 :
664 0 : fn get_process_id(&self) -> i32 {
665 0 : self.get_process_id()
666 0 : }
667 : }
668 :
669 : impl<C: ClientInnerExt> Discard<'_, C> {
670 0 : pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
671 0 : let conn_info = &self.conn_info;
672 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
673 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle");
674 0 : }
675 0 : }
676 1 : pub(crate) fn discard(&mut self) {
677 1 : let conn_info = &self.conn_info;
678 1 : if std::mem::take(self.pool).strong_count() > 0 {
679 1 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
680 0 : }
681 1 : }
682 : }
|