Line data Source code
1 : use std::fmt;
2 : use std::pin::pin;
3 : use std::sync::{Arc, Weak};
4 : use std::task::{ready, Poll};
5 :
6 : use futures::future::poll_fn;
7 : use futures::Future;
8 : use smallvec::SmallVec;
9 : use tokio::time::Instant;
10 : use tokio_postgres::tls::NoTlsStream;
11 : use tokio_postgres::{AsyncMessage, Socket};
12 : use tokio_util::sync::CancellationToken;
13 : use tracing::{error, info, info_span, warn, Instrument};
14 : #[cfg(test)]
15 : use {
16 : super::conn_pool_lib::GlobalConnPoolOptions,
17 : crate::auth::backend::ComputeUserInfo,
18 : std::{sync::atomic, time::Duration},
19 : };
20 :
21 : use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
22 : use crate::context::RequestMonitoring;
23 : use crate::control_plane::messages::MetricsAuxInfo;
24 : use crate::metrics::Metrics;
25 :
26 : #[derive(Debug, Clone)]
27 : pub(crate) struct ConnInfoWithAuth {
28 : pub(crate) conn_info: ConnInfo,
29 : pub(crate) auth: AuthData,
30 : }
31 :
32 : #[derive(Debug, Clone)]
33 : pub(crate) enum AuthData {
34 : Password(SmallVec<[u8; 16]>),
35 : Jwt(String),
36 : }
37 :
38 : impl fmt::Display for ConnInfo {
39 : // use custom display to avoid logging password
40 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 0 : write!(
42 0 : f,
43 0 : "{}@{}/{}?{}",
44 0 : self.user_info.user,
45 0 : self.user_info.endpoint,
46 0 : self.dbname,
47 0 : self.user_info.options.get_cache_key("")
48 0 : )
49 0 : }
50 : }
51 :
52 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
53 0 : global_pool: Arc<GlobalConnPool<C>>,
54 0 : ctx: &RequestMonitoring,
55 0 : conn_info: ConnInfo,
56 0 : client: C,
57 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
58 0 : conn_id: uuid::Uuid,
59 0 : aux: MetricsAuxInfo,
60 0 : ) -> Client<C> {
61 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
62 0 : let mut session_id = ctx.session_id();
63 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
64 :
65 0 : let span = info_span!(parent: None, "connection", %conn_id);
66 0 : let cold_start_info = ctx.cold_start_info();
67 0 : span.in_scope(|| {
68 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
69 0 : });
70 0 : let pool = match conn_info.endpoint_cache_key() {
71 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
72 0 : None => Weak::new(),
73 : };
74 0 : let pool_clone = pool.clone();
75 0 :
76 0 : let db_user = conn_info.db_and_user();
77 0 : let idle = global_pool.get_idle_timeout();
78 0 : let cancel = CancellationToken::new();
79 0 : let cancelled = cancel.clone().cancelled_owned();
80 0 :
81 0 : tokio::spawn(
82 0 : async move {
83 0 : let _conn_gauge = conn_gauge;
84 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
85 0 : let mut cancelled = pin!(cancelled);
86 0 :
87 0 : poll_fn(move |cx| {
88 0 : if cancelled.as_mut().poll(cx).is_ready() {
89 0 : info!("connection dropped");
90 0 : return Poll::Ready(())
91 0 : }
92 0 :
93 0 : match rx.has_changed() {
94 : Ok(true) => {
95 0 : session_id = *rx.borrow_and_update();
96 0 : info!(%session_id, "changed session");
97 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
98 : }
99 : Err(_) => {
100 0 : info!("connection dropped");
101 0 : return Poll::Ready(())
102 : }
103 0 : _ => {}
104 : }
105 :
106 : // 5 minute idle connection timeout
107 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
108 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
109 0 : info!("connection idle");
110 0 : if let Some(pool) = pool.clone().upgrade() {
111 : // remove client from pool - should close the connection if it's idle.
112 : // does nothing if the client is currently checked-out and in-use
113 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
114 0 : info!("idle connection removed");
115 0 : }
116 0 : }
117 0 : }
118 :
119 : loop {
120 0 : let message = ready!(connection.poll_message(cx));
121 :
122 0 : match message {
123 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
124 0 : info!(%session_id, "notice: {}", notice);
125 : }
126 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
127 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
128 : }
129 : Some(Ok(_)) => {
130 0 : warn!(%session_id, "unknown message");
131 : }
132 0 : Some(Err(e)) => {
133 0 : error!(%session_id, "connection error: {}", e);
134 0 : break
135 : }
136 : None => {
137 0 : info!("connection closed");
138 0 : break
139 : }
140 : }
141 : }
142 :
143 : // remove from connection pool
144 0 : if let Some(pool) = pool.clone().upgrade() {
145 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
146 0 : info!("closed connection removed");
147 0 : }
148 0 : }
149 :
150 0 : Poll::Ready(())
151 0 : }).await;
152 :
153 0 : }
154 0 : .instrument(span));
155 0 : let inner = ClientInnerRemote {
156 0 : inner: client,
157 0 : session: tx,
158 0 : cancel,
159 0 : aux,
160 0 : conn_id,
161 0 : };
162 0 : Client::new(inner, conn_info, pool_clone)
163 0 : }
164 :
165 : pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
166 : inner: C,
167 : session: tokio::sync::watch::Sender<uuid::Uuid>,
168 : cancel: CancellationToken,
169 : aux: MetricsAuxInfo,
170 : conn_id: uuid::Uuid,
171 : }
172 :
173 : impl<C: ClientInnerExt> ClientInnerRemote<C> {
174 1 : pub(crate) fn inner_mut(&mut self) -> &mut C {
175 1 : &mut self.inner
176 1 : }
177 :
178 0 : pub(crate) fn inner(&self) -> &C {
179 0 : &self.inner
180 0 : }
181 :
182 0 : pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
183 0 : &mut self.session
184 0 : }
185 :
186 0 : pub(crate) fn aux(&self) -> &MetricsAuxInfo {
187 0 : &self.aux
188 0 : }
189 :
190 6 : pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
191 6 : self.conn_id
192 6 : }
193 :
194 8 : pub(crate) fn is_closed(&self) -> bool {
195 8 : self.inner.is_closed()
196 8 : }
197 : }
198 :
199 : impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
200 7 : fn drop(&mut self) {
201 7 : // on client drop, tell the conn to shut down
202 7 : self.cancel.cancel();
203 7 : }
204 : }
205 :
#[cfg(test)]
mod tests {
    use std::mem;
    use std::sync::atomic::AtomicBool;

