Line data Source code
1 : use dashmap::DashMap;
2 : use futures::{future::poll_fn, Future};
3 : use parking_lot::RwLock;
4 : use rand::Rng;
5 : use smallvec::SmallVec;
6 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
7 : use std::{
8 : fmt,
9 : task::{ready, Poll},
10 : };
11 : use std::{
12 : ops::Deref,
13 : sync::atomic::{self, AtomicUsize},
14 : };
15 : use tokio::time::Instant;
16 : use tokio_postgres::tls::NoTlsStream;
17 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
18 : use tokio_util::sync::CancellationToken;
19 :
20 : use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
21 : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
22 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
23 : use crate::{
24 : auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
25 : };
26 :
27 : use tracing::{debug, error, warn, Span};
28 : use tracing::{info, info_span, Instrument};
29 :
30 : use super::backend::HttpConnError;
31 :
32 : #[derive(Debug, Clone)]
33 : pub struct ConnInfo {
34 : pub user_info: ComputeUserInfo,
35 : pub dbname: DbName,
36 : pub password: SmallVec<[u8; 16]>,
37 : }
38 :
39 : impl ConnInfo {
40 : // TODO: switch to a hasher here to avoid cloning?
41 6 : pub fn db_and_user(&self) -> (DbName, RoleName) {
42 6 : (self.dbname.clone(), self.user_info.user.clone())
43 6 : }
44 :
45 4 : pub fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
46 4 : // We don't want to cache http connections for ephemeral endpoints.
47 4 : if self.user_info.options.is_ephemeral() {
48 0 : None
49 : } else {
50 4 : Some(self.user_info.endpoint_cache_key())
51 : }
52 4 : }
53 : }
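 : // Note: when endpoint_cache_key() returns None (ephemeral endpoints), both
 : // GlobalConnPool::get below and poll_client skip pooling entirely, so such
 : // connections are opened and torn down per request.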
54 :
55 : impl fmt::Display for ConnInfo {
56 : // custom Display implementation to avoid logging the password
57 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 0 : write!(
59 0 : f,
60 0 : "{}@{}/{}?{}",
61 0 : self.user_info.user,
62 0 : self.user_info.endpoint,
63 0 : self.dbname,
64 0 : self.user_info.options.get_cache_key("")
65 0 : )
66 0 : }
67 : }
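 : // Illustrative output of the Display impl above (names are made up):
 : //   "alice@ep-example-123456/neondb?" -- the password is never printed, and the
 : //   trailing segment comes from options.get_cache_key("").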
68 :
69 : struct ConnPoolEntry<C: ClientInnerExt> {
70 : conn: ClientInner<C>,
71 : _last_access: std::time::Instant,
72 : }
73 :
74 : // Per-endpoint connection pool, keyed by (dbname, username) -> DbUserConnPool.
75 : // The number of open connections is limited by `max_conns_per_endpoint`.
76 : pub struct EndpointConnPool<C: ClientInnerExt> {
77 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
78 : total_conns: usize,
79 : max_conns: usize,
80 : _guard: HttpEndpointPoolsGuard<'static>,
81 : global_connections_count: Arc<AtomicUsize>,
82 : global_pool_size_max_conns: usize,
83 : }
84 :
85 : impl<C: ClientInnerExt> EndpointConnPool<C> {
86 0 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
87 0 : let Self {
88 0 : pools,
89 0 : total_conns,
90 0 : global_connections_count,
91 0 : ..
92 0 : } = self;
93 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
94 0 : pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
95 0 : })
96 0 : }
97 :
98 0 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
99 0 : let Self {
100 0 : pools,
101 0 : total_conns,
102 0 : global_connections_count,
103 0 : ..
104 0 : } = self;
105 0 : if let Some(pool) = pools.get_mut(&db_user) {
106 0 : let old_len = pool.conns.len();
107 0 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
108 0 : let new_len = pool.conns.len();
109 0 : let removed = old_len - new_len;
110 0 : if removed > 0 {
111 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
112 0 : Metrics::get()
113 0 : .proxy
114 0 : .http_pool_opened_connections
115 0 : .get_metric()
116 0 : .dec_by(removed as i64);
117 0 : }
118 0 : *total_conns -= removed;
119 0 : removed > 0
120 : } else {
121 0 : false
122 : }
123 0 : }
124 :
125 12 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
126 12 : let conn_id = client.conn_id;
127 12 :
128 12 : if client.is_closed() {
129 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
130 2 : return;
131 10 : }
132 10 : let global_max_conn = pool.read().global_pool_size_max_conns;
133 10 : if pool
134 10 : .read()
135 10 : .global_connections_count
136 10 : .load(atomic::Ordering::Relaxed)
137 10 : >= global_max_conn
138 : {
139 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
140 2 : return;
141 8 : }
142 8 :
143 8 : // return connection to the pool
144 8 : let mut returned = false;
145 8 : let mut per_db_size = 0;
146 8 : let total_conns = {
147 8 : let mut pool = pool.write();
148 8 :
149 8 : if pool.total_conns < pool.max_conns {
150 6 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
151 6 : pool_entries.conns.push(ConnPoolEntry {
152 6 : conn: client,
153 6 : _last_access: std::time::Instant::now(),
154 6 : });
155 6 :
156 6 : returned = true;
157 6 : per_db_size = pool_entries.conns.len();
158 6 :
159 6 : pool.total_conns += 1;
160 6 : pool.global_connections_count
161 6 : .fetch_add(1, atomic::Ordering::Relaxed);
162 6 : Metrics::get()
163 6 : .proxy
164 6 : .http_pool_opened_connections
165 6 : .get_metric()
166 6 : .inc();
167 6 : }
168 :
169 8 : pool.total_conns
170 8 : };
171 8 :
172 8 : // do logging outside of the mutex
173 8 : if returned {
174 6 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
175 : } else {
176 2 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
177 : }
178 12 : }
179 : }
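 : // put() above applies, in order: drop the connection if it is already closed; drop it
 : // if the global budget (global_pool_size_max_conns) is exhausted; drop it if this
 : // endpoint is already at max_conns; otherwise push it onto the (db, user) entry and
 : // bump the per-endpoint and global counters plus the opened-connections metric.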
180 :
181 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
182 4 : fn drop(&mut self) {
183 4 : if self.total_conns > 0 {
184 4 : self.global_connections_count
185 4 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
186 4 : Metrics::get()
187 4 : .proxy
188 4 : .http_pool_opened_connections
189 4 : .get_metric()
190 4 : .dec_by(self.total_conns as i64);
191 4 : }
192 4 : }
193 : }
194 :
195 : pub struct DbUserConnPool<C: ClientInnerExt> {
196 : conns: Vec<ConnPoolEntry<C>>,
197 : }
198 :
199 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
200 4 : fn default() -> Self {
201 4 : Self { conns: Vec::new() }
202 4 : }
203 : }
204 :
205 : impl<C: ClientInnerExt> DbUserConnPool<C> {
206 2 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
207 2 : let old_len = self.conns.len();
208 2 :
209 4 : self.conns.retain(|conn| !conn.conn.is_closed());
210 2 :
211 2 : let new_len = self.conns.len();
212 2 : let removed = old_len - new_len;
213 2 : *conns -= removed;
214 2 : removed
215 2 : }
216 :
217 0 : fn get_conn_entry(
218 0 : &mut self,
219 0 : conns: &mut usize,
220 0 : global_connections_count: Arc<AtomicUsize>,
221 0 : ) -> Option<ConnPoolEntry<C>> {
222 0 : let mut removed = self.clear_closed_clients(conns);
223 0 : let conn = self.conns.pop();
224 0 : if conn.is_some() {
225 0 : *conns -= 1;
226 0 : removed += 1;
227 0 : }
228 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
229 0 : Metrics::get()
230 0 : .proxy
231 0 : .http_pool_opened_connections
232 0 : .get_metric()
233 0 : .dec_by(removed as i64);
234 0 : conn
235 0 : }
236 : }
237 :
238 : pub struct GlobalConnPool<C: ClientInnerExt> {
239 : // endpoint -> per-endpoint connection pool
240 : //
241 : // This map is expected to be fairly contended, so return a reference to the
242 : // per-endpoint pool as early as possible and release the lock.
243 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
244 :
245 : /// Number of endpoint-connection pools
246 : ///
247 : /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
248 : /// That seems like far too much effort, so we maintain this counter with relaxed atomic increments instead.
249 : /// It's only used for diagnostics.
250 : global_pool_size: AtomicUsize,
251 :
252 : /// Total number of connections in the pool
253 : global_connections_count: Arc<AtomicUsize>,
254 :
255 : config: &'static crate::config::HttpConfig,
256 : }
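 : // Rough shape of the two-level pool (a sketch for orientation, not code from this file):
 : //   global_pool: EndpointCacheKey -> Arc<RwLock<EndpointConnPool>>
 : //   EndpointConnPool::pools: (DbName, RoleName) -> DbUserConnPool { conns: Vec<ConnPoolEntry> }
 : // A lookup therefore takes a DashMap shard lock first, then the per-endpoint RwLock.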
257 :
258 : #[derive(Debug, Clone, Copy)]
259 : pub struct GlobalConnPoolOptions {
260 : // Maximum number of connections per one endpoint.
261 : // Can mix different (dbname, username) connections.
262 : // When running out of free slots for a particular endpoint,
263 : // falls back to opening a new connection for each request.
264 : pub max_conns_per_endpoint: usize,
265 :
266 : pub gc_epoch: Duration,
267 :
268 : pub pool_shards: usize,
269 :
270 : pub idle_timeout: Duration,
271 :
272 : pub opt_in: bool,
273 :
274 : // Total number of connections in the pool.
275 : pub max_total_conns: usize,
276 : }
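 : // For reference, the test at the bottom of this file builds these options roughly as:
 : //   GlobalConnPoolOptions {
 : //       max_conns_per_endpoint: 2,
 : //       gc_epoch: Duration::from_secs(1),
 : //       pool_shards: 2,
 : //       idle_timeout: Duration::from_secs(1),
 : //       opt_in: false,
 : //       max_total_conns: 3,
 : //   }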
277 :
278 : impl<C: ClientInnerExt> GlobalConnPool<C> {
279 2 : pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
280 2 : let shards = config.pool_options.pool_shards;
281 2 : Arc::new(Self {
282 2 : global_pool: DashMap::with_shard_amount(shards),
283 2 : global_pool_size: AtomicUsize::new(0),
284 2 : config,
285 2 : global_connections_count: Arc::new(AtomicUsize::new(0)),
286 2 : })
287 2 : }
288 :
289 : #[cfg(test)]
290 18 : pub fn get_global_connections_count(&self) -> usize {
291 18 : self.global_connections_count
292 18 : .load(atomic::Ordering::Relaxed)
293 18 : }
294 :
295 0 : pub fn get_idle_timeout(&self) -> Duration {
296 0 : self.config.pool_options.idle_timeout
297 0 : }
298 :
299 0 : pub fn shutdown(&self) {
300 0 : // drops all strong references to endpoint-pools
301 0 : self.global_pool.clear();
302 0 : }
303 :
304 0 : pub async fn gc_worker(&self, mut rng: impl Rng) {
305 0 : let epoch = self.config.pool_options.gc_epoch;
306 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
307 : loop {
308 0 : interval.tick().await;
309 :
310 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
311 0 : self.gc(shard);
312 : }
313 : }
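 : // For example, with the test configuration below (gc_epoch = 1s, pool_shards = 2) the
 : // interval ticks every 500ms and each tick scans one randomly chosen shard, so every
 : // shard is scanned about once per epoch on average.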
314 :
315 4 : fn gc(&self, shard: usize) {
316 4 : debug!(shard, "pool: performing epoch reclamation");
317 :
318 : // acquire a write lock on the selected shard
319 4 : let mut shard = self.global_pool.shards()[shard].write();
320 4 :
321 4 : let timer = Metrics::get()
322 4 : .proxy
323 4 : .http_pool_reclaimation_lag_seconds
324 4 : .start_timer();
325 4 : let current_len = shard.len();
326 4 : let mut clients_removed = 0;
327 4 : shard.retain(|endpoint, x| {
328 : // if the current endpoint pool is unique (no other strong or weak references)
329 : // then it is currently not in use by any connections.
330 4 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
331 : let EndpointConnPool {
332 2 : pools, total_conns, ..
333 2 : } = pool.get_mut();
334 2 :
335 2 : // ensure that closed clients are removed
336 2 : pools.iter_mut().for_each(|(_, db_pool)| {
337 2 : clients_removed += db_pool.clear_closed_clients(total_conns);
338 2 : });
339 2 :
340 2 : // we only remove this pool if it has no active connections
341 2 : if *total_conns == 0 {
342 0 : info!("pool: discarding pool for endpoint {endpoint}");
343 0 : return false;
344 2 : }
345 2 : }
346 :
347 4 : true
348 4 : });
349 4 :
350 4 : let new_len = shard.len();
351 4 : drop(shard);
352 4 : timer.observe();
353 4 :
354 4 : // Do logging outside of the lock.
355 4 : if clients_removed > 0 {
356 2 : let size = self
357 2 : .global_connections_count
358 2 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
359 2 : - clients_removed;
360 2 : Metrics::get()
361 2 : .proxy
362 2 : .http_pool_opened_connections
363 2 : .get_metric()
364 2 : .dec_by(clients_removed as i64);
365 2 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
366 2 : }
367 4 : let removed = current_len - new_len;
368 4 :
369 4 : if removed > 0 {
370 0 : let global_pool_size = self
371 0 : .global_pool_size
372 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
373 0 : - removed;
374 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
375 4 : }
376 4 : }
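 : // The sweep above only prunes pools it can access uniquely (Arc::get_mut requires no
 : // other strong or weak references) and removes an endpoint entry only when its pool
 : // ends up with zero connections; pools still referenced by checked-out clients are
 : // left untouched.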
377 :
378 0 : pub fn get(
379 0 : self: &Arc<Self>,
380 0 : ctx: &mut RequestMonitoring,
381 0 : conn_info: &ConnInfo,
382 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
383 0 : let mut client: Option<ClientInner<C>> = None;
384 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
385 0 : return Ok(None);
386 : };
387 :
388 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
389 0 : if let Some(entry) = endpoint_pool
390 0 : .write()
391 0 : .get_conn_entry(conn_info.db_and_user())
392 : {
393 0 : client = Some(entry.conn)
394 0 : }
395 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
396 :
397 : // ok return cached connection if found and establish a new one otherwise
398 0 : if let Some(client) = client {
399 0 : if client.is_closed() {
400 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
401 0 : return Ok(None);
402 : } else {
403 0 : tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
404 0 : tracing::Span::current().record(
405 0 : "pid",
406 0 : &tracing::field::display(client.inner.get_process_id()),
407 0 : );
408 0 : info!(
409 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
410 0 : "pool: reusing connection '{conn_info}'"
411 : );
412 0 : client.session.send(ctx.session_id)?;
413 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
414 0 : ctx.latency_timer.success();
415 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
416 : }
417 0 : }
418 0 : Ok(None)
419 0 : }
420 :
421 4 : fn get_or_create_endpoint_pool(
422 4 : self: &Arc<Self>,
423 4 : endpoint: &EndpointCacheKey,
424 4 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
425 : // fast path
426 4 : if let Some(pool) = self.global_pool.get(endpoint) {
427 0 : return pool.clone();
428 4 : }
429 4 :
430 4 : // slow path
431 4 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
432 4 : pools: HashMap::new(),
433 4 : total_conns: 0,
434 4 : max_conns: self.config.pool_options.max_conns_per_endpoint,
435 4 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
436 4 : global_connections_count: self.global_connections_count.clone(),
437 4 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
438 4 : }));
439 4 :
440 4 : // find or create a pool for this endpoint
441 4 : let mut created = false;
442 4 : let pool = self
443 4 : .global_pool
444 4 : .entry(endpoint.clone())
445 4 : .or_insert_with(|| {
446 4 : created = true;
447 4 : new_pool
448 4 : })
449 4 : .clone();
450 4 :
451 4 : // log new global pool size
452 4 : if created {
453 4 : let global_pool_size = self
454 4 : .global_pool_size
455 4 : .fetch_add(1, atomic::Ordering::Relaxed)
456 4 : + 1;
457 4 : info!(
458 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
459 : );
460 0 : }
461 :
462 4 : pool
463 4 : }
464 : }
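 : // If two callers race in get_or_create_endpoint_pool, both build a candidate pool but
 : // or_insert_with keeps only the first; the loser's Arc is simply dropped, and its
 : // _guard presumably un-registers it from the http_endpoint_pools gauge again.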
465 :
466 0 : pub fn poll_client<C: ClientInnerExt>(
467 0 : global_pool: Arc<GlobalConnPool<C>>,
468 0 : ctx: &mut RequestMonitoring,
469 0 : conn_info: ConnInfo,
470 0 : client: C,
471 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
472 0 : conn_id: uuid::Uuid,
473 0 : aux: MetricsAuxInfo,
474 0 : ) -> Client<C> {
475 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
476 0 : let mut session_id = ctx.session_id;
477 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
478 :
479 0 : let span = info_span!(parent: None, "connection", %conn_id);
480 0 : let cold_start_info = ctx.cold_start_info;
481 0 : span.in_scope(|| {
482 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
483 0 : });
484 0 : let pool = match conn_info.endpoint_cache_key() {
485 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
486 0 : None => Weak::new(),
487 : };
488 0 : let pool_clone = pool.clone();
489 0 :
490 0 : let db_user = conn_info.db_and_user();
491 0 : let idle = global_pool.get_idle_timeout();
492 0 : let cancel = CancellationToken::new();
493 0 : let cancelled = cancel.clone().cancelled_owned();
494 0 :
495 0 : tokio::spawn(
496 0 : async move {
497 0 : let _conn_gauge = conn_gauge;
498 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
499 0 : let mut cancelled = pin!(cancelled);
500 0 :
501 0 : poll_fn(move |cx| {
502 0 : if cancelled.as_mut().poll(cx).is_ready() {
503 0 : info!("connection dropped");
504 0 : return Poll::Ready(())
505 0 : }
506 0 :
507 0 : match rx.has_changed() {
508 : Ok(true) => {
509 0 : session_id = *rx.borrow_and_update();
510 0 : info!(%session_id, "changed session");
511 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
512 : }
513 : Err(_) => {
514 0 : info!("connection dropped");
515 0 : return Poll::Ready(())
516 : }
517 0 : _ => {}
518 : }
519 :
520 : // idle connection timeout (configured by the pool's idle_timeout option)
521 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
522 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
523 0 : info!("connection idle");
524 0 : if let Some(pool) = pool.clone().upgrade() {
525 : // remove client from pool - should close the connection if it's idle.
526 : // does nothing if the client is currently checked-out and in-use
527 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
528 0 : info!("idle connection removed");
529 0 : }
530 0 : }
531 0 : }
532 :
533 : loop {
534 0 : let message = ready!(connection.poll_message(cx));
535 :
536 0 : match message {
537 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
538 0 : info!(%session_id, "notice: {}", notice);
539 : }
540 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
541 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
542 : }
543 : Some(Ok(_)) => {
544 0 : warn!(%session_id, "unknown message");
545 : }
546 0 : Some(Err(e)) => {
547 0 : error!(%session_id, "connection error: {}", e);
548 0 : break
549 : }
550 : None => {
551 0 : info!("connection closed");
552 0 : break
553 : }
554 : }
555 : }
556 :
557 : // remove from connection pool
558 0 : if let Some(pool) = pool.clone().upgrade() {
559 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
560 0 : info!("closed connection removed");
561 0 : }
562 0 : }
563 :
564 0 : Poll::Ready(())
565 0 : }).await;
566 :
567 0 : }
568 0 : .instrument(span));
569 0 : let inner = ClientInner {
570 0 : inner: client,
571 0 : session: tx,
572 0 : cancel,
573 0 : aux,
574 0 : conn_id,
575 0 : };
576 0 : Client::new(inner, conn_info, pool_clone)
577 0 : }
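 : // The task spawned above multiplexes three concerns on every poll: exit once the
 : // CancellationToken fires (the owning ClientInner was dropped); watch the session
 : // channel and reset the idle timer, evicting the client from the pool when that timer
 : // fires; and drain AsyncMessages from the driver until the connection errors or
 : // closes, at which point the client is removed from the pool.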
578 :
579 : struct ClientInner<C: ClientInnerExt> {
580 : inner: C,
581 : session: tokio::sync::watch::Sender<uuid::Uuid>,
582 : cancel: CancellationToken,
583 : aux: MetricsAuxInfo,
584 : conn_id: uuid::Uuid,
585 : }
586 :
587 : impl<C: ClientInnerExt> Drop for ClientInner<C> {
588 14 : fn drop(&mut self) {
589 14 : // on client drop, tell the conn to shut down
590 14 : self.cancel.cancel();
591 14 : }
592 : }
593 :
594 : pub trait ClientInnerExt: Sync + Send + 'static {
595 : fn is_closed(&self) -> bool;
596 : fn get_process_id(&self) -> i32;
597 : }
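 : // Two implementations live in this file: the real tokio_postgres::Client just below,
 : // and the test-only MockClient in the tests module, which flips an AtomicBool to
 : // simulate a closed connection.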
598 :
599 : impl ClientInnerExt for tokio_postgres::Client {
600 0 : fn is_closed(&self) -> bool {
601 0 : self.is_closed()
602 0 : }
603 0 : fn get_process_id(&self) -> i32 {
604 0 : self.get_process_id()
605 0 : }
606 : }
607 :
608 : impl<C: ClientInnerExt> ClientInner<C> {
609 16 : pub fn is_closed(&self) -> bool {
610 16 : self.inner.is_closed()
611 16 : }
612 : }
613 :
614 : impl<C: ClientInnerExt> Client<C> {
615 0 : pub fn metrics(&self) -> Arc<MetricCounter> {
616 0 : let aux = &self.inner.as_ref().unwrap().aux;
617 0 : USAGE_METRICS.register(Ids {
618 0 : endpoint_id: aux.endpoint_id,
619 0 : branch_id: aux.branch_id,
620 0 : })
621 0 : }
622 : }
623 :
624 : pub struct Client<C: ClientInnerExt> {
625 : span: Span,
626 : inner: Option<ClientInner<C>>,
627 : conn_info: ConnInfo,
628 : pool: Weak<RwLock<EndpointConnPool<C>>>,
629 : }
630 :
631 : pub struct Discard<'a, C: ClientInnerExt> {
632 : conn_info: &'a ConnInfo,
633 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
634 : }
635 :
636 : impl<C: ClientInnerExt> Client<C> {
637 14 : pub(self) fn new(
638 14 : inner: ClientInner<C>,
639 14 : conn_info: ConnInfo,
640 14 : pool: Weak<RwLock<EndpointConnPool<C>>>,
641 14 : ) -> Self {
642 14 : Self {
643 14 : inner: Some(inner),
644 14 : span: Span::current(),
645 14 : conn_info,
646 14 : pool,
647 14 : }
648 14 : }
649 2 : pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
650 2 : let Self {
651 2 : inner,
652 2 : pool,
653 2 : conn_info,
654 2 : span: _,
655 2 : } = self;
656 2 : let inner = inner.as_mut().expect("client inner should not be removed");
657 2 : (&mut inner.inner, Discard { pool, conn_info })
658 2 : }
659 : }
660 :
661 : impl<C: ClientInnerExt> Discard<'_, C> {
662 0 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
663 0 : let conn_info = &self.conn_info;
664 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
665 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle")
666 0 : }
667 0 : }
668 2 : pub fn discard(&mut self) {
669 2 : let conn_info = &self.conn_info;
670 2 : if std::mem::take(self.pool).strong_count() > 0 {
671 2 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
672 0 : }
673 2 : }
674 : }
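 : // Both methods above work by std::mem::take-ing the Weak pool handle (check_idle only
 : // does so when the status is not Idle). Once the handle is an empty Weak,
 : // Client::do_drop cannot upgrade it, so the connection is dropped rather than being
 : // returned to the pool.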
675 :
676 : impl<C: ClientInnerExt> Deref for Client<C> {
677 : type Target = C;
678 :
679 0 : fn deref(&self) -> &Self::Target {
680 0 : &self
681 0 : .inner
682 0 : .as_ref()
683 0 : .expect("client inner should not be removed")
684 0 : .inner
685 0 : }
686 : }
687 :
688 : impl<C: ClientInnerExt> Client<C> {
689 14 : fn do_drop(&mut self) -> Option<impl FnOnce()> {
690 14 : let conn_info = self.conn_info.clone();
691 14 : let client = self
692 14 : .inner
693 14 : .take()
694 14 : .expect("client inner should not be removed");
695 14 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
696 12 : let current_span = self.span.clone();
697 12 : // return connection to the pool
698 12 : return Some(move || {
699 12 : let _span = current_span.enter();
700 12 : EndpointConnPool::put(&conn_pool, &conn_info, client);
701 12 : });
702 2 : }
703 2 : None
704 14 : }
705 : }
706 :
707 : impl<C: ClientInnerExt> Drop for Client<C> {
708 2 : fn drop(&mut self) {
709 2 : if let Some(drop) = self.do_drop() {
710 0 : tokio::task::spawn_blocking(drop);
711 2 : }
712 2 : }
713 : }
714 :
715 : #[cfg(test)]
716 : mod tests {
717 : use std::{mem, sync::atomic::AtomicBool};
718 :
719 : use crate::{serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId};
720 :
721 : use super::*;
722 :
723 : struct MockClient(Arc<AtomicBool>);
724 : impl MockClient {
725 12 : fn new(is_closed: bool) -> Self {
726 12 : MockClient(Arc::new(is_closed.into()))
727 12 : }
728 : }
729 : impl ClientInnerExt for MockClient {
730 16 : fn is_closed(&self) -> bool {
731 16 : self.0.load(atomic::Ordering::Relaxed)
732 16 : }
733 0 : fn get_process_id(&self) -> i32 {
734 0 : 0
735 0 : }
736 : }
737 :
738 10 : fn create_inner() -> ClientInner<MockClient> {
739 10 : create_inner_with(MockClient::new(false))
740 10 : }
741 :
742 14 : fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
743 14 : ClientInner {
744 14 : inner: client,
745 14 : session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
746 14 : cancel: CancellationToken::new(),
747 14 : aux: MetricsAuxInfo {
748 14 : endpoint_id: (&EndpointId::from("endpoint")).into(),
749 14 : project_id: (&ProjectId::from("project")).into(),
750 14 : branch_id: (&BranchId::from("branch")).into(),
751 14 : cold_start_info: crate::console::messages::ColdStartInfo::Warm,
752 14 : },
753 14 : conn_id: uuid::Uuid::new_v4(),
754 14 : }
755 14 : }
756 :
757 : #[tokio::test]
758 2 : async fn test_pool() {
759 2 : let _ = env_logger::try_init();
760 2 : let config = Box::leak(Box::new(crate::config::HttpConfig {
761 2 : pool_options: GlobalConnPoolOptions {
762 2 : max_conns_per_endpoint: 2,
763 2 : gc_epoch: Duration::from_secs(1),
764 2 : pool_shards: 2,
765 2 : idle_timeout: Duration::from_secs(1),
766 2 : opt_in: false,
767 2 : max_total_conns: 3,
768 2 : },
769 2 : request_timeout: Duration::from_secs(1),
770 2 : cancel_set: CancelSet::new(0),
771 2 : client_conn_threshold: u64::MAX,
772 2 : }));
773 2 : let pool = GlobalConnPool::new(config);
774 2 : let conn_info = ConnInfo {
775 2 : user_info: ComputeUserInfo {
776 2 : user: "user".into(),
777 2 : endpoint: "endpoint".into(),
778 2 : options: Default::default(),
779 2 : },
780 2 : dbname: "dbname".into(),
781 2 : password: "password".as_bytes().into(),
782 2 : };
783 2 : let ep_pool = Arc::downgrade(
784 2 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
785 2 : );
786 2 : {
787 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
788 2 : assert_eq!(0, pool.get_global_connections_count());
789 2 : client.inner().1.discard();
790 2 : // Discard should not return the connection to the pool.
791 2 : assert_eq!(0, pool.get_global_connections_count());
792 2 : }
793 2 : {
794 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
795 2 : client.do_drop().unwrap()();
796 2 : mem::forget(client); // forget so Drop doesn't run do_drop() a second time
797 2 : assert_eq!(1, pool.get_global_connections_count());
798 2 : }
799 2 : {
800 2 : let mut closed_client = Client::new(
801 2 : create_inner_with(MockClient::new(true)),
802 2 : conn_info.clone(),
803 2 : ep_pool.clone(),
804 2 : );
805 2 : closed_client.do_drop().unwrap()();
806 2 : mem::forget(closed_client); // forget so Drop doesn't run do_drop() a second time
807 2 : // The closed client shouldn't be added to the pool.
808 2 : assert_eq!(1, pool.get_global_connections_count());
809 2 : }
810 2 : let is_closed: Arc<AtomicBool> = Arc::new(false.into());
811 2 : {
812 2 : let mut client = Client::new(
813 2 : create_inner_with(MockClient(is_closed.clone())),
814 2 : conn_info.clone(),
815 2 : ep_pool.clone(),
816 2 : );
817 2 : client.do_drop().unwrap()();
818 2 : mem::forget(client); // forget so Drop doesn't run do_drop() a second time
819 2 :
820 2 : // The client should be added to the pool.
821 2 : assert_eq!(2, pool.get_global_connections_count());
822 2 : }
823 2 : {
824 2 : let mut client = Client::new(create_inner(), conn_info, ep_pool);
825 2 : client.do_drop().unwrap()();
826 2 : mem::forget(client); // forget so Drop doesn't run do_drop() a second time
827 2 :
828 2 : // The client shouldn't be added to the pool. Because the ep-pool is full.
829 2 : assert_eq!(2, pool.get_global_connections_count());
830 2 : }
831 2 :
832 2 : let conn_info = ConnInfo {
833 2 : user_info: ComputeUserInfo {
834 2 : user: "user".into(),
835 2 : endpoint: "endpoint-2".into(),
836 2 : options: Default::default(),
837 2 : },
838 2 : dbname: "dbname".into(),
839 2 : password: "password".as_bytes().into(),
840 2 : };
841 2 : let ep_pool = Arc::downgrade(
842 2 : &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
843 2 : );
844 2 : {
845 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
846 2 : client.do_drop().unwrap()();
847 2 : mem::forget(client); // drop the client
848 2 : assert_eq!(3, pool.get_global_connections_count());
849 2 : }
850 2 : {
851 2 : let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
852 2 : client.do_drop().unwrap()();
853 2 : mem::forget(client); // forget so Drop doesn't run do_drop() a second time
854 2 :
855 2 : // The client shouldn't be added to the pool. Because the global pool is full.
856 2 : assert_eq!(3, pool.get_global_connections_count());
857 2 : }
858 2 :
859 2 : is_closed.store(true, atomic::Ordering::Relaxed);
860 2 : // Do gc for all shards.
861 2 : pool.gc(0);
862 2 : pool.gc(1);
863 2 : // Closed client should be removed from the pool.
864 2 : assert_eq!(2, pool.get_global_connections_count());
865 2 : }
866 : }