Line data Source code
1 : use std::collections::HashMap;
2 : use std::ops::Deref;
3 : use std::sync::atomic::{self, AtomicUsize};
4 : use std::sync::{Arc, Weak};
5 : use std::time::Duration;
6 :
7 : use dashmap::DashMap;
8 : use parking_lot::RwLock;
9 : use rand::Rng;
10 : use tokio_postgres::ReadyForQueryStatus;
11 : use tracing::{debug, info, Span};
12 :
13 : use super::backend::HttpConnError;
14 : use super::conn_pool::ClientInnerRemote;
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::context::RequestMonitoring;
17 : use crate::control_plane::messages::ColdStartInfo;
18 : use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
19 : use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
20 : use crate::{DbName, EndpointCacheKey, RoleName};
21 :
22 : #[derive(Debug, Clone)]
23 : pub(crate) struct ConnInfo {
24 : pub(crate) user_info: ComputeUserInfo,
25 : pub(crate) dbname: DbName,
26 : }
27 :
28 : impl ConnInfo {
29 : // hm, change to hasher to avoid cloning? (see the hashing sketch after this impl)
30 3 : pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
31 3 : (self.dbname.clone(), self.user_info.user.clone())
32 3 : }
33 :
34 2 : pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
35 2 : // We don't want to cache http connections for ephemeral endpoints.
36 2 : if self.user_info.options.is_ephemeral() {
37 0 : None
38 : } else {
39 2 : Some(self.user_info.endpoint_cache_key())
40 : }
41 2 : }
42 : }
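
// A minimal sketch (illustrative, not part of this crate) of the "hash instead
// of clone" idea mentioned in `db_and_user` above: the (db, user) pair can be
// reduced to a u64 key with std's hasher, so map lookups would not need owned
// clones. `db` and `user` are plain &str stand-ins for `DbName`/`RoleName`.
#[cfg(test)]
mod db_user_key_sketch {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Hash both parts of the key without cloning either string.
    fn db_user_key(db: &str, user: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        db.hash(&mut hasher);
        user.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn same_pair_hashes_to_same_key() {
        assert_eq!(
            db_user_key("neondb", "web_access"),
            db_user_key("neondb", "web_access")
        );
    }
}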
43 :
44 : pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
45 : pub(crate) conn: ClientInnerRemote<C>,
46 : pub(crate) _last_access: std::time::Instant,
47 : }
48 :
49 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
50 : // The number of open connections is limited by `max_conns_per_endpoint`.
51 : pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
52 : pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
53 : total_conns: usize,
54 : max_conns: usize,
55 : _guard: HttpEndpointPoolsGuard<'static>,
56 : global_connections_count: Arc<AtomicUsize>,
57 : global_pool_size_max_conns: usize,
58 : }
59 :
60 : impl<C: ClientInnerExt> EndpointConnPool<C> {
61 0 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
62 0 : let Self {
63 0 : pools,
64 0 : total_conns,
65 0 : global_connections_count,
66 0 : ..
67 0 : } = self;
68 0 : pools.get_mut(&db_user).and_then(|pool_entries| {
69 0 : let (entry, removed) = pool_entries.get_conn_entry(total_conns);
70 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
71 0 : entry
72 0 : })
73 0 : }
74 :
75 0 : pub(crate) fn remove_client(
76 0 : &mut self,
77 0 : db_user: (DbName, RoleName),
78 0 : conn_id: uuid::Uuid,
79 0 : ) -> bool {
80 0 : let Self {
81 0 : pools,
82 0 : total_conns,
83 0 : global_connections_count,
84 0 : ..
85 0 : } = self;
86 0 : if let Some(pool) = pools.get_mut(&db_user) {
87 0 : let old_len = pool.conns.len();
88 0 : pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
89 0 : let new_len = pool.conns.len();
90 0 : let removed = old_len - new_len;
91 0 : if removed > 0 {
92 0 : global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
93 0 : Metrics::get()
94 0 : .proxy
95 0 : .http_pool_opened_connections
96 0 : .get_metric()
97 0 : .dec_by(removed as i64);
98 0 : }
99 0 : *total_conns -= removed;
100 0 : removed > 0
101 : } else {
102 0 : false
103 : }
104 0 : }
105 :
106 6 : pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
107 6 : let conn_id = client.get_conn_id();
108 6 :
109 6 : if client.is_closed() {
110 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
111 1 : return;
112 5 : }
113 5 :
114 5 : let global_max_conn = pool.read().global_pool_size_max_conns;
115 5 : if pool
116 5 : .read()
117 5 : .global_connections_count
118 5 : .load(atomic::Ordering::Relaxed)
119 5 : >= global_max_conn
120 : {
121 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
122 1 : return;
123 4 : }
124 4 :
125 4 : // return connection to the pool
126 4 : let mut returned = false;
127 4 : let mut per_db_size = 0;
128 4 : let total_conns = {
129 4 : let mut pool = pool.write();
130 4 :
131 4 : if pool.total_conns < pool.max_conns {
132 3 : let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
133 3 : pool_entries.conns.push(ConnPoolEntry {
134 3 : conn: client,
135 3 : _last_access: std::time::Instant::now(),
136 3 : });
137 3 :
138 3 : returned = true;
139 3 : per_db_size = pool_entries.conns.len();
140 3 :
141 3 : pool.total_conns += 1;
142 3 : pool.global_connections_count
143 3 : .fetch_add(1, atomic::Ordering::Relaxed);
144 3 : Metrics::get()
145 3 : .proxy
146 3 : .http_pool_opened_connections
147 3 : .get_metric()
148 3 : .inc();
149 3 : }
150 :
151 4 : pool.total_conns
152 4 : };
153 4 :
154 4 : // do logging outside of the mutex
155 4 : if returned {
156 3 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
157 : } else {
158 1 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
159 : }
160 6 : }
161 : }
162 :
163 : impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
164 2 : fn drop(&mut self) {
165 2 : if self.total_conns > 0 {
166 2 : self.global_connections_count
167 2 : .fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
168 2 : Metrics::get()
169 2 : .proxy
170 2 : .http_pool_opened_connections
171 2 : .get_metric()
172 2 : .dec_by(self.total_conns as i64);
173 2 : }
174 2 : }
175 : }
176 :
177 : pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
178 : pub(crate) conns: Vec<ConnPoolEntry<C>>,
179 : }
180 :
181 : impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
182 2 : fn default() -> Self {
183 2 : Self { conns: Vec::new() }
184 2 : }
185 : }
186 :
187 : impl<C: ClientInnerExt> DbUserConnPool<C> {
188 1 : fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
189 1 : let old_len = self.conns.len();
190 1 :
191 2 : self.conns.retain(|conn| !conn.conn.is_closed());
192 1 :
193 1 : let new_len = self.conns.len();
194 1 : let removed = old_len - new_len;
195 1 : *conns -= removed;
196 1 : removed
197 1 : }
198 :
199 0 : pub(crate) fn get_conn_entry(
200 0 : &mut self,
201 0 : conns: &mut usize,
202 0 : ) -> (Option<ConnPoolEntry<C>>, usize) {
203 0 : let mut removed = self.clear_closed_clients(conns);
204 0 : let conn = self.conns.pop();
205 0 : if conn.is_some() {
206 0 : *conns -= 1;
207 0 : removed += 1;
208 0 : }
209 :
210 0 : Metrics::get()
211 0 : .proxy
212 0 : .http_pool_opened_connections
213 0 : .get_metric()
214 0 : .dec_by(removed as i64);
215 0 :
216 0 : (conn, removed)
217 0 : }
218 : }
219 :
220 : pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
221 : // endpoint -> per-endpoint connection pool
222 : //
223 : // This is likely to be a fairly contended map, so return a reference to the per-endpoint
224 : // pool as early as possible and release the lock.
225 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
226 :
227 : /// Number of endpoint-connection pools
228 : ///
229 : /// [`DashMap::len`] iterates over all shards and acquires a read lock on each.
230 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
231 : /// It's only used for diagnostics.
232 : global_pool_size: AtomicUsize,
233 :
234 : /// Total number of connections in the pool
235 : global_connections_count: Arc<AtomicUsize>,
236 :
237 : config: &'static crate::config::HttpConfig,
238 : }
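
// A minimal sketch (illustrative, not part of this crate) of the relaxed-counter
// pattern that `global_pool_size` uses instead of `DashMap::len`: bump on insert,
// subtract on removal, and read only for diagnostics, so `Ordering::Relaxed`
// suffices. `PoolSize` is a hypothetical stand-in type.
#[cfg(test)]
mod relaxed_counter_sketch {
    use std::sync::atomic::{AtomicUsize, Ordering};

    struct PoolSize(AtomicUsize);

    impl PoolSize {
        // Returns the size after the insert, mirroring `fetch_add(1, ..) + 1` above.
        fn on_insert(&self) -> usize {
            self.0.fetch_add(1, Ordering::Relaxed) + 1
        }
        // Returns the size after removing `n` pools, mirroring `fetch_sub(n, ..) - n`.
        fn on_remove(&self, n: usize) -> usize {
            self.0.fetch_sub(n, Ordering::Relaxed) - n
        }
    }

    #[test]
    fn tracks_the_pool_count() {
        let size = PoolSize(AtomicUsize::new(0));
        assert_eq!(size.on_insert(), 1);
        assert_eq!(size.on_insert(), 2);
        assert_eq!(size.on_remove(2), 0);
    }
}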
239 :
240 : #[derive(Debug, Clone, Copy)]
241 : pub struct GlobalConnPoolOptions {
242 : // Maximum number of connections per endpoint.
243 : // The pool may mix connections for different (dbname, username) pairs.
244 : // When an endpoint runs out of free slots,
245 : // the proxy falls back to opening a new connection for each request.
246 : pub max_conns_per_endpoint: usize,
247 :
248 : pub gc_epoch: Duration,
249 :
250 : pub pool_shards: usize,
251 :
252 : pub idle_timeout: Duration,
253 :
254 : pub opt_in: bool,
255 :
256 : // Maximum total number of connections across the whole pool.
257 : pub max_total_conns: usize,
258 : }
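
// A minimal sketch (illustrative, not part of this crate): splitting `gc_epoch`
// across the pool shards means the gc worker ticks `pool_shards` times per epoch
// and reclaims one randomly chosen shard per tick, so each shard is visited
// about once per epoch on average. The numbers below are illustrative only.
#[cfg(test)]
mod gc_schedule_sketch {
    use std::time::Duration;

    #[test]
    fn per_shard_ticks_add_up_to_one_epoch() {
        let gc_epoch = Duration::from_secs(600);
        let pool_shards = 16u32;
        // One tick per shard per epoch.
        let tick = gc_epoch / pool_shards;
        assert_eq!(tick * pool_shards, gc_epoch);
    }
}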
259 :
260 : impl<C: ClientInnerExt> GlobalConnPool<C> {
261 1 : pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
262 1 : let shards = config.pool_options.pool_shards;
263 1 : Arc::new(Self {
264 1 : global_pool: DashMap::with_shard_amount(shards),
265 1 : global_pool_size: AtomicUsize::new(0),
266 1 : config,
267 1 : global_connections_count: Arc::new(AtomicUsize::new(0)),
268 1 : })
269 1 : }
270 :
271 : #[cfg(test)]
272 9 : pub(crate) fn get_global_connections_count(&self) -> usize {
273 9 : self.global_connections_count
274 9 : .load(atomic::Ordering::Relaxed)
275 9 : }
276 :
277 0 : pub(crate) fn get_idle_timeout(&self) -> Duration {
278 0 : self.config.pool_options.idle_timeout
279 0 : }
280 :
281 0 : pub(crate) fn shutdown(&self) {
282 0 : // drops all strong references to endpoint-pools
283 0 : self.global_pool.clear();
284 0 : }
285 :
286 0 : pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
287 0 : let epoch = self.config.pool_options.gc_epoch;
288 0 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
289 : loop {
290 0 : interval.tick().await;
291 :
292 0 : let shard = rng.gen_range(0..self.global_pool.shards().len());
293 0 : self.gc(shard);
294 : }
295 : }
296 :
297 2 : pub(crate) fn gc(&self, shard: usize) {
298 2 : debug!(shard, "pool: performing epoch reclamation");
299 :
300 : // acquire a random shard lock
301 2 : let mut shard = self.global_pool.shards()[shard].write();
302 2 :
303 2 : let timer = Metrics::get()
304 2 : .proxy
305 2 : .http_pool_reclaimation_lag_seconds
306 2 : .start_timer();
307 2 : let current_len = shard.len();
308 2 : let mut clients_removed = 0;
309 2 : shard.retain(|endpoint, x| {
310 : // if the current endpoint pool is unique (no other strong or weak references)
311 : // then it is currently not in use by any connections.
312 2 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
313 : let EndpointConnPool {
314 1 : pools, total_conns, ..
315 1 : } = pool.get_mut();
316 :
317 : // ensure that closed clients are removed
318 1 : for db_pool in pools.values_mut() {
319 1 : clients_removed += db_pool.clear_closed_clients(total_conns);
320 1 : }
321 :
322 : // we only remove this pool if it has no active connections
323 1 : if *total_conns == 0 {
324 0 : info!("pool: discarding pool for endpoint {endpoint}");
325 0 : return false;
326 1 : }
327 1 : }
328 :
329 2 : true
330 2 : });
331 2 :
332 2 : let new_len = shard.len();
333 2 : drop(shard);
334 2 : timer.observe();
335 2 :
336 2 : // Do logging outside of the lock.
337 2 : if clients_removed > 0 {
338 1 : let size = self
339 1 : .global_connections_count
340 1 : .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
341 1 : - clients_removed;
342 1 : Metrics::get()
343 1 : .proxy
344 1 : .http_pool_opened_connections
345 1 : .get_metric()
346 1 : .dec_by(clients_removed as i64);
347 1 : info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
348 1 : }
349 2 : let removed = current_len - new_len;
350 2 :
351 2 : if removed > 0 {
352 0 : let global_pool_size = self
353 0 : .global_pool_size
354 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
355 0 : - removed;
356 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
357 2 : }
358 2 : }
359 :
360 2 : pub(crate) fn get_or_create_endpoint_pool(
361 2 : self: &Arc<Self>,
362 2 : endpoint: &EndpointCacheKey,
363 2 : ) -> Arc<RwLock<EndpointConnPool<C>>> {
364 : // fast path
365 2 : if let Some(pool) = self.global_pool.get(endpoint) {
366 0 : return pool.clone();
367 2 : }
368 2 :
369 2 : // slow path
370 2 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
371 2 : pools: HashMap::new(),
372 2 : total_conns: 0,
373 2 : max_conns: self.config.pool_options.max_conns_per_endpoint,
374 2 : _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
375 2 : global_connections_count: self.global_connections_count.clone(),
376 2 : global_pool_size_max_conns: self.config.pool_options.max_total_conns,
377 2 : }));
378 2 :
379 2 : // find or create a pool for this endpoint
380 2 : let mut created = false;
381 2 : let pool = self
382 2 : .global_pool
383 2 : .entry(endpoint.clone())
384 2 : .or_insert_with(|| {
385 2 : created = true;
386 2 : new_pool
387 2 : })
388 2 : .clone();
389 2 :
390 2 : // log new global pool size
391 2 : if created {
392 2 : let global_pool_size = self
393 2 : .global_pool_size
394 2 : .fetch_add(1, atomic::Ordering::Relaxed)
395 2 : + 1;
396 2 : info!(
397 0 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
398 : );
399 0 : }
400 :
401 2 : pool
402 2 : }
403 :
404 0 : pub(crate) fn get(
405 0 : self: &Arc<Self>,
406 0 : ctx: &RequestMonitoring,
407 0 : conn_info: &ConnInfo,
408 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
409 0 : let mut client: Option<ClientInnerRemote<C>> = None;
410 0 : let Some(endpoint) = conn_info.endpoint_cache_key() else {
411 0 : return Ok(None);
412 : };
413 :
414 0 : let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
415 0 : if let Some(entry) = endpoint_pool
416 0 : .write()
417 0 : .get_conn_entry(conn_info.db_and_user())
418 0 : {
419 0 : client = Some(entry.conn);
420 0 : }
421 0 : let endpoint_pool = Arc::downgrade(&endpoint_pool);
422 :
423 : // Return the cached connection if it is still usable; otherwise the caller opens a new one.
424 0 : if let Some(mut client) = client {
425 0 : if client.is_closed() {
426 0 : info!("pool: cached connection '{conn_info}' is closed, opening a new one");
427 0 : return Ok(None);
428 0 : }
429 0 : tracing::Span::current()
430 0 : .record("conn_id", tracing::field::display(client.get_conn_id()));
431 0 : tracing::Span::current().record(
432 0 : "pid",
433 0 : tracing::field::display(client.inner().get_process_id()),
434 0 : );
435 0 : info!(
436 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
437 0 : "pool: reusing connection '{conn_info}'"
438 : );
439 :
440 0 : client.session().send(ctx.session_id())?;
441 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
442 0 : ctx.success();
443 0 : return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
444 0 : }
445 0 : Ok(None)
446 0 : }
447 : }
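
// A minimal sketch (illustrative, not part of this crate) of the reclamation
// check used in `gc` above: `Arc::get_mut` succeeds only when no other strong
// or weak references exist, so an endpoint pool can be dropped exactly when the
// map holds the last handle to it.
#[cfg(test)]
mod arc_get_mut_sketch {
    use std::sync::Arc;

    #[test]
    fn only_unique_arcs_can_be_reclaimed() {
        let mut pool = Arc::new(vec![1, 2, 3]);

        // Another handle exists: the entry is "in use" and must be kept.
        let in_use = Arc::clone(&pool);
        assert!(Arc::get_mut(&mut pool).is_none());

        // The last outside handle is gone: the entry can now be removed.
        drop(in_use);
        assert!(Arc::get_mut(&mut pool).is_some());
    }
}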
448 :
449 : impl<C: ClientInnerExt> Client<C> {
450 7 : pub(crate) fn new(
451 7 : inner: ClientInnerRemote<C>,
452 7 : conn_info: ConnInfo,
453 7 : pool: Weak<RwLock<EndpointConnPool<C>>>,
454 7 : ) -> Self {
455 7 : Self {
456 7 : inner: Some(inner),
457 7 : span: Span::current(),
458 7 : conn_info,
459 7 : pool,
460 7 : }
461 7 : }
462 :
463 1 : pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
464 1 : let Self {
465 1 : inner,
466 1 : pool,
467 1 : conn_info,
468 1 : span: _,
469 1 : } = self;
470 1 : let inner = inner.as_mut().expect("client inner should not be removed");
471 1 : let inner_ref = inner.inner_mut();
472 1 : (inner_ref, Discard { conn_info, pool })
473 1 : }
474 :
475 0 : pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
476 0 : let aux = &self.inner.as_ref().unwrap().aux();
477 0 : USAGE_METRICS.register(Ids {
478 0 : endpoint_id: aux.endpoint_id,
479 0 : branch_id: aux.branch_id,
480 0 : })
481 0 : }
482 :
483 7 : pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
484 7 : let conn_info = self.conn_info.clone();
485 7 : let client = self
486 7 : .inner
487 7 : .take()
488 7 : .expect("client inner should not be removed");
489 7 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
490 6 : let current_span = self.span.clone();
491 6 : // return connection to the pool
492 6 : return Some(move || {
493 6 : let _span = current_span.enter();
494 6 : EndpointConnPool::put(&conn_pool, &conn_info, client);
495 6 : });
496 1 : }
497 1 : None
498 7 : }
499 : }
500 :
501 : pub(crate) struct Client<C: ClientInnerExt> {
502 : span: Span,
503 : inner: Option<ClientInnerRemote<C>>,
504 : conn_info: ConnInfo,
505 : pool: Weak<RwLock<EndpointConnPool<C>>>,
506 : }
507 :
508 : impl<C: ClientInnerExt> Drop for Client<C> {
509 1 : fn drop(&mut self) {
510 1 : if let Some(drop) = self.do_drop() {
511 0 : tokio::task::spawn_blocking(drop);
512 1 : }
513 1 : }
514 : }
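
// A minimal sketch (illustrative, not part of this crate) of the drop-time
// return pattern that `Client::do_drop` and `Discard` rely on: the checked-out
// item keeps only a `Weak` handle to its pool and is returned on drop iff the
// pool is still alive; otherwise it is simply discarded. All types here are
// hypothetical stand-ins, not this crate's API.
#[cfg(test)]
mod weak_return_sketch {
    use std::sync::{Arc, Mutex, Weak};

    struct Pool(Mutex<Vec<String>>);

    struct Checked {
        value: Option<String>,
        pool: Weak<Pool>,
    }

    impl Drop for Checked {
        fn drop(&mut self) {
            // Return the value only if the pool has not been shut down.
            if let (Some(value), Some(pool)) = (self.value.take(), self.pool.upgrade()) {
                pool.0.lock().unwrap().push(value);
            }
        }
    }

    #[test]
    fn returns_to_pool_only_while_it_is_alive() {
        let pool = Arc::new(Pool(Mutex::new(Vec::new())));
        drop(Checked {
            value: Some("conn".into()),
            pool: Arc::downgrade(&pool),
        });
        assert_eq!(pool.0.lock().unwrap().len(), 1);

        let weak = Arc::downgrade(&pool);
        drop(pool);
        // The pool is gone, so this connection is dropped instead of returned.
        drop(Checked {
            value: Some("conn".into()),
            pool: weak,
        });
    }
}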
515 :
516 : impl<C: ClientInnerExt> Deref for Client<C> {
517 : type Target = C;
518 :
519 0 : fn deref(&self) -> &Self::Target {
520 0 : self.inner
521 0 : .as_ref()
522 0 : .expect("client inner should not be removed")
523 0 : .inner()
524 0 : }
525 : }
526 :
527 : pub(crate) trait ClientInnerExt: Sync + Send + 'static {
528 : fn is_closed(&self) -> bool;
529 : fn get_process_id(&self) -> i32;
530 : }
531 :
532 : impl ClientInnerExt for tokio_postgres::Client {
533 0 : fn is_closed(&self) -> bool {
534 0 : self.is_closed()
535 0 : }
536 :
537 0 : fn get_process_id(&self) -> i32 {
538 0 : self.get_process_id()
539 0 : }
540 : }
541 :
542 : pub(crate) struct Discard<'a, C: ClientInnerExt> {
543 : conn_info: &'a ConnInfo,
544 : pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
545 : }
546 :
547 : impl<C: ClientInnerExt> Discard<'_, C> {
548 0 : pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
549 0 : let conn_info = &self.conn_info;
550 0 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
551 0 : info!("pool: throwing away connection '{conn_info}' because connection is not idle");
552 0 : }
553 0 : }
554 1 : pub(crate) fn discard(&mut self) {
555 1 : let conn_info = &self.conn_info;
556 1 : if std::mem::take(self.pool).strong_count() > 0 {
557 1 : info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
558 0 : }
559 1 : }
560 : }