    use super::*;
    use crate::proxy::NeonOptions;
    use crate::serverless::cancel_set::CancelSet;
    use crate::{BranchId, EndpointId, ProjectId};

    /// Test double for a client: `is_closed` reports the shared flag, which a
    /// test can flip after the client has been handed to the pool.
    struct MockClient(Arc<AtomicBool>);

    impl MockClient {
        /// Builds a mock whose closed flag starts at `is_closed`.
        fn new(is_closed: bool) -> Self {
            Self(Arc::new(AtomicBool::new(is_closed)))
        }
    }

    impl ClientInnerExt for MockClient {
        fn is_closed(&self) -> bool {
            let Self(flag) = self;
            flag.load(atomic::Ordering::Relaxed)
        }

        fn get_process_id(&self) -> i32 {
            0
        }
    }

    /// A `ClientInnerRemote` around an open mock client.
    fn create_inner() -> ClientInnerRemote<MockClient> {
        let open_client = MockClient::new(false);
        create_inner_with(open_client)
    }

    /// A `ClientInnerRemote` around `client`, with fresh random ids and fixed
    /// placeholder metrics metadata.
    fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
        let aux = MetricsAuxInfo {
            endpoint_id: (&EndpointId::from("endpoint")).into(),
            project_id: (&ProjectId::from("project")).into(),
            branch_id: (&BranchId::from("branch")).into(),
            cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
        };
        ClientInnerRemote {
            inner: client,
            session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
            cancel: CancellationToken::new(),
            aux,
            conn_id: uuid::Uuid::new_v4(),
        }
    }

    /// Connection info for user `user` and database `dbname` on `endpoint`.
    fn conn_info_for(endpoint: &'static str) -> ConnInfo {
        ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: endpoint.into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        }
    }

    /// Runs the client's return-to-pool step by hand, then forgets the wrapper
    /// so its destructor never runs (presumably that would repeat the step).
    fn return_to_pool(mut client: Client<MockClient>) {
        client.do_drop().unwrap()();
        mem::forget(client);
    }

    #[tokio::test]
    async fn test_pool() {
        let _ = env_logger::try_init();
        let config = Box::leak(Box::new(crate::config::HttpConfig {
            accept_websockets: false,
            pool_options: GlobalConnPoolOptions {
                max_conns_per_endpoint: 2,
                gc_epoch: Duration::from_secs(1),
                pool_shards: 2,
                idle_timeout: Duration::from_secs(1),
                opt_in: false,
                max_total_conns: 3,
            },
            cancel_set: CancelSet::new(0),
            client_conn_threshold: u64::MAX,
            max_request_size_bytes: u64::MAX,
            max_response_size_bytes: usize::MAX,
        }));
        let pool = GlobalConnPool::new(config);

        // --- first endpoint: per-endpoint limit is 2 ---
        let conn_info = conn_info_for("endpoint");
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );

        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            assert_eq!(0, pool.get_global_connections_count());
            client.inner_mut().1.discard();
            // A discarded connection must not be added to the pool.
            assert_eq!(0, pool.get_global_connections_count());
        }

        // An open client is pooled when returned.
        return_to_pool(Client::new(
            create_inner(),
            conn_info.clone(),
            ep_pool.clone(),
        ));
        assert_eq!(1, pool.get_global_connections_count());

        // A client that is already closed must not be pooled.
        return_to_pool(Client::new(
            create_inner_with(MockClient::new(true)),
            conn_info.clone(),
            ep_pool.clone(),
        ));
        assert_eq!(1, pool.get_global_connections_count());

        // Keep a handle on this client's flag so it can be closed later.
        let is_closed: Arc<AtomicBool> = Arc::new(false.into());
        return_to_pool(Client::new(
            create_inner_with(MockClient(is_closed.clone())),
            conn_info.clone(),
            ep_pool.clone(),
        ));
        // The client should be added to the pool.
        assert_eq!(2, pool.get_global_connections_count());

        // The endpoint pool is full now, so this one is not added.
        return_to_pool(Client::new(create_inner(), conn_info, ep_pool));
        assert_eq!(2, pool.get_global_connections_count());

        // --- second endpoint: global limit is 3 ---
        let conn_info = conn_info_for("endpoint-2");
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );

        return_to_pool(Client::new(
            create_inner(),
            conn_info.clone(),
            ep_pool.clone(),
        ));
        assert_eq!(3, pool.get_global_connections_count());

        // The global pool is full now, so this one is not added.
        return_to_pool(Client::new(
            create_inner(),
            conn_info.clone(),
            ep_pool.clone(),
        ));
        assert_eq!(3, pool.get_global_connections_count());

        // Mark the tracked client as closed and collect every shard.
        is_closed.store(true, atomic::Ordering::Relaxed);
        pool.gc(0);
        pool.gc(1);
        // The closed client should have been removed from the pool.
        assert_eq!(2, pool.get_global_connections_count());
    }
}
|