use std::sync::Arc;

use futures::{FutureExt, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};

use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::CancellationHandler;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
    ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};

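/// Accept loop for the console-redirect proxy: accepts TCP connections,
/// parses the optional PROXY protocol header, and spawns a per-client task
/// for each connection. Once the cancellation token fires, it stops
/// accepting and drains outstanding tasks before returning.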
pub async fn task_main(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        let (socket, peer_addr) = accept_result?;

        let conn_gauge = Metrics::get()
            .proxy
            .client_connections
            .guard(crate::metrics::Protocol::Tcp);

        let session_id = uuid::Uuid::new_v4();
        let cancellation_handler = Arc::clone(&cancellation_handler);
        let cancellations = cancellations.clone();

        debug!(protocol = "tcp", %session_id, "accepted new TCP connection");

        connections.spawn(async move {
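            // Parse the optional PROXY protocol header so the real client
            // address is available when running behind a load balancer.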
            let (socket, peer_addr) = match read_proxy_protocol(socket).await {
                Err(e) => {
                    error!("per-client task finished with an error: {e:#}");
                    return;
                }
                // Our load balancers will not send any more data; exit immediately.
                Ok((_socket, ConnectHeader::Local)) => {
                    debug!("healthcheck received");
                    return;
                }
                Ok((_socket, ConnectHeader::Missing))
                    if config.proxy_protocol_v2 == ProxyProtocolV2::Required =>
                {
                    error!("missing required proxy protocol header");
                    return;
                }
                Ok((_socket, ConnectHeader::Proxy(_)))
                    if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected =>
                {
                    error!("proxy protocol header not supported");
                    return;
                }
                Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
                Ok((socket, ConnectHeader::Missing)) => (
                    socket,
                    ConnectionInfo {
                        addr: peer_addr,
                        extra: None,
                    },
                ),
            };

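            // Disable Nagle's algorithm so small protocol messages are
            // forwarded without artificial delay.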
            match socket.inner.set_nodelay(true) {
                Ok(()) => {}
                Err(e) => {
                    error!(
                        "per-client task finished with an error: failed to set socket option: {e:#}"
                    );
                    return;
                }
            }

            let ctx = RequestContext::new(
                session_id,
                peer_addr,
                crate::metrics::Protocol::Tcp,
                &config.region,
            );

            let res = handle_client(
                config,
                backend,
                &ctx,
                cancellation_handler,
                socket,
                conn_gauge,
                cancellations,
            )
            .instrument(ctx.span())
            .boxed()
            .await;

            match res {
                Err(e) => {
                    ctx.set_error_kind(e.get_error_kind());
                    error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
                }
                Ok(None) => {
                    ctx.set_success();
                }
                Ok(Some(p)) => {
                    ctx.set_success();
                    let _disconnect = ctx.log_connect();
                    match p.proxy_pass(&config.connect_to_compute).await {
                        Ok(()) => {}
                        Err(ErrorSource::Client(e)) => {
                            error!(
                                ?session_id,
                                "per-client task finished with an IO error from the client: {e:#}"
                            );
                        }
                        Err(ErrorSource::Compute(e)) => {
                            error!(
                                ?session_id,
                                "per-client task finished with an IO error from the compute: {e:#}"
                            );
                        }
                    }
                }
            }
        });
    }

    connections.close();
    cancellations.close();
    drop(listener);

    // Drain connections
    connections.wait().await;
    cancellations.wait().await;

    Ok(())
}

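/// Handles a single client connection: runs the startup/TLS handshake,
/// authenticates through the console-redirect backend, connects to the
/// compute node, and returns a `ProxyPassthrough` for the caller to drive.
/// Returns `Ok(None)` when the connection was only a cancel request.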
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    ctx: &RequestContext,
    cancellation_handler: Arc<CancellationHandler>,
    stream: S,
    conn_gauge: NumClientConnectionsGuard<'static>,
    cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
    debug!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    let metrics = &Metrics::get().proxy;
    let proto = ctx.protocol();
    let request_gauge = metrics.connection_requests.guard(proto);

    let tls = config.tls_config.as_ref();

    let record_handshake_error = !ctx.has_private_peer_addr();
    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
    let do_handshake = handshake(ctx, stream, tls, record_handshake_error);

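    // The handshake yields either a startup packet (a new session) or a
    // cancel request targeting an existing session.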
    let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
        .await??
    {
        HandshakeData::Startup(stream, params) => (stream, params),
        HandshakeData::Cancel(cancel_key_data) => {
            // spawn a task to cancel the session, but don't wait for it
            cancellations.spawn({
                let cancellation_handler_clone = Arc::clone(&cancellation_handler);
                let ctx = ctx.clone();
                let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
                cancel_span.follows_from(tracing::Span::current());
                async move {
                    cancellation_handler_clone
                        .cancel_session(
                            cancel_key_data,
                            ctx,
                            config.authentication_config.ip_allowlist_check_enabled,
                            config.authentication_config.is_vpc_acccess_proxy,
                            backend.get_api(),
                        )
                        .await
                        .inspect_err(|e| debug!(error = ?e, "cancel_session failed")).ok();
                }.instrument(cancel_span)
            });

            return Ok(None);
        }
    };
    drop(pause);

    ctx.set_db_options(params.clone());

    let (node_info, user_info, _ip_allowlist) = match backend
        .authenticate(ctx, &config.authentication_config, &mut stream)
        .await
    {
        Ok(auth_result) => auth_result,
        Err(e) => {
            return stream.throw_error(e).await?;
        }
    };

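    // Connect to the compute node, retrying wake-ups per the configured
    // policy; failures are reported to the client before returning.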
    let mut node = connect_to_compute(
        ctx,
        &TcpMechanism {
            user_info,
            params_compat: true,
            params: &params,
            locks: &config.connect_compute_locks,
        },
        &node_info,
        config.wake_compute_retry_config,
        &config.connect_to_compute,
    )
    .or_else(|e| stream.throw_error(e))
    .await?;

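    // Register a cancellation key for this session, record how to cancel the
    // compute-side connection, and forward the key to the client.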
    let cancellation_handler_clone = Arc::clone(&cancellation_handler);
    let session = cancellation_handler_clone.get_key();

    session
        .write_cancel_key(node.cancel_closure.clone())
        .await?;

    prepare_client_connection(&node, *session.key(), &mut stream).await?;

    // Before proxy passing, forward to compute whatever data is left in the
    // PqStream input buffer. Normally there is none, but our serverless npm
    // driver in pipeline mode sends startup, password and first query
    // immediately after opening the connection.
    let (stream, read_buf) = stream.into_inner();
    node.stream.write_all(&read_buf).await?;

    Ok(Some(ProxyPassthrough {
        client: stream,
        aux: node.aux.clone(),
        private_link_id: None,
        compute: node,
        session_id: ctx.session_id(),
        cancel: session,
        _req: request_gauge,
        _conn: conn_gauge,
    }))
}