use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::future::poll_fn;
use parking_lot::RwLock;
use pbkdf2::{
    password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
    Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use std::sync::atomic::{self, AtomicUsize};
use std::{collections::HashMap, sync::Arc};
use std::{
    fmt,
    task::{ready, Poll},
};
use tokio::time;
use tokio_postgres::AsyncMessage;

use crate::{auth, console};
use crate::{compute, config};

use super::sql_over_http::MAX_RESPONSE_SIZE;

use crate::proxy::ConnectMechanism;

use tracing::{error, warn};
use tracing::{info, info_span, Instrument};

pub const APP_NAME: &str = "sql_over_http";
const MAX_CONNS_PER_ENDPOINT: usize = 20;

#[derive(Debug)]
pub struct ConnInfo {
    pub username: String,
    pub dbname: String,
    pub hostname: String,
    pub password: String,
}

impl ConnInfo {
    // hm, change to a hasher to avoid cloning?
    pub fn db_and_user(&self) -> (String, String) {
        (self.dbname.clone(), self.username.clone())
    }
}

impl fmt::Display for ConnInfo {
    // use a custom Display impl to avoid logging the password
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
    }
}

struct ConnPoolEntry {
    conn: Client,
    _last_access: std::time::Instant,
}

// Per-endpoint connection pool: (dbname, username) -> DbUserConnPool.
// The number of open connections is limited by `max_conns_per_endpoint`.
pub struct EndpointConnPool {
    pools: HashMap<(String, String), DbUserConnPool>,
    total_conns: usize,
}

/// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
/// It's not the 600,000 that OWASP recommends... but our passwords are high-entropy anyway.
///
/// Hashing still takes 1.4ms on my hardware.
/// We don't want to ruin the latency improvements of using the pool by making
/// password verification take too long.
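///
/// An illustrative round-trip with these params, mirroring the hash/verify
/// calls in `get` below (a sketch, not a doctest; `password` stands in for the
/// caller's plaintext password):
///
/// ```ignore
/// let salt = SaltString::generate(rand::rngs::OsRng);
/// let hash = Pbkdf2
///     .hash_password_customized(password.as_bytes(), None, None, PARAMS, &salt)?
///     .serialize();
/// assert!(Pbkdf2.verify_password(password.as_bytes(), &hash.password_hash()).is_ok());
/// ```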
const PARAMS: Params = Params {
    rounds: 4096,
    output_length: 32,
};

#[derive(Default)]
pub struct DbUserConnPool {
    conns: Vec<ConnPoolEntry>,
    password_hash: Option<PasswordHashString>,
}
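// Pool layout, sketched: two levels of maps, keyed by endpoint first and then
// by (dbname, username).
//
//     GlobalConnPool
//     └── global_pool: endpoint -> Arc<RwLock<EndpointConnPool>>
//                        └── pools: (dbname, username) -> DbUserConnPool
//                                     └── conns: Vec<ConnPoolEntry>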
pub struct GlobalConnPool {
    // endpoint -> per-endpoint connection pool
    //
    // This is likely to be a fairly contended map, so return a reference to
    // the per-endpoint pool as early as possible and release the lock.
    global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,

    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    // Maximum number of connections per endpoint. Connections for different
    // (dbname, username) pairs count against the same limit. When an endpoint
    // runs out of free slots, we fall back to opening a new connection for
    // each request.
    max_conns_per_endpoint: usize,

    proxy_config: &'static crate::config::ProxyConfig,

    // Use a lock to avoid race conditions, e.g. cleaning up connections while
    // a new connection is being returned.
    closed: RwLock<bool>,
}
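// Typical lifecycle (an illustrative sketch; `config`, `conn_info`, and
// `session_id` come from the caller):
//
//     let pool = GlobalConnPool::new(config);
//     let client = pool.get(&conn_info, false, session_id).await?;
//     /* run queries on client.inner ... */
//     pool.put(&conn_info, client)?;
//     pool.shutdown();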
impl GlobalConnPool {
    pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
        Arc::new(Self {
            global_pool: DashMap::new(),
            global_pool_size: AtomicUsize::new(0),
            max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
            proxy_config: config,
            closed: RwLock::new(false),
        })
    }

    pub fn shutdown(&self) {
        *self.closed.write() = true;

        self.global_pool.retain(|_, endpoint_pool| {
            let mut pool = endpoint_pool.write();
            // by clearing this hashmap, we remove the slots that a connection
            // can be returned to; `put` drops the connection if the slot
            // doesn't exist
            pool.pools.clear();
            pool.total_conns = 0;

            false
        });
    }
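    /// Get a connection for `conn_info`. Unless `force_new` is set, try to
    /// reuse a pooled connection (after verifying the password against the
    /// stored PBKDF2 hash); otherwise open a fresh one via `connect_to_compute`.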
    pub async fn get(
        &self,
        conn_info: &ConnInfo,
        force_new: bool,
        session_id: uuid::Uuid,
    ) -> anyhow::Result<Client> {
        let mut client: Option<Client> = None;

        let mut hash_valid = false;
        if !force_new {
            let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
            let mut hash = None;

            // find a pool entry by (dbname, username) if one exists
            {
                let pool = pool.read();
                if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
                    if !pool_entries.conns.is_empty() {
                        hash = pool_entries.password_hash.clone();
                    }
                }
            }

            // a connection exists in the pool, verify the password hash
            if let Some(hash) = hash {
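                // PBKDF2 is CPU-bound (~1.4ms with the PARAMS above), so verify
                // on the blocking thread pool instead of stalling the async executor.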
                let pw = conn_info.password.clone();
                let validate = tokio::task::spawn_blocking(move || {
                    Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
                })
                .await?;

                // if the hash is invalid, don't error:
                // we will continue with the regular connection flow
                if validate.is_ok() {
                    hash_valid = true;
                    let mut pool = pool.write();
                    if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                        if let Some(entry) = pool_entries.conns.pop() {
                            client = Some(entry.conn);
                            pool.total_conns -= 1;
                        }
                    }
                }
            }
        }

        // return the cached connection if we found one; establish a new one otherwise
        let new_client = if let Some(client) = client {
            if client.inner.is_closed() {
                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
                connect_to_compute(self.proxy_config, conn_info, session_id).await
            } else {
                info!("pool: reusing connection '{conn_info}'");
                client.session.send(session_id)?;
                return Ok(client);
            }
        } else {
            info!("pool: opening a new connection '{conn_info}'");
            connect_to_compute(self.proxy_config, conn_info, session_id).await
        };

        match &new_client {
            // clear the hash, it's no longer valid
            // TODO: update the tokio-postgres fork to allow access to this error kind directly
            Err(err)
                if hash_valid && err.to_string().contains("password authentication failed") =>
            {
                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    entry.password_hash = None;
                }
            }
            // the new password is valid; insert/update its hash
            Ok(_) if !force_new && !hash_valid => {
                let pw = conn_info.password.clone();
                let new_hash = tokio::task::spawn_blocking(move || {
                    let salt = SaltString::generate(rand::rngs::OsRng);
                    Pbkdf2
                        .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
                        .map(|s| s.serialize())
                })
                .await??;

                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                pool.pools
                    .entry(conn_info.db_and_user())
                    .or_default()
                    .password_hash = Some(new_hash);
            }
            _ => {}
        }

        new_client
    }
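    /// Return a connection to the pool. The connection is dropped instead if
    /// the pool is shut down, the connection itself is closed, or the endpoint
    /// is already at `max_conns_per_endpoint`.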
    pub fn put(&self, conn_info: &ConnInfo, client: Client) -> anyhow::Result<()> {
        // We want to hold this lock while we return. This ensures that the pool
        // can't close while we are in the middle of returning the connection.
        let closed = self.closed.read();
        if *closed {
            info!("pool: throwing away connection '{conn_info}' because pool is closed");
            return Ok(());
        }

        if client.inner.is_closed() {
            info!("pool: throwing away connection '{conn_info}' because connection is closed");
            return Ok(());
        }

        let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);

        // return connection to the pool
        let mut returned = false;
        let mut per_db_size = 0;
        let total_conns = {
            let mut pool = pool.write();

            if pool.total_conns < self.max_conns_per_endpoint {
                // we create this (db, user) entry in `get`, so it should not be None
                if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    pool_entries.conns.push(ConnPoolEntry {
                        conn: client,
                        _last_access: std::time::Instant::now(),
                    });

                    returned = true;
                    per_db_size = pool_entries.conns.len();

                    pool.total_conns += 1;
                }
            }

            pool.total_conns
        };

        // do logging outside of the mutex
        if returned {
            info!("pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
        } else {
            info!("pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
        }

        Ok(())
    }
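    // Fast path: look up an existing pool in the DashMap. Slow path: build a
    // fresh pool and insert it; `or_insert_with` runs the closure at most
    // once, so the global size counter stays accurate even if callers race.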
    fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}
struct TokioMechanism<'a> {
    conn_info: &'a ConnInfo,
    session_id: uuid::Uuid,
}

#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
    type Connection = Client;
    type ConnectError = tokio_postgres::Error;
    type Error = anyhow::Error;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        connect_to_compute_once(node_info, self.conn_info, timeout, self.session_id).await
    }

    fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
// Wake up the destination if needed. The code here is a bit involved because
// we reuse the code from the regular proxy, and we need to prepare a few
// structures that this code expects.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
    config: &config::ProxyConfig,
    conn_info: &ConnInfo,
    session_id: uuid::Uuid,
) -> anyhow::Result<Client> {
    let tls = config.tls_config.as_ref();
    let common_names = tls.and_then(|tls| tls.common_names.clone());

    let credential_params = StartupMessageParams::new([
        ("user", &conn_info.username),
        ("database", &conn_info.dbname),
        ("application_name", APP_NAME),
    ]);

    let creds = config
        .auth_backend
        .as_ref()
        .map(|_| {
            auth::ClientCredentials::parse(
                &credential_params,
                Some(&conn_info.hostname),
                common_names,
            )
        })
        .transpose()?;
    let extra = console::ConsoleReqExtra {
        session_id: uuid::Uuid::new_v4(),
        application_name: Some(APP_NAME),
    };

    let node_info = creds
        .wake_compute(&extra)
        .await?
        .context("missing cache entry from wake_compute")?;

    crate::proxy::connect_to_compute(
        &TokioMechanism {
            conn_info,
            session_id,
        },
        node_info,
        &extra,
        &creds,
    )
    .await
}
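// Open a single connection attempt to the compute node and spawn a background
// task that drives the connection, logging async messages tagged with the
// current session id. The watch channel lets `get` re-tag a pooled connection
// with the session that picks it up next.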
async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    conn_info: &ConnInfo,
    timeout: time::Duration,
    mut session: uuid::Uuid,
) -> Result<Client, tokio_postgres::Error> {
    let mut config = (*node_info.config).clone();

    let (client, mut connection) = config
        .user(&conn_info.username)
        .password(&conn_info.password)
        .dbname(&conn_info.dbname)
        .max_backend_message_size(MAX_RESPONSE_SIZE)
        .connect_timeout(timeout)
        .connect(tokio_postgres::NoTls)
        .await?;

    let (tx, mut rx) = tokio::sync::watch::channel(session);

    let conn_id = uuid::Uuid::new_v4();
    let span = info_span!(parent: None, "connection", %conn_id);
    span.in_scope(|| {
        info!(%conn_info, %session, "new connection");
    });

    tokio::spawn(
        poll_fn(move |cx| {
            if matches!(rx.has_changed(), Ok(true)) {
                session = *rx.borrow_and_update();
                info!(%session, "changed session");
            }

            loop {
                let message = ready!(connection.poll_message(cx));

                match message {
                    Some(Ok(AsyncMessage::Notice(notice))) => {
                        info!(%session, "notice: {}", notice);
                    }
                    Some(Ok(AsyncMessage::Notification(notif))) => {
                        warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
                    }
                    Some(Ok(_)) => {
                        warn!(%session, "unknown message");
                    }
                    Some(Err(e)) => {
                        error!(%session, "connection error: {}", e);
                        return Poll::Ready(());
                    }
                    None => {
                        info!("connection closed");
                        return Poll::Ready(());
                    }
                }
            }
        })
        .instrument(span),
    );

    Ok(Client {
        inner: client,
        session: tx,
    })
}
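// A pooled client: the underlying tokio_postgres client plus the sender half
// of the session watch channel, used by `get` to re-tag the connection task's
// logs when the connection is reused.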
pub struct Client {
    pub inner: tokio_postgres::Client,
    session: tokio::sync::watch::Sender<uuid::Uuid>,
}