use std::collections::HashMap;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use std::time::Duration;

use clashmap::ClashMap;
use parking_lot::RwLock;
use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{Span, debug, info, warn};

use super::backend::HttpConnError;
use super::conn_pool::ClientDataRemote;
use super::http_conn_pool::ClientDataHttp;
use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::{DbName, EndpointCacheKey, RoleName};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};

#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
    pub(crate) user_info: ComputeUserInfo,
    pub(crate) dbname: DbName,
}

impl ConnInfo {
    // TODO: change to a hasher to avoid cloning?
    pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
        (self.dbname.clone(), self.user_info.user.clone())
    }

    pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
        // We don't want to cache http connections for ephemeral endpoints.
        if self.user_info.options.is_ephemeral() {
            None
        } else {
            Some(self.user_info.endpoint_cache_key())
        }
    }
}
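
// Usage sketch (illustrative only; constructing a `ComputeUserInfo` is the
// auth layer's job and is elided here):
//
//     let key = conn_info.endpoint_cache_key();
//     // `Some(key)` for regular endpoints; `None` for ephemeral endpoints,
//     // which are intentionally never pooled.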

#[derive(Clone)]
#[allow(clippy::large_enum_variant, reason = "TODO")]
pub(crate) enum ClientDataEnum {
    Remote(ClientDataRemote),
    Local(ClientDataLocal),
    Http(ClientDataHttp),
}

#[derive(Clone)]
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
    pub(crate) inner: C,
    pub(crate) aux: MetricsAuxInfo,
    pub(crate) conn_id: uuid::Uuid,
    pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
}

impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
    fn drop(&mut self) {
        match &mut self.data {
            ClientDataEnum::Remote(remote_data) => {
                remote_data.cancel();
            }
            ClientDataEnum::Local(local_data) => {
                local_data.cancel();
            }
            ClientDataEnum::Http(_http_data) => (),
        }
    }
}

impl<C: ClientInnerExt> ClientInnerCommon<C> {
    pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
        self.conn_id
    }

    pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
        &mut self.data
    }
}

pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
    pub(crate) conn: ClientInnerCommon<C>,
    pub(crate) _last_access: std::time::Instant,
}

// Per-endpoint connection pool: (dbname, username) -> DbUserConnPool.
// The number of open connections is limited by `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
    total_conns: usize,
    /// max # connections per endpoint
    max_conns: usize,
    _guard: HttpEndpointPoolsGuard<'static>,
    global_connections_count: Arc<AtomicUsize>,
    global_pool_size_max_conns: usize,
    pool_name: String,
}

impl<C: ClientInnerExt> EndpointConnPool<C> {
    pub(crate) fn new(
        hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
        tconns: usize,
        max_conns_per_endpoint: usize,
        global_connections_count: Arc<AtomicUsize>,
        max_total_conns: usize,
        pname: String,
    ) -> Self {
        Self {
            pools: hmap,
            total_conns: tconns,
            max_conns: max_conns_per_endpoint,
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count,
            global_pool_size_max_conns: max_total_conns,
            pool_name: pname,
        }
    }

    pub(crate) fn get_conn_entry(
        &mut self,
        db_user: (DbName, RoleName),
    ) -> Option<ConnPoolEntry<C>> {
        let Self {
            pools,
            total_conns,
            global_connections_count,
            ..
        } = self;
        pools.get_mut(&db_user).and_then(|pool_entries| {
            let (entry, removed) = pool_entries.get_conn_entry(total_conns);
            global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
            entry
        })
    }

    pub(crate) fn remove_client(
        &mut self,
        db_user: (DbName, RoleName),
        conn_id: uuid::Uuid,
    ) -> bool {
        let Self {
            pools,
            total_conns,
            global_connections_count,
            ..
        } = self;
        if let Some(pool) = pools.get_mut(&db_user) {
            let old_len = pool.get_conns().len();
            pool.get_conns()
                .retain(|conn| conn.conn.get_conn_id() != conn_id);
            let new_len = pool.get_conns().len();
            let removed = old_len - new_len;
            if removed > 0 {
                global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
                Metrics::get()
                    .proxy
                    .http_pool_opened_connections
                    .get_metric()
                    .dec_by(removed as i64);
            }
            *total_conns -= removed;
            removed > 0
        } else {
            false
        }
    }

    pub(crate) fn get_name(&self) -> &str {
        &self.pool_name
    }

    pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
        self.pools.get(&db_user)
    }

    pub(crate) fn get_pool_mut(
        &mut self,
        db_user: (DbName, RoleName),
    ) -> Option<&mut DbUserConnPool<C>> {
        self.pools.get_mut(&db_user)
    }

    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, mut client: ClientInnerCommon<C>) {
        let conn_id = client.get_conn_id();
        let (max_conn, conn_count, pool_name) = {
            let pool = pool.read();
            (
                pool.global_pool_size_max_conns,
                pool.global_connections_count
                    .load(atomic::Ordering::Relaxed),
                pool.get_name().to_string(),
            )
        };

        if client.inner.is_closed() {
            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed");
            return;
        }

        if let Err(error) = client.inner.reset() {
            warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset");
            return;
        }

        if conn_count >= max_conn {
            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full");
            return;
        }

        // return connection to the pool
        let mut returned = false;
        let mut per_db_size = 0;
        let total_conns = {
            let mut pool = pool.write();

            if pool.total_conns < pool.max_conns {
                let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
                pool_entries.get_conns().push(ConnPoolEntry {
                    conn: client,
                    _last_access: std::time::Instant::now(),
                });

                returned = true;
                per_db_size = pool_entries.get_conns().len();

                pool.total_conns += 1;
                pool.global_connections_count
                    .fetch_add(1, atomic::Ordering::Relaxed);
                Metrics::get()
                    .proxy
                    .http_pool_opened_connections
                    .get_metric()
                    .inc();
            }

            pool.total_conns
        };

        // do logging outside of the mutex
        if returned {
            debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
        } else {
            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
        }
    }
}
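
// `put` is normally reached through `Client::drop` (see below), but a direct
// call looks like this (sketch; assumes `pool`, `conn_info`, and `client`
// already exist):
//
//     EndpointConnPool::put(&pool, &conn_info, client);
//     // The connection is discarded instead of pooled if it is closed,
//     // fails to reset, or the per-endpoint or global cap is reached.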

impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
    fn drop(&mut self) {
        if self.total_conns > 0 {
            self.global_connections_count
                .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(self.total_conns as i64);
        }
    }
}

pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
    pub(crate) conns: Vec<ConnPoolEntry<C>>,
    pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
}

impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
    fn default() -> Self {
        Self {
            conns: Vec::new(),
            initialized: None,
        }
    }
}

pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
    fn set_initialized(&mut self);
    fn is_initialized(&self) -> bool;
    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
}

impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
    fn set_initialized(&mut self) {
        self.initialized = Some(true);
    }

    fn is_initialized(&self) -> bool {
        self.initialized.unwrap_or(false)
    }

    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
        let old_len = self.conns.len();

        self.conns.retain(|conn| !conn.conn.inner.is_closed());

        let new_len = self.conns.len();
        let removed = old_len - new_len;
        *conns -= removed;
        removed
    }

    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
        let mut removed = self.clear_closed_clients(conns);
        let conn = self.conns.pop();
        if conn.is_some() {
            *conns -= 1;
            removed += 1;
        }

        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by(removed as i64);

        (conn, removed)
    }

    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
        &mut self.conns
    }
}

pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
    fn clear_closed(&mut self) -> usize;
    fn total_conns(&self) -> usize;
}

impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
    fn clear_closed(&mut self) -> usize {
        let mut clients_removed: usize = 0;
        for db_pool in self.pools.values_mut() {
            clients_removed += db_pool.clear_closed_clients(&mut self.total_conns);
        }
        clients_removed
    }

    fn total_conns(&self) -> usize {
        self.total_conns
    }
}

pub(crate) struct GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    // endpoint -> per-endpoint connection pool
    //
    // This is expected to be a fairly contended map, so return a reference to
    // the per-endpoint pool as early as possible and release the lock.
    pub(crate) global_pool: ClashMap<EndpointCacheKey, Arc<RwLock<P>>>,

    /// Number of endpoint-connection pools.
    ///
    /// [`ClashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    pub(crate) global_pool_size: AtomicUsize,

    /// Total number of connections in the pool.
    pub(crate) global_connections_count: Arc<AtomicUsize>,

    pub(crate) config: &'static crate::config::HttpConfig,

    _marker: PhantomData<C>,
}

#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
    // Maximum number of connections per one endpoint.
    // Can mix different (dbname, username) connections.
    // When running out of free slots for a particular endpoint,
    // falls back to opening a new connection for each request.
    pub max_conns_per_endpoint: usize,

    pub gc_epoch: Duration,

    pub pool_shards: usize,

    pub idle_timeout: Duration,

    pub opt_in: bool,

    // Total number of connections in the pool.
    pub max_total_conns: usize,
}
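
// Illustrative option values (a sketch only; the real values come from the
// proxy configuration, not from this module):
//
//     let opts = GlobalConnPoolOptions {
//         max_conns_per_endpoint: 20,
//         gc_epoch: Duration::from_secs(60),
//         pool_shards: 16,
//         idle_timeout: Duration::from_secs(300),
//         opt_in: true,
//         max_total_conns: 1000,
//     };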

impl<C, P> GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
        let shards = config.pool_options.pool_shards;
        Arc::new(Self {
            global_pool: ClashMap::with_shard_amount(shards),
            global_pool_size: AtomicUsize::new(0),
            config,
            global_connections_count: Arc::new(AtomicUsize::new(0)),
            _marker: PhantomData,
        })
    }

    #[cfg(test)]
    pub(crate) fn get_global_connections_count(&self) -> usize {
        self.global_connections_count
            .load(atomic::Ordering::Relaxed)
    }

    pub(crate) fn get_idle_timeout(&self) -> Duration {
        self.config.pool_options.idle_timeout
    }

    pub(crate) fn shutdown(&self) {
        // drops all strong references to endpoint-pools
        self.global_pool.clear();
    }

    pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
        let epoch = self.config.pool_options.gc_epoch;
        let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
        loop {
            interval.tick().await;

            let shard = rng.random_range(0..self.global_pool.shards().len());
            self.gc(shard);
        }
    }
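
    // For example, with `gc_epoch = 60s` and 16 shards, the worker ticks
    // every 3.75s and reclaims one uniformly random shard per tick, so each
    // shard is scanned roughly once per epoch on average.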

    pub(crate) fn gc(&self, shard: usize) {
        debug!(shard, "pool: performing epoch reclamation");

        // acquire a random shard lock
        let mut shard = self.global_pool.shards()[shard].write();

        let timer = Metrics::get()
            .proxy
            .http_pool_reclaimation_lag_seconds
            .start_timer();
        let current_len = shard.len();
        let mut clients_removed = 0;
        shard.retain(|(endpoint, x)| {
            // if the current endpoint pool is unique (no other strong or weak references)
            // then it is currently not in use by any connections.
            if let Some(pool) = Arc::get_mut(x) {
                let endpoints = pool.get_mut();
                clients_removed = endpoints.clear_closed();

                if endpoints.total_conns() == 0 {
                    info!("pool: discarding pool for endpoint {endpoint}");
                    return false;
                }
            }

            true
        });

        let new_len = shard.len();
        drop(shard);
        timer.observe();

        // Do logging outside of the lock.
        if clients_removed > 0 {
            let size = self
                .global_connections_count
                .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
                - clients_removed;
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(clients_removed as i64);
            info!(
                "pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"
            );
        }
        let removed = current_len - new_len;

        if removed > 0 {
            let global_pool_size = self
                .global_pool_size
                .fetch_sub(removed, atomic::Ordering::Relaxed)
                - removed;
            info!("pool: performed global pool gc. size now {global_pool_size}");
        }
    }
}

impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
    pub(crate) fn get(
        self: &Arc<Self>,
        ctx: &RequestContext,
        conn_info: &ConnInfo,
    ) -> Result<Option<Client<C>>, HttpConnError> {
        let mut client: Option<ClientInnerCommon<C>> = None;
        let Some(endpoint) = conn_info.endpoint_cache_key() else {
            return Ok(None);
        };

        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
        if let Some(entry) = endpoint_pool
            .write()
            .get_conn_entry(conn_info.db_and_user())
        {
            client = Some(entry.conn);
        }
        let endpoint_pool = Arc::downgrade(&endpoint_pool);

        // return the cached connection if one was found; otherwise the caller
        // establishes a new one
        if let Some(mut client) = client {
            if client.inner.is_closed() {
                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
                return Ok(None);
            }
            tracing::Span::current()
                .record("conn_id", tracing::field::display(client.get_conn_id()));
            tracing::Span::current().record(
                "pid",
                tracing::field::display(client.inner.get_process_id()),
            );
            debug!(
                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
                "pool: reusing connection '{conn_info}'"
            );

            match client.get_data() {
                ClientDataEnum::Local(data) => {
                    data.session().send(ctx.session_id())?;
                }

                ClientDataEnum::Remote(data) => {
                    data.session().send(ctx.session_id())?;
                }
                ClientDataEnum::Http(_) => (),
            }

            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
            ctx.success();
            return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
        }
        Ok(None)
    }

    pub(crate) fn get_or_create_endpoint_pool(
        self: &Arc<Self>,
        endpoint: &EndpointCacheKey,
    ) -> Arc<RwLock<EndpointConnPool<C>>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
            max_conns: self.config.pool_options.max_conns_per_endpoint,
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count: self.global_connections_count.clone(),
            global_pool_size_max_conns: self.config.pool_options.max_total_conns,
            pool_name: String::from("remote"),
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}
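
// Typical lookup flow (sketch): `get` returns `Ok(None)` both for ephemeral
// endpoints and for pool misses, in which case the caller opens a fresh
// connection.
//
//     if let Some(client) = global_pool.get(&ctx, &conn_info)? {
//         // reuse the pooled connection
//     } else {
//         // connect anew; `Client::drop` later returns it to the pool
//     }
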
pub(crate) struct Client<C: ClientInnerExt> {
    span: Span,
    inner: Option<ClientInnerCommon<C>>,
    conn_info: ConnInfo,
    pool: Weak<RwLock<EndpointConnPool<C>>>,
}

pub(crate) struct Discard<'a, C: ClientInnerExt> {
    conn_info: &'a ConnInfo,
    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}

impl<C: ClientInnerExt> Client<C> {
    pub(crate) fn new(
        inner: ClientInnerCommon<C>,
        conn_info: ConnInfo,
        pool: Weak<RwLock<EndpointConnPool<C>>>,
    ) -> Self {
        Self {
            inner: Some(inner),
            span: Span::current(),
            conn_info,
            pool,
        }
    }

    pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
        let Self {
            inner,
            pool,
            conn_info,
            span: _,
        } = self;
        let inner_m = inner.as_mut().expect("client inner should not be removed");
        (inner_m, Discard { conn_info, pool })
    }

    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
        let Self {
            inner,
            pool,
            conn_info,
            span: _,
        } = self;
        let inner = inner.as_mut().expect("client inner should not be removed");
        (&mut inner.inner, Discard { conn_info, pool })
    }

    pub(crate) fn metrics(&self, ctx: &RequestContext) -> Arc<MetricCounter> {
        let aux = &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .aux;

        let private_link_id = match ctx.extra() {
            None => None,
            Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
            Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
        };

        USAGE_METRICS.register(Ids {
            endpoint_id: aux.endpoint_id,
            branch_id: aux.branch_id,
            private_link_id,
        })
    }
}

impl<C: ClientInnerExt> Drop for Client<C> {
    fn drop(&mut self) {
        let conn_info = self.conn_info.clone();
        let client = self
            .inner
            .take()
            .expect("client inner should not be removed");
        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
            let _current_span = self.span.enter();
            // return connection to the pool
            EndpointConnPool::put(&conn_pool, &conn_info, client);
        }
    }
}

impl<C: ClientInnerExt> Deref for Client<C> {
    type Target = C;

    fn deref(&self) -> &Self::Target {
        &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .inner
    }
}

pub(crate) trait ClientInnerExt: Sync + Send + 'static {
    fn is_closed(&self) -> bool;
    fn get_process_id(&self) -> i32;
    fn reset(&mut self) -> Result<(), postgres_client::Error>;
}
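
// A minimal test double for `ClientInnerExt` (sketch; `FakeClient` is
// hypothetical and not part of this module):
//
//     struct FakeClient { closed: bool }
//
//     impl ClientInnerExt for FakeClient {
//         fn is_closed(&self) -> bool { self.closed }
//         fn get_process_id(&self) -> i32 { 0 }
//         fn reset(&mut self) -> Result<(), postgres_client::Error> { Ok(()) }
//     }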

impl ClientInnerExt for postgres_client::Client {
    fn is_closed(&self) -> bool {
        self.is_closed()
    }

    fn get_process_id(&self) -> i32 {
        self.get_process_id()
    }

    fn reset(&mut self) -> Result<(), postgres_client::Error> {
        self.reset_session_background()
    }
}

impl<C: ClientInnerExt> Discard<'_, C> {
    pub(crate) fn discard(&mut self) {
        let conn_info = &self.conn_info;
        if std::mem::take(self.pool).strong_count() > 0 {
            info!(
                "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"
            );
        }
    }
}