Line data Source code
1 : use dashmap::DashMap;
2 : use futures::{future::poll_fn, Future};
3 : use metrics::IntCounterPairGuard;
4 : use parking_lot::RwLock;
5 : use rand::Rng;
6 : use smallvec::SmallVec;
7 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
8 : use std::{
9 : fmt,
10 : task::{ready, Poll},
11 : };
12 : use std::{
13 : ops::Deref,
14 : sync::atomic::{self, AtomicUsize},
15 : };
16 : use tokio::time::Instant;
17 : use tokio_postgres::tls::NoTlsStream;
18 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
19 :
20 : use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
21 : use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 : use crate::{
24 : auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE,
25 : DbName, EndpointCacheKey, RoleName,
26 : };
27 :
28 : use tracing::{debug, error, warn, Span};
29 : use tracing::{info, info_span, Instrument};
30 :
31 : use super::backend::HttpConnError;
32 :
33 : #[derive(Debug, Clone)]
34 : pub struct ConnInfo {
35 : pub user_info: ComputeUserInfo,
36 : pub dbname: DbName,
37 : pub password: SmallVec<[u8; 16]>,
38 : }
39 :
40 : impl ConnInfo {
41 : // TODO: hash (dbname, user) instead of cloning to build the key?
42 6 : pub fn db_and_user(&self) -> (DbName, RoleName) {
43 6 : (self.dbname.clone(), self.user_info.user.clone())
44 6 : }
45 :
46 4 : pub fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
47 4 : // We don't want to cache http connections for ephemeral endpoints.
48 4 : if self.user_info.options.is_ephemeral() {
49 0 : None
50 : } else {
51 4 : Some(self.user_info.endpoint_cache_key())
52 : }
53 4 : }
54 : }
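// A `ConnInfo` yields the two keys the pool is organised around (a sketch, not
// taken from a call site):
//
// let (db, user) = conn_info.db_and_user();        // key inside an endpoint pool
// let endpoint = conn_info.endpoint_cache_key();   // key into the global pool;
//                                                  // `None` for ephemeral endpoints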
55 :
56 : impl fmt::Display for ConnInfo {
57 : // use custom display to avoid logging password
58 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 0 : write!(
60 0 : f,
61 0 : "{}@{}/{}?{}",
62 0 : self.user_info.user,
63 0 : self.user_info.endpoint,
64 0 : self.dbname,
65 0 : self.user_info.options.get_cache_key("")
66 0 : )
67 0 : }
68 : }
69 :
70 : struct ConnPoolEntry<C: ClientInnerExt> {
71 : conn: ClientInner<C>,
72 : _last_access: std::time::Instant,
73 : }
74 :
75 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
76 : // Number of open connections is limited by the `max_conns_per_endpoint`.
77 : pub struct EndpointConnPool<C: ClientInnerExt> {
78 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
79 : total_conns: usize,
80 : max_conns: usize,
81 : _guard: IntCounterPairGuard,
82 : global_connections_count: Arc<AtomicUsize>,
83 : global_pool_size_max_conns: usize,
84 : }
85 :
86 : impl<C: ClientInnerExt> EndpointConnPool<C> {
87 0 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
88 0 : let Self {
89 0 : pools,
90 0 : total_conns,
91 0 : global_connections_count,
92 0 : ..
93 0 : } = self;
94 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
95 0 : pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
96 0 : })
97 0 : }
98 :
99 0 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
100 0 : let Self {
101 0 : pools,
102 0 : total_conns,
103 0 : global_connections_count,
104 0 : ..
105 0 : } = self;
106 0 : if let Some(pool) = pools.get_mut(&db_user) {
107 0 : let old_len = pool.conns.len();
108 0 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
109 0 : let new_len = pool.conns.len();
110 0 : let removed = old_len - new_len;
111 0 : if removed > 0 {
112 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
113 0 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
114 0 : }
115 0 : *total_conns -= removed;
116 0 : removed > 0
117 : } else {
118 0 : false
119 : }
120 0 : }
121 :
122 12 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
123 12 : let conn_id = client.conn_id;
124 12 :
125 12 : if client.is_closed() {
126 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
127 2 : return;
128 10 : }
129 10 : let global_max_conn = pool.read().global_pool_size_max_conns;
130 10 : if pool
131 10 : .read()
132 10 : .global_connections_count
133 10 : .load(atomic::Ordering::Relaxed)
134 10 : >= global_max_conn
135 : {
136 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because global pool is full");
137 2 : return;
138 8 : }
139 8 :
140 8 : // return connection to the pool
141 8 : let mut returned = false;
142 8 : let mut per_db_size = 0;
143 8 : let total_conns = {
144 8 : let mut pool = pool.write();
145 8 :
146 8 : if pool.total_conns < pool.max_conns {
147 6 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
148 6 : pool_entries.conns.push(ConnPoolEntry {
149 6 : conn: client,
150 6 : _last_access: std::time::Instant::now(),
151 6 : });
152 6 :
153 6 : returned = true;
154 6 : per_db_size = pool_entries.conns.len();
155 6 :
156 6 : pool.total_conns += 1;
157 6 : pool.global_connections_count
158 6 : .fetch_add(1, atomic::Ordering::Relaxed);
159 6 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc();
160 6 : }
161 :
162 8 : pool.total_conns
163 8 : };
164 8 :
165 8 : // do logging outside of the mutex
166 8 : if returned {
167 6 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
168 : } else {
169 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
170 : }
171 12 : }
172 : }
173 :
174 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
175 4 : fn drop(&mut self) {
176 4 : if self.total_conns > 0 {
177 4 : self.global_connections_count
178 4 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
179 4 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64);
180 4 : }
181 4 : }
182 : }
183 :
184 : pub struct DbUserConnPool<C: ClientInnerExt> {
185 : conns: Vec<ConnPoolEntry<C>>,
186 : }
187 :
188 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
189 4 : fn default() -> Self {
190 4 : Self { conns: Vec::new() }
191 4 : }
192 : }
193 :
194 : impl<C: ClientInnerExt> DbUserConnPool<C> {
195 2 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
196 2 : let old_len = self.conns.len();
197 2 :
198 4 : self.conns.retain(|conn| !conn.conn.is_closed());
199 2 :
200 2 : let new_len = self.conns.len();
201 2 : let removed = old_len - new_len;
202 2 : *conns -= removed;
203 2 : removed
204 2 : }
205 :
206 0 : fn get_conn_entry(
207 0 : &mut self,
208 0 : conns: &mut usize,
209 0 : global_connections_count: Arc<AtomicUsize>,
210 0 : ) -> Option<ConnPoolEntry<C>> {
211 0 : let mut removed = self.clear_closed_clients(conns);
212 0 : let conn = self.conns.pop();
213 0 : if conn.is_some() {
214 0 : *conns -= 1;
215 0 : removed += 1;
216 0 : }
217 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
218 0 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
219 0 : conn
220 0 : }
221 : }
222 :
223 : pub struct GlobalConnPool<C: ClientInnerExt> {
224 : // endpoint -> per-endpoint connection pool
225 : //
226 : // This is expected to be a fairly contended map, so return a reference to the
227 : // per-endpoint pool as early as possible and release the lock.
228 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
229 :
230 : /// Number of endpoint-connection pools
231 : ///
232 : /// [`DashMap::len`] iterates over all the shards and acquires a read lock on each one,
233 : /// which is far too much effort, so we maintain a relaxed atomic counter instead.
234 : /// It's only used for diagnostics.
235 : global_pool_size: AtomicUsize,
236 :
237 : /// Total number of connections in the pool
238 : global_connections_count: Arc<AtomicUsize>,
239 :
240 : config: &'static crate::config::HttpConfig,
241 : }
242 :
243 : #[derive(Debug, Clone, Copy)]
244 : pub struct GlobalConnPoolOptions {
245 : // Maximum number of connections per endpoint.
246 : // The slots may be shared by different (dbname, username) connections.
247 : // When an endpoint runs out of free slots, we fall back to opening
248 : // a new connection for each request.
249 : pub max_conns_per_endpoint: usize,
250 :
251 : pub gc_epoch: Duration,
252 :
253 : pub pool_shards: usize,
254 :
255 : pub idle_timeout: Duration,
256 :
257 : pub opt_in: bool,
258 :
259 : // Total number of connections in the pool.
260 : pub max_total_conns: usize,
261 : }
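// A minimal sketch of how these options fit together; the values below are
// illustrative assumptions, not defaults shipped by the proxy:
//
// let pool_options = GlobalConnPoolOptions {
//     max_conns_per_endpoint: 20,        // per-endpoint cap, shared across (db, user) pairs
//     gc_epoch: Duration::from_secs(60),
//     pool_shards: 16,
//     idle_timeout: Duration::from_secs(300),
//     opt_in: true,
//     max_total_conns: 1000,             // global cap checked in `EndpointConnPool::put`
// };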
262 :
263 : impl<C: ClientInnerExt> GlobalConnPool<C> {
264 2 : pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
265 2 : let shards = config.pool_options.pool_shards;
266 2 : Arc::new(Self {
267 2 : global_pool: DashMap::with_shard_amount(shards),
268 2 : global_pool_size: AtomicUsize::new(0),
269 2 : config,
270 2 : global_connections_count: Arc::new(AtomicUsize::new(0)),
271 2 : })
272 2 : }
273 :
274 : #[cfg(test)]
275 18 : pub fn get_global_connections_count(&self) -> usize {
276 18 : self.global_connections_count
277 18 : .load(atomic::Ordering::Relaxed)
278 18 : }
279 :
280 0 : pub fn get_idle_timeout(&self) -> Duration {
281 0 : self.config.pool_options.idle_timeout
282 0 : }
283 :
284 0 : pub fn shutdown(&self) {
285 0 : // drops all strong references to endpoint-pools
286 0 : self.global_pool.clear();
287 0 : }
288 :
289 0 : pub async fn gc_worker(&self, mut rng: impl Rng) {
290 0 : let epoch = self.config.pool_options.gc_epoch;
291 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
292 : loop {
293 0 : interval.tick().await;
294 :
295 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
296 0 : self.gc(shard);
297 : }
298 : }
299 :
300 4 : fn gc(&self, shard: usize) {
301 4 : debug!(shard, "pool: performing epoch reclamation");
302 :
303 : // acquire the write lock for the chosen shard (gc_worker picks one at random)
304 4 : let mut shard = self.global_pool.shards()[shard].write();
305 4 :
306 4 : let timer = GC_LATENCY.start_timer();
307 4 : let current_len = shard.len();
308 4 : let mut clients_removed = 0;
309 4 : shard.retain(|endpoint, x| {
310 : // if the current endpoint pool is unique (no other strong or weak references)
311 : // then it is currently not in use by any connections.
312 4 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
313 : let EndpointConnPool {
314 2 : pools, total_conns, ..
315 2 : } = pool.get_mut();
316 2 :
317 2 : // ensure that closed clients are removed
318 2 : pools.iter_mut().for_each(|(_, db_pool)| {
319 2 : clients_removed += db_pool.clear_closed_clients(total_conns);
320 2 : });
321 2 :
322 2 : // we only remove this pool if it has no active connections
323 2 : if *total_conns == 0 {
324 0 : info!("pool: discarding pool for endpoint {endpoint}");
325 0 : return false;
326 2 : }
327 2 : }
328 :
329 4 : true
330 4 : });
331 4 :
332 4 : let new_len = shard.len();
333 4 : drop(shard);
334 4 : timer.observe_duration();
335 4 :
336 4 : // Do logging outside of the lock.
337 4 : if clients_removed > 0 {
338 2 : let size = self
339 2 : .global_connections_count
340 2 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
341 2 : - clients_removed;
342 2 : NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64);
343 2 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
344 2 : }
345 4 : let removed = current_len - new_len;
346 4 :
347 4 : if removed > 0 {
348 0 : let global_pool_size = self
349 0 : .global_pool_size
350 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
351 0 : - removed;
352 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
353 4 : }
354 4 : }
355 :
356 0 : pub async fn get(
357 0 : self: &Arc<Self>,
358 0 : ctx: &mut RequestMonitoring,
359 0 : conn_info: &ConnInfo,
360 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
361 0 : let mut client: Option<ClientInner<C>> = None;
362 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
363 0 : return Ok(None);
364 : };
365 :
366 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
367 0 : if let Some(entry) = endpoint_pool
368 0 : .write()
369 0 : .get_conn_entry(conn_info.db_and_user())
370 : {
371 0 : client = Some(entry.conn)
372 0 : }
373 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
374 :
375 : // return the cached connection if it is still usable; otherwise return None so the caller establishes a new one
376 0 : if let Some(client) = client {
377 0 : if client.is_closed() {
378 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
379 0 : return Ok(None);
380 : } else {
381 0 : tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
382 0 : tracing::Span::current().record(
383 0 : "pid",
384 0 : &tracing::field::display(client.inner.get_process_id()),
385 0 : );
386 0 : info!(
387 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
388 0 : "pool: reusing connection '{conn_info}'"
389 0 : );
390 0 : client.session.send(ctx.session_id)?;
391 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
392 0 : ctx.latency_timer.success();
393 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
394 : }
395 0 : }
396 0 : Ok(None)
397 0 : }
398 :
399 4 : fn get_or_create_endpoint_pool(
400 4 : self: &Arc<Self>,
401 4 : endpoint: &EndpointCacheKey,
402 4 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
403 : // fast path
404 4 : if let Some(pool) = self.global_pool.get(endpoint) {
405 0 : return pool.clone();
406 4 : }
407 4 :
408 4 : // slow path
409 4 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
410 4 : pools: HashMap::new(),
411 4 : total_conns: 0,
412 4 : max_conns: self.config.pool_options.max_conns_per_endpoint,
413 4 : _guard: ENDPOINT_POOLS.guard(),
414 4 : global_connections_count: self.global_connections_count.clone(),
415 4 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
416 4 : }));
417 4 :
418 4 : // find or create a pool for this endpoint
419 4 : let mut created = false;
420 4 : let pool = self
421 4 : .global_pool
422 4 : .entry(endpoint.clone())
423 4 : .or_insert_with(|| {
424 4 : created = true;
425 4 : new_pool
426 4 : })
427 4 : .clone();
428 4 :
429 4 : // log new global pool size
430 4 : if created {
431 4 : let global_pool_size = self
432 4 : .global_pool_size
433 4 : .fetch_add(1, atomic::Ordering::Relaxed)
434 4 : + 1;
435 4 : info!(
436 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
437 0 : );
438 0 : }
439 :
440 4 : pool
441 4 : }
442 : }
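// Expected call sequence from the SQL-over-HTTP handler, shown here as a sketch only;
// `open_new_connection` and `aux` stand in for the actual connect path and its metrics
// info and are not defined in this module:
//
// let client = match global_pool.get(ctx, &conn_info).await? {
//     // pool hit: an existing connection is reused
//     Some(client) => client,
//     // pool miss: open a fresh connection and hand it to `poll_client`, which
//     // drives it in a background task and makes it poolable on drop
//     None => {
//         let (client, connection) = open_new_connection(ctx, &conn_info).await?;
//         poll_client(global_pool.clone(), ctx, conn_info.clone(), client,
//                     connection, uuid::Uuid::new_v4(), aux)
//     }
// };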
443 :
444 0 : pub fn poll_client<C: ClientInnerExt>(
445 0 : global_pool: Arc<GlobalConnPool<C>>,
446 0 : ctx: &mut RequestMonitoring,
447 0 : conn_info: ConnInfo,
448 0 : client: C,
449 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
450 0 : conn_id: uuid::Uuid,
451 0 : aux: MetricsAuxInfo,
452 0 : ) -> Client<C> {
453 0 : let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
454 0 : .with_label_values(&[ctx.protocol])
455 0 : .guard();
456 0 : let mut session_id = ctx.session_id;
457 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
458 :
459 0 : let span = info_span!(parent: None, "connection", %conn_id);
460 0 : let cold_start_info = ctx.cold_start_info;
461 0 : span.in_scope(|| {
462 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
463 0 : });
464 0 : let pool = match conn_info.endpoint_cache_key() {
465 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
466 0 : None => Weak::new(),
467 : };
468 0 : let pool_clone = pool.clone();
469 0 :
470 0 : let db_user = conn_info.db_and_user();
471 0 : let idle = global_pool.get_idle_timeout();
472 0 : tokio::spawn(
473 0 : async move {
474 0 : let _conn_gauge = conn_gauge;
475 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
476 0 : poll_fn(move |cx| {
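// when this pooled connection is checked out again, `GlobalConnPool::get` sends the
// new session id through `ClientInner::session`; pick it up here and reset the idle timer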
477 0 : if matches!(rx.has_changed(), Ok(true)) {
478 0 : session_id = *rx.borrow_and_update();
479 0 : info!(%session_id, "changed session");
480 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
481 0 : }
482 :
483 : // idle connection timeout (configured via `pool_options.idle_timeout`)
484 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
485 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
486 0 : info!("connection idle");
487 0 : if let Some(pool) = pool.clone().upgrade() {
488 : // remove client from pool - should close the connection if it's idle.
489 : // does nothing if the client is currently checked-out and in-use
490 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
491 0 : info!("idle connection removed");
492 0 : }
493 0 : }
494 0 : }
495 :
496 : loop {
497 0 : let message = ready!(connection.poll_message(cx));
498 :
499 0 : match message {
500 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
501 0 : info!(%session_id, "notice: {}", notice);
502 : }
503 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
504 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
505 : }
506 : Some(Ok(_)) => {
507 0 : warn!(%session_id, "unknown message");
508 : }
509 0 : Some(Err(e)) => {
510 0 : error!(%session_id, "connection error: {}", e);
511 0 : break
512 : }
513 : None => {
514 0 : info!("connection closed");
515 0 : break
516 : }
517 : }
518 : }
519 :
520 : // remove from connection pool
521 0 : if let Some(pool) = pool.clone().upgrade() {
522 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
523 0 : info!("closed connection removed");
524 0 : }
525 0 : }
526 :
527 0 : Poll::Ready(())
528 0 : }).await;
529 :
530 0 : }
531 0 : .instrument(span));
532 0 : let inner = ClientInner {
533 0 : inner: client,
534 0 : session: tx,
535 0 : aux,
536 0 : conn_id,
537 0 : };
538 0 : Client::new(inner, conn_info, pool_clone)
539 0 : }
540 :
541 : struct ClientInner<C: ClientInnerExt> {
542 : inner: C,
543 : session: tokio::sync::watch::Sender<uuid::Uuid>,
544 : aux: MetricsAuxInfo,
545 : conn_id: uuid::Uuid,
546 : }
547 :
548 : pub trait ClientInnerExt: Sync + Send + 'static {
549 : fn is_closed(&self) -> bool;
550 : fn get_process_id(&self) -> i32;
551 : }
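// This trait is the seam that lets the pool be exercised without a live Postgres
// connection: production code uses the `tokio_postgres::Client` impl below, while the
// tests at the bottom of this file substitute a `MockClient`.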
552 :
553 : impl ClientInnerExt for tokio_postgres::Client {
554 0 : fn is_closed(&self) -> bool {
555 0 : self.is_closed()
556 0 : }
557 0 : fn get_process_id(&self) -> i32 {
558 0 : self.get_process_id()
559 0 : }
560 : }
561 :
562 : impl<C: ClientInnerExt> ClientInner<C> {
563 16 : pub fn is_closed(&self) -> bool {
564 16 : self.inner.is_closed()
565 16 : }
566 : }
567 :
568 : impl<C: ClientInnerExt> Client<C> {
569 0 : pub fn metrics(&self) -> Arc<MetricCounter> {
570 0 : let aux = &self.inner.as_ref().unwrap().aux;
571 0 : USAGE_METRICS.register(Ids {
572 0 : endpoint_id: aux.endpoint_id,
573 0 : branch_id: aux.branch_id,
574 0 : })
575 0 : }
576 : }
577 :
578 : pub struct Client<C: ClientInnerExt> {
579 : span: Span,
580 : inner: Option<ClientInner<C>>,
581 : conn_info: ConnInfo,
582 : pool: Weak<RwLock<EndpointConnPool<C>>>,
583 : }
584 :
585 : pub struct Discard<'a, C: ClientInnerExt> {
586 : conn_info: &'a ConnInfo,
587 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
588 : }
589 :
590 : impl<C: ClientInnerExt> Client<C> {
591 14 : pub(self) fn new(
592 14 : inner: ClientInner<C>,
593 14 : conn_info: ConnInfo,
594 14 : pool: Weak<RwLock<EndpointConnPool<C>>>,
595 14 : ) -> Self {
596 14 : Self {
597 14 : inner: Some(inner),
598 14 : span: Span::current(),
599 14 : conn_info,
600 14 : pool,
601 14 : }
602 14 : }
603 2 : pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
604 2 : let Self {
605 2 : inner,
606 2 : pool,
607 2 : conn_info,
608 2 : span: _,
609 2 : } = self;
610 2 : let inner = inner.as_mut().expect("client inner should not be removed");
611 2 : (&mut inner.inner, Discard { pool, conn_info })
612 2 : }
613 : }
614 :
615 : impl<C: ClientInnerExt> Discard<'_, C> {
616 0 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
617 0 : let conn_info = &self.conn_info;
618 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
619 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle")
620 0 : }
621 0 : }
622 2 : pub fn discard(&mut self) {
623 2 : let conn_info = &self.conn_info;
624 2 : if std::mem::take(self.pool).strong_count() > 0 {
625 2 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
626 0 : }
627 2 : }
628 : }
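// Intended usage from a request handler, as a sketch; `run_query` and `status` stand in
// for however the caller executes SQL and obtains a `ReadyForQueryStatus`:
//
// let (inner, mut discard) = client.inner();
// match run_query(inner).await {
//     // only idle connections are allowed back into the pool
//     Ok(status) => discard.check_idle(status),
//     // a failed query leaves the connection in an unknown state, so never pool it
//     Err(e) => { discard.discard(); return Err(e); }
// }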
629 :
630 : impl<C: ClientInnerExt> Deref for Client<C> {
631 : type Target = C;
632 :
633 0 : fn deref(&self) -> &Self::Target {
634 0 : &self
635 0 : .inner
636 0 : .as_ref()
637 0 : .expect("client inner should not be removed")
638 0 : .inner
639 0 : }
640 : }
641 :
642 : impl<C: ClientInnerExt> Client<C> {
643 14 : fn do_drop(&mut self) -> Option<impl FnOnce()> {
644 14 : let conn_info = self.conn_info.clone();
645 14 : let client = self
646 14 : .inner
647 14 : .take()
648 14 : .expect("client inner should not be removed");
649 14 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
650 12 : let current_span = self.span.clone();
651 12 : // return connection to the pool
652 12 : return Some(move || {
653 12 : let _span = current_span.enter();
654 12 : EndpointConnPool::put(&conn_pool, &conn_info, client);
655 12 : });
656 2 : }
657 2 : None
658 14 : }
659 : }
660 :
661 : impl<C: ClientInnerExt> Drop for Client<C> {
662 2 : fn drop(&mut self) {
663 2 : if let Some(drop) = self.do_drop() {
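// `EndpointConnPool::put` takes the pool's write lock, so return the connection on
// the blocking thread pool rather than blocking wherever the client happens to drop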
664 0 : tokio::task::spawn_blocking(drop);
665 2 : }
666 2 : }
667 : }
668 :
669 : #[cfg(test)]
670 : mod tests {
671 : use std::{mem, sync::atomic::AtomicBool};
672 :
673 : use crate::{BranchId, EndpointId, ProjectId};
674 :
675 : use super::*;
676 :
677 : struct MockClient(Arc<AtomicBool>);
678 : impl MockClient {
679 12 : fn new(is_closed: bool) -> Self {
680 12 : MockClient(Arc::new(is_closed.into()))
681 12 : }
682 : }
683 : impl ClientInnerExt for MockClient {
684 16 : fn is_closed(&self) -> bool {
685 16 : self.0.load(atomic::Ordering::Relaxed)
686 16 : }
687 0 : fn get_process_id(&self) -> i32 {
688 0 : 0
689 0 : }
690 : }
691 :
692 10 : fn create_inner() -> ClientInner<MockClient> {
693 10 : create_inner_with(MockClient::new(false))
694 10 : }
695 :
696 14 : fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
697 14 : ClientInner {
698 14 : inner: client,
699 14 : session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
700 14 : aux: MetricsAuxInfo {
701 14 : endpoint_id: (&EndpointId::from("endpoint")).into(),
702 14 : project_id: (&ProjectId::from("project")).into(),
703 14 : branch_id: (&BranchId::from("branch")).into(),
704 14 : cold_start_info: crate::console::messages::ColdStartInfo::Warm,
705 14 : },
706 14 : conn_id: uuid::Uuid::new_v4(),
707 14 : }
708 14 : }
709 :
710 : #[tokio::test]
711 2 : async fn test_pool() {
712 2 : let _ = env_logger::try_init();
713 2 : let config = Box::leak(Box::new(crate::config::HttpConfig {
714 2 : pool_options: GlobalConnPoolOptions {
715 2 : max_conns_per_endpoint: 2,
716 2 : gc_epoch: Duration::from_secs(1),
717 2 : pool_shards: 2,
718 2 : idle_timeout: Duration::from_secs(1),
719 2 : opt_in: false,
720 2 : max_total_conns: 3,
721 2 : },
722 2 : request_timeout: Duration::from_secs(1),
723 2 : }));
724 2 : let pool = GlobalConnPool::new(config);
725 2 : let conn_info = ConnInfo {
726 2 : user_info: ComputeUserInfo {
727 2 : user: "user".into(),
728 2 : endpoint: "endpoint".into(),
729 2 : options: Default::default(),
730 2 : },
731 2 : dbname: "dbname".into(),
732 2 : password: "password".as_bytes().into(),
733 2 : };
734 2 : let ep_pool = Arc::downgrade(
735 2 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
736 2 : );
737 2 : {
738 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
739 2 : assert_eq!(0, pool.get_global_connections_count());
740 2 : client.inner().1.discard();
741 2 : // Discard should not return the connection to the pool.
742 2 : assert_eq!(0, pool.get_global_connections_count());
743 2 : }
744 2 : {
745 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
746 2 : client.do_drop().unwrap()();
747 2 : mem::forget(client); // drop the client
748 2 : assert_eq!(1, pool.get_global_connections_count());
749 2 : }
750 2 : {
751 2 : let mut closed_client = Client::new(
752 2 : create_inner_with(MockClient::new(true)),
753 2 : conn_info.clone(),
754 2 : ep_pool.clone(),
755 2 : );
756 2 : closed_client.do_drop().unwrap()();
757 2 : mem::forget(closed_client); // drop the client
758 2 : // The closed client shouldn't be added to the pool.
759 2 : assert_eq!(1, pool.get_global_connections_count());
760 2 : }
761 2 : let is_closed: Arc<AtomicBool> = Arc::new(false.into());
762 2 : {
763 2 : let mut client = Client::new(
764 2 : create_inner_with(MockClient(is_closed.clone())),
765 2 : conn_info.clone(),
766 2 : ep_pool.clone(),
767 2 : );
768 2 : client.do_drop().unwrap()();
769 2 : mem::forget(client); // drop the client
770 2 :
771 2 : // The client should be added to the pool.
772 2 : assert_eq!(2, pool.get_global_connections_count());
773 2 : }
774 2 : {
775 2 : let mut client = Client::new(create_inner(), conn_info, ep_pool);
776 2 : client.do_drop().unwrap()();
777 2 : mem::forget(client); // drop the client
778 2 :
779 2 : // The client shouldn't be added to the pool because the per-endpoint pool is full.
780 2 : assert_eq!(2, pool.get_global_connections_count());
781 2 : }
782 2 :
783 2 : let conn_info = ConnInfo {
784 2 : user_info: ComputeUserInfo {
785 2 : user: "user".into(),
786 2 : endpoint: "endpoint-2".into(),
787 2 : options: Default::default(),
788 2 : },
789 2 : dbname: "dbname".into(),
790 2 : password: "password".as_bytes().into(),
791 2 : };
792 2 : let ep_pool = Arc::downgrade(
793 2 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
794 2 : );
795 2 : {
796 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
797 2 : client.do_drop().unwrap()();
798 2 : mem::forget(client); // drop the client
799 2 : assert_eq!(3, pool.get_global_connections_count());
800 2 : }
801 2 : {
802 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
803 2 : client.do_drop().unwrap()();
804 2 : mem::forget(client); // drop the client
805 2 :
806 2 : // The client shouldn't be added to the pool because the global pool is full.
807 2 : assert_eq!(3, pool.get_global_connections_count());
808 2 : }
809 2 :
810 2 : is_closed.store(true, atomic::Ordering::Relaxed);
811 2 : // Do gc for all shards.
812 2 : pool.gc(0);
813 2 : pool.gc(1);
814 2 : // Closed client should be removed from the pool.
815 2 : assert_eq!(2, pool.get_global_connections_count());
816 2 : }
817 : }