Line data Source code
1 : pub mod copy_bidirectional;
2 : pub mod handshake;
3 : pub mod inprocess;
4 : pub mod passthrough;
5 :
6 : use std::sync::Arc;
7 :
8 : use futures::FutureExt;
9 : use smol_str::ToSmolStr;
10 : use thiserror::Error;
11 : use tokio::io::{AsyncRead, AsyncWrite};
12 : use tokio_util::sync::CancellationToken;
13 : use tracing::{Instrument, debug, error, info, warn};
14 :
15 : use crate::auth;
16 : use crate::cancellation::{self, CancellationHandler};
17 : use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
18 : use crate::context::RequestContext;
19 : use crate::error::{ReportableError, UserFacingError};
20 : use crate::metrics::{Metrics, NumClientConnectionsGuard};
21 : pub use crate::pglb::copy_bidirectional::ErrorSource;
22 : use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
23 : use crate::pglb::passthrough::ProxyPassthrough;
24 : use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
25 : use crate::proxy::handle_client;
26 : use crate::rate_limiter::EndpointRateLimiter;
27 : use crate::stream::Stream;
28 : use crate::util::run_until_cancelled;
29 :
/// Error message used as the `Display` text of [`TlsRequired`], i.e. when a client
/// connection is rejected for not being secured.
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
31 :
32 : #[derive(Error, Debug)]
33 : #[error("{ERR_INSECURE_CONNECTION}")]
34 : pub struct TlsRequired;
35 :
36 : impl ReportableError for TlsRequired {
37 2 : fn get_error_kind(&self) -> crate::error::ErrorKind {
38 2 : crate::error::ErrorKind::User
39 2 : }
40 : }
41 :
42 : impl UserFacingError for TlsRequired {}
43 :
44 0 : pub async fn task_main(
45 0 : config: &'static ProxyConfig,
46 0 : auth_backend: &'static auth::Backend<'static, ()>,
47 0 : listener: tokio::net::TcpListener,
48 0 : cancellation_token: CancellationToken,
49 0 : cancellation_handler: Arc<CancellationHandler>,
50 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
51 0 : ) -> anyhow::Result<()> {
52 0 : scopeguard::defer! {
53 : info!("proxy has shut down");
54 : }
55 :
56 : // When set for the server socket, the keepalive setting
57 : // will be inherited by all accepted client sockets.
58 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
59 :
60 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
61 0 : let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
62 :
63 0 : while let Some(accept_result) =
64 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
65 : {
66 0 : let (socket, peer_addr) = accept_result?;
67 :
68 0 : let conn_gauge = Metrics::get()
69 0 : .proxy
70 0 : .client_connections
71 0 : .guard(crate::metrics::Protocol::Tcp);
72 :
73 0 : let session_id = uuid::Uuid::new_v4();
74 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
75 0 : let cancellations = cancellations.clone();
76 :
77 0 : debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
78 0 : let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
79 :
80 0 : connections.spawn(async move {
81 0 : let (socket, conn_info) = match config.proxy_protocol_v2 {
82 : ProxyProtocolV2::Required => {
83 0 : match read_proxy_protocol(socket).await {
84 0 : Err(e) => {
85 0 : warn!("per-client task finished with an error: {e:#}");
86 0 : return;
87 : }
88 : // our load balancers will not send any more data. let's just exit immediately
89 0 : Ok((_socket, ConnectHeader::Local)) => {
90 0 : debug!("healthcheck received");
91 0 : return;
92 : }
93 0 : Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
94 : }
95 : }
96 : // ignore the header - it cannot be confused for a postgres or http connection so will
97 : // error later.
98 0 : ProxyProtocolV2::Rejected => (
99 0 : socket,
100 0 : ConnectionInfo {
101 0 : addr: peer_addr,
102 0 : extra: None,
103 0 : },
104 0 : ),
105 : };
106 :
107 0 : match socket.set_nodelay(true) {
108 0 : Ok(()) => {}
109 0 : Err(e) => {
110 0 : error!(
111 0 : "per-client task finished with an error: failed to set socket option: {e:#}"
112 : );
113 0 : return;
114 : }
115 : }
116 :
117 0 : let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
118 :
119 0 : let res = handle_connection(
120 0 : config,
121 0 : auth_backend,
122 0 : &ctx,
123 0 : cancellation_handler,
124 0 : socket,
125 0 : ClientMode::Tcp,
126 0 : endpoint_rate_limiter2,
127 0 : conn_gauge,
128 0 : cancellations,
129 0 : )
130 0 : .instrument(ctx.span())
131 0 : .boxed()
132 0 : .await;
133 :
134 0 : match res {
135 0 : Err(e) => {
136 0 : ctx.set_error_kind(e.get_error_kind());
137 0 : warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
138 : }
139 0 : Ok(None) => {
140 0 : ctx.set_success();
141 0 : }
142 0 : Ok(Some(p)) => {
143 0 : ctx.set_success();
144 0 : let _disconnect = ctx.log_connect();
145 0 : match p.proxy_pass().await {
146 0 : Ok(()) => {}
147 0 : Err(ErrorSource::Client(e)) => {
148 0 : warn!(
149 : ?session_id,
150 0 : "per-client task finished with an IO error from the client: {e:#}"
151 : );
152 : }
153 0 : Err(ErrorSource::Compute(e)) => {
154 0 : error!(
155 : ?session_id,
156 0 : "per-client task finished with an IO error from the compute: {e:#}"
157 : );
158 : }
159 : }
160 : }
161 : }
162 0 : });
163 : }
164 :
165 0 : connections.close();
166 0 : cancellations.close();
167 0 : drop(listener);
168 :
169 : // Drain connections
170 0 : connections.wait().await;
171 0 : cancellations.wait().await;
172 :
173 0 : Ok(())
174 0 : }
175 :
/// The transport a client used to reach the proxy.
pub(crate) enum ClientMode {
    /// A plain TCP connection.
    Tcp,
    /// A connection carried over websockets.
    Websockets {
        /// Hostname obtained by the websocket layer, if one was available.
        hostname: Option<String>,
    },
}
180 :
181 : /// Abstracts the logic of handling TCP vs WS clients
182 : impl ClientMode {
183 0 : pub fn allow_cleartext(&self) -> bool {
184 0 : match self {
185 0 : ClientMode::Tcp => false,
186 0 : ClientMode::Websockets { .. } => true,
187 : }
188 0 : }
189 :
190 0 : pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
191 0 : match self {
192 0 : ClientMode::Tcp => s.sni_hostname(),
193 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
194 : }
195 0 : }
196 :
197 0 : pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
198 0 : match self {
199 0 : ClientMode::Tcp => tls,
200 : // TLS is None here if using websockets, because the connection is already encrypted.
201 0 : ClientMode::Websockets { .. } => None,
202 : }
203 0 : }
204 : }
205 :
206 : #[derive(Debug, Error)]
207 : // almost all errors should be reported to the user, but there's a few cases where we cannot
208 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
209 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
210 : // we cannot be sure the client even understands our error message
211 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
212 : pub(crate) enum ClientRequestError {
213 : #[error("{0}")]
214 : Cancellation(#[from] cancellation::CancelError),
215 : #[error("{0}")]
216 : Handshake(#[from] HandshakeError),
217 : #[error("{0}")]
218 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
219 : #[error("{0}")]
220 : PrepareClient(#[from] std::io::Error),
221 : #[error("{0}")]
222 : ReportedError(#[from] crate::stream::ReportedError),
223 : }
224 :
225 : impl ReportableError for ClientRequestError {
226 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
227 0 : match self {
228 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
229 0 : ClientRequestError::Handshake(e) => e.get_error_kind(),
230 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
231 0 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
232 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
233 : }
234 0 : }
235 : }
236 :
237 : #[allow(clippy::too_many_arguments)]
238 0 : pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
239 0 : config: &'static ProxyConfig,
240 0 : auth_backend: &'static auth::Backend<'static, ()>,
241 0 : ctx: &RequestContext,
242 0 : cancellation_handler: Arc<CancellationHandler>,
243 0 : client: S,
244 0 : mode: ClientMode,
245 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
246 0 : conn_gauge: NumClientConnectionsGuard<'static>,
247 0 : cancellations: tokio_util::task::task_tracker::TaskTracker,
248 0 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
249 0 : debug!(
250 0 : protocol = %ctx.protocol(),
251 0 : "handling interactive connection from client"
252 : );
253 :
254 0 : let metrics = &Metrics::get().proxy;
255 0 : let proto = ctx.protocol();
256 0 : let request_gauge = metrics.connection_requests.guard(proto);
257 :
258 0 : let tls = config.tls_config.load();
259 0 : let tls = tls.as_deref();
260 :
261 0 : let record_handshake_error = !ctx.has_private_peer_addr();
262 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
263 0 : let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
264 :
265 0 : let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
266 0 : .await??
267 : {
268 0 : HandshakeData::Startup(client, params) => (client, params),
269 0 : HandshakeData::Cancel(cancel_key_data) => {
270 : // spawn a task to cancel the session, but don't wait for it
271 0 : cancellations.spawn({
272 0 : let cancellation_handler_clone = Arc::clone(&cancellation_handler);
273 0 : let ctx = ctx.clone();
274 0 : let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
275 0 : cancel_span.follows_from(tracing::Span::current());
276 0 : async move {
277 0 : cancellation_handler_clone
278 0 : .cancel_session(
279 0 : cancel_key_data,
280 0 : ctx,
281 0 : config.authentication_config.ip_allowlist_check_enabled,
282 0 : config.authentication_config.is_vpc_acccess_proxy,
283 0 : auth_backend.get_api(),
284 0 : )
285 0 : .await
286 0 : .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
287 0 : }.instrument(cancel_span)
288 : });
289 :
290 0 : return Ok(None);
291 : }
292 : };
293 0 : drop(pause);
294 :
295 0 : ctx.set_db_options(params.clone());
296 :
297 0 : let common_names = tls.map(|tls| &tls.common_names);
298 :
299 0 : let (node, cancel_on_shutdown) = handle_client(
300 0 : config,
301 0 : auth_backend,
302 0 : ctx,
303 0 : cancellation_handler,
304 0 : &mut client,
305 0 : &mode,
306 0 : endpoint_rate_limiter,
307 0 : common_names,
308 0 : ¶ms,
309 0 : )
310 0 : .await?;
311 :
312 0 : let client = client.flush_and_into_inner().await?;
313 :
314 0 : let private_link_id = match ctx.extra() {
315 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
316 0 : Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
317 0 : None => None,
318 : };
319 :
320 0 : Ok(Some(ProxyPassthrough {
321 0 : client,
322 0 : compute: node.stream,
323 0 :
324 0 : aux: node.aux,
325 0 : private_link_id,
326 0 :
327 0 : _cancel_on_shutdown: cancel_on_shutdown,
328 0 :
329 0 : _req: request_gauge,
330 0 : _conn: conn_gauge,
331 0 : _db_conn: node.guage,
332 0 : }))
333 0 : }
|