Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub(crate) mod connect_compute;
5 : mod copy_bidirectional;
6 : pub(crate) mod handshake;
7 : pub(crate) mod passthrough;
8 : pub(crate) mod retry;
9 : pub(crate) mod wake_compute;
10 : use std::sync::Arc;
11 :
12 : pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
13 : use futures::{FutureExt, TryFutureExt};
14 : use itertools::Itertools;
15 : use once_cell::sync::OnceCell;
16 : use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
17 : use regex::Regex;
18 : use serde::{Deserialize, Serialize};
19 : use smol_str::{SmolStr, ToSmolStr, format_smolstr};
20 : use thiserror::Error;
21 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
22 : use tokio_util::sync::CancellationToken;
23 : use tracing::{Instrument, debug, error, info, warn};
24 :
25 : use self::connect_compute::{TcpMechanism, connect_to_compute};
26 : use self::passthrough::ProxyPassthrough;
27 : use crate::cancellation::{self, CancellationHandler};
28 : use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
29 : use crate::context::RequestContext;
30 : use crate::error::ReportableError;
31 : use crate::metrics::{Metrics, NumClientConnectionsGuard};
32 : use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
33 : use crate::proxy::handshake::{HandshakeData, handshake};
34 : use crate::rate_limiter::EndpointRateLimiter;
35 : use crate::stream::{PqStream, Stream};
36 : use crate::types::EndpointCacheKey;
37 : use crate::{auth, compute};
38 :
/// Client-facing error text for an insecure (non-TLS) connection attempt, suggesting a
/// retry with `sslmode=require`.
///
/// NOTE(review): not referenced by the code visible in this module; presumably used by a
/// submodule (e.g. `handshake`) via `super::` — confirm.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
40 :
41 0 : pub async fn run_until_cancelled<F: std::future::Future>(
42 0 : f: F,
43 0 : cancellation_token: &CancellationToken,
44 0 : ) -> Option<F::Output> {
45 0 : match futures::future::select(
46 0 : std::pin::pin!(f),
47 0 : std::pin::pin!(cancellation_token.cancelled()),
48 0 : )
49 0 : .await
50 : {
51 0 : futures::future::Either::Left((f, _)) => Some(f),
52 0 : futures::future::Either::Right(((), _)) => None,
53 : }
54 0 : }
55 :
56 0 : pub async fn task_main(
57 0 : config: &'static ProxyConfig,
58 0 : auth_backend: &'static auth::Backend<'static, ()>,
59 0 : listener: tokio::net::TcpListener,
60 0 : cancellation_token: CancellationToken,
61 0 : cancellation_handler: Arc<CancellationHandler>,
62 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
63 0 : ) -> anyhow::Result<()> {
64 0 : scopeguard::defer! {
65 0 : info!("proxy has shut down");
66 0 : }
67 0 :
68 0 : // When set for the server socket, the keepalive setting
69 0 : // will be inherited by all accepted client sockets.
70 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
71 :
72 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
73 0 : let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
74 :
75 0 : while let Some(accept_result) =
76 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
77 : {
78 0 : let (socket, peer_addr) = accept_result?;
79 :
80 0 : let conn_gauge = Metrics::get()
81 0 : .proxy
82 0 : .client_connections
83 0 : .guard(crate::metrics::Protocol::Tcp);
84 0 :
85 0 : let session_id = uuid::Uuid::new_v4();
86 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
87 0 : let cancellations = cancellations.clone();
88 0 :
89 0 : debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
90 0 : let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
91 0 :
92 0 : connections.spawn(async move {
93 0 : let (socket, conn_info) = match read_proxy_protocol(socket).await {
94 0 : Err(e) => {
95 0 : warn!("per-client task finished with an error: {e:#}");
96 0 : return;
97 : }
98 : // our load balancers will not send any more data. let's just exit immediately
99 0 : Ok((_socket, ConnectHeader::Local)) => {
100 0 : debug!("healthcheck received");
101 0 : return;
102 : }
103 0 : Ok((_socket, ConnectHeader::Missing))
104 0 : if config.proxy_protocol_v2 == ProxyProtocolV2::Required =>
105 0 : {
106 0 : warn!("missing required proxy protocol header");
107 0 : return;
108 : }
109 0 : Ok((_socket, ConnectHeader::Proxy(_)))
110 0 : if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected =>
111 0 : {
112 0 : warn!("proxy protocol header not supported");
113 0 : return;
114 : }
115 0 : Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
116 0 : Ok((socket, ConnectHeader::Missing)) => (
117 0 : socket,
118 0 : ConnectionInfo {
119 0 : addr: peer_addr,
120 0 : extra: None,
121 0 : },
122 0 : ),
123 : };
124 :
125 0 : match socket.inner.set_nodelay(true) {
126 0 : Ok(()) => {}
127 0 : Err(e) => {
128 0 : error!(
129 0 : "per-client task finished with an error: failed to set socket option: {e:#}"
130 : );
131 0 : return;
132 : }
133 : }
134 :
135 0 : let ctx = RequestContext::new(
136 0 : session_id,
137 0 : conn_info,
138 0 : crate::metrics::Protocol::Tcp,
139 0 : &config.region,
140 0 : );
141 :
142 0 : let res = handle_client(
143 0 : config,
144 0 : auth_backend,
145 0 : &ctx,
146 0 : cancellation_handler,
147 0 : socket,
148 0 : ClientMode::Tcp,
149 0 : endpoint_rate_limiter2,
150 0 : conn_gauge,
151 0 : cancellations,
152 0 : )
153 0 : .instrument(ctx.span())
154 0 : .boxed()
155 0 : .await;
156 :
157 0 : match res {
158 0 : Err(e) => {
159 0 : ctx.set_error_kind(e.get_error_kind());
160 0 : warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
161 : }
162 0 : Ok(None) => {
163 0 : ctx.set_success();
164 0 : }
165 0 : Ok(Some(p)) => {
166 0 : ctx.set_success();
167 0 : let _disconnect = ctx.log_connect();
168 0 : match p.proxy_pass(&config.connect_to_compute).await {
169 0 : Ok(()) => {}
170 0 : Err(ErrorSource::Client(e)) => {
171 0 : warn!(
172 : ?session_id,
173 0 : "per-client task finished with an IO error from the client: {e:#}"
174 : );
175 : }
176 0 : Err(ErrorSource::Compute(e)) => {
177 0 : error!(
178 : ?session_id,
179 0 : "per-client task finished with an IO error from the compute: {e:#}"
180 : );
181 : }
182 : }
183 : }
184 : }
185 0 : });
186 : }
187 :
188 0 : connections.close();
189 0 : cancellations.close();
190 0 : drop(listener);
191 0 :
192 0 : // Drain connections
193 0 : connections.wait().await;
194 0 : cancellations.wait().await;
195 :
196 0 : Ok(())
197 0 : }
198 :
199 : pub(crate) enum ClientMode {
200 : Tcp,
201 : Websockets { hostname: Option<String> },
202 : }
203 :
204 : /// Abstracts the logic of handling TCP vs WS clients
205 : impl ClientMode {
206 0 : pub(crate) fn allow_cleartext(&self) -> bool {
207 0 : match self {
208 0 : ClientMode::Tcp => false,
209 0 : ClientMode::Websockets { .. } => true,
210 : }
211 0 : }
212 :
213 0 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
214 0 : match self {
215 0 : ClientMode::Tcp => s.sni_hostname(),
216 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
217 : }
218 0 : }
219 :
220 0 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
221 0 : match self {
222 0 : ClientMode::Tcp => tls,
223 : // TLS is None here if using websockets, because the connection is already encrypted.
224 0 : ClientMode::Websockets { .. } => None,
225 : }
226 0 : }
227 : }
228 :
229 : #[derive(Debug, Error)]
230 : // almost all errors should be reported to the user, but there's a few cases where we cannot
231 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
232 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
233 : // we cannot be sure the client even understands our error message
234 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
235 : pub(crate) enum ClientRequestError {
236 : #[error("{0}")]
237 : Cancellation(#[from] cancellation::CancelError),
238 : #[error("{0}")]
239 : Handshake(#[from] handshake::HandshakeError),
240 : #[error("{0}")]
241 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
242 : #[error("{0}")]
243 : PrepareClient(#[from] std::io::Error),
244 : #[error("{0}")]
245 : ReportedError(#[from] crate::stream::ReportedError),
246 : }
247 :
248 : impl ReportableError for ClientRequestError {
249 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
250 0 : match self {
251 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
252 0 : ClientRequestError::Handshake(e) => e.get_error_kind(),
253 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
254 0 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
255 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
256 : }
257 0 : }
258 : }
259 :
260 : #[allow(clippy::too_many_arguments)]
261 0 : pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
262 0 : config: &'static ProxyConfig,
263 0 : auth_backend: &'static auth::Backend<'static, ()>,
264 0 : ctx: &RequestContext,
265 0 : cancellation_handler: Arc<CancellationHandler>,
266 0 : stream: S,
267 0 : mode: ClientMode,
268 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
269 0 : conn_gauge: NumClientConnectionsGuard<'static>,
270 0 : cancellations: tokio_util::task::task_tracker::TaskTracker,
271 0 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
272 0 : debug!(
273 0 : protocol = %ctx.protocol(),
274 0 : "handling interactive connection from client"
275 : );
276 :
277 0 : let metrics = &Metrics::get().proxy;
278 0 : let proto = ctx.protocol();
279 0 : let request_gauge = metrics.connection_requests.guard(proto);
280 0 :
281 0 : let tls = config.tls_config.load();
282 0 : let tls = tls.as_deref();
283 0 :
284 0 : let record_handshake_error = !ctx.has_private_peer_addr();
285 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
286 0 : let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
287 :
288 0 : let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
289 0 : .await??
290 : {
291 0 : HandshakeData::Startup(stream, params) => (stream, params),
292 0 : HandshakeData::Cancel(cancel_key_data) => {
293 0 : // spawn a task to cancel the session, but don't wait for it
294 0 : cancellations.spawn({
295 0 : let cancellation_handler_clone = Arc::clone(&cancellation_handler);
296 0 : let ctx = ctx.clone();
297 0 : let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
298 0 : cancel_span.follows_from(tracing::Span::current());
299 0 : async move {
300 0 : cancellation_handler_clone
301 0 : .cancel_session(
302 0 : cancel_key_data,
303 0 : ctx,
304 0 : config.authentication_config.ip_allowlist_check_enabled,
305 0 : config.authentication_config.is_vpc_acccess_proxy,
306 0 : auth_backend.get_api(),
307 0 : )
308 0 : .await
309 0 : .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
310 0 : }.instrument(cancel_span)
311 0 : });
312 0 :
313 0 : return Ok(None);
314 : }
315 : };
316 0 : drop(pause);
317 0 :
318 0 : ctx.set_db_options(params.clone());
319 0 :
320 0 : let hostname = mode.hostname(stream.get_ref());
321 0 :
322 0 : let common_names = tls.map(|tls| &tls.common_names);
323 0 :
324 0 : // Extract credentials which we're going to use for auth.
325 0 : let result = auth_backend
326 0 : .as_ref()
327 0 : .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
328 0 : .transpose();
329 :
330 0 : let user_info = match result {
331 0 : Ok(user_info) => user_info,
332 0 : Err(e) => stream.throw_error(e).await?,
333 : };
334 :
335 0 : let user = user_info.get_user().to_owned();
336 0 : let (user_info, _ip_allowlist) = match user_info
337 0 : .authenticate(
338 0 : ctx,
339 0 : &mut stream,
340 0 : mode.allow_cleartext(),
341 0 : &config.authentication_config,
342 0 : endpoint_rate_limiter,
343 0 : )
344 0 : .await
345 : {
346 0 : Ok(auth_result) => auth_result,
347 0 : Err(e) => {
348 0 : let db = params.get("database");
349 0 : let app = params.get("application_name");
350 0 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
351 :
352 0 : return stream.throw_error(e).instrument(params_span).await?;
353 : }
354 : };
355 :
356 0 : let compute_user_info = match &user_info {
357 0 : auth::Backend::ControlPlane(_, info) => &info.info,
358 0 : auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
359 : };
360 0 : let params_compat = compute_user_info
361 0 : .options
362 0 : .get(NeonOptions::PARAMS_COMPAT)
363 0 : .is_some();
364 :
365 0 : let mut node = connect_to_compute(
366 0 : ctx,
367 0 : &TcpMechanism {
368 0 : user_info: compute_user_info.clone(),
369 0 : params_compat,
370 0 : params: ¶ms,
371 0 : locks: &config.connect_compute_locks,
372 0 : },
373 0 : &user_info,
374 0 : config.wake_compute_retry_config,
375 0 : &config.connect_to_compute,
376 0 : )
377 0 : .or_else(|e| stream.throw_error(e))
378 0 : .await?;
379 :
380 0 : let cancellation_handler_clone = Arc::clone(&cancellation_handler);
381 0 : let session = cancellation_handler_clone.get_key();
382 0 :
383 0 : session
384 0 : .write_cancel_key(node.cancel_closure.clone())
385 0 : .await?;
386 :
387 0 : prepare_client_connection(&node, *session.key(), &mut stream).await?;
388 :
389 : // Before proxy passing, forward to compute whatever data is left in the
390 : // PqStream input buffer. Normally there is none, but our serverless npm
391 : // driver in pipeline mode sends startup, password and first query
392 : // immediately after opening the connection.
393 0 : let (stream, read_buf) = stream.into_inner();
394 0 : node.stream.write_all(&read_buf).await?;
395 :
396 0 : let private_link_id = match ctx.extra() {
397 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
398 0 : Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
399 0 : None => None,
400 : };
401 :
402 0 : Ok(Some(ProxyPassthrough {
403 0 : client: stream,
404 0 : aux: node.aux.clone(),
405 0 : private_link_id,
406 0 : compute: node,
407 0 : session_id: ctx.session_id(),
408 0 : cancel: session,
409 0 : _req: request_gauge,
410 0 : _conn: conn_gauge,
411 0 : }))
412 0 : }
413 :
414 : /// Finish client connection initialization: confirm auth success, send params, etc.
415 : #[tracing::instrument(skip_all)]
416 : pub(crate) async fn prepare_client_connection(
417 : node: &compute::PostgresConnection,
418 : cancel_key_data: CancelKeyData,
419 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
420 : ) -> Result<(), std::io::Error> {
421 : // Forward all deferred notices to the client.
422 : for notice in &node.delayed_notice {
423 : stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
424 : }
425 :
426 : // Forward all postgres connection params to the client.
427 : for (name, value) in &node.params {
428 : stream.write_message_noflush(&Be::ParameterStatus {
429 : name: name.as_bytes(),
430 : value: value.as_bytes(),
431 : })?;
432 : }
433 :
434 : stream
435 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
436 : .write_message(&Be::ReadyForQuery)
437 : .await?;
438 :
439 : Ok(())
440 : }
441 :
442 0 : #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
443 : pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
444 :
445 : impl NeonOptions {
446 : // proxy options:
447 :
448 : /// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
449 : const PARAMS_COMPAT: &str = "proxy_params_compat";
450 :
451 : // cplane options:
452 :
453 : /// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN.
454 : const LSN: &str = "lsn";
455 :
456 : /// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write.
457 : const ENDPOINT_TYPE: &str = "endpoint_type";
458 :
459 11 : pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
460 11 : params
461 11 : .options_raw()
462 11 : .map(Self::parse_from_iter)
463 11 : .unwrap_or_default()
464 11 : }
465 :
466 7 : pub(crate) fn parse_options_raw(options: &str) -> Self {
467 7 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
468 7 : }
469 :
470 0 : pub(crate) fn get(&self, key: &str) -> Option<SmolStr> {
471 0 : self.0
472 0 : .iter()
473 0 : .find_map(|(k, v)| (k == key).then_some(v))
474 0 : .cloned()
475 0 : }
476 :
477 2 : pub(crate) fn is_ephemeral(&self) -> bool {
478 2 : self.0.iter().any(|(k, _)| match &**k {
479 0 : // This is not a cplane option, we know it does not create ephemeral computes.
480 0 : Self::PARAMS_COMPAT => false,
481 0 : Self::LSN => true,
482 0 : Self::ENDPOINT_TYPE => true,
483 : // err on the side of caution. any cplane options we don't know about
484 : // might lead to ephemeral computes.
485 0 : _ => true,
486 2 : })
487 2 : }
488 :
489 13 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
490 13 : let mut options = options
491 13 : .filter_map(neon_option)
492 13 : .map(|(k, v)| (k.into(), v.into()))
493 13 : .collect_vec();
494 13 : options.sort();
495 13 : Self(options)
496 13 : }
497 :
498 4 : pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
499 4 : // prefix + format!(" {k}:{v}")
500 4 : // kinda jank because SmolStr is immutable
501 4 : std::iter::once(prefix)
502 4 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
503 4 : .collect::<SmolStr>()
504 4 : .into()
505 4 : }
506 :
507 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
508 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
509 0 : pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
510 0 : self.0
511 0 : .iter()
512 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
513 0 : .collect()
514 0 : }
515 : }
516 :
517 33 : pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
518 : static RE: OnceCell<Regex> = OnceCell::new();
519 33 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").expect("regex should be correct"));
520 :
521 33 : let cap = re.captures(bytes)?;
522 5 : let (_, [k, v]) = cap.extract();
523 5 : Some((k, v))
524 33 : }
|