use std::sync::Arc;

use futures::{FutureExt, TryFutureExt};
use postgres_client::RawCancelToken;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};

use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting};
use crate::util::run_until_cancelled;

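/// Accept loop for the console-redirect proxy.
///
/// Accepts TCP connections until `cancellation_token` fires, spawning a
/// per-client task for each one. On shutdown, the listener is dropped and
/// all in-flight client and cancellation tasks are drained before returning.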
pub async fn task_main(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

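    // Track per-client tasks and fire-and-forget cancellation tasks
    // separately, so both can be drained on shutdown.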
    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        let (socket, peer_addr) = accept_result?;

        let conn_gauge = Metrics::get()
            .proxy
            .client_connections
            .guard(crate::metrics::Protocol::Tcp);

        let session_id = uuid::Uuid::new_v4();
        let cancellation_handler = Arc::clone(&cancellation_handler);
        let cancellations = cancellations.clone();

        debug!(protocol = "tcp", %session_id, "accepted new TCP connection");

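        // One task per client. Errors are logged and swallowed inside the
        // task, so a single bad connection cannot take down the accept loop.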
        connections.spawn(async move {
            let (socket, conn_info) = match config.proxy_protocol_v2 {
                ProxyProtocolV2::Required => {
                    match read_proxy_protocol(socket).await {
                        Err(e) => {
                            error!("per-client task finished with an error: {e:#}");
                            return;
                        }
                        // Our load balancers will not send any more data; exit immediately.
                        Ok((_socket, ConnectHeader::Local)) => {
                            debug!("healthcheck received");
                            return;
                        }
                        Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
                    }
                }
                // Ignore the header: it cannot be confused for a postgres or http
                // connection, so it will error later.
                ProxyProtocolV2::Rejected => (
                    socket,
                    ConnectionInfo {
                        addr: peer_addr,
                        extra: None,
                    },
                ),
            };

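            // Disable Nagle's algorithm: the session exchanges many small
            // packets, and batching them would add avoidable latency.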
            match socket.set_nodelay(true) {
                Ok(()) => {}
                Err(e) => {
                    error!(
                        "per-client task finished with an error: failed to set socket option: {e:#}"
                    );
                    return;
                }
            }

            let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);

            let res = handle_client(
                config,
                backend,
                &ctx,
                cancellation_handler,
                socket,
                conn_gauge,
                cancellations,
            )
            .instrument(ctx.span())
            .boxed()
            .await;

            match res {
                Err(e) => {
                    ctx.set_error_kind(e.get_error_kind());
                    error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
                }
                Ok(None) => {
                    ctx.set_success();
                }
                Ok(Some(p)) => {
                    ctx.set_success();
                    let _disconnect = ctx.log_connect();
                    match p.proxy_pass().await {
                        Ok(()) => {}
                        Err(ErrorSource::Client(e)) => {
                            error!(
                                ?session_id,
                                "per-client task finished with an IO error from the client: {e:#}"
                            );
                        }
                        Err(ErrorSource::Compute(e)) => {
                            error!(
                                ?session_id,
                                "per-client task finished with an IO error from the compute: {e:#}"
                            );
                        }
                    }
                }
            }
        });
    }

    connections.close();
    cancellations.close();
    drop(listener);

    // Drain connections
    connections.wait().await;
    cancellations.wait().await;

    Ok(())
}

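/// Handles one client connection end-to-end: TLS/postgres handshake,
/// console-redirect authentication, connecting and authenticating to the
/// compute node, and registering a cancel key for the session.
///
/// Returns `Ok(Some(..))` with a `ProxyPassthrough` ready to shuttle bytes,
/// or `Ok(None)` if the connection was a cancel request handled in the
/// background.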
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    ctx: &RequestContext,
    cancellation_handler: Arc<CancellationHandler>,
    stream: S,
    conn_gauge: NumClientConnectionsGuard<'static>,
    cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
    debug!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    let metrics = &Metrics::get().proxy;
    let proto = ctx.protocol();
    let request_gauge = metrics.connection_requests.guard(proto);

    let tls = config.tls_config.load();
    let tls = tls.as_deref();

    let record_handshake_error = !ctx.has_private_peer_addr();
    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
    let do_handshake = handshake(ctx, stream, tls, record_handshake_error);

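    // The handshake yields either a startup packet (a new session) or a
    // cancel request carrying the key of an existing session.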
    let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
        .await??
    {
        HandshakeData::Startup(stream, params) => (stream, params),
        HandshakeData::Cancel(cancel_key_data) => {
            // Spawn a task to cancel the session, but don't wait for it.
            cancellations.spawn({
                let cancellation_handler_clone = Arc::clone(&cancellation_handler);
                let ctx = ctx.clone();
                let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
                cancel_span.follows_from(tracing::Span::current());
                async move {
                    cancellation_handler_clone
                        .cancel_session(
                            cancel_key_data,
                            ctx,
                            config.authentication_config.ip_allowlist_check_enabled,
                            config.authentication_config.is_vpc_acccess_proxy,
                            backend.get_api(),
                        )
                        .await
                        .inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
                        .ok();
                }
                .instrument(cancel_span)
            });

            return Ok(None);
        }
    };
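    // The handshake is done: resume the client latency timer.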
    drop(pause);

    ctx.set_db_options(params.clone());

    let (node_info, mut auth_info, user_info) = match backend
        .authenticate(ctx, &config.authentication_config, &mut stream)
        .await
    {
        Ok(auth_result) => auth_result,
        Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
    };
    auth_info.set_startup_params(&params, true);

    let mut node = connect_to_compute(
        ctx,
        &TcpMechanism {
            locks: &config.connect_compute_locks,
        },
        &node_info,
        config.wake_compute_retry_config,
        &config.connect_to_compute,
    )
    .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
    .await?;

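    // Complete the postgres authentication handshake with the compute node,
    // using the credentials gathered above.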
    auth_info
        .authenticate(ctx, &mut node)
        .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
        .await?;
    send_client_greeting(ctx, &config.greetings, &mut stream);

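    // Allocate a proxy-owned cancel key for this session; the compute's real
    // cancel token (process_id, secret_key) is obtained below and registered
    // against it.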
    let session = cancellation_handler.get_key();

    let (process_id, secret_key) =
        forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
            .await?;
    let stream = stream.flush_and_into_inner().await?;
    let hostname = node.hostname.to_string();

    let session_id = ctx.session_id();
    let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
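    // Keep the cancel key registered in the background. Dropping the
    // `cancel_on_shutdown` sender (stored in the `ProxyPassthrough` below)
    // signals this task to release the key.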
    tokio::spawn(async move {
        session
            .maintain_cancel_key(
                session_id,
                cancel,
                &CancelClosure {
                    socket_addr: node.socket_addr,
                    cancel_token: RawCancelToken {
                        ssl_mode: node.ssl_mode,
                        process_id,
                        secret_key,
                    },
                    hostname,
                    user_info,
                },
                &config.connect_to_compute,
            )
            .await;
    });

    Ok(Some(ProxyPassthrough {
        client: stream,
        compute: node.stream.into_framed().into_inner(),

        aux: node.aux,
        private_link_id: None,

        _cancel_on_shutdown: cancel_on_shutdown,

        _req: request_gauge,
        _conn: conn_gauge,
        _db_conn: node.guage,
    }))
}