use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};

use dashmap::DashMap;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};

use super::backend::HttpConnError;
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};

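// `Send` is the hyper HTTP/2 request sender handle that the pool stores and clones per
// request; `Connect` is the connection driver future that must be polled for the
// underlying connection to make progress.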
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =
    http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;

#[derive(Clone)]
pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
    conn: C,
    conn_id: uuid::Uuid,
    aux: MetricsAuxInfo,
}

pub(crate) struct ClientDataHttp();

// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
    // TODO(conrad):
    // either we should open more connections depending on stream count
    // (not exposed by hyper, need our own counter)
    // or we can change this to an Option rather than a VecDeque.
    //
    // Opening more connections to the same db because we run out of streams
    // seems somewhat redundant though.
    //
    // Probably we should use a semaphore and keep just the single connection. TBD.
    conns: VecDeque<ConnPoolEntry<C>>,
    _guard: HttpEndpointPoolsGuard<'static>,
    global_connections_count: Arc<AtomicUsize>,
}

impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
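    /// Pop entries from the front of the queue until a live one is found.
    /// Closed connections are discarded; a live entry is cloned back onto the
    /// queue (the HTTP/2 handle is multiplexed, so it can serve several
    /// requests at once) and returned.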
    fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
        let Self { conns, .. } = self;

        loop {
            let conn = conns.pop_front()?;
            if !conn.conn.is_closed() {
                conns.push_back(conn.clone());
                return Some(conn);
            }
        }
    }

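    /// Remove every pooled entry belonging to `conn_id`, keeping the global
    /// connection counter and the `http_pool_opened_connections` metric in
    /// sync. Returns `true` if anything was removed.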
    fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
        let Self {
            conns,
            global_connections_count,
            ..
        } = self;

        let old_len = conns.len();
        conns.retain(|conn| conn.conn_id != conn_id);
        let new_len = conns.len();
        let removed = old_len - new_len;
        if removed > 0 {
            global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(removed as i64);
        }
        removed > 0
    }
}

impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
    fn drop(&mut self) {
        if !self.conns.is_empty() {
            self.global_connections_count
                .fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(self.conns.len() as i64);
        }
    }
}

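// Global pool of per-endpoint HTTP/2 connection pools.
//
// Lookups are two-level: the `DashMap` shard lock is held only long enough to
// clone the `Arc` of the per-endpoint pool, and the per-endpoint `RwLock` is
// taken afterwards. A minimal usage sketch (`config`, `ctx` and `conn_info`
// are assumed to be supplied by the caller):
//
//     let pool = GlobalConnPool::new(config);
//     if let Some(client) = pool.get(&ctx, &conn_info)? {
//         // reuse the pooled HTTP/2 client
//     }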
pub(crate) struct GlobalConnPool<C: ClientInnerExt + Clone> {
    // endpoint -> per-endpoint connection pool
    //
    // That should be a fairly contended map, so return a reference to the per-endpoint
    // pool as early as possible and release the lock.
    global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,

    /// Number of endpoint-connection pools
    ///
    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    /// Total number of connections in the pool
    global_connections_count: Arc<AtomicUsize>,

    config: &'static crate::config::HttpConfig,
}

impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
        let shards = config.pool_options.pool_shards;
        Arc::new(Self {
            global_pool: DashMap::with_shard_amount(shards),
            global_pool_size: AtomicUsize::new(0),
            config,
            global_connections_count: Arc::new(AtomicUsize::new(0)),
        })
    }

    pub(crate) fn shutdown(&self) {
        // drops all strong references to endpoint-pools
        self.global_pool.clear();
    }

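    /// Background garbage-collection loop: ticks once per `gc_epoch / shard count`
    /// and reclaims one randomly chosen shard per tick, so on average every shard
    /// is visited about once per epoch.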
    pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
        let epoch = self.config.pool_options.gc_epoch;
        let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
        loop {
            interval.tick().await;

            let shard = rng.gen_range(0..self.global_pool.shards().len());
            self.gc(shard);
        }
    }

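    /// Reclaim one shard: for every endpoint pool that nothing else currently
    /// references, drop closed connections and, if the pool ends up empty,
    /// remove it from the map entirely. Counters, metrics and logging are
    /// handled after the shard lock is released.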
    fn gc(&self, shard: usize) {
        debug!(shard, "pool: performing epoch reclamation");

        // acquire a random shard lock
        let mut shard = self.global_pool.shards()[shard].write();

        let timer = Metrics::get()
            .proxy
            .http_pool_reclaimation_lag_seconds
            .start_timer();
        let current_len = shard.len();
        let mut clients_removed = 0;
        shard.retain(|endpoint, x| {
            // if the current endpoint pool is unique (no other strong or weak references)
            // then it is currently not in use by any connections.
            if let Some(pool) = Arc::get_mut(x.get_mut()) {
                let EndpointConnPool { conns, .. } = pool.get_mut();

                let old_len = conns.len();

                conns.retain(|conn| !conn.conn.is_closed());

                let new_len = conns.len();
                let removed = old_len - new_len;
                clients_removed += removed;

                // we only remove this pool if it has no active connections
                if conns.is_empty() {
                    info!("pool: discarding pool for endpoint {endpoint}");
                    return false;
                }
            }

            true
        });

        let new_len = shard.len();
        drop(shard);
        timer.observe();

        // Do logging outside of the lock.
        if clients_removed > 0 {
            let size = self
                .global_connections_count
                .fetch_sub(clients_removed, atomic::Ordering::Relaxed)
                - clients_removed;
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(clients_removed as i64);
            info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
        }
        let removed = current_len - new_len;

        if removed > 0 {
            let global_pool_size = self
                .global_pool_size
                .fetch_sub(removed, atomic::Ordering::Relaxed)
                - removed;
            info!("pool: performed global pool gc. size now {global_pool_size}");
        }
    }

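    /// Try to reuse a pooled connection for this endpoint. On a hit, the
    /// `conn_id` is recorded on the current span, the request context is
    /// marked as an HTTP-pool hit, and the wrapped client is returned;
    /// otherwise `Ok(None)` is returned.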
    #[expect(unused_results)]
    pub(crate) fn get(
        self: &Arc<Self>,
        ctx: &RequestMonitoring,
        conn_info: &ConnInfo,
    ) -> Result<Option<Client<C>>, HttpConnError> {
        let result: Result<Option<Client<C>>, HttpConnError>;
        let Some(endpoint) = conn_info.endpoint_cache_key() else {
            result = Ok(None);
            return result;
        };
        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
        let Some(client) = endpoint_pool.write().get_conn_entry() else {
            result = Ok(None);
            return result;
        };

        tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
        info!(
            cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
            "pool: reusing connection '{conn_info}'"
        );
        ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
        ctx.success();
        Ok(Some(Client::new(client.conn, client.aux)))
    }

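    /// Fetch the pool for `endpoint`, creating an empty one on first use. The
    /// fast path only reads the `DashMap`; the slow path allocates a new pool
    /// up front and inserts it under the entry lock, bumping `global_pool_size`
    /// only if this call actually created the entry.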
    fn get_or_create_endpoint_pool(
        self: &Arc<Self>,
        endpoint: &EndpointCacheKey,
    ) -> Arc<RwLock<EndpointConnPool<C>>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            conns: VecDeque::new(),
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count: self.global_connections_count.clone(),
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}

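// Register a freshly established HTTP/2 connection with the pool and spawn the
// task that drives it. The send handle is pushed into the per-endpoint pool
// right away (HTTP/2 is multiplexed, so it can be shared while this request is
// still in flight); the spawned task polls the hyper `Connection` to completion
// and removes the entry from the pool once it terminates.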
pub(crate) fn poll_http2_client(
    global_pool: Arc<GlobalConnPool<Send>>,
    ctx: &RequestMonitoring,
    conn_info: &ConnInfo,
    client: Send,
    connection: Connect,
    conn_id: uuid::Uuid,
    aux: MetricsAuxInfo,
) -> Client<Send> {
    let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
    let session_id = ctx.session_id();

    let span = info_span!(parent: None, "connection", %conn_id);
    let cold_start_info = ctx.cold_start_info();
    span.in_scope(|| {
        info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
    });

    let pool = match conn_info.endpoint_cache_key() {
        Some(endpoint) => {
            let pool = global_pool.get_or_create_endpoint_pool(&endpoint);

            pool.write().conns.push_back(ConnPoolEntry {
                conn: client.clone(),
                conn_id,
                aux: aux.clone(),
            });
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .inc();

            Arc::downgrade(&pool)
        }
        None => Weak::new(),
    };

    tokio::spawn(
        async move {
            let _conn_gauge = conn_gauge;
            let res = connection.await;
            match res {
                Ok(()) => info!("connection closed"),
                Err(e) => error!(%session_id, "connection error: {e:?}"),
            }

            // remove from connection pool
            if let Some(pool) = pool.clone().upgrade() {
                if pool.write().remove_conn(conn_id) {
                    info!("closed connection removed");
                }
            }
        }
        .instrument(span),
    );

    Client::new(client, aux)
}

pub(crate) struct Client<C: ClientInnerExt + Clone> {
    pub(crate) inner: C,
    aux: MetricsAuxInfo,
}

impl<C: ClientInnerExt + Clone> Client<C> {
    pub(self) fn new(inner: C, aux: MetricsAuxInfo) -> Self {
        Self { inner, aux }
    }

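    /// Register a usage-metrics counter keyed by this client's endpoint and
    /// branch, taken from the auxiliary control-plane metadata.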
    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
        USAGE_METRICS.register(Ids {
            endpoint_id: self.aux.endpoint_id,
            branch_id: self.aux.branch_id,
        })
    }
}

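// `ClientInnerExt` lets the pool check liveness generically. For an HTTP/2
// send handle, `is_closed` delegates to hyper's inherent method; there is no
// backend process id to report.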
impl ClientInnerExt for Send {
    fn is_closed(&self) -> bool {
        self.is_closed()
    }

    fn get_process_id(&self) -> i32 {
        // ideally we would return something meaningful here, but HTTP connections
        // have no backend process id, so use -1 as a placeholder
        -1
    }
}