use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};

use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tokio::net::TcpStream;
use tracing::{Instrument, debug, error, info, info_span};

use super::backend::HttpConnError;
use super::conn_pool_lib::{
    ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
    EndpointConnPoolExt, GlobalConnPool,
};
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS};

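// `Send` is the HTTP/2 request sender handle; `Connect` is the connection
// driver future that must be polled for those requests to make progress.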
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =
    http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;

#[derive(Clone)]
pub(crate) struct ClientDataHttp();

// Per-endpoint connection pool.
// The number of open connections is limited by `max_conns_per_endpoint`.
pub(crate) struct HttpConnPool<C: ClientInnerExt + Clone> {
    // TODO(conrad):
    // either we should open more connections depending on stream count
    // (not exposed by hyper, need our own counter)
    // or we can change this to an Option rather than a VecDeque.
    //
    // Opening more connections to the same db because we run out of streams
    // seems somewhat redundant though.
    //
    // Probably we should use a semaphore and just a single conn. TBD.
    conns: VecDeque<ConnPoolEntry<C>>,
    _guard: HttpEndpointPoolsGuard<'static>,
    global_connections_count: Arc<AtomicUsize>,
}

impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
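    /// Pop entries until a live connection is found, re-queue a clone of it at
    /// the back (round-robin over the pooled HTTP/2 connections), and return it.
    /// Closed connections encountered along the way are discarded.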
    fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
        let Self { conns, .. } = self;

        loop {
            let conn = conns.pop_front()?;
            if !conn.conn.inner.is_closed() {
                let new_conn = ConnPoolEntry {
                    conn: conn.conn.clone(),
                    _last_access: std::time::Instant::now(),
                };

                conns.push_back(new_conn);
                return Some(conn);
            }
        }
    }

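    /// Remove every pooled entry belonging to `conn_id`, decrementing the global
    /// connection counter and the pool metric. Returns `true` if anything was removed.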
    fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
        let Self {
            conns,
            global_connections_count,
            ..
        } = self;

        let old_len = conns.len();
        conns.retain(|entry| entry.conn.conn_id != conn_id);
        let new_len = conns.len();
        let removed = old_len - new_len;
        if removed > 0 {
            global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(removed as i64);
        }
        removed > 0
    }
}

impl<C: ClientInnerExt + Clone> EndpointConnPoolExt<C> for HttpConnPool<C> {
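    /// Drop all closed connections from the pool and return how many were removed.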
    fn clear_closed(&mut self) -> usize {
        let Self { conns, .. } = self;
        let old_len = conns.len();
        conns.retain(|entry| !entry.conn.inner.is_closed());

        let new_len = conns.len();
        old_len - new_len
    }

    fn total_conns(&self) -> usize {
        self.conns.len()
    }
}

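// On drop, account for any connections still held by this pool so the global
// connection counter and the pool metric stay consistent.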
impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
    fn drop(&mut self) {
        if !self.conns.is_empty() {
            self.global_connections_count
                .fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(self.conns.len() as i64);
        }
    }
}

impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
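    /// Try to take an existing pooled connection for this endpoint. Returns
    /// `Ok(None)` when the endpoint is unknown or no live pooled connection is
    /// available.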
    #[expect(unused_results)]
    pub(crate) fn get(
        self: &Arc<Self>,
        ctx: &RequestContext,
        conn_info: &ConnInfo,
    ) -> Result<Option<Client<C>>, HttpConnError> {
        let result: Result<Option<Client<C>>, HttpConnError>;
        let Some(endpoint) = conn_info.endpoint_cache_key() else {
            result = Ok(None);
            return result;
        };
        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
        let Some(client) = endpoint_pool.write().get_conn_entry() else {
            result = Ok(None);
            return result;
        };

        tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
        debug!(
            cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
            "pool: reusing connection '{conn_info}'"
        );
        ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
        ctx.success();

        Ok(Some(Client::new(client.conn.clone())))
    }

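    /// Look up the per-endpoint pool, creating and registering an empty one if
    /// this endpoint has not been seen before.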
    fn get_or_create_endpoint_pool(
        self: &Arc<Self>,
        endpoint: &EndpointCacheKey,
    ) -> Arc<RwLock<HttpConnPool<C>>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(HttpConnPool {
            conns: VecDeque::new(),
            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
            global_connections_count: self.global_connections_count.clone(),
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}

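/// Register a freshly established HTTP/2 connection in the per-endpoint pool
/// (when an endpoint cache key is available) and spawn a background task that
/// drives the connection, removing it from the pool once it terminates.
/// Returns a [`Client`] wrapping the request sender.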
pub(crate) fn poll_http2_client(
    global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
    ctx: &RequestContext,
    conn_info: &ConnInfo,
    client: Send,
    connection: Connect,
    conn_id: uuid::Uuid,
    aux: MetricsAuxInfo,
) -> Client<Send> {
    let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
    let session_id = ctx.session_id();

    let span = info_span!(parent: None, "connection", %conn_id);
    let cold_start_info = ctx.cold_start_info();
    span.in_scope(|| {
        info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
    });

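    // Add the connection to the per-endpoint pool (if the endpoint is known),
    // keeping only a weak reference so the pool itself can still be dropped.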
    let pool = match conn_info.endpoint_cache_key() {
        Some(endpoint) => {
            let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
            let client = ClientInnerCommon {
                inner: client.clone(),
                aux: aux.clone(),
                conn_id,
                data: ClientDataEnum::Http(ClientDataHttp()),
            };
            pool.write().conns.push_back(ConnPoolEntry {
                conn: client,
                _last_access: std::time::Instant::now(),
            });
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .inc();

            Arc::downgrade(&pool)
        }
        None => Weak::new(),
    };

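    // Drive the HTTP/2 connection to completion in the background; once it
    // finishes (cleanly or with an error), drop its entry from the pool.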
    tokio::spawn(
        async move {
            let _conn_gauge = conn_gauge;
            let res = connection.await;
            match res {
                Ok(()) => info!("connection closed"),
                Err(e) => error!(%session_id, "connection error: {e:?}"),
            }

            // remove from connection pool
            if let Some(pool) = pool.clone().upgrade() {
                if pool.write().remove_conn(conn_id) {
                    info!("closed connection removed");
                }
            }
        }
        .instrument(span),
    );

    let client = ClientInnerCommon {
        inner: client,
        aux,
        conn_id,
        data: ClientDataEnum::Http(ClientDataHttp()),
    };

    Client::new(client)
}

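/// Handle handed back to callers: the HTTP/2 request sender plus the metadata
/// needed for usage metrics.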
pub(crate) struct Client<C: ClientInnerExt + Clone> {
    pub(crate) inner: ClientInnerCommon<C>,
}

impl<C: ClientInnerExt + Clone> Client<C> {
    pub(self) fn new(inner: ClientInnerCommon<C>) -> Self {
        Self { inner }
    }

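    /// Register a usage-metrics counter for this endpoint/branch and traffic
    /// direction, tagging it with a private-link id when the request arrived
    /// via an AWS VPC endpoint or an Azure private link.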
    pub(crate) fn metrics(
        &self,
        direction: TrafficDirection,
        ctx: &RequestContext,
    ) -> Arc<MetricCounter> {
        let aux = &self.inner.aux;

        let private_link_id = match ctx.extra() {
            None => None,
            Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
            Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
        };

        USAGE_METRICS.register(Ids {
            endpoint_id: aux.endpoint_id,
            branch_id: aux.branch_id,
            direction,
            private_link_id,
        })
    }
}

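// `self.is_closed()` below resolves to the inherent `http2::SendRequest::is_closed`
// (inherent methods take precedence over trait methods), so this is not recursive.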
impl ClientInnerExt for Send {
    fn is_closed(&self) -> bool {
        self.is_closed()
    }

    fn get_process_id(&self) -> i32 {
        // ideally we would return something meaningful here, but an HTTP/2
        // pooled connection has no backend process id.
        -1
    }
}