Line data Source code
1 : use std::fmt;
2 : use std::pin::pin;
3 : use std::sync::{Arc, Weak};
4 : use std::task::{ready, Poll};
5 :
6 : use futures::future::poll_fn;
7 : use futures::Future;
8 : use smallvec::SmallVec;
9 : use tokio::time::Instant;
10 : use tokio_postgres::tls::NoTlsStream;
11 : use tokio_postgres::{AsyncMessage, Socket};
12 : use tokio_util::sync::CancellationToken;
13 : use tracing::{error, info, info_span, warn, Instrument};
14 : #[cfg(test)]
15 : use {
16 : super::conn_pool_lib::GlobalConnPoolOptions,
17 : crate::auth::backend::ComputeUserInfo,
18 : std::{sync::atomic, time::Duration},
19 : };
20 :
21 : use super::conn_pool_lib::{
22 : Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
23 : GlobalConnPool,
24 : };
25 : use crate::context::RequestContext;
26 : use crate::control_plane::messages::MetricsAuxInfo;
27 : use crate::metrics::Metrics;
28 :
/// A connection target bundled with the credential presented for it.
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
    /// Identifies the connection target (see [`ConnInfo`]).
    pub(crate) conn_info: ConnInfo,
    /// Credential that accompanies `conn_info`.
    pub(crate) auth: AuthData,
}
34 :
/// Credential material carried alongside a [`ConnInfo`].
///
/// NOTE(review): `Debug` is derived, so formatting a value with `{:?}` prints the
/// contained bytes / token verbatim — confirm this type is never debug-logged.
#[derive(Debug, Clone)]
pub(crate) enum AuthData {
    /// Password bytes; up to 16 bytes are stored inline by the `SmallVec`.
    Password(SmallVec<[u8; 16]>),
    /// A JWT held as a string.
    Jwt(String),
}
40 :
41 : impl fmt::Display for ConnInfo {
42 : // use custom display to avoid logging password
43 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 0 : write!(
45 0 : f,
46 0 : "{}@{}/{}?{}",
47 0 : self.user_info.user,
48 0 : self.user_info.endpoint,
49 0 : self.dbname,
50 0 : self.user_info.options.get_cache_key("")
51 0 : )
52 0 : }
53 : }
54 :
55 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
56 0 : global_pool: Arc<GlobalConnPool<C, EndpointConnPool<C>>>,
57 0 : ctx: &RequestContext,
58 0 : conn_info: ConnInfo,
59 0 : client: C,
60 0 : mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
61 0 : conn_id: uuid::Uuid,
62 0 : aux: MetricsAuxInfo,
63 0 : ) -> Client<C> {
64 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
65 0 : let mut session_id = ctx.session_id();
66 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
67 :
68 0 : let span = info_span!(parent: None, "connection", %conn_id);
69 0 : let cold_start_info = ctx.cold_start_info();
70 0 : span.in_scope(|| {
71 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
72 0 : });
73 0 : let pool = match conn_info.endpoint_cache_key() {
74 0 : Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
75 0 : None => Weak::new(),
76 : };
77 0 : let pool_clone = pool.clone();
78 0 :
79 0 : let db_user = conn_info.db_and_user();
80 0 : let idle = global_pool.get_idle_timeout();
81 0 : let cancel = CancellationToken::new();
82 0 : let cancelled = cancel.clone().cancelled_owned();
83 0 :
84 0 : tokio::spawn(
85 0 : async move {
86 0 : let _conn_gauge = conn_gauge;
87 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
88 0 : let mut cancelled = pin!(cancelled);
89 0 :
90 0 : poll_fn(move |cx| {
91 0 : if cancelled.as_mut().poll(cx).is_ready() {
92 0 : info!("connection dropped");
93 0 : return Poll::Ready(())
94 0 : }
95 0 :
96 0 : match rx.has_changed() {
97 : Ok(true) => {
98 0 : session_id = *rx.borrow_and_update();
99 0 : info!(%session_id, "changed session");
100 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
101 : }
102 : Err(_) => {
103 0 : info!("connection dropped");
104 0 : return Poll::Ready(())
105 : }
106 0 : _ => {}
107 : }
108 :
109 : // 5 minute idle connection timeout
110 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
111 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
112 0 : info!("connection idle");
113 0 : if let Some(pool) = pool.clone().upgrade() {
114 : // remove client from pool - should close the connection if it's idle.
115 : // does nothing if the client is currently checked-out and in-use
116 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
117 0 : info!("idle connection removed");
118 0 : }
119 0 : }
120 0 : }
121 :
122 : loop {
123 0 : let message = ready!(connection.poll_message(cx));
124 :
125 0 : match message {
126 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
127 0 : info!(%session_id, "notice: {}", notice);
128 : }
129 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
130 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
131 : }
132 : Some(Ok(_)) => {
133 0 : warn!(%session_id, "unknown message");
134 : }
135 0 : Some(Err(e)) => {
136 0 : error!(%session_id, "connection error: {}", e);
137 0 : break
138 : }
139 : None => {
140 0 : info!("connection closed");
141 0 : break
142 : }
143 : }
144 : }
145 :
146 : // remove from connection pool
147 0 : if let Some(pool) = pool.clone().upgrade() {
148 0 : if pool.write().remove_client(db_user.clone(), conn_id) {
149 0 : info!("closed connection removed");
150 0 : }
151 0 : }
152 :
153 0 : Poll::Ready(())
154 0 : }).await;
155 :
156 0 : }
157 0 : .instrument(span));
158 0 : let inner = ClientInnerCommon {
159 0 : inner: client,
160 0 : aux,
161 0 : conn_id,
162 0 : data: ClientDataEnum::Remote(ClientDataRemote {
163 0 : session: tx,
164 0 : cancel,
165 0 : }),
166 0 : };
167 0 :
168 0 : Client::new(inner, conn_info, pool_clone)
169 0 : }
170 :
/// Handles a client keeps for talking to the background task spawned by `poll_client`.
#[derive(Clone)]
pub(crate) struct ClientDataRemote {
    /// Sender half of the watch channel carrying the current session id; the
    /// background task logs a "changed session" event when the value changes.
    session: tokio::sync::watch::Sender<uuid::Uuid>,
    /// Token whose cancellation makes the background task exit.
    cancel: CancellationToken,
}
176 :
177 : impl ClientDataRemote {
178 0 : pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
179 0 : &mut self.session
180 0 : }
181 :
182 7 : pub fn cancel(&mut self) {
183 7 : self.cancel.cancel();
184 7 : }
185 : }
186 :
#[cfg(test)]
mod tests {
    use std::mem;
    use std::sync::atomic::AtomicBool;

