Line data Source code
1 : use std::sync::Arc;
2 :
3 : use futures::TryFutureExt;
4 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
5 : use tokio_util::sync::CancellationToken;
6 : use tracing::{error, info, Instrument};
7 :
8 : use crate::auth::backend::ConsoleRedirectBackend;
9 : use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
10 : use crate::config::{ProxyConfig, ProxyProtocolV2};
11 : use crate::context::RequestMonitoring;
12 : use crate::error::ReportableError;
13 : use crate::metrics::{Metrics, NumClientConnectionsGuard};
14 : use crate::protocol2::read_proxy_protocol;
15 : use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
16 : use crate::proxy::handshake::{handshake, HandshakeData};
17 : use crate::proxy::passthrough::ProxyPassthrough;
18 : use crate::proxy::{
19 : prepare_client_connection, run_until_cancelled, ClientRequestError, ErrorSource,
20 : };
21 :
22 0 : pub async fn task_main(
23 0 : config: &'static ProxyConfig,
24 0 : backend: &'static ConsoleRedirectBackend,
25 0 : listener: tokio::net::TcpListener,
26 0 : cancellation_token: CancellationToken,
27 0 : cancellation_handler: Arc<CancellationHandlerMain>,
28 0 : ) -> anyhow::Result<()> {
29 0 : scopeguard::defer! {
30 0 : info!("proxy has shut down");
31 0 : }
32 0 :
33 0 : // When set for the server socket, the keepalive setting
34 0 : // will be inherited by all accepted client sockets.
35 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
36 :
37 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
38 :
39 0 : while let Some(accept_result) =
40 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
41 : {
42 0 : let (socket, peer_addr) = accept_result?;
43 :
44 0 : let conn_gauge = Metrics::get()
45 0 : .proxy
46 0 : .client_connections
47 0 : .guard(crate::metrics::Protocol::Tcp);
48 0 :
49 0 : let session_id = uuid::Uuid::new_v4();
50 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
51 0 :
52 0 : tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
53 :
54 0 : connections.spawn(async move {
55 0 : let (socket, peer_addr) = match read_proxy_protocol(socket).await {
56 0 : Err(e) => {
57 0 : error!("per-client task finished with an error: {e:#}");
58 0 : return;
59 : }
60 0 : Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
61 0 : error!("missing required proxy protocol header");
62 0 : return;
63 : }
64 0 : Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
65 0 : error!("proxy protocol header not supported");
66 0 : return;
67 : }
68 0 : Ok((socket, Some(addr))) => (socket, addr.ip()),
69 0 : Ok((socket, None)) => (socket, peer_addr.ip()),
70 : };
71 :
72 0 : match socket.inner.set_nodelay(true) {
73 0 : Ok(()) => {}
74 0 : Err(e) => {
75 0 : error!("per-client task finished with an error: failed to set socket option: {e:#}");
76 0 : return;
77 : }
78 : };
79 :
80 0 : let ctx = RequestMonitoring::new(
81 0 : session_id,
82 0 : peer_addr,
83 0 : crate::metrics::Protocol::Tcp,
84 0 : &config.region,
85 0 : );
86 0 : let span = ctx.span();
87 0 :
88 0 : let startup = Box::pin(
89 0 : handle_client(
90 0 : config,
91 0 : backend,
92 0 : &ctx,
93 0 : cancellation_handler,
94 0 : socket,
95 0 : conn_gauge,
96 0 : )
97 0 : .instrument(span.clone()),
98 0 : );
99 0 : let res = startup.await;
100 :
101 0 : match res {
102 0 : Err(e) => {
103 0 : // todo: log and push to ctx the error kind
104 0 : ctx.set_error_kind(e.get_error_kind());
105 0 : error!(parent: &span, "per-client task finished with an error: {e:#}");
106 : }
107 0 : Ok(None) => {
108 0 : ctx.set_success();
109 0 : }
110 0 : Ok(Some(p)) => {
111 0 : ctx.set_success();
112 0 : ctx.log_connect();
113 0 : match p.proxy_pass().instrument(span.clone()).await {
114 0 : Ok(()) => {}
115 0 : Err(ErrorSource::Client(e)) => {
116 0 : error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
117 : }
118 0 : Err(ErrorSource::Compute(e)) => {
119 0 : error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
120 : }
121 : }
122 : }
123 : }
124 0 : });
125 : }
126 :
127 0 : connections.close();
128 0 : drop(listener);
129 0 :
130 0 : // Drain connections
131 0 : connections.wait().await;
132 :
133 0 : Ok(())
134 0 : }
135 :
136 0 : pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
137 0 : config: &'static ProxyConfig,
138 0 : backend: &'static ConsoleRedirectBackend,
139 0 : ctx: &RequestMonitoring,
140 0 : cancellation_handler: Arc<CancellationHandlerMain>,
141 0 : stream: S,
142 0 : conn_gauge: NumClientConnectionsGuard<'static>,
143 0 : ) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
144 0 : info!(
145 0 : protocol = %ctx.protocol(),
146 0 : "handling interactive connection from client"
147 : );
148 :
149 0 : let metrics = &Metrics::get().proxy;
150 0 : let proto = ctx.protocol();
151 0 : let request_gauge = metrics.connection_requests.guard(proto);
152 0 :
153 0 : let tls = config.tls_config.as_ref();
154 0 :
155 0 : let record_handshake_error = !ctx.has_private_peer_addr();
156 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
157 0 : let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
158 0 : let (mut stream, params) =
159 0 : match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
160 0 : HandshakeData::Startup(stream, params) => (stream, params),
161 0 : HandshakeData::Cancel(cancel_key_data) => {
162 0 : return Ok(cancellation_handler
163 0 : .cancel_session(cancel_key_data, ctx.session_id())
164 0 : .await
165 0 : .map(|()| None)?)
166 : }
167 : };
168 0 : drop(pause);
169 0 :
170 0 : ctx.set_db_options(params.clone());
171 :
172 0 : let user_info = match backend
173 0 : .authenticate(ctx, &config.authentication_config, &mut stream)
174 0 : .await
175 : {
176 0 : Ok(auth_result) => auth_result,
177 0 : Err(e) => {
178 0 : return stream.throw_error(e).await?;
179 : }
180 : };
181 :
182 0 : let mut node = connect_to_compute(
183 0 : ctx,
184 0 : &TcpMechanism {
185 0 : params: ¶ms,
186 0 : locks: &config.connect_compute_locks,
187 0 : },
188 0 : &user_info,
189 0 : config.allow_self_signed_compute,
190 0 : config.wake_compute_retry_config,
191 0 : config.connect_to_compute_retry_config,
192 0 : )
193 0 : .or_else(|e| stream.throw_error(e))
194 0 : .await?;
195 :
196 0 : let session = cancellation_handler.get_session();
197 0 : prepare_client_connection(&node, &session, &mut stream).await?;
198 :
199 : // Before proxy passing, forward to compute whatever data is left in the
200 : // PqStream input buffer. Normally there is none, but our serverless npm
201 : // driver in pipeline mode sends startup, password and first query
202 : // immediately after opening the connection.
203 0 : let (stream, read_buf) = stream.into_inner();
204 0 : node.stream.write_all(&read_buf).await?;
205 :
206 0 : Ok(Some(ProxyPassthrough {
207 0 : client: stream,
208 0 : aux: node.aux.clone(),
209 0 : compute: node,
210 0 : _req: request_gauge,
211 0 : _conn: conn_gauge,
212 0 : _cancel: session,
213 0 : }))
214 0 : }
|