TLA Line data Source code
1 : use anyhow::{anyhow, Context};
2 : use async_trait::async_trait;
3 : use dashmap::DashMap;
4 : use futures::{future::poll_fn, Future};
5 : use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard};
6 : use once_cell::sync::Lazy;
7 : use parking_lot::RwLock;
8 : use pbkdf2::{
9 : password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
10 : Params, Pbkdf2,
11 : };
12 : use pq_proto::StartupMessageParams;
13 : use prometheus::{exponential_buckets, register_histogram, Histogram};
14 : use rand::Rng;
15 : use smol_str::SmolStr;
16 : use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
17 : use std::{
18 : fmt,
19 : task::{ready, Poll},
20 : };
21 : use std::{
22 : ops::Deref,
23 : sync::atomic::{self, AtomicUsize},
24 : };
25 : use tokio::time::{self, Instant};
26 : use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
27 :
28 : use crate::{
29 : auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
30 : console,
31 : context::RequestMonitoring,
32 : metrics::NUM_DB_CONNECTIONS_GAUGE,
33 : proxy::{connect_compute::ConnectMechanism, neon_options},
34 : usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
35 : };
36 : use crate::{compute, config};
37 :
38 : use tracing::{debug, error, warn, Span};
39 : use tracing::{info, info_span, Instrument};
40 :
41 : pub const APP_NAME: &str = "/sql_over_http";
42 :
43 CBC 42 : #[derive(Debug, Clone)]
44 : pub struct ConnInfo {
45 : pub username: SmolStr,
46 : pub dbname: SmolStr,
47 : pub hostname: SmolStr,
48 : pub password: SmolStr,
49 : pub options: Option<SmolStr>,
50 : }
51 :
52 : impl ConnInfo {
53 : // hm, change to hasher to avoid cloning?
54 109 : pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
55 109 : (self.dbname.clone(), self.username.clone())
56 109 : }
57 : }
58 :
59 : impl fmt::Display for ConnInfo {
60 : // use custom display to avoid logging password
61 214 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 214 : write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
63 214 : }
64 : }
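// A minimal sketch of what the custom `Display` impl above buys us: a
// `ConnInfo` can be interpolated into log lines without leaking the password.
// The field values here are illustrative only.
#[cfg(test)]
mod conn_info_display_tests {
    use super::ConnInfo;

    #[test]
    fn display_omits_password() {
        let conn_info = ConnInfo {
            username: "alice".into(),
            dbname: "main".into(),
            hostname: "ep-example-123456.us-east-2.aws.neon.tech".into(),
            password: "super-secret".into(),
            options: None,
        };
        let rendered = format!("{conn_info}");
        assert_eq!(
            rendered,
            "alice@ep-example-123456.us-east-2.aws.neon.tech/main"
        );
        assert!(!rendered.contains("super-secret"));
    }
}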
65 :
66 : struct ConnPoolEntry {
67 : conn: ClientInner,
68 : _last_access: std::time::Instant,
69 : }
70 :
71 : // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
72 : // The number of open connections is limited by `max_conns_per_endpoint`.
73 : pub struct EndpointConnPool {
74 : pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
75 : total_conns: usize,
76 : max_conns: usize,
77 : _guard: IntCounterPairGuard,
78 : }
79 :
80 : impl EndpointConnPool {
81 13 : fn get_conn_entry(&mut self, db_user: (SmolStr, SmolStr)) -> Option<ConnPoolEntry> {
82 13 : let Self {
83 13 : pools, total_conns, ..
84 13 : } = self;
85 13 : pools
86 13 : .get_mut(&db_user)
87 13 : .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
88 13 : }
89 :
90 3 : fn remove_client(&mut self, db_user: (SmolStr, SmolStr), conn_id: uuid::Uuid) -> bool {
91 3 : let Self {
92 3 : pools, total_conns, ..
93 3 : } = self;
94 3 : if let Some(pool) = pools.get_mut(&db_user) {
95 3 : let old_len = pool.conns.len();
96 3 : pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
97 3 : let new_len = pool.conns.len();
98 3 : let removed = old_len - new_len;
99 3 : *total_conns -= removed;
100 3 : removed > 0
101 : } else {
102 UBC 0 : false
103 : }
104 CBC 3 : }
105 :
106 20 : fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
107 20 : let conn_id = client.conn_id;
108 20 :
109 20 : if client.inner.is_closed() {
110 UBC 0 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
111 0 : return Ok(());
112 CBC 20 : }
113 20 :
114 20 : // return connection to the pool
115 20 : let mut returned = false;
116 20 : let mut per_db_size = 0;
117 20 : let total_conns = {
118 20 : let mut pool = pool.write();
119 20 :
120 20 : if pool.total_conns < pool.max_conns {
121 : // we create this db-user entry in get, so it should not be None
122 20 : if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
123 20 : pool_entries.conns.push(ConnPoolEntry {
124 20 : conn: client,
125 20 : _last_access: std::time::Instant::now(),
126 20 : });
127 20 :
128 20 : returned = true;
129 20 : per_db_size = pool_entries.conns.len();
130 20 :
131 20 : pool.total_conns += 1;
132 20 : }
133 UBC 0 : }
134 :
135 CBC 20 : pool.total_conns
136 20 : };
137 20 :
138 20 : // do logging outside of the mutex
139 20 : if returned {
140 20 : info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
141 : } else {
142 UBC 0 : info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
143 : }
144 :
145 CBC 20 : Ok(())
146 20 : }
147 : }
148 :
149 : /// 4096 is the minimum number of rounds that SCRAM-SHA-256 recommends.
150 : /// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
151 : ///
152 : /// It still takes about 1.4ms to hash on my hardware.
153 : /// We don't want to ruin the latency improvements of using the pool by making password verification take too long.
154 : const PARAMS: Params = Params {
155 : rounds: 4096,
156 : output_length: 32,
157 : };
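// A minimal sketch of how `PARAMS` is used by the pool below: hash the
// password once with a fresh salt, then verify it on later checkouts. It
// reuses the pbkdf2/rand imports at the top of this file; the password
// literal is illustrative only.
#[cfg(test)]
mod pbkdf2_params_tests {
    use super::*;

    #[test]
    fn hash_and_verify_roundtrip() {
        let password = b"example-password";
        let salt = SaltString::generate(rand::rngs::OsRng);
        let hash = Pbkdf2
            .hash_password_customized(password, None, None, PARAMS, &salt)
            .expect("hashing should succeed")
            .serialize();

        // The stored `PasswordHashString` is enough to verify later attempts.
        assert!(Pbkdf2
            .verify_password(password, &hash.password_hash())
            .is_ok());
        assert!(Pbkdf2
            .verify_password(b"wrong-password", &hash.password_hash())
            .is_err());
    }
}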
158 :
159 7 : #[derive(Default)]
160 : pub struct DbUserConnPool {
161 : conns: Vec<ConnPoolEntry>,
162 : password_hash: Option<PasswordHashString>,
163 : }
164 :
165 : impl DbUserConnPool {
166 13 : fn clear_closed_clients(&mut self, conns: &mut usize) {
167 13 : let old_len = self.conns.len();
168 13 :
169 13 : self.conns.retain(|conn| !conn.conn.inner.is_closed());
170 13 :
171 13 : let new_len = self.conns.len();
172 13 : let removed = old_len - new_len;
173 13 : *conns -= removed;
174 13 : }
175 :
176 13 : fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry> {
177 13 : self.clear_closed_clients(conns);
178 13 : let conn = self.conns.pop();
179 13 : if conn.is_some() {
180 4 : *conns -= 1;
181 9 : }
182 13 : conn
183 13 : }
184 : }
185 :
186 : pub struct GlobalConnPool {
187 : // endpoint -> per-endpoint connection pool
188 : //
189 : // This is likely to be a fairly contended map, so return a reference to the per-endpoint
190 : // pool as early as possible and release the lock.
191 : global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
192 :
193 : /// Number of endpoint-connection pools
194 : ///
195 : /// [`DashMap::len`] iterates over all shards and acquires a read lock on each one.
196 : /// That seems like far too much effort, so we're using a relaxed increment counter instead.
197 : /// It's only used for diagnostics.
198 : global_pool_size: AtomicUsize,
199 :
200 : proxy_config: &'static crate::config::ProxyConfig,
201 : }
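// A minimal sketch of the relaxed-counter approach described in the comment
// on `global_pool_size`: keep an `AtomicUsize` next to the map and bump it
// with relaxed ordering on insert, instead of calling `DashMap::len` (which
// takes a read lock on every shard). The names and key/value types here are
// hypothetical.
#[allow(dead_code)]
mod relaxed_counter_sketch {
    use std::sync::atomic::{AtomicUsize, Ordering};

    pub struct CountedMap {
        map: dashmap::DashMap<String, u64>,
        size: AtomicUsize,
    }

    impl CountedMap {
        pub fn new() -> Self {
            Self {
                map: dashmap::DashMap::new(),
                size: AtomicUsize::new(0),
            }
        }

        pub fn insert(&self, key: String, value: u64) {
            if self.map.insert(key, value).is_none() {
                // Relaxed is fine: the count is only read for diagnostics and
                // never synchronizes other memory accesses.
                self.size.fetch_add(1, Ordering::Relaxed);
            }
        }

        /// May lag slightly behind the true size, which is acceptable for logging.
        pub fn approximate_len(&self) -> usize {
            self.size.load(Ordering::Relaxed)
        }
    }
}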
202 :
203 UBC 0 : #[derive(Debug, Clone, Copy)]
204 : pub struct GlobalConnPoolOptions {
205 : // Maximum number of connections per endpoint.
206 : // Connections for different (dbname, username) pairs share this limit.
207 : // When an endpoint runs out of free slots,
208 : // we fall back to opening a new connection for each request.
209 : pub max_conns_per_endpoint: usize,
210 :
211 : pub gc_epoch: Duration,
212 :
213 : pub pool_shards: usize,
214 :
215 : pub idle_timeout: Duration,
216 :
217 : pub opt_in: bool,
218 : }
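// A minimal sketch of filling in these options. The concrete values are
// illustrative only, not the proxy's actual defaults.
#[allow(dead_code)]
fn example_pool_options() -> GlobalConnPoolOptions {
    GlobalConnPoolOptions {
        max_conns_per_endpoint: 20,
        gc_epoch: Duration::from_secs(60),
        pool_shards: 128,
        idle_timeout: Duration::from_secs(5 * 60),
        opt_in: true,
    }
}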
219 :
220 CBC 22 : pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
221 22 : register_histogram!(
222 22 : "proxy_http_pool_reclaimation_lag_seconds",
223 22 : "Time it takes to reclaim unused connection pools",
224 22 : // 16 buckets: 1us -> ~33ms
225 22 : exponential_buckets(1e-6, 2.0, 16).unwrap(),
226 22 : )
227 22 : .unwrap()
228 22 : });
229 :
230 7 : pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
231 14 : register_int_counter_pair!(
232 14 : "proxy_http_pool_endpoints_registered_total",
233 14 : "Number of endpoints we have registered pools for",
234 14 : "proxy_http_pool_endpoints_unregistered_total",
235 14 : "Number of endpoints we have unregistered pools for",
236 14 : )
237 7 : .unwrap()
238 7 : });
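// A minimal sketch of the counter-pair guard pattern, assuming (as the metric
// names and the `_guard` field on `EndpointConnPool` suggest) that creating a
// guard bumps the "registered" counter and dropping it bumps the
// "unregistered" one, so the difference is the number of live per-endpoint
// pools.
#[allow(dead_code)]
fn endpoint_pools_guard_sketch() {
    // proxy_http_pool_endpoints_registered_total += 1
    let guard = ENDPOINT_POOLS.guard();

    // ... keep `guard` alive for as long as the per-endpoint pool exists ...

    // proxy_http_pool_endpoints_unregistered_total += 1
    drop(guard);
}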
239 :
240 : impl GlobalConnPool {
241 22 : pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
242 22 : let shards = config.http_config.pool_options.pool_shards;
243 22 : Arc::new(Self {
244 22 : global_pool: DashMap::with_shard_amount(shards),
245 22 : global_pool_size: AtomicUsize::new(0),
246 22 : proxy_config: config,
247 22 : })
248 22 : }
249 :
250 22 : pub fn shutdown(&self) {
251 22 : // drops all strong references to endpoint-pools
252 22 : self.global_pool.clear();
253 22 : }
254 :
255 22 : pub async fn gc_worker(&self, mut rng: impl Rng) {
256 22 : let epoch = self.proxy_config.http_config.pool_options.gc_epoch;
257 22 : let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
258 : loop {
259 52 : interval.tick().await;
260 :
261 30 : let shard = rng.gen_range(0..self.global_pool.shards().len());
262 30 : self.gc(shard);
263 : }
264 : }
265 :
266 30 : fn gc(&self, shard: usize) {
267 30 : debug!(shard, "pool: performing epoch reclamation");
268 :
269 : // acquire a random shard lock
270 30 : let mut shard = self.global_pool.shards()[shard].write();
271 30 :
272 30 : let timer = GC_LATENCY.start_timer();
273 30 : let current_len = shard.len();
274 30 : shard.retain(|endpoint, x| {
275 : // if the current endpoint pool is unique (no other strong or weak references)
276 : // then it is currently not in use by any connections.
277 UBC 0 : if let Some(pool) = Arc::get_mut(x.get_mut()) {
278 : let EndpointConnPool {
279 0 : pools, total_conns, ..
280 0 : } = pool.get_mut();
281 0 :
282 0 : // ensure that closed clients are removed
283 0 : pools
284 0 : .iter_mut()
285 0 : .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns));
286 0 :
287 0 : // we only remove this pool if it has no active connections
288 0 : if *total_conns == 0 {
289 0 : info!("pool: discarding pool for endpoint {endpoint}");
290 0 : return false;
291 0 : }
292 0 : }
293 :
294 0 : true
295 CBC 30 : });
296 30 : let new_len = shard.len();
297 30 : drop(shard);
298 30 : timer.observe_duration();
299 30 :
300 30 : let removed = current_len - new_len;
301 30 :
302 30 : if removed > 0 {
303 UBC 0 : let global_pool_size = self
304 0 : .global_pool_size
305 0 : .fetch_sub(removed, atomic::Ordering::Relaxed)
306 0 : - removed;
307 0 : info!("pool: performed global pool gc. size now {global_pool_size}");
308 CBC 30 : }
309 30 : }
310 :
311 45 : pub async fn get(
312 45 : self: &Arc<Self>,
313 45 : ctx: &mut RequestMonitoring,
314 45 : conn_info: ConnInfo,
315 45 : force_new: bool,
316 45 : ) -> anyhow::Result<Client> {
317 45 : let mut client: Option<ClientInner> = None;
318 45 :
319 45 : let mut hash_valid = false;
320 45 : let mut endpoint_pool = Weak::new();
321 45 : if !force_new {
322 27 : let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
323 27 : endpoint_pool = Arc::downgrade(&pool);
324 27 : let mut hash = None;
325 27 :
326 27 : // find a pool entry by (dbname, username) if exists
327 27 : {
328 27 : let pool = pool.read();
329 27 : if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
330 19 : if !pool_entries.conns.is_empty() {
331 16 : hash = pool_entries.password_hash.clone();
332 16 : }
333 8 : }
334 : }
335 :
336 : // a connection exists in the pool, verify the password hash
337 27 : if let Some(hash) = hash {
338 16 : let pw = conn_info.password.clone();
339 16 : let validate = tokio::task::spawn_blocking(move || {
340 16 : Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
341 16 : })
342 32 : .await?;
343 :
344 : // if the hash is invalid, don't error
345 : // we will continue with the regular connection flow
346 16 : if validate.is_ok() {
347 13 : hash_valid = true;
348 13 : if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) {
349 4 : client = Some(entry.conn)
350 9 : }
351 3 : }
352 11 : }
353 18 : }
354 :
355 : // ok return cached connection if found and establish a new one otherwise
356 45 : let new_client = if let Some(client) = client {
357 4 : if client.inner.is_closed() {
358 UBC 0 : let conn_id = uuid::Uuid::new_v4();
359 0 : info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
360 0 : connect_to_compute(
361 0 : self.proxy_config,
362 0 : ctx,
363 0 : &conn_info,
364 0 : conn_id,
365 0 : endpoint_pool.clone(),
366 0 : )
367 0 : .await
368 : } else {
369 CBC 4 : info!("pool: reusing connection '{conn_info}'");
370 4 : client.session.send(ctx.session_id)?;
371 4 : tracing::Span::current().record(
372 4 : "pid",
373 4 : &tracing::field::display(client.inner.get_process_id()),
374 4 : );
375 4 : ctx.latency_timer.pool_hit();
376 4 : ctx.latency_timer.success();
377 4 : return Ok(Client::new(client, conn_info, endpoint_pool).await);
378 : }
379 : } else {
380 41 : let conn_id = uuid::Uuid::new_v4();
381 41 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
382 41 : connect_to_compute(
383 41 : self.proxy_config,
384 41 : ctx,
385 41 : &conn_info,
386 41 : conn_id,
387 41 : endpoint_pool.clone(),
388 41 : )
389 459 : .await
390 : };
391 41 : if let Ok(client) = &new_client {
392 38 : tracing::Span::current().record(
393 38 : "pid",
394 38 : &tracing::field::display(client.inner.get_process_id()),
395 38 : );
396 38 : }
397 :
398 38 : match &new_client {
399 3 : // clear the hash. it's no longer valid
400 3 : // TODO: update tokio-postgres fork to allow access to this error kind directly
401 3 : Err(err)
402 3 : if hash_valid && err.to_string().contains("password authentication failed") =>
403 UBC 0 : {
404 0 : let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
405 0 : let mut pool = pool.write();
406 0 : if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
407 0 : entry.password_hash = None;
408 0 : }
409 : }
410 : // new password is valid and we should insert/update it
411 CBC 38 : Ok(_) if !force_new && !hash_valid => {
412 11 : let pw = conn_info.password.clone();
413 11 : let new_hash = tokio::task::spawn_blocking(move || {
414 11 : let salt = SaltString::generate(rand::rngs::OsRng);
415 11 : Pbkdf2
416 11 : .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
417 11 : .map(|s| s.serialize())
418 11 : })
419 11 : .await??;
420 :
421 11 : let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
422 11 : let mut pool = pool.write();
423 11 : pool.pools
424 11 : .entry(conn_info.db_and_user())
425 11 : .or_default()
426 11 : .password_hash = Some(new_hash);
427 : }
428 30 : _ => {}
429 : }
430 41 : let new_client = new_client?;
431 38 : Ok(Client::new(new_client, conn_info, endpoint_pool).await)
432 45 : }
433 :
434 38 : fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
435 : // fast path
436 38 : if let Some(pool) = self.global_pool.get(endpoint) {
437 31 : return pool.clone();
438 7 : }
439 7 :
440 7 : // slow path
441 7 : let new_pool = Arc::new(RwLock::new(EndpointConnPool {
442 7 : pools: HashMap::new(),
443 7 : total_conns: 0,
444 7 : max_conns: self
445 7 : .proxy_config
446 7 : .http_config
447 7 : .pool_options
448 7 : .max_conns_per_endpoint,
449 7 : _guard: ENDPOINT_POOLS.guard(),
450 7 : }));
451 7 :
452 7 : // find or create a pool for this endpoint
453 7 : let mut created = false;
454 7 : let pool = self
455 7 : .global_pool
456 7 : .entry(endpoint.clone())
457 7 : .or_insert_with(|| {
458 7 : created = true;
459 7 : new_pool
460 7 : })
461 7 : .clone();
462 7 :
463 7 : // log new global pool size
464 7 : if created {
465 7 : let global_pool_size = self
466 7 : .global_pool_size
467 7 : .fetch_add(1, atomic::Ordering::Relaxed)
468 7 : + 1;
469 7 : info!(
470 7 : "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
471 7 : );
472 UBC 0 : }
473 :
474 CBC 7 : pool
475 38 : }
476 : }
477 :
478 : struct TokioMechanism<'a> {
479 : pool: Weak<RwLock<EndpointConnPool>>,
480 : conn_info: &'a ConnInfo,
481 : conn_id: uuid::Uuid,
482 : idle: Duration,
483 : }
484 :
485 : #[async_trait]
486 : impl ConnectMechanism for TokioMechanism<'_> {
487 : type Connection = ClientInner;
488 : type ConnectError = tokio_postgres::Error;
489 : type Error = anyhow::Error;
490 :
491 42 : async fn connect_once(
492 42 : &self,
493 42 : ctx: &mut RequestMonitoring,
494 42 : node_info: &console::CachedNodeInfo,
495 42 : timeout: time::Duration,
496 42 : ) -> Result<Self::Connection, Self::ConnectError> {
497 42 : connect_to_compute_once(
498 42 : ctx,
499 42 : node_info,
500 42 : self.conn_info,
501 42 : timeout,
502 42 : self.conn_id,
503 42 : self.pool.clone(),
504 42 : self.idle,
505 42 : )
506 146 : .await
507 84 : }
508 :
509 42 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
510 : }
511 :
512 : // Wake up the destination if needed. The code here is a bit involved because
513 : // we reuse the code from the regular proxy and need to prepare a few structures
514 : // that it expects.
515 UBC 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
516 : async fn connect_to_compute(
517 : config: &config::ProxyConfig,
518 : ctx: &mut RequestMonitoring,
519 : conn_info: &ConnInfo,
520 : conn_id: uuid::Uuid,
521 : pool: Weak<RwLock<EndpointConnPool>>,
522 : ) -> anyhow::Result<ClientInner> {
523 : let tls = config.tls_config.as_ref();
524 CBC 41 : let common_names = tls.and_then(|tls| tls.common_names.clone());
525 :
526 : let params = StartupMessageParams::new([
527 : ("user", &conn_info.username),
528 : ("database", &conn_info.dbname),
529 : ("application_name", APP_NAME),
530 : ("options", conn_info.options.as_deref().unwrap_or("")),
531 : ]);
532 : let creds =
533 : auth::ClientCredentials::parse(ctx, ¶ms, Some(&conn_info.hostname), common_names)?;
534 :
535 : let creds =
536 UBC 0 : ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
537 CBC 41 : let backend = config.auth_backend.as_ref().map(|_| creds);
538 :
539 : let console_options = neon_options(¶ms);
540 :
541 : if !config.disable_ip_check_for_http {
542 : let allowed_ips = backend.get_allowed_ips(ctx).await?;
543 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
544 : return Err(auth::AuthError::ip_address_not_allowed().into());
545 : }
546 : }
547 : let extra = console::ConsoleReqExtra {
548 : options: console_options,
549 : };
550 : let node_info = backend
551 : .wake_compute(ctx, &extra)
552 : .await?
553 : .context("missing cache entry from wake_compute")?;
554 :
555 : ctx.set_project(node_info.aux.clone());
556 :
557 : crate::proxy::connect_compute::connect_to_compute(
558 : ctx,
559 : &TokioMechanism {
560 : conn_id,
561 : conn_info,
562 : pool,
563 : idle: config.http_config.pool_options.idle_timeout,
564 : },
565 : node_info,
566 : &extra,
567 : &backend,
568 : )
569 : .await
570 : }
571 :
572 42 : async fn connect_to_compute_once(
573 42 : ctx: &mut RequestMonitoring,
574 42 : node_info: &console::CachedNodeInfo,
575 42 : conn_info: &ConnInfo,
576 42 : timeout: time::Duration,
577 42 : conn_id: uuid::Uuid,
578 42 : pool: Weak<RwLock<EndpointConnPool>>,
579 42 : idle: Duration,
580 42 : ) -> Result<ClientInner, tokio_postgres::Error> {
581 42 : let mut config = (*node_info.config).clone();
582 42 : let mut session = ctx.session_id;
583 :
584 42 : let (client, mut connection) = config
585 42 : .user(&conn_info.username)
586 42 : .password(&*conn_info.password)
587 42 : .dbname(&conn_info.dbname)
588 42 : .connect_timeout(timeout)
589 42 : .connect(tokio_postgres::NoTls)
590 146 : .await?;
591 :
592 38 : let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
593 38 : .with_label_values(&[ctx.protocol])
594 38 : .guard();
595 38 :
596 38 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
597 38 :
598 38 : let (tx, mut rx) = tokio::sync::watch::channel(session);
599 :
600 38 : let span = info_span!(parent: None, "connection", %conn_id);
601 38 : span.in_scope(|| {
602 38 : info!(%conn_info, %session, "new connection");
603 38 : });
604 38 : let ids = Ids {
605 38 : endpoint_id: node_info.aux.endpoint_id.clone(),
606 38 : branch_id: node_info.aux.branch_id.clone(),
607 38 : };
608 38 :
609 38 : let db_user = conn_info.db_and_user();
610 38 : tokio::spawn(
611 38 : async move {
612 38 : let _conn_gauge = conn_gauge;
613 38 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
614 5846 : poll_fn(move |cx| {
615 5846 : if matches!(rx.has_changed(), Ok(true)) {
616 4 : session = *rx.borrow_and_update();
617 4 : info!(%session, "changed session");
618 4 : idle_timeout.as_mut().reset(Instant::now() + idle);
619 5842 : }
620 :
621 : // the connection has hit the configured idle timeout
622 5846 : if idle_timeout.as_mut().poll(cx).is_ready() {
623 UBC 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
624 0 : info!("connection idle");
625 0 : if let Some(pool) = pool.clone().upgrade() {
626 : // remove client from pool - should close the connection if it's idle.
627 : // does nothing if the client is currently checked-out and in-use
628 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
629 0 : info!("idle connection removed");
630 0 : }
631 0 : }
632 CBC 5846 : }
633 :
634 : loop {
635 5846 : let message = ready!(connection.poll_message(cx));
636 :
637 UBC 0 : match message {
638 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
639 0 : info!(%session, "notice: {}", notice);
640 : }
641 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
642 0 : warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
643 : }
644 : Some(Ok(_)) => {
645 0 : warn!(%session, "unknown message");
646 : }
647 0 : Some(Err(e)) => {
648 0 : error!(%session, "connection error: {}", e);
649 0 : break
650 : }
651 : None => {
652 CBC 37 : info!("connection closed");
653 37 : break
654 : }
655 : }
656 : }
657 :
658 : // remove from connection pool
659 37 : if let Some(pool) = pool.clone().upgrade() {
660 3 : if pool.write().remove_client(db_user.clone(), conn_id) {
661 UBC 0 : info!("closed connection removed");
662 CBC 3 : }
663 34 : }
664 :
665 37 : Poll::Ready(())
666 5846 : }).await;
667 :
668 38 : }
669 38 : .instrument(span)
670 38 : );
671 38 :
672 38 : Ok(ClientInner {
673 38 : inner: client,
674 38 : session: tx,
675 38 : ids,
676 38 : conn_id,
677 38 : })
678 42 : }
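// A minimal sketch of the watch-channel handoff used above, in isolation:
// `ClientInner::session` holds the `Sender`, each checkout of a pooled
// connection sends the new session id, and the background task picks it up
// via `has_changed`/`borrow_and_update` so its log lines carry the right
// session. None of these calls are async, so a plain test suffices.
#[cfg(test)]
mod session_watch_tests {
    #[test]
    fn watch_channel_handoff() {
        let first = uuid::Uuid::new_v4();
        let second = uuid::Uuid::new_v4();

        let (tx, mut rx) = tokio::sync::watch::channel(first);
        assert_eq!(*rx.borrow(), first);

        // A new checkout publishes its session id ...
        tx.send(second).expect("receiver is still alive");

        // ... and the connection task observes it on its next poll.
        assert!(matches!(rx.has_changed(), Ok(true)));
        assert_eq!(*rx.borrow_and_update(), second);
        assert!(matches!(rx.has_changed(), Ok(false)));
    }
}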
679 :
680 : struct ClientInner {
681 : inner: tokio_postgres::Client,
682 : session: tokio::sync::watch::Sender<uuid::Uuid>,
683 : ids: Ids,
684 : conn_id: uuid::Uuid,
685 : }
686 :
687 : impl Client {
688 40 : pub fn metrics(&self) -> Arc<MetricCounter> {
689 40 : USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
690 40 : }
691 : }
692 :
693 : pub struct Client {
694 : conn_id: uuid::Uuid,
695 : span: Span,
696 : inner: Option<ClientInner>,
697 : conn_info: ConnInfo,
698 : pool: Weak<RwLock<EndpointConnPool>>,
699 : }
700 :
701 : pub struct Discard<'a> {
702 : conn_id: uuid::Uuid,
703 : conn_info: &'a ConnInfo,
704 : pool: &'a mut Weak<RwLock<EndpointConnPool>>,
705 : }
706 :
707 : impl Client {
708 42 : pub(self) async fn new(
709 42 : inner: ClientInner,
710 42 : conn_info: ConnInfo,
711 42 : pool: Weak<RwLock<EndpointConnPool>>,
712 42 : ) -> Self {
713 42 : Self {
714 42 : conn_id: inner.conn_id,
715 42 : inner: Some(inner),
716 42 : span: Span::current(),
717 42 : conn_info,
718 42 : pool,
719 42 : }
720 42 : }
721 42 : pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
722 42 : let Self {
723 42 : inner,
724 42 : pool,
725 42 : conn_id,
726 42 : conn_info,
727 42 : span: _,
728 42 : } = self;
729 42 : (
730 42 : &mut inner
731 42 : .as_mut()
732 42 : .expect("client inner should not be removed")
733 42 : .inner,
734 42 : Discard {
735 42 : pool,
736 42 : conn_info,
737 42 : conn_id: *conn_id,
738 42 : },
739 42 : )
740 42 : }
741 :
742 38 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
743 38 : self.inner().1.check_idle(status)
744 38 : }
745 2 : pub fn discard(&mut self) {
746 2 : self.inner().1.discard()
747 2 : }
748 : }
749 :
750 : impl Discard<'_> {
751 40 : pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
752 40 : let conn_info = &self.conn_info;
753 40 : if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
754 2 : info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
755 38 : }
756 40 : }
757 2 : pub fn discard(&mut self) {
758 2 : let conn_info = &self.conn_info;
759 2 : if std::mem::take(self.pool).strong_count() > 0 {
760 2 : info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
761 UBC 0 : }
762 CBC 2 : }
763 : }
764 :
765 : impl Deref for Client {
766 : type Target = tokio_postgres::Client;
767 :
768 40 : fn deref(&self) -> &Self::Target {
769 40 : &self
770 40 : .inner
771 40 : .as_ref()
772 40 : .expect("client inner should not be removed")
773 40 : .inner
774 40 : }
775 : }
776 :
777 : impl Drop for Client {
778 42 : fn drop(&mut self) {
779 42 : let conn_info = self.conn_info.clone();
780 42 : let client = self
781 42 : .inner
782 42 : .take()
783 42 : .expect("client inner should not be removed");
784 42 : if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
785 20 : let current_span = self.span.clone();
786 20 : // return connection to the pool
787 20 : tokio::task::spawn_blocking(move || {
788 20 : let _span = current_span.enter();
789 20 : let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
790 20 : });
791 22 : }
792 42 : }
793 : }