    use super::*;
    use crate::proxy::NeonOptions;
    use crate::serverless::cancel_set::CancelSet;
    use crate::types::{BranchId, EndpointId, ProjectId};

    /// Stand-in client whose "closed" state is a shared flag the test can flip.
    struct MockClient(Arc<AtomicBool>);
    impl MockClient {
        /// Creates a mock that reports `is_closed` from its own, unshared flag.
        fn new(is_closed: bool) -> Self {
            MockClient(Arc::new(is_closed.into()))
        }
    }
    impl ClientInnerExt for MockClient {
        fn is_closed(&self) -> bool {
            self.0.load(atomic::Ordering::Relaxed)
        }
        // The mock has no backend process; a fixed id is returned.
        fn get_process_id(&self) -> i32 {
            0
        }
    }

    /// Builds client internals around an open (not closed) mock client.
    fn create_inner() -> ClientInnerCommon<MockClient> {
        create_inner_with(MockClient::new(false))
    }

    /// Builds client internals around `client`, with fixed metrics metadata, a random
    /// connection id, and a fresh session sender and cancellation token.
    fn create_inner_with(client: MockClient) -> ClientInnerCommon<MockClient> {
        ClientInnerCommon {
            inner: client,
            aux: MetricsAuxInfo {
                endpoint_id: (&EndpointId::from("endpoint")).into(),
                project_id: (&ProjectId::from("project")).into(),
                branch_id: (&BranchId::from("branch")).into(),
                cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
            },
            conn_id: uuid::Uuid::new_v4(),
            data: ClientDataEnum::Remote(ClientDataRemote {
                session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
                cancel: CancellationToken::new(),
            }),
        }
    }

    /// Exercises pool accounting: discarded and closed clients are not pooled, the
    /// per-endpoint limit (2) and the global limit (3) are enforced, and `gc` evicts a
    /// pooled client once it reports itself closed.
    #[tokio::test]
    async fn test_pool() {
        let _ = env_logger::try_init();
        // Leaked so the config reference outlives the test body.
        let config = Box::leak(Box::new(crate::config::HttpConfig {
            accept_websockets: false,
            pool_options: GlobalConnPoolOptions {
                max_conns_per_endpoint: 2,
                gc_epoch: Duration::from_secs(1),
                pool_shards: 2,
                idle_timeout: Duration::from_secs(1),
                opt_in: false,
                max_total_conns: 3,
            },
            cancel_set: CancelSet::new(0),
            client_conn_threshold: u64::MAX,
            max_request_size_bytes: usize::MAX,
            max_response_size_bytes: usize::MAX,
        }));
        let pool = GlobalConnPool::new(config);
        let conn_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            assert_eq!(0, pool.get_global_connections_count());
            client.inner().1.discard();
            // Discard should not add the connection to the pool.
            assert_eq!(0, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            // Run the closure returned by `do_drop` by hand; the assertion below shows
            // this is what puts the client into the pool.
            client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(client);
            assert_eq!(1, pool.get_global_connections_count());
        }
        {
            let mut closed_client = Client::new(
                create_inner_with(MockClient::new(true)),
                conn_info.clone(),
                ep_pool.clone(),
            );
            closed_client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(closed_client);
            // The closed client shouldn't be added to the pool.
            assert_eq!(1, pool.get_global_connections_count());
        }
        // Shared flag so the test can mark the next pooled client as closed later on.
        let is_closed: Arc<AtomicBool> = Arc::new(false.into());
        {
            let mut client = Client::new(
                create_inner_with(MockClient(is_closed.clone())),
                conn_info.clone(),
                ep_pool.clone(),
            );
            client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(client);

            // The client should be added to the pool.
            assert_eq!(2, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info, ep_pool);
            client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(client);

            // The client shouldn't be added to the pool, because the endpoint pool is full.
            assert_eq!(2, pool.get_global_connections_count());
        }

        // Switch to a second endpoint to exercise the global connection limit.
        let conn_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint-2".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(client);
            assert_eq!(3, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            client.do_drop().unwrap()();
            // Skip the destructor, since the drop logic was already run manually above.
            mem::forget(client);

            // The client shouldn't be added to the pool, because the global pool is full.
            assert_eq!(3, pool.get_global_connections_count());
        }

        is_closed.store(true, atomic::Ordering::Relaxed);
        // Run gc on both shards (`pool_shards` is 2).
        pool.gc(0);
        pool.gc(1);
        // The client that now reports closed should have been removed from the pool.
        assert_eq!(2, pool.get_global_connections_count());
    }
}
|