use std::collections::HashMap;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use std::time::Duration;

use dashmap::DashMap;
use parking_lot::RwLock;
use rand::Rng;
use tokio_postgres::ReadyForQueryStatus;
use tracing::{debug, info, Span};

use super::backend::HttpConnError;
use super::conn_pool::ClientDataRemote;
use super::http_conn_pool::ClientDataHttp;
use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::types::{DbName, EndpointCacheKey, RoleName};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};

#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
    pub(crate) user_info: ComputeUserInfo,
    pub(crate) dbname: DbName,
}

impl ConnInfo {
    // hm, change to hasher to avoid cloning?
    pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
        (self.dbname.clone(), self.user_info.user.clone())
    }

    pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
        // We don't want to cache http connections for ephemeral endpoints.
        if self.user_info.options.is_ephemeral() {
            None
        } else {
            Some(self.user_info.endpoint_cache_key())
        }
    }
}
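
// Hedged sketch (editorial, not from the original source): `ConnInfo` supplies
// the keys for both levels of pooling. `endpoint_cache_key()` picks the
// per-endpoint pool out of the global map (it is `None` for ephemeral
// endpoints, which are never pooled), and `db_and_user()` picks the
// (DbName, RoleName) bucket inside that pool:
//
//     if let Some(endpoint_key) = conn_info.endpoint_cache_key() {
//         let endpoint_pool = global_pool.get_or_create_endpoint_pool(&endpoint_key);
//         let entry = endpoint_pool.write().get_conn_entry(conn_info.db_and_user());
//     }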

#[derive(Clone)]
pub(crate) enum ClientDataEnum {
    Remote(ClientDataRemote),
    Local(ClientDataLocal),
    Http(ClientDataHttp),
}

#[derive(Clone)]
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
    pub(crate) inner: C,
    pub(crate) aux: MetricsAuxInfo,
    pub(crate) conn_id: uuid::Uuid,
    pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
}

impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
    fn drop(&mut self) {
        match &mut self.data {
            ClientDataEnum::Remote(remote_data) => {
                remote_data.cancel();
            }
            ClientDataEnum::Local(local_data) => {
                local_data.cancel();
            }
            ClientDataEnum::Http(_http_data) => (),
        }
    }
}

impl<C: ClientInnerExt> ClientInnerCommon<C> {
    pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
        self.conn_id
    }

    pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
        &mut self.data
    }
}

pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
    pub(crate) conn: ClientInnerCommon<C>,
    pub(crate) _last_access: std::time::Instant,
}

// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
    total_conns: usize,
    /// max # connections per endpoint
    max_conns: usize,
    _guard: HttpEndpointPoolsGuard<'static>,
    global_connections_count: Arc<AtomicUsize>,
    global_pool_size_max_conns: usize,
    pool_name: String,
}

impl<C: ClientInnerExt> EndpointConnPool<C> {
    pub(crate) fn new(
        hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
        tconns: usize,
        max_conns_per_endpoint: usize,
        global_connections_count: Arc<AtomicUsize>,
        max_total_conns: usize,
        pname: String,
    ) -> Self {
        Self {
            pools: hmap,
            total_conns: tconns,
            max_conns: max_conns_per_endpoint,
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count,
            global_pool_size_max_conns: max_total_conns,
            pool_name: pname,
        }
    }

    pub(crate) fn get_conn_entry(
        &mut self,
        db_user: (DbName, RoleName),
    ) -> Option<ConnPoolEntry<C>> {
        let Self {
            pools,
            total_conns,
            global_connections_count,
            ..
        } = self;
        pools.get_mut(&db_user).and_then(|pool_entries| {
            let (entry, removed) = pool_entries.get_conn_entry(total_conns);
            global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
            entry
        })
    }

    pub(crate) fn remove_client(
        &mut self,
        db_user: (DbName, RoleName),
        conn_id: uuid::Uuid,
    ) -> bool {
        let Self {
            pools,
            total_conns,
            global_connections_count,
            ..
        } = self;
        if let Some(pool) = pools.get_mut(&db_user) {
            let old_len = pool.get_conns().len();
            pool.get_conns()
                .retain(|conn| conn.conn.get_conn_id() != conn_id);
            let new_len = pool.get_conns().len();
            let removed = old_len - new_len;
            if removed > 0 {
                global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
                Metrics::get()
                    .proxy
                    .http_pool_opened_connections
                    .get_metric()
                    .dec_by(removed as i64);
            }
            *total_conns -= removed;
            removed > 0
        } else {
            false
        }
    }

    pub(crate) fn get_name(&self) -> &str {
        &self.pool_name
    }

    pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
        self.pools.get(&db_user)
    }

    pub(crate) fn get_pool_mut(
        &mut self,
        db_user: (DbName, RoleName),
    ) -> Option<&mut DbUserConnPool<C>> {
        self.pools.get_mut(&db_user)
    }

    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
        let conn_id = client.get_conn_id();
        let pool_name = pool.read().get_name().to_string();
        if client.inner.is_closed() {
            info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
            return;
        }

        let global_max_conn = pool.read().global_pool_size_max_conns;
        if pool
            .read()
            .global_connections_count
            .load(atomic::Ordering::Relaxed)
            >= global_max_conn
        {
            info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
            return;
        }

        // return connection to the pool
        let mut returned = false;
        let mut per_db_size = 0;
        let total_conns = {
            let mut pool = pool.write();

            if pool.total_conns < pool.max_conns {
                let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
                pool_entries.get_conns().push(ConnPoolEntry {
                    conn: client,
                    _last_access: std::time::Instant::now(),
                });

                returned = true;
                per_db_size = pool_entries.get_conns().len();

                pool.total_conns += 1;
                pool.global_connections_count
                    .fetch_add(1, atomic::Ordering::Relaxed);
                Metrics::get()
                    .proxy
                    .http_pool_opened_connections
                    .get_metric()
                    .inc();
            }

            pool.total_conns
        };

        // do logging outside of the mutex
        if returned {
            debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
        } else {
            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
        }
    }
}
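
// Hedged usage sketch (editorial): `put` is the only way a connection re-enters
// the pool. The client is dropped instead if the connection is already closed,
// if the global budget (`global_pool_size_max_conns`) is exhausted, or if this
// endpoint is already at `max_conns`:
//
//     // pool: Arc<RwLock<EndpointConnPool<C>>>, client: ClientInnerCommon<C>
//     EndpointConnPool::put(&pool, &conn_info, client);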

impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
    fn drop(&mut self) {
        if self.total_conns > 0 {
            self.global_connections_count
                .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(self.total_conns as i64);
        }
    }
}

pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
    pub(crate) conns: Vec<ConnPoolEntry<C>>,
    pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
}

impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
    fn default() -> Self {
        Self {
            conns: Vec::new(),
            initialized: None,
        }
    }
}

pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
    fn set_initialized(&mut self);
    fn is_initialized(&self) -> bool;
    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
}

impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
    fn set_initialized(&mut self) {
        self.initialized = Some(true);
    }

    fn is_initialized(&self) -> bool {
        self.initialized.unwrap_or(false)
    }

    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
        let old_len = self.conns.len();

        self.conns.retain(|conn| !conn.conn.inner.is_closed());

        let new_len = self.conns.len();
        let removed = old_len - new_len;
        *conns -= removed;
        removed
    }

    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
        let mut removed = self.clear_closed_clients(conns);
        let conn = self.conns.pop();
        if conn.is_some() {
            *conns -= 1;
            removed += 1;
        }

        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by(removed as i64);

        (conn, removed)
    }

    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
        &mut self.conns
    }
}

pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
    fn clear_closed(&mut self) -> usize;
    fn total_conns(&self) -> usize;
}

impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
    fn clear_closed(&mut self) -> usize {
        let mut clients_removed: usize = 0;
        for db_pool in self.pools.values_mut() {
            clients_removed += db_pool.clear_closed_clients(&mut self.total_conns);
        }
        clients_removed
    }

    fn total_conns(&self) -> usize {
        self.total_conns
    }
}

pub(crate) struct GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    // endpoint -> per-endpoint connection pool
    //
    // This map can be fairly contended, so we return a reference to the
    // per-endpoint pool as early as possible and release the lock.
    pub(crate) global_pool: DashMap<EndpointCacheKey, Arc<RwLock<P>>>,

    /// Number of endpoint-connection pools
    ///
    /// [`DashMap::len`] iterates over all shards and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    pub(crate) global_pool_size: AtomicUsize,

    /// Total number of connections in the pool
    pub(crate) global_connections_count: Arc<AtomicUsize>,

    pub(crate) config: &'static crate::config::HttpConfig,

    _marker: PhantomData<C>,
}

#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
    // Maximum number of connections per endpoint.
    // Can mix different (dbname, username) connections.
    // When running out of free slots for a particular endpoint,
    // falls back to opening a new connection for each request.
    pub max_conns_per_endpoint: usize,

    pub gc_epoch: Duration,

    pub pool_shards: usize,

    pub idle_timeout: Duration,

    pub opt_in: bool,

    // Maximum total number of connections in the pool, across all endpoints.
    pub max_total_conns: usize,
}
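
// Hedged example (editorial; the values below are illustrative, not the
// proxy's actual defaults):
//
//     let opts = GlobalConnPoolOptions {
//         max_conns_per_endpoint: 20,
//         gc_epoch: Duration::from_secs(60),
//         pool_shards: 16,
//         idle_timeout: Duration::from_secs(300),
//         opt_in: true,
//         max_total_conns: 1000,
//     };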

impl<C, P> GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
        let shards = config.pool_options.pool_shards;
        Arc::new(Self {
            global_pool: DashMap::with_shard_amount(shards),
            global_pool_size: AtomicUsize::new(0),
            config,
            global_connections_count: Arc::new(AtomicUsize::new(0)),
            _marker: PhantomData,
        })
    }

    #[cfg(test)]
    pub(crate) fn get_global_connections_count(&self) -> usize {
        self.global_connections_count
            .load(atomic::Ordering::Relaxed)
    }

    pub(crate) fn get_idle_timeout(&self) -> Duration {
        self.config.pool_options.idle_timeout
    }

    pub(crate) fn shutdown(&self) {
        // drops all strong references to endpoint-pools
        self.global_pool.clear();
    }

    pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
        let epoch = self.config.pool_options.gc_epoch;
        let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
        loop {
            interval.tick().await;

            let shard = rng.gen_range(0..self.global_pool.shards().len());
            self.gc(shard);
        }
    }

    pub(crate) fn gc(&self, shard: usize) {
        debug!(shard, "pool: performing epoch reclamation");

        // acquire a random shard lock
        let mut shard = self.global_pool.shards()[shard].write();

        let timer = Metrics::get()
            .proxy
            .http_pool_reclaimation_lag_seconds
            .start_timer();
        let current_len = shard.len();
        let mut clients_removed = 0;
        shard.retain(|endpoint, x| {
            // if the current endpoint pool is unique (no other strong or weak references)
            // then it is currently not in use by any connections.
            if let Some(pool) = Arc::get_mut(x.get_mut()) {
                let endpoints = pool.get_mut();
                clients_removed += endpoints.clear_closed();

                if endpoints.total_conns() == 0 {
                    info!("pool: discarding pool for endpoint {endpoint}");
                    return false;
                }
            }

            true
        });

        let new_len = shard.len();
        drop(shard);
        timer.observe();

        // Do logging outside of the lock.
        if clients_removed > 0 {
            let size = self
                .global_connections_count
                .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
                - clients_removed;
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(clients_removed as i64);
            info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
        }
        let removed = current_len - new_len;

        if removed > 0 {
            let global_pool_size = self
                .global_pool_size
                .fetch_sub(removed, atomic::Ordering::Relaxed)
                - removed;
            info!("pool: performed global pool gc. size now {global_pool_size}");
        }
    }
}

impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
    pub(crate) fn get(
        self: &Arc<Self>,
        ctx: &RequestContext,
        conn_info: &ConnInfo,
    ) -> Result<Option<Client<C>>, HttpConnError> {
        let mut client: Option<ClientInnerCommon<C>> = None;
        let Some(endpoint) = conn_info.endpoint_cache_key() else {
            return Ok(None);
        };

        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
        if let Some(entry) = endpoint_pool
            .write()
            .get_conn_entry(conn_info.db_and_user())
        {
            client = Some(entry.conn);
        }
        let endpoint_pool = Arc::downgrade(&endpoint_pool);

        // Return the cached connection if we found a usable one;
        // otherwise return None and let the caller establish a new connection.
        if let Some(mut client) = client {
            if client.inner.is_closed() {
                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
                return Ok(None);
            }
            tracing::Span::current()
                .record("conn_id", tracing::field::display(client.get_conn_id()));
            tracing::Span::current().record(
                "pid",
                tracing::field::display(client.inner.get_process_id()),
            );
            debug!(
                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
                "pool: reusing connection '{conn_info}'"
            );

            match client.get_data() {
                ClientDataEnum::Local(data) => {
                    data.session().send(ctx.session_id())?;
                }

                ClientDataEnum::Remote(data) => {
                    data.session().send(ctx.session_id())?;
                }
                ClientDataEnum::Http(_) => (),
            }

            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
            ctx.success();
            return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
        }
        Ok(None)
    }

    pub(crate) fn get_or_create_endpoint_pool(
        self: &Arc<Self>,
        endpoint: &EndpointCacheKey,
    ) -> Arc<RwLock<EndpointConnPool<C>>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
            max_conns: self.config.pool_options.max_conns_per_endpoint,
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count: self.global_connections_count.clone(),
            global_pool_size_max_conns: self.config.pool_options.max_total_conns,
            pool_name: String::from("remote"),
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}

pub(crate) struct Client<C: ClientInnerExt> {
    span: Span,
    inner: Option<ClientInnerCommon<C>>,
    conn_info: ConnInfo,
    pool: Weak<RwLock<EndpointConnPool<C>>>,
}

pub(crate) struct Discard<'a, C: ClientInnerExt> {
    conn_info: &'a ConnInfo,
    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}

impl<C: ClientInnerExt> Client<C> {
    pub(crate) fn new(
        inner: ClientInnerCommon<C>,
        conn_info: ConnInfo,
        pool: Weak<RwLock<EndpointConnPool<C>>>,
    ) -> Self {
        Self {
            inner: Some(inner),
            span: Span::current(),
            conn_info,
            pool,
        }
    }

    pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
        let Self {
            inner,
            pool,
            conn_info,
            span: _,
        } = self;
        let inner_m = inner.as_mut().expect("client inner should not be removed");
        (inner_m, Discard { conn_info, pool })
    }

    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
        let Self {
            inner,
            pool,
            conn_info,
            span: _,
        } = self;
        let inner = inner.as_mut().expect("client inner should not be removed");
        (&mut inner.inner, Discard { conn_info, pool })
    }

    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
        let aux = &self.inner.as_ref().unwrap().aux;
        USAGE_METRICS.register(Ids {
            endpoint_id: aux.endpoint_id,
            branch_id: aux.branch_id,
        })
    }

    pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
        let conn_info = self.conn_info.clone();
        let client = self
            .inner
            .take()
            .expect("client inner should not be removed");
        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
            let current_span = self.span.clone();
            // return connection to the pool
            return Some(move || {
                let _span = current_span.enter();
                EndpointConnPool::put(&conn_pool, &conn_info, client);
            });
        }
        None
    }
}

impl<C: ClientInnerExt> Drop for Client<C> {
    fn drop(&mut self) {
        if let Some(drop) = self.do_drop() {
            tokio::task::spawn_blocking(drop);
        }
    }
}

impl<C: ClientInnerExt> Deref for Client<C> {
    type Target = C;

    fn deref(&self) -> &Self::Target {
        &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .inner
    }
}

pub(crate) trait ClientInnerExt: Sync + Send + 'static {
    fn is_closed(&self) -> bool;
    fn get_process_id(&self) -> i32;
}

impl ClientInnerExt for tokio_postgres::Client {
    fn is_closed(&self) -> bool {
        self.is_closed()
    }

    fn get_process_id(&self) -> i32 {
        self.get_process_id()
    }
}

impl<C: ClientInnerExt> Discard<'_, C> {
    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
        let conn_info = &self.conn_info;
        if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
            info!("pool: throwing away connection '{conn_info}' because connection is not idle");
        }
    }

    pub(crate) fn discard(&mut self) {
        let conn_info = &self.conn_info;
        if std::mem::take(self.pool).strong_count() > 0 {
            info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
        }
    }
}