use std::fmt;
use std::pin::pin;
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};

use futures::future::poll_fn;
use futures::Future;
use smallvec::SmallVec;
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, Socket};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
#[cfg(test)]
use {
    super::conn_pool_lib::GlobalConnPoolOptions,
    crate::auth::backend::ComputeUserInfo,
    std::{sync::atomic, time::Duration},
};

use super::conn_pool_lib::{
    Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, GlobalConnPool,
};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;

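/// Connection parameters together with the credentials presented for them.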
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
    pub(crate) conn_info: ConnInfo,
    pub(crate) auth: AuthData,
}

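/// Credentials accepted by the proxy: a cleartext password (kept in a
/// `SmallVec` so short passwords stay inline, without a heap allocation)
/// or a JWT.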
#[derive(Debug, Clone)]
pub(crate) enum AuthData {
    Password(SmallVec<[u8; 16]>),
    Jwt(String),
}

impl fmt::Display for ConnInfo {
    // Use a custom Display impl to avoid logging the password.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}@{}/{}?{}",
            self.user_info.user,
            self.user_info.endpoint,
            self.dbname,
            self.user_info.options.get_cache_key("")
        )
    }
}

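/// Registers a pooled client and spawns a background task that drives the
/// underlying postgres connection.
///
/// The task polls four things: the cancellation token, session-id changes on
/// the watch channel, the idle timeout (reset on every session change), and
/// the connection's async messages. It removes the client from the endpoint
/// pool when the connection goes idle, errors, or closes.
///
/// A sketch of the intended wiring (hypothetical call site; the real caller
/// lives elsewhere in the proxy):
///
/// ```ignore
/// let (client, connection) = pg_config.connect(tokio_postgres::NoTls).await?;
/// let pooled = poll_client(global_pool, &ctx, conn_info, client, connection, conn_id, aux);
/// ```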
pub(crate) fn poll_client<C: ClientInnerExt>(
    global_pool: Arc<GlobalConnPool<C>>,
    ctx: &RequestMonitoring,
    conn_info: ConnInfo,
    client: C,
    mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
    conn_id: uuid::Uuid,
    aux: MetricsAuxInfo,
) -> Client<C> {
    let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
    let mut session_id = ctx.session_id();
    let (tx, mut rx) = tokio::sync::watch::channel(session_id);

    let span = info_span!(parent: None, "connection", %conn_id);
    let cold_start_info = ctx.cold_start_info();
    span.in_scope(|| {
        info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
    });
    let pool = match conn_info.endpoint_cache_key() {
        Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
        None => Weak::new(),
    };
    let pool_clone = pool.clone();

    let db_user = conn_info.db_and_user();
    let idle = global_pool.get_idle_timeout();
    let cancel = CancellationToken::new();
    let cancelled = cancel.clone().cancelled_owned();

    tokio::spawn(
        async move {
            let _conn_gauge = conn_gauge;
            let mut idle_timeout = pin!(tokio::time::sleep(idle));
            let mut cancelled = pin!(cancelled);

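            // A single hand-rolled future multiplexes everything the
            // connection task cares about: cancellation, session handoff,
            // idle eviction, and the postgres wire messages themselves.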
            poll_fn(move |cx| {
                if cancelled.as_mut().poll(cx).is_ready() {
                    info!("connection dropped");
                    return Poll::Ready(());
                }

                match rx.has_changed() {
                    Ok(true) => {
                        session_id = *rx.borrow_and_update();
                        info!(%session_id, "changed session");
                        idle_timeout.as_mut().reset(Instant::now() + idle);
                    }
                    Err(_) => {
                        info!("connection dropped");
                        return Poll::Ready(());
                    }
                    _ => {}
                }

                // Idle connection timeout; the duration is configured on the
                // global pool.
                if idle_timeout.as_mut().poll(cx).is_ready() {
                    idle_timeout.as_mut().reset(Instant::now() + idle);
                    info!("connection idle");
                    if let Some(pool) = pool.clone().upgrade() {
                        // Remove the client from the pool: this closes the
                        // connection if it's idle, and does nothing if the
                        // client is currently checked out and in use.
                        if pool.write().remove_client(db_user.clone(), conn_id) {
                            info!("idle connection removed");
                        }
                    }
                }

                loop {
                    let message = ready!(connection.poll_message(cx));

                    match message {
                        Some(Ok(AsyncMessage::Notice(notice))) => {
                            info!(%session_id, "notice: {}", notice);
                        }
                        Some(Ok(AsyncMessage::Notification(notif))) => {
                            warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
                        }
                        Some(Ok(_)) => {
                            warn!(%session_id, "unknown message");
                        }
                        Some(Err(e)) => {
                            error!(%session_id, "connection error: {}", e);
                            break;
                        }
                        None => {
                            info!("connection closed");
                            break;
                        }
                    }
                }

                // remove from connection pool
                if let Some(pool) = pool.clone().upgrade() {
                    if pool.write().remove_client(db_user.clone(), conn_id) {
                        info!("closed connection removed");
                    }
                }

                Poll::Ready(())
            })
            .await;
        }
        .instrument(span),
    );

    let inner = ClientInnerCommon {
        inner: client,
        aux,
        conn_id,
        data: ClientDataEnum::Remote(ClientDataRemote {
            session: tx,
            cancel,
        }),
    };

    Client::new(inner, conn_info, pool_clone)
}

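/// Handles the pool keeps for a remote (tokio-postgres) connection: a watch
/// channel for re-associating the connection task with a new session id, and
/// a token for cancelling the task outright.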
pub(crate) struct ClientDataRemote {
    session: tokio::sync::watch::Sender<uuid::Uuid>,
    cancel: CancellationToken,
}

impl ClientDataRemote {
    pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
        &mut self.session
    }

    pub fn cancel(&mut self) {
        self.cancel.cancel();
    }
}

#[cfg(test)]
mod tests {
    use std::mem;
    use std::sync::atomic::AtomicBool;

    use super::*;
    use crate::proxy::NeonOptions;
    use crate::serverless::cancel_set::CancelSet;
    use crate::types::{BranchId, EndpointId, ProjectId};

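    /// Minimal `ClientInnerExt` impl: the only state is a shared "is closed"
    /// flag, so tests can flip a pooled client to closed after the fact.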
    struct MockClient(Arc<AtomicBool>);
    impl MockClient {
        fn new(is_closed: bool) -> Self {
            MockClient(Arc::new(is_closed.into()))
        }
    }
    impl ClientInnerExt for MockClient {
        fn is_closed(&self) -> bool {
            self.0.load(atomic::Ordering::Relaxed)
        }
        fn get_process_id(&self) -> i32 {
            0
        }
    }

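    /// Builds a `ClientInnerCommon` around an open `MockClient` with dummy
    /// metrics metadata and fresh session/cancellation handles.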
    fn create_inner() -> ClientInnerCommon<MockClient> {
        create_inner_with(MockClient::new(false))
    }

    fn create_inner_with(client: MockClient) -> ClientInnerCommon<MockClient> {
        ClientInnerCommon {
            inner: client,
            aux: MetricsAuxInfo {
                endpoint_id: (&EndpointId::from("endpoint")).into(),
                project_id: (&ProjectId::from("project")).into(),
                branch_id: (&BranchId::from("branch")).into(),
                cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
            },
            conn_id: uuid::Uuid::new_v4(),
            data: ClientDataEnum::Remote(ClientDataRemote {
                session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
                cancel: CancellationToken::new(),
            }),
        }
    }

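    /// Exercises pool admission end to end: discarded and closed clients are
    /// never pooled, the per-endpoint cap (2) and global cap (3) are enforced,
    /// and GC evicts clients whose connections have since closed.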
    #[tokio::test]
    async fn test_pool() {
        let _ = env_logger::try_init();
        let config = Box::leak(Box::new(crate::config::HttpConfig {
            accept_websockets: false,
            pool_options: GlobalConnPoolOptions {
                max_conns_per_endpoint: 2,
                gc_epoch: Duration::from_secs(1),
                pool_shards: 2,
                idle_timeout: Duration::from_secs(1),
                opt_in: false,
                max_total_conns: 3,
            },
            cancel_set: CancelSet::new(0),
            client_conn_threshold: u64::MAX,
            max_request_size_bytes: u64::MAX,
            max_response_size_bytes: usize::MAX,
        }));
        let pool = GlobalConnPool::new(config);
        let conn_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            assert_eq!(0, pool.get_global_connections_count());
            client.inner().1.discard();
            // Discard should not add the connection to the pool.
            assert_eq!(0, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            client.do_drop().unwrap()();
            mem::forget(client); // drop the client
            assert_eq!(1, pool.get_global_connections_count());
        }
        {
            let mut closed_client = Client::new(
                create_inner_with(MockClient::new(true)),
                conn_info.clone(),
                ep_pool.clone(),
            );
            closed_client.do_drop().unwrap()();
            mem::forget(closed_client); // drop the client
            // The closed client shouldn't be added to the pool.
            assert_eq!(1, pool.get_global_connections_count());
        }
        let is_closed: Arc<AtomicBool> = Arc::new(false.into());
        {
            let mut client = Client::new(
                create_inner_with(MockClient(is_closed.clone())),
                conn_info.clone(),
                ep_pool.clone(),
            );
            client.do_drop().unwrap()();
            mem::forget(client); // drop the client

            // The client should be added to the pool.
            assert_eq!(2, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info, ep_pool);
            client.do_drop().unwrap()();
            mem::forget(client); // drop the client

            // The client shouldn't be added to the pool because the
            // per-endpoint pool is already full.
            assert_eq!(2, pool.get_global_connections_count());
        }

        let conn_info = ConnInfo {
            user_info: ComputeUserInfo {
                user: "user".into(),
                endpoint: "endpoint-2".into(),
                options: NeonOptions::default(),
            },
            dbname: "dbname".into(),
        };
        let ep_pool = Arc::downgrade(
            &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
        );
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            client.do_drop().unwrap()();
            mem::forget(client); // drop the client
            assert_eq!(3, pool.get_global_connections_count());
        }
        {
            let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
            client.do_drop().unwrap()();
            mem::forget(client); // drop the client

            // The client shouldn't be added to the pool because the global
            // pool is already full.
            assert_eq!(3, pool.get_global_connections_count());
        }

        is_closed.store(true, atomic::Ordering::Relaxed);
        // Run GC on both shards.
        pool.gc(0);
        pool.gc(1);
        // The closed client should be removed from the pool.
        assert_eq!(2, pool.get_global_connections_count());
    }
}