Line data Source code
1 : use std::fmt;
2 : use std::pin::pin;
3 : use std::sync::{Arc, Weak};
4 : use std::task::{Poll, ready};
5 :
6 : use futures::Future;
7 : use futures::future::poll_fn;
8 : use postgres_client::AsyncMessage;
9 : use postgres_client::tls::NoTlsStream;
10 : use smallvec::SmallVec;
11 : use tokio::net::TcpStream;
12 : use tokio::time::Instant;
13 : use tokio_util::sync::CancellationToken;
14 : use tracing::{Instrument, error, info, info_span, warn};
15 : #[cfg(test)]
16 : use {
17 : super::conn_pool_lib::GlobalConnPoolOptions,
18 : crate::auth::backend::ComputeUserInfo,
19 : std::{sync::atomic, time::Duration},
20 : };
21 :
22 : use super::conn_pool_lib::{
23 : Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
24 : GlobalConnPool,
25 : };
26 : use crate::context::RequestContext;
27 : use crate::control_plane::messages::MetricsAuxInfo;
28 : use crate::metrics::Metrics;
29 :
30 : #[derive(Debug, Clone)]
31 : pub(crate) struct ConnInfoWithAuth {
32 : pub(crate) conn_info: ConnInfo,
33 : pub(crate) auth: AuthData,
34 : }
35 :
36 : #[derive(Debug, Clone)]
37 : pub(crate) enum AuthData {
38 : Password(SmallVec<[u8; 16]>),
39 : Jwt(String),
40 : }
41 :
42 : impl fmt::Display for ConnInfo {
43 : // use custom display to avoid logging password
44 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 0 : write!(
46 0 : f,
47 0 : "{}@{}/{}?{}",
48 0 : self.user_info.user,
49 0 : self.user_info.endpoint,
50 0 : self.dbname,
51 0 : self.user_info.options.get_cache_key("")
52 0 : )
53 0 : }
54 : }
55 :
56 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
57 0 : global_pool: Arc<GlobalConnPool<C, EndpointConnPool<C>>>,
58 0 : ctx: &RequestContext,
59 0 : conn_info: ConnInfo,
60 0 : client: C,
61 0 : mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
62 0 : conn_id: uuid::Uuid,
63 0 : aux: MetricsAuxInfo,
64 0 : ) -> Client<C> {
65 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
66 0 : let mut session_id = ctx.session_id();
67 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
68 :
69 0 : let span = info_span!(parent: None, "connection", %conn_id);
70 0 : let cold_start_info = ctx.cold_start_info();
71 0 : span.in_scope(|| {
72 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
73 0 : });
74 0 : let pool = match conn_info.endpoint_cache_key() {
75 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
76 0 : None => Weak::new(),
77 : };
78 0 : let pool_clone = pool.clone();
79 0 :
80 0 : let db_user = conn_info.db_and_user();
81 0 : let idle = global_pool.get_idle_timeout();
82 0 : let cancel = CancellationToken::new();
83 0 : let cancelled = cancel.clone().cancelled_owned();
84 0 :
85 0 : tokio::spawn(
86 0 : async move {
87 0 : let _conn_gauge = conn_gauge;
88 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
89 0 : let mut cancelled = pin!(cancelled);
90 0 :
91 0 : poll_fn(move |cx| {
92 0 : if cancelled.as_mut().poll(cx).is_ready() {
93 0 : info!("connection dropped");
94 0 : return Poll::Ready(())
95 0 : }
96 0 :
97 0 : match rx.has_changed() {
98 : Ok(true) => {
99 0 : session_id = *rx.borrow_and_update();
100 0 : info!(%session_id, "changed session");
101 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
102 : }
103 : Err(_) => {
104 0 : info!("connection dropped");
105 0 : return Poll::Ready(())
106 : }
107 0 : _ => {}
108 : }
109 :
110 : // 5 minute idle connection timeout
111 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
112 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
113 0 : info!("connection idle");
114 0 : if let Some(pool) = pool.clone().upgrade() {
115 : // remove client from pool - should close the connection if it's idle.
116 : // does nothing if the client is currently checked-out and in-use
117 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
118 0 : info!("idle connection removed");
119 0 : }
120 0 : }
121 0 : }
122 :
123 : loop {
124 0 : let message = ready!(connection.poll_message(cx));
125 :
126 0 : match message {
127 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
128 0 : info!(%session_id, "notice: {}", notice);
129 : }
130 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
131 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
132 : }
133 : Some(Ok(_)) => {
134 0 : warn!(%session_id, "unknown message");
135 : }
136 0 : Some(Err(e)) => {
137 0 : error!(%session_id, "connection error: {}", e);
138 0 : break
139 : }
140 : None => {
141 0 : info!("connection closed");
142 0 : break
143 : }
144 : }
145 : }
146 :
147 : // remove from connection pool
148 0 : if let Some(pool) = pool.clone().upgrade() {
149 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
150 0 : info!("closed connection removed");
151 0 : }
152 0 : }
153 :
154 0 : Poll::Ready(())
155 0 : }).await;
156 :
157 0 : }
158 0 : .instrument(span));
159 0 : let inner = ClientInnerCommon {
160 0 : inner: client,
161 0 : aux,
162 0 : conn_id,
163 0 : data: ClientDataEnum::Remote(ClientDataRemote {
164 0 : session: tx,
165 0 : cancel,
166 0 : }),
167 0 : };
168 0 :
169 0 : Client::new(inner, conn_info, pool_clone)
170 0 : }
171 :
172 : #[derive(Clone)]
173 : pub(crate) struct ClientDataRemote {
174 : session: tokio::sync::watch::Sender<uuid::Uuid>,
175 : cancel: CancellationToken,
176 : }
177 :
178 : impl ClientDataRemote {
179 0 : pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
180 0 : &mut self.session
181 0 : }
182 :
183 7 : pub fn cancel(&mut self) {
184 7 : self.cancel.cancel();
185 7 : }
186 : }
187 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::AtomicBool;

    use super::*;
    use crate::proxy::NeonOptions;
    use crate::serverless::cancel_set::CancelSet;
    use crate::types::{BranchId, EndpointId, ProjectId};

    /// Test double whose "closed" state is a shared flag the test can flip at any time.
    struct MockClient(Arc<AtomicBool>);

    impl MockClient {
        /// Builds a mock with its own, unshared flag set to `is_closed`.
        fn new(is_closed: bool) -> Self {
            Self(Arc::new(AtomicBool::new(is_closed)))
        }
    }

    impl ClientInnerExt for MockClient {
        fn is_closed(&self) -> bool {
            let Self(flag) = self;
            flag.load(atomic::Ordering::Relaxed)
        }

        fn get_process_id(&self) -> i32 {
            0
        }
    }

    /// Client internals wrapping a mock that reports itself as open.
    fn create_inner() -> ClientInnerCommon<MockClient> {
        let open_client = MockClient::new(false);
        create_inner_with(open_client)
    }

    /// Client internals wrapping `client`, with fixed metrics metadata, a random
    /// connection id, and a fresh session sender and cancellation token.
    fn create_inner_with(client: MockClient) -> ClientInnerCommon<MockClient> {
        let aux = MetricsAuxInfo {
            endpoint_id: (&EndpointId::from("endpoint")).into(),
            project_id: (&ProjectId::from("project")).into(),
            branch_id: (&BranchId::from("branch")).into(),
            compute_id: "compute".into(),
            cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
        };
        let conn_id = uuid::Uuid::new_v4();
        let data = ClientDataEnum::Remote(ClientDataRemote {
            session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
            cancel: CancellationToken::new(),
        });

        ClientInnerCommon {
            inner: client,
            aux,
            conn_id,
            data,
        }
    }

    #[tokio::test]
    async fn test_pool() {
        let _ = env_logger::try_init();

        // Limits under test: 2 connections per endpoint, 3 in total, 2 shards.
        let pool_options = GlobalConnPoolOptions {
            max_conns_per_endpoint: 2,
            gc_epoch: Duration::from_secs(1),
            pool_shards: 2,
            idle_timeout: Duration::from_secs(1),
            opt_in: false,
            max_total_conns: 3,
        };
        let config = Box::leak(Box::new(crate::config::HttpConfig {
            accept_websockets: false,
            pool_options,
            cancel_set: CancelSet::new(0),
            client_conn_threshold: u64::MAX,
            max_request_size_bytes: usize::MAX,
            max_response_size_bytes: usize::MAX,
        }));
        let pool = GlobalConnPool::new(config);

        // ---- first endpoint ----
        let first_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let first_key = first_info.endpoint_cache_key().unwrap();
        let first_ep_pool = Arc::downgrade(&pool.get_or_create_endpoint_pool(&first_key));

        {
            // A discarded client must not be put into the pool, even once it is dropped.
            let mut discarded =
                Client::new(create_inner(), first_info.clone(), first_ep_pool.clone());
            assert_eq!(0, pool.get_global_connections_count());
            discarded.inner().1.discard();
            assert_eq!(0, pool.get_global_connections_count());
        }

        // Dropping a healthy client hands it to the pool.
        drop(Client::new(
            create_inner(),
            first_info.clone(),
            first_ep_pool.clone(),
        ));
        assert_eq!(1, pool.get_global_connections_count());

        // A client that is already closed is not pooled on drop.
        let already_closed = create_inner_with(MockClient::new(true));
        drop(Client::new(
            already_closed,
            first_info.clone(),
            first_ep_pool.clone(),
        ));
        assert_eq!(1, pool.get_global_connections_count());

        // This client is open for now; the flag is kept so it can be closed later.
        let is_closed: Arc<AtomicBool> = Arc::new(false.into());
        let closable = create_inner_with(MockClient(is_closed.clone()));
        drop(Client::new(
            closable,
            first_info.clone(),
            first_ep_pool.clone(),
        ));
        assert_eq!(2, pool.get_global_connections_count());

        // The endpoint pool is at its limit of two, so this one is not pooled.
        drop(Client::new(create_inner(), first_info, first_ep_pool));
        assert_eq!(2, pool.get_global_connections_count());

        // ---- second endpoint ----
        let second_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint-2".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let second_key = second_info.endpoint_cache_key().unwrap();
        let second_ep_pool = Arc::downgrade(&pool.get_or_create_endpoint_pool(&second_key));

        // A different endpoint still has room: the global count reaches three.
        drop(Client::new(
            create_inner(),
            second_info.clone(),
            second_ep_pool.clone(),
        ));
        assert_eq!(3, pool.get_global_connections_count());

        // The global pool is at its limit of three, so this one is not pooled.
        drop(Client::new(
            create_inner(),
            second_info.clone(),
            second_ep_pool.clone(),
        ));
        assert_eq!(3, pool.get_global_connections_count());

        // Close the pooled client from the outside, then collect garbage on both shards.
        is_closed.store(true, atomic::Ordering::Relaxed);
        for shard in 0..2 {
            pool.gc(shard);
        }
        // The closed client is gone from the pool.
        assert_eq!(2, pool.get_global_connections_count());
    }
}
|