use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::future::poll_fn;
use parking_lot::RwLock;
use pbkdf2::{
    password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
    Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use std::{collections::HashMap, sync::Arc};
use std::{
    fmt,
    task::{ready, Poll},
};
use std::{
    ops::Deref,
    sync::atomic::{self, AtomicUsize},
};
use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};

use crate::{
    auth, console,
    metrics::{Ids, MetricCounter, USAGE_METRICS},
    proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER},
};
use crate::{compute, config};

use crate::proxy::ConnectMechanism;

use tracing::{error, warn, Span};
use tracing::{info, info_span, Instrument};

pub const APP_NAME: &str = "sql_over_http";
const MAX_CONNS_PER_ENDPOINT: usize = 20;

#[derive(Debug, Clone)]
pub struct ConnInfo {
    pub username: String,
    pub dbname: String,
    pub hostname: String,
    pub password: String,
}

impl ConnInfo {
    // TODO: use a hash of (dbname, username) as the key to avoid cloning?
    pub fn db_and_user(&self) -> (String, String) {
        (self.dbname.clone(), self.username.clone())
    }
}

impl fmt::Display for ConnInfo {
    // use custom display to avoid logging password
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
    }
}
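
// A minimal sketch, not part of the original file: a unit test checking that
// the custom `Display` impl above formats as `user@host/db` and never leaks
// the password field. The field values are hypothetical.
#[cfg(test)]
mod conn_info_display_tests {
    use super::ConnInfo;

    #[test]
    fn display_omits_password() {
        let info = ConnInfo {
            username: "alice".into(),
            dbname: "main".into(),
            hostname: "ep-example.host".into(),
            password: "s3cret".into(),
        };
        let shown = info.to_string();
        assert_eq!(shown, "alice@ep-example.host/main");
        assert!(!shown.contains("s3cret"));
    }
}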

struct ConnPoolEntry {
    conn: ClientInner,
    _last_access: std::time::Instant,
}

// Per-endpoint connection pool, keyed by (dbname, username) -> DbUserConnPool.
// The number of open connections is limited by `max_conns_per_endpoint`.
pub struct EndpointConnPool {
    pools: HashMap<(String, String), DbUserConnPool>,
    total_conns: usize,
}

/// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
/// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
///
/// Still takes 1.4ms to hash on my hardware.
/// We don't want to ruin the latency improvements of using the pool by making
/// password verification take too long.
const PARAMS: Params = Params {
    rounds: 4096,
    output_length: 32,
};
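
// An illustrative sketch, not in the original file: how `PARAMS` feeds into the
// `pbkdf2` crate calls used by `GlobalConnPool::get` below, hashing with a fresh
// salt and then verifying against the serialized hash string.
#[cfg(test)]
mod pbkdf2_params_tests {
    use super::*;

    #[test]
    fn hash_roundtrip_with_pool_params() {
        let salt = SaltString::generate(rand::rngs::OsRng);
        let hash = Pbkdf2
            .hash_password_customized(b"hunter2", None, None, PARAMS, &salt)
            .expect("hashing should succeed")
            .serialize();
        // Verification re-derives the key using the parameters encoded in the hash string.
        assert!(Pbkdf2.verify_password(b"hunter2", &hash.password_hash()).is_ok());
        assert!(Pbkdf2.verify_password(b"wrong", &hash.password_hash()).is_err());
    }
}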

#[derive(Default)]
pub struct DbUserConnPool {
    conns: Vec<ConnPoolEntry>,
    password_hash: Option<PasswordHashString>,
}

pub struct GlobalConnPool {
    // endpoint -> per-endpoint connection pool
    //
    // This is expected to be a fairly contended map, so grab a reference to the
    // per-endpoint pool as early as possible and release the lock.
    global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,

    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    // Maximum number of connections per endpoint.
    // The limit is shared across different (dbname, username) pairs.
    // When an endpoint runs out of free slots,
    // we fall back to opening a new connection for each request.
    max_conns_per_endpoint: usize,

    proxy_config: &'static crate::config::ProxyConfig,

    // Using a lock to remove any race conditions,
    // e.g. cleaning up connections while a new connection is being returned.
    closed: RwLock<bool>,
}

impl GlobalConnPool {
    pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
        Arc::new(Self {
            global_pool: DashMap::new(),
            global_pool_size: AtomicUsize::new(0),
            max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
            proxy_config: config,
            closed: RwLock::new(false),
        })
    }

    pub fn shutdown(&self) {
        *self.closed.write() = true;

        self.global_pool.retain(|_, endpoint_pool| {
            let mut pool = endpoint_pool.write();
            // By clearing this hashmap, we remove the slots that a connection can be
            // returned to; `put` drops the connection if its slot no longer exists.
            pool.pools.clear();
            pool.total_conns = 0;

            false
        });
    }

    pub async fn get(
        self: &Arc<Self>,
        conn_info: &ConnInfo,
        force_new: bool,
        session_id: uuid::Uuid,
    ) -> anyhow::Result<Client> {
        let mut client: Option<ClientInner> = None;
        let mut latency_timer = LatencyTimer::new("http");

        let pool = if force_new {
            None
        } else {
            Some((conn_info.clone(), self.clone()))
        };

        let mut hash_valid = false;
        if !force_new {
            let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
            let mut hash = None;

            // find a pool entry by (dbname, username) if it exists
            {
                let pool = pool.read();
                if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
                    if !pool_entries.conns.is_empty() {
                        hash = pool_entries.password_hash.clone();
                    }
                }
            }

            // a connection exists in the pool, so verify the password hash
            if let Some(hash) = hash {
                let pw = conn_info.password.clone();
                let validate = tokio::task::spawn_blocking(move || {
                    Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
                })
                .await?;

                // if the hash is invalid, don't error out;
                // we will continue with the regular connection flow
                if validate.is_ok() {
                    hash_valid = true;
                    let mut pool = pool.write();
                    if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                        if let Some(entry) = pool_entries.conns.pop() {
                            client = Some(entry.conn);
                            pool.total_conns -= 1;
                        }
                    }
                }
            }
        }

        // return the cached connection if one was found; establish a new one otherwise
        let new_client = if let Some(client) = client {
            if client.inner.is_closed() {
                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
                connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
            } else {
                latency_timer.pool_hit();
                info!("pool: reusing connection '{conn_info}'");
                client.session.send(session_id)?;
                return Ok(Client {
                    inner: Some(client),
                    span: Span::current(),
                    pool,
                });
            }
        } else {
            info!("pool: opening a new connection '{conn_info}'");
            connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
        };

        match &new_client {
            // clear the hash; it's no longer valid
            // TODO: update the tokio-postgres fork to allow access to this error kind directly
            Err(err)
                if hash_valid && err.to_string().contains("password authentication failed") =>
            {
                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    entry.password_hash = None;
                }
            }
            // the new password is valid and we should insert/update it
            Ok(_) if !force_new && !hash_valid => {
                let pw = conn_info.password.clone();
                let new_hash = tokio::task::spawn_blocking(move || {
                    let salt = SaltString::generate(rand::rngs::OsRng);
                    Pbkdf2
                        .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
                        .map(|s| s.serialize())
                })
                .await??;

                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                pool.pools
                    .entry(conn_info.db_and_user())
                    .or_default()
                    .password_hash = Some(new_hash);
            }
            _ => {}
        }

        new_client.map(|inner| Client {
            inner: Some(inner),
            span: Span::current(),
            pool,
        })
    }

    fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
        // We want to hold this open while we return. This ensures that the pool can't close
        // while we are in the middle of returning the connection.
        let closed = self.closed.read();
        if *closed {
            info!("pool: throwing away connection '{conn_info}' because pool is closed");
            return Ok(());
        }

        if client.inner.is_closed() {
            info!("pool: throwing away connection '{conn_info}' because connection is closed");
            return Ok(());
        }

        let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);

        // return the connection to the pool
        let mut returned = false;
        let mut per_db_size = 0;
        let total_conns = {
            let mut pool = pool.write();

            if pool.total_conns < self.max_conns_per_endpoint {
                // we create this db-user entry in `get`, so it should not be None
                if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    pool_entries.conns.push(ConnPoolEntry {
                        conn: client,
                        _last_access: std::time::Instant::now(),
                    });

                    returned = true;
                    per_db_size = pool_entries.conns.len();

                    pool.total_conns += 1;
                }
            }

            pool.total_conns
        };

        // do logging outside of the mutex
        if returned {
            info!("pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
        } else {
            info!("pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
        }

        Ok(())
    }

    fn get_or_create_endpoint_pool(&self, endpoint: &str) -> Arc<RwLock<EndpointConnPool>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.to_owned())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log the new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}
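
// An illustrative sketch, not in the original file: the intended call pattern for
// the pool from an SQL-over-HTTP request handler. The handler function and the
// `SELECT 1` query are hypothetical; `config` must be `'static`, matching
// `GlobalConnPool::new`. It performs real network I/O, so it is a compile-time
// sketch rather than a test.
#[allow(dead_code)]
async fn example_pool_usage(
    config: &'static config::ProxyConfig,
    conn_info: ConnInfo,
) -> anyhow::Result<()> {
    let pool = GlobalConnPool::new(config);

    // `force_new = false` allows reuse of a cached connection whose stored
    // password hash matches; a fresh session id is attached either way.
    let mut client = pool.get(&conn_info, false, uuid::Uuid::new_v4()).await?;

    let (pg_client, mut discard) = client.inner();
    pg_client.simple_query("SELECT 1").await?;

    // Only an idle connection may go back to the pool; `Client::drop` then
    // returns it via `GlobalConnPool::put`.
    discard.check_idle(ReadyForQueryStatus::Idle);
    Ok(())
}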

struct TokioMechanism<'a> {
    conn_info: &'a ConnInfo,
    session_id: uuid::Uuid,
}

#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
    type Connection = ClientInner;
    type ConnectError = tokio_postgres::Error;
    type Error = anyhow::Error;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        connect_to_compute_once(node_info, self.conn_info, timeout, self.session_id).await
    }

    fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}

// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare a few
// structures that this code expects.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
    config: &config::ProxyConfig,
    conn_info: &ConnInfo,
    session_id: uuid::Uuid,
    latency_timer: LatencyTimer,
) -> anyhow::Result<ClientInner> {
    let tls = config.tls_config.as_ref();
    let common_names = tls.and_then(|tls| tls.common_names.clone());

    let credential_params = StartupMessageParams::new([
        ("user", &conn_info.username),
        ("database", &conn_info.dbname),
        ("application_name", APP_NAME),
    ]);

    let creds = config
        .auth_backend
        .as_ref()
        .map(|_| {
            auth::ClientCredentials::parse(
                &credential_params,
                Some(&conn_info.hostname),
                common_names,
            )
        })
        .transpose()?;
    let extra = console::ConsoleReqExtra {
        session_id: uuid::Uuid::new_v4(),
        application_name: Some(APP_NAME),
    };

    let node_info = creds
        .wake_compute(&extra)
        .await?
        .context("missing cache entry from wake_compute")?;

    crate::proxy::connect_to_compute(
        &TokioMechanism {
            conn_info,
            session_id,
        },
        node_info,
        &extra,
        &creds,
        latency_timer,
    )
    .await
}

async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    conn_info: &ConnInfo,
    timeout: time::Duration,
    mut session: uuid::Uuid,
) -> Result<ClientInner, tokio_postgres::Error> {
    let mut config = (*node_info.config).clone();

    let (client, mut connection) = config
        .user(&conn_info.username)
        .password(&conn_info.password)
        .dbname(&conn_info.dbname)
        .connect_timeout(timeout)
        .connect(tokio_postgres::NoTls)
        .await?;

    // a watch channel lets `get` re-associate this long-lived connection with a
    // new session id every time it is handed out from the pool
    let (tx, mut rx) = tokio::sync::watch::channel(session);

    let conn_id = uuid::Uuid::new_v4();
    let span = info_span!(parent: None, "connection", %conn_id);
    span.in_scope(|| {
        info!(%conn_info, %session, "new connection");
    });
    let ids = Ids {
        endpoint_id: node_info.aux.endpoint_id.to_string(),
        branch_id: node_info.aux.branch_id.to_string(),
    };

    // drive the connection in the background, logging async messages under the
    // current session id until the connection closes or errors out
    tokio::spawn(
        async move {
            NUM_DB_CONNECTIONS_OPENED_COUNTER.with_label_values(&["http"]).inc();
            scopeguard::defer! {
                NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
            }
            poll_fn(move |cx| {
                if matches!(rx.has_changed(), Ok(true)) {
                    session = *rx.borrow_and_update();
                    info!(%session, "changed session");
                }

                loop {
                    let message = ready!(connection.poll_message(cx));

                    match message {
                        Some(Ok(AsyncMessage::Notice(notice))) => {
                            info!(%session, "notice: {}", notice);
                        }
                        Some(Ok(AsyncMessage::Notification(notif))) => {
                            warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
                        }
                        Some(Ok(_)) => {
                            warn!(%session, "unknown message");
                        }
                        Some(Err(e)) => {
                            error!(%session, "connection error: {}", e);
                            return Poll::Ready(())
                        }
                        None => {
                            info!("connection closed");
                            return Poll::Ready(())
                        }
                    }
                }
            }).await
        }
        .instrument(span)
    );

    Ok(ClientInner {
        inner: client,
        session: tx,
        ids,
    })
}
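
// A minimal sketch, not in the original file: the `tokio::sync::watch` pattern
// used above, where `get` sends a new session id into the long-lived connection
// task through `ClientInner::session`. Assumes tokio's `macros` feature for
// `#[tokio::test]`.
#[cfg(test)]
mod session_watch_tests {
    #[tokio::test]
    async fn watch_delivers_latest_session_id() {
        let first = uuid::Uuid::new_v4();
        let (tx, mut rx) = tokio::sync::watch::channel(first);

        // The initial value counts as seen, mirroring the connection task's state
        // right after `connect_to_compute_once` creates the channel.
        assert!(!rx.has_changed().unwrap());

        // Reusing a pooled connection sends a fresh session id (see `get`).
        let second = uuid::Uuid::new_v4();
        tx.send(second).unwrap();
        assert!(rx.has_changed().unwrap());
        assert_eq!(*rx.borrow_and_update(), second);
        assert!(!rx.has_changed().unwrap());
    }
}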

struct ClientInner {
    inner: tokio_postgres::Client,
    session: tokio::sync::watch::Sender<uuid::Uuid>,
    ids: Ids,
}

impl Client {
    pub fn metrics(&self) -> Arc<MetricCounter> {
        USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
    }
}

pub struct Client {
    span: Span,
    inner: Option<ClientInner>,
    pool: Option<(ConnInfo, Arc<GlobalConnPool>)>,
}

pub struct Discard<'a> {
    pool: &'a mut Option<(ConnInfo, Arc<GlobalConnPool>)>,
}

impl Client {
    pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
        let Self {
            inner,
            pool,
            span: _,
        } = self;
        (
            &mut inner
                .as_mut()
                .expect("client inner should not be removed")
                .inner,
            Discard { pool },
        )
    }

    pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
        self.inner().1.check_idle(status)
    }
    pub fn discard(&mut self) {
        self.inner().1.discard()
    }
}

impl Discard<'_> {
    pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
        if status != ReadyForQueryStatus::Idle {
            if let Some((conn_info, _)) = self.pool.take() {
                info!("pool: throwing away connection '{conn_info}' because connection is not idle")
            }
        }
    }
    pub fn discard(&mut self) {
        if let Some((conn_info, _)) = self.pool.take() {
            info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
        }
    }
}

impl Deref for Client {
    type Target = tokio_postgres::Client;

    fn deref(&self) -> &Self::Target {
        &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .inner
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        let client = self
            .inner
            .take()
            .expect("client inner should not be removed");
        if let Some((conn_info, conn_pool)) = self.pool.take() {
            let current_span = self.span.clone();
            // Return the connection to the pool on a blocking thread:
            // `put` acquires synchronous `parking_lot` locks.
            tokio::task::spawn_blocking(move || {
                let _span = current_span.enter();
                let _ = conn_pool.put(&conn_info, client);
            });
        }
    }
}