Line data Source code
1 : use anyhow::Context;
2 : use async_trait::async_trait;
3 : use dashmap::DashMap;
4 : use futures::{future::poll_fn, Future};
5 : use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard};
6 : use once_cell::sync::Lazy;
7 : use parking_lot::RwLock;
8 : use pbkdf2::{
9 : password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
10 : Params, Pbkdf2,
11 : };
12 : use prometheus::{exponential_buckets, register_histogram, Histogram};
13 : use rand::Rng;
14 : use smol_str::SmolStr;
15 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
16 : use std::{
17 : fmt,
18 : task::{ready, Poll},
19 : };
20 : use std::{
21 : ops::Deref,
22 : sync::atomic::{self, AtomicUsize},
23 : };
24 : use tokio::time::{self, Instant};
25 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
26 :
27 : use crate::{
28 : auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
29 : console::{self, messages::MetricsAuxInfo},
30 : context::RequestMonitoring,
31 : metrics::NUM_DB_CONNECTIONS_GAUGE,
32 : proxy::connect_compute::ConnectMechanism,
33 : usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
34 : DbName, EndpointCacheKey, RoleName,
35 : };
36 : use crate::{compute, config};
37 :
38 : use tracing::{debug, error, warn, Span};
39 : use tracing::{info, info_span, Instrument};
40 :
41 : pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
42 :
43 43 : #[derive(Debug, Clone)]
44 : pub struct ConnInfo {
45 : pub user_info: ComputeUserInfo,
46 : pub dbname: DbName,
47 : pub password: SmolStr,
48 : }
49 :
50 : impl ConnInfo {
51 : // hm, change to hasher to avoid cloning?
52 110 : pub fn db_and_user(&self) -> (DbName, RoleName) {
53 110 : (self.dbname.clone(), self.user_info.user.clone())
54 110 : }
55 :
56 38 : pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
57 38 : self.user_info.endpoint_cache_key()
58 38 : }
59 : }
60 :
61 : impl fmt::Display for ConnInfo {
62 : // use custom display to avoid logging password
63 218 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 218 : write!(
65 218 : f,
66 218 : "{}@{}/{}?{}",
67 218 : self.user_info.user,
68 218 : self.user_info.endpoint,
69 218 : self.dbname,
70 218 : self.user_info.options.get_cache_key("")
71 218 : )
72 218 : }
73 : }
74 :
75 : struct ConnPoolEntry {
76 : conn: ClientInner,
77 : _last_access: std::time::Instant,
78 : }
79 :
80 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
81 : // Number of open connections is limited by the `max_conns_per_endpoint`.
82 : pub struct EndpointConnPool {
83 : pools: HashMap<(DbName, RoleName), DbUserConnPool>,
84 : total_conns: usize,
85 : max_conns: usize,
86 : _guard: IntCounterPairGuard,
87 : }
88 :
89 : impl EndpointConnPool {
90 13 : fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry> {
91 13 : let Self {
92 13 : pools, total_conns, ..
93 13 : } = self;
94 13 : pools
95 13 : .get_mut(&db_user)
96 13 : .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
97 13 : }
98 :
99 3 : fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
100 3 : let Self {
101 3 : pools, total_conns, ..
102 3 : } = self;
103 3 : if let Some(pool) = pools.get_mut(&db_user) {
104 3 : let old_len = pool.conns.len();
105 3 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
106 3 : let new_len = pool.conns.len();
107 3 : let removed = old_len - new_len;
108 3 : *total_conns -= removed;
109 3 : removed > 0
110 : } else {
111 0 : false
112 : }
113 3 : }
114 :
115 20 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
116 20 : let conn_id = client.conn_id;
117 20 :
118 20 : if client.inner.is_closed() {
119 0 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
120 0 : return Ok(());
121 20 : }
122 20 :
123 20 : // return connection to the pool
124 20 : let mut returned = false;
125 20 : let mut per_db_size = 0;
126 20 : let total_conns = {
127 20 : let mut pool = pool.write();
128 20 :
129 20 : if pool.total_conns < pool.max_conns {
130 : // we create this db-user entry in get, so it should not be None
131 20 : if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
132 20 : pool_entries.conns.push(ConnPoolEntry {
133 20 : conn: client,
134 20 : _last_access: std::time::Instant::now(),
135 20 : });
136 20 :
137 20 : returned = true;
138 20 : per_db_size = pool_entries.conns.len();
139 20 :
140 20 : pool.total_conns += 1;
141 20 : }
142 0 : }
143 :
144 20 : pool.total_conns
145 20 : };
146 20 :
147 20 : // do logging outside of the mutex
148 20 : if returned {
149 20 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
150 : } else {
151 0 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
152 : }
153 :
154 20 : Ok(())
155 20 : }
156 : }
157 :
158 : /// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
159 : /// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
160 : ///
161 : /// Still takes 1.4ms to hash on my hardware.
162 : /// We don't want to ruin the latency improvements of using the pool by making password verification take too long.
163 : const PARAMS: Params = Params {
164 : rounds: 4096,
165 : output_length: 32,
166 : };
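// A minimal illustrative sketch (not part of the proxy code path; the helper name is
// made up) of the hash/verify round trip these `PARAMS` drive below: hashing happens
// when a connection is first pooled, verification on every pool hit.
#[allow(dead_code)]
fn pbkdf2_round_trip_sketch(password: &[u8]) -> pbkdf2::password_hash::Result<()> {
    let salt = SaltString::generate(rand::rngs::OsRng);
    let stored: PasswordHashString = Pbkdf2
        .hash_password_customized(password, None, None, PARAMS, &salt)?
        .serialize();
    // With 4096 rounds this verification costs on the order of a millisecond.
    Pbkdf2.verify_password(password, &stored.password_hash())
}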
167 :
168 7 : #[derive(Default)]
169 : pub struct DbUserConnPool {
170 : conns: Vec<ConnPoolEntry>,
171 : password_hash: Option<PasswordHashString>,
172 : }
173 :
174 : impl DbUserConnPool {
175 13 : fn clear_closed_clients(&mut self, conns: &mut usize) {
176 13 : let old_len = self.conns.len();
177 13 :
178 13 : self.conns.retain(|conn| !conn.conn.inner.is_closed());
179 13 :
180 13 : let new_len = self.conns.len();
181 13 : let removed = old_len - new_len;
182 13 : *conns -= removed;
183 13 : }
184 :
185 13 : fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry> {
186 13 : self.clear_closed_clients(conns);
187 13 : let conn = self.conns.pop();
188 13 : if conn.is_some() {
189 4 : *conns -= 1;
190 9 : }
191 13 : conn
192 13 : }
193 : }
194 :
195 : pub struct GlobalConnPool {
196 : // endpoint -> per-endpoint connection pool
197 : //
198 : // This is expected to be a fairly contended map, so return a reference to the per-endpoint
199 : // pool as early as possible and release the lock.
200 : global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
201 :
202 : /// Number of endpoint-connection pools
203 : ///
204 : /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
205 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
206 : /// It's only used for diagnostics.
207 : global_pool_size: AtomicUsize,
208 :
209 : proxy_config: &'static crate::config::ProxyConfig,
210 : }
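// Sketch of how the relaxed counter above is intended to be read: the value is purely
// diagnostic, so a racy `Ordering::Relaxed` load is fine and avoids the per-shard read
// locks that `DashMap::len` would take. (Helper name is illustrative, not called anywhere.)
#[allow(dead_code)]
fn approximate_endpoint_pool_count(pool: &GlobalConnPool) -> usize {
    pool.global_pool_size.load(atomic::Ordering::Relaxed)
}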
211 :
212 0 : #[derive(Debug, Clone, Copy)]
213 : pub struct GlobalConnPoolOptions {
214 : // Maximum number of connections per endpoint.
215 : // Connections for different (dbname, username) pairs count against the same limit.
216 : // When running out of free slots for a particular endpoint,
217 : // falls back to opening a new connection for each request.
218 : pub max_conns_per_endpoint: usize,
219 :
220 : pub gc_epoch: Duration,
221 :
222 : pub pool_shards: usize,
223 :
224 : pub idle_timeout: Duration,
225 :
226 : pub opt_in: bool,
227 : }
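// Illustrative values only (the real options come from the proxy's `http_config.pool_options`);
// this shows how the knobs above combine: per-endpoint capacity, GC cadence, DashMap shard
// count, idle timeout, and opt-in pooling.
#[allow(dead_code)]
fn example_pool_options() -> GlobalConnPoolOptions {
    GlobalConnPoolOptions {
        max_conns_per_endpoint: 20,
        gc_epoch: Duration::from_secs(60),
        pool_shards: 128,
        idle_timeout: Duration::from_secs(5 * 60),
        opt_in: true,
    }
}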
228 :
229 23 : pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
230 23 : register_histogram!(
231 23 : "proxy_http_pool_reclaimation_lag_seconds",
232 23 : "Time it takes to reclaim unused connection pools",
233 23 : // 16 exponential buckets: 1us -> ~33ms
234 23 : exponential_buckets(1e-6, 2.0, 16).unwrap(),
235 23 : )
236 23 : .unwrap()
237 23 : });
238 :
239 7 : pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
240 14 : register_int_counter_pair!(
241 14 : "proxy_http_pool_endpoints_registered_total",
242 14 : "Number of endpoints we have registered pools for",
243 14 : "proxy_http_pool_endpoints_unregistered_total",
244 14 : "Number of endpoints we have unregistered pools for",
245 14 : )
246 7 : .unwrap()
247 7 : });
248 :
249 : impl GlobalConnPool {
250 23 : pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
251 23 : let shards = config.http_config.pool_options.pool_shards;
252 23 : Arc::new(Self {
253 23 : global_pool: DashMap::with_shard_amount(shards),
254 23 : global_pool_size: AtomicUsize::new(0),
255 23 : proxy_config: config,
256 23 : })
257 23 : }
258 :
259 23 : pub fn shutdown(&self) {
260 23 : // drops all strong references to endpoint-pools
261 23 : self.global_pool.clear();
262 23 : }
263 :
264 23 : pub async fn gc_worker(&self, mut rng: impl Rng) {
265 23 : let epoch = self.proxy_config.http_config.pool_options.gc_epoch;
266 23 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
267 : loop {
268 55 : interval.tick().await;
269 :
270 32 : let shard = rng.gen_range(0..self.global_pool.shards().len());
271 32 : self.gc(shard);
272 : }
273 : }
274 :
275 32 : fn gc(&self, shard: usize) {
276 32 : debug!(shard, "pool: performing epoch reclamation");
277 :
278 : // acquire a random shard lock
279 32 : let mut shard = self.global_pool.shards()[shard].write();
280 32 :
281 32 : let timer = GC_LATENCY.start_timer();
282 32 : let current_len = shard.len();
283 32 : shard.retain(|endpoint, x| {
284 : // if the current endpoint pool is unique (no other strong or weak references)
285 : // then it is currently not in use by any connections.
286 0 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
287 : let EndpointConnPool {
288 0 : pools, total_conns, ..
289 0 : } = pool.get_mut();
290 0 :
291 0 : // ensure that closed clients are removed
292 0 : pools
293 0 : .iter_mut()
294 0 : .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns));
295 0 :
296 0 : // we only remove this pool if it has no active connections
297 0 : if *total_conns == 0 {
298 0 : info!("pool: discarding pool for endpoint {endpoint}");
299 0 : return false;
300 0 : }
301 0 : }
302 :
303 0 : true
304 32 : });
305 32 : let new_len = shard.len();
306 32 : drop(shard);
307 32 : timer.observe_duration();
308 32 :
309 32 : let removed = current_len - new_len;
310 32 :
311 32 : if removed > 0 {
312 0 : let global_pool_size = self
313 0 : .global_pool_size
314 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
315 0 : - removed;
316 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
317 32 : }
318 32 : }
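// Sketch of the `Arc::get_mut` uniqueness check that `gc` relies on above: it returns
// `Some` only while no other strong or weak handles exist, i.e. no `Client` or
// connection task can still reach the per-endpoint pool. (Illustrative, not called anywhere.)
#[allow(dead_code)]
fn arc_get_mut_sketch() {
    let mut unique = Arc::new(());
    assert!(Arc::get_mut(&mut unique).is_some());
    let _shared = Arc::clone(&unique);
    assert!(Arc::get_mut(&mut unique).is_none());
}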
319 :
320 46 : pub async fn get(
321 46 : self: &Arc<Self>,
322 46 : ctx: &mut RequestMonitoring,
323 46 : conn_info: ConnInfo,
324 46 : force_new: bool,
325 46 : ) -> anyhow::Result<Client> {
326 46 : let mut client: Option<ClientInner> = None;
327 46 :
328 46 : let mut hash_valid = false;
329 46 : let mut endpoint_pool = Weak::new();
330 46 : if !force_new {
331 27 : let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
332 27 : endpoint_pool = Arc::downgrade(&pool);
333 27 : let mut hash = None;
334 27 :
335 27 : // find a pool entry by (dbname, username) if one exists
336 27 : {
337 27 : let pool = pool.read();
338 27 : if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
339 19 : if !pool_entries.conns.is_empty() {
340 16 : hash = pool_entries.password_hash.clone();
341 16 : }
342 8 : }
343 : }
344 :
345 : // a connection exists in the pool, verify the password hash
346 27 : if let Some(hash) = hash {
347 16 : let pw = conn_info.password.clone();
348 16 : let validate = tokio::task::spawn_blocking(move || {
349 16 : Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
350 16 : })
351 32 : .await?;
352 :
353 : // if the hash is invalid, don't error
354 : // we will continue with the regular connection flow
355 16 : if validate.is_ok() {
356 13 : hash_valid = true;
357 13 : if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) {
358 4 : client = Some(entry.conn)
359 9 : }
360 3 : }
361 11 : }
362 19 : }
363 :
364 : // return the cached connection if one was found, otherwise establish a new one
365 46 : let new_client = if let Some(client) = client {
366 4 : ctx.set_project(client.aux.clone());
367 4 : if client.inner.is_closed() {
368 0 : let conn_id = uuid::Uuid::new_v4();
369 0 : info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
370 0 : connect_to_compute(
371 0 : self.proxy_config,
372 0 : ctx,
373 0 : &conn_info,
374 0 : conn_id,
375 0 : endpoint_pool.clone(),
376 0 : )
377 0 : .await
378 : } else {
379 4 : info!("pool: reusing connection '{conn_info}'");
380 4 : client.session.send(ctx.session_id)?;
381 4 : tracing::Span::current().record(
382 4 : "pid",
383 4 : &tracing::field::display(client.inner.get_process_id()),
384 4 : );
385 4 : ctx.latency_timer.pool_hit();
386 4 : ctx.latency_timer.success();
387 4 : return Ok(Client::new(client, conn_info, endpoint_pool).await);
388 : }
389 : } else {
390 42 : let conn_id = uuid::Uuid::new_v4();
391 42 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
392 42 : connect_to_compute(
393 42 : self.proxy_config,
394 42 : ctx,
395 42 : &conn_info,
396 42 : conn_id,
397 42 : endpoint_pool.clone(),
398 42 : )
399 469 : .await
400 : };
401 42 : if let Ok(client) = &new_client {
402 39 : tracing::Span::current().record(
403 39 : "pid",
404 39 : &tracing::field::display(client.inner.get_process_id()),
405 39 : );
406 39 : }
407 :
408 39 : match &new_client {
409 3 : // clear the hash. it's no longer valid
410 3 : // TODO: update tokio-postgres fork to allow access to this error kind directly
411 3 : Err(err)
412 3 : if hash_valid && err.to_string().contains("password authentication failed") =>
413 0 : {
414 0 : let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
415 0 : let mut pool = pool.write();
416 0 : if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
417 0 : entry.password_hash = None;
418 0 : }
419 : }
420 : // new password is valid and we should insert/update it
421 39 : Ok(_) if !force_new && !hash_valid => {
422 11 : let pw = conn_info.password.clone();
423 11 : let new_hash = tokio::task::spawn_blocking(move || {
424 11 : let salt = SaltString::generate(rand::rngs::OsRng);
425 11 : Pbkdf2
426 11 : .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
427 11 : .map(|s| s.serialize())
428 11 : })
429 11 : .await??;
430 :
431 11 : let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
432 11 : let mut pool = pool.write();
433 11 : pool.pools
434 11 : .entry(conn_info.db_and_user())
435 11 : .or_default()
436 11 : .password_hash = Some(new_hash);
437 : }
438 31 : _ => {}
439 : }
440 42 : let new_client = new_client?;
441 39 : Ok(Client::new(new_client, conn_info, endpoint_pool).await)
442 46 : }
443 :
444 38 : fn get_or_create_endpoint_pool(
445 38 : &self,
446 38 : endpoint: &EndpointCacheKey,
447 38 : ) -> Arc<RwLock<EndpointConnPool>> {
448 : // fast path
449 38 : if let Some(pool) = self.global_pool.get(endpoint) {
450 31 : return pool.clone();
451 7 : }
452 7 :
453 7 : // slow path
454 7 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
455 7 : pools: HashMap::new(),
456 7 : total_conns: 0,
457 7 : max_conns: self
458 7 : .proxy_config
459 7 : .http_config
460 7 : .pool_options
461 7 : .max_conns_per_endpoint,
462 7 : _guard: ENDPOINT_POOLS.guard(),
463 7 : }));
464 7 :
465 7 : // find or create a pool for this endpoint
466 7 : let mut created = false;
467 7 : let pool = self
468 7 : .global_pool
469 7 : .entry(endpoint.clone())
470 7 : .or_insert_with(|| {
471 7 : created = true;
472 7 : new_pool
473 7 : })
474 7 : .clone();
475 7 :
476 7 : // log new global pool size
477 7 : if created {
478 7 : let global_pool_size = self
479 7 : .global_pool_size
480 7 : .fetch_add(1, atomic::Ordering::Relaxed)
481 7 : + 1;
482 7 : info!(
483 7 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
484 7 : );
485 0 : }
486 :
487 7 : pool
488 38 : }
489 : }
490 :
491 : struct TokioMechanism<'a> {
492 : pool: Weak<RwLock<EndpointConnPool>>,
493 : conn_info: &'a ConnInfo,
494 : conn_id: uuid::Uuid,
495 : idle: Duration,
496 : }
497 :
498 : #[async_trait]
499 : impl ConnectMechanism for TokioMechanism<'_> {
500 : type Connection = ClientInner;
501 : type ConnectError = tokio_postgres::Error;
502 : type Error = anyhow::Error;
503 :
504 43 : async fn connect_once(
505 43 : &self,
506 43 : ctx: &mut RequestMonitoring,
507 43 : node_info: &console::CachedNodeInfo,
508 43 : timeout: time::Duration,
509 43 : ) -> Result<Self::Connection, Self::ConnectError> {
510 43 : connect_to_compute_once(
511 43 : ctx,
512 43 : node_info,
513 43 : self.conn_info,
514 43 : timeout,
515 43 : self.conn_id,
516 43 : self.pool.clone(),
517 43 : self.idle,
518 43 : )
519 149 : .await
520 86 : }
521 :
522 43 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
523 : }
524 :
525 : // Wake up the destination if needed. Code here is a bit involved because
526 : // we reuse the code from the usual proxy and we need to prepare a few structures
527 : // that this code expects.
528 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
529 : async fn connect_to_compute(
530 : config: &config::ProxyConfig,
531 : ctx: &mut RequestMonitoring,
532 : conn_info: &ConnInfo,
533 : conn_id: uuid::Uuid,
534 : pool: Weak<RwLock<EndpointConnPool>>,
535 : ) -> anyhow::Result<ClientInner> {
536 : ctx.set_application(Some(APP_NAME));
537 : let backend = config
538 : .auth_backend
539 : .as_ref()
540 42 : .map(|_| conn_info.user_info.clone());
541 :
542 : if !config.disable_ip_check_for_http {
543 : let (allowed_ips, _) = backend.get_allowed_ips_and_secret(ctx).await?;
544 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
545 : return Err(auth::AuthError::ip_address_not_allowed().into());
546 : }
547 : }
548 : let node_info = backend
549 : .wake_compute(ctx)
550 : .await?
551 : .context("missing cache entry from wake_compute")?;
552 :
553 : ctx.set_project(node_info.aux.clone());
554 :
555 : crate::proxy::connect_compute::connect_to_compute(
556 : ctx,
557 : &TokioMechanism {
558 : conn_id,
559 : conn_info,
560 : pool,
561 : idle: config.http_config.pool_options.idle_timeout,
562 : },
563 : node_info,
564 : &backend,
565 : )
566 : .await
567 : }
568 :
569 43 : async fn connect_to_compute_once(
570 43 : ctx: &mut RequestMonitoring,
571 43 : node_info: &console::CachedNodeInfo,
572 43 : conn_info: &ConnInfo,
573 43 : timeout: time::Duration,
574 43 : conn_id: uuid::Uuid,
575 43 : pool: Weak<RwLock<EndpointConnPool>>,
576 43 : idle: Duration,
577 43 : ) -> Result<ClientInner, tokio_postgres::Error> {
578 43 : let mut config = (*node_info.config).clone();
579 43 : let mut session = ctx.session_id;
580 :
581 43 : let (client, mut connection) = config
582 43 : .user(&conn_info.user_info.user)
583 43 : .password(&*conn_info.password)
584 43 : .dbname(&conn_info.dbname)
585 43 : .connect_timeout(timeout)
586 43 : .connect(tokio_postgres::NoTls)
587 149 : .await?;
588 :
589 39 : let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
590 39 : .with_label_values(&[ctx.protocol])
591 39 : .guard();
592 39 :
593 39 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
594 39 :
595 39 : let (tx, mut rx) = tokio::sync::watch::channel(session);
596 :
597 39 : let span = info_span!(parent: None, "connection", %conn_id);
598 39 : span.in_scope(|| {
599 39 : info!(%conn_info, %session, "new connection");
600 39 : });
601 39 :
602 39 : let db_user = conn_info.db_and_user();
603 39 : tokio::spawn(
604 39 : async move {
605 39 : let _conn_gauge = conn_gauge;
606 39 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
607 7010 : poll_fn(move |cx| {
608 7010 : if matches!(rx.has_changed(), Ok(true)) {
609 4 : session = *rx.borrow_and_update();
610 4 : info!(%session, "changed session");
611 4 : idle_timeout.as_mut().reset(Instant::now() + idle);
612 7006 : }
613 :
614 : // the connection has been idle for the configured `idle` timeout (5 minutes)
615 7010 : if idle_timeout.as_mut().poll(cx).is_ready() {
616 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
617 0 : info!("connection idle");
618 0 : if let Some(pool) = pool.clone().upgrade() {
619 : // remove client from pool - should close the connection if it's idle.
620 : // does nothing if the client is currently checked-out and in-use
621 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
622 0 : info!("idle connection removed");
623 0 : }
624 0 : }
625 7010 : }
626 :
627 : loop {
628 7010 : let message = ready!(connection.poll_message(cx));
629 :
630 0 : match message {
631 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
632 0 : info!(%session, "notice: {}", notice);
633 : }
634 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
635 0 : warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
636 : }
637 : Some(Ok(_)) => {
638 0 : warn!(%session, "unknown message");
639 : }
640 0 : Some(Err(e)) => {
641 0 : error!(%session, "connection error: {}", e);
642 0 : break
643 : }
644 : None => {
645 38 : info!("connection closed");
646 38 : break
647 : }
648 : }
649 : }
650 :
651 : // remove from connection pool
652 38 : if let Some(pool) = pool.clone().upgrade() {
653 3 : if pool.write().remove_client(db_user.clone(), conn_id) {
654 0 : info!("closed connection removed");
655 3 : }
656 35 : }
657 :
658 38 : Poll::Ready(())
659 7010 : }).await;
660 :
661 39 : }
662 39 : .instrument(span)
663 39 : );
664 39 :
665 39 : Ok(ClientInner {
666 39 : inner: client,
667 39 : session: tx,
668 39 : aux: node_info.aux.clone(),
669 39 : conn_id,
670 39 : })
671 43 : }
672 :
673 : struct ClientInner {
674 : inner: tokio_postgres::Client,
675 : session: tokio::sync::watch::Sender<uuid::Uuid>,
676 : aux: MetricsAuxInfo,
677 : conn_id: uuid::Uuid,
678 : }
679 :
680 : impl Client {
681 41 : pub fn metrics(&self) -> Arc<MetricCounter> {
682 41 : let aux = &self.inner.as_ref().unwrap().aux;
683 41 : USAGE_METRICS.register(Ids {
684 41 : endpoint_id: aux.endpoint_id.clone(),
685 41 : branch_id: aux.branch_id.clone(),
686 41 : })
687 41 : }
688 : }
689 :
690 : pub struct Client {
691 : conn_id: uuid::Uuid,
692 : span: Span,
693 : inner: Option<ClientInner>,
694 : conn_info: ConnInfo,
695 : pool: Weak<RwLock<EndpointConnPool>>,
696 : }
697 :
698 : pub struct Discard<'a> {
699 : conn_id: uuid::Uuid,
700 : conn_info: &'a ConnInfo,
701 : pool: &'a mut Weak<RwLock<EndpointConnPool>>,
702 : }
703 :
704 : impl Client {
705 43 : pub(self) async fn new(
706 43 : inner: ClientInner,
707 43 : conn_info: ConnInfo,
708 43 : pool: Weak<RwLock<EndpointConnPool>>,
709 43 : ) -> Self {
710 43 : Self {
711 43 : conn_id: inner.conn_id,
712 43 : inner: Some(inner),
713 43 : span: Span::current(),
714 43 : conn_info,
715 43 : pool,
716 43 : }
717 43 : }
718 43 : pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
719 43 : let Self {
720 43 : inner,
721 43 : pool,
722 43 : conn_id,
723 43 : conn_info,
724 43 : span: _,
725 43 : } = self;
726 43 : (
727 43 : &mut inner
728 43 : .as_mut()
729 43 : .expect("client inner should not be removed")
730 43 : .inner,
731 43 : Discard {
732 43 : pool,
733 43 : conn_info,
734 43 : conn_id: *conn_id,
735 43 : },
736 43 : )
737 43 : }
738 :
739 39 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
740 39 : self.inner().1.check_idle(status)
741 39 : }
742 2 : pub fn discard(&mut self) {
743 2 : self.inner().1.discard()
744 2 : }
745 : }
746 :
747 : impl Discard<'_> {
748 41 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
749 41 : let conn_info = &self.conn_info;
750 41 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
751 2 : info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
752 39 : }
753 41 : }
754 2 : pub fn discard(&mut self) {
755 2 : let conn_info = &self.conn_info;
756 2 : if std::mem::take(self.pool).strong_count() > 0 {
757 2 : info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
758 0 : }
759 2 : }
760 : }
761 :
762 : impl Deref for Client {
763 : type Target = tokio_postgres::Client;
764 :
765 41 : fn deref(&self) -> &Self::Target {
766 41 : &self
767 41 : .inner
768 41 : .as_ref()
769 41 : .expect("client inner should not be removed")
770 41 : .inner
771 41 : }
772 : }
773 :
774 : impl Drop for Client {
775 43 : fn drop(&mut self) {
776 43 : let conn_info = self.conn_info.clone();
777 43 : let client = self
778 43 : .inner
779 43 : .take()
780 43 : .expect("client inner should not be removed");
781 43 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
782 20 : let current_span = self.span.clone();
783 20 : // return connection to the pool
784 20 : tokio::task::spawn_blocking(move || {
785 20 : let _span = current_span.enter();
786 20 : let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
787 20 : });
788 23 : }
789 43 : }
790 : }