LCOV - code coverage report
Current view: top level - proxy/src/pglb - mod.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 1.6 % 192 3
Test Date: 2025-07-16 12:29:03 Functions: 5.9 % 17 1

            Line data    Source code
       1              : pub mod copy_bidirectional;
       2              : pub mod handshake;
       3              : pub mod inprocess;
       4              : pub mod passthrough;
       5              : 
       6              : use std::sync::Arc;
       7              : 
       8              : use futures::FutureExt;
       9              : use smol_str::ToSmolStr;
      10              : use thiserror::Error;
      11              : use tokio::io::{AsyncRead, AsyncWrite};
      12              : use tokio_util::sync::CancellationToken;
      13              : use tracing::{Instrument, debug, error, info, warn};
      14              : 
      15              : use crate::auth;
      16              : use crate::cancellation::{self, CancellationHandler};
      17              : use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
      18              : use crate::context::RequestContext;
      19              : use crate::error::{ReportableError, UserFacingError};
      20              : use crate::metrics::{Metrics, NumClientConnectionsGuard};
      21              : pub use crate::pglb::copy_bidirectional::ErrorSource;
      22              : use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
      23              : use crate::pglb::passthrough::ProxyPassthrough;
      24              : use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
      25              : use crate::proxy::handle_client;
      26              : use crate::rate_limiter::EndpointRateLimiter;
      27              : use crate::stream::Stream;
      28              : use crate::util::run_until_cancelled;
      29              : 
      30              : pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      31              : 
      32              : #[derive(Error, Debug)]
      33              : #[error("{ERR_INSECURE_CONNECTION}")]
      34              : pub struct TlsRequired;
      35              : 
      36              : impl ReportableError for TlsRequired {
      37            2 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      38            2 :         crate::error::ErrorKind::User
      39            2 :     }
      40              : }
      41              : 
      42              : impl UserFacingError for TlsRequired {}
      43              : 
      44            0 : pub async fn task_main(
      45            0 :     config: &'static ProxyConfig,
      46            0 :     auth_backend: &'static auth::Backend<'static, ()>,
      47            0 :     listener: tokio::net::TcpListener,
      48            0 :     cancellation_token: CancellationToken,
      49            0 :     cancellation_handler: Arc<CancellationHandler>,
      50            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      51            0 : ) -> anyhow::Result<()> {
      52            0 :     scopeguard::defer! {
      53              :         info!("proxy has shut down");
      54              :     }
      55              : 
      56              :     // When set for the server socket, the keepalive setting
      57              :     // will be inherited by all accepted client sockets.
      58            0 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
      59              : 
      60            0 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
      61            0 :     let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
      62              : 
      63            0 :     while let Some(accept_result) =
      64            0 :         run_until_cancelled(listener.accept(), &cancellation_token).await
      65              :     {
      66            0 :         let (socket, peer_addr) = accept_result?;
      67              : 
      68            0 :         let conn_gauge = Metrics::get()
      69            0 :             .proxy
      70            0 :             .client_connections
      71            0 :             .guard(crate::metrics::Protocol::Tcp);
      72              : 
      73            0 :         let session_id = uuid::Uuid::new_v4();
      74            0 :         let cancellation_handler = Arc::clone(&cancellation_handler);
      75            0 :         let cancellations = cancellations.clone();
      76              : 
      77            0 :         debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
      78            0 :         let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
      79              : 
      80            0 :         connections.spawn(async move {
      81            0 :             let (socket, conn_info) = match config.proxy_protocol_v2 {
      82              :                 ProxyProtocolV2::Required => {
      83            0 :                     match read_proxy_protocol(socket).await {
      84            0 :                         Err(e) => {
      85            0 :                             warn!("per-client task finished with an error: {e:#}");
      86            0 :                             return;
      87              :                         }
      88              :                         // our load balancers will not send any more data. let's just exit immediately
      89            0 :                         Ok((_socket, ConnectHeader::Local)) => {
      90            0 :                             debug!("healthcheck received");
      91            0 :                             return;
      92              :                         }
      93            0 :                         Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
      94              :                     }
      95              :                 }
      96              :                 // ignore the header - it cannot be confused for a postgres or http connection so will
      97              :                 // error later.
      98            0 :                 ProxyProtocolV2::Rejected => (
      99            0 :                     socket,
     100            0 :                     ConnectionInfo {
     101            0 :                         addr: peer_addr,
     102            0 :                         extra: None,
     103            0 :                     },
     104            0 :                 ),
     105              :             };
     106              : 
     107            0 :             match socket.set_nodelay(true) {
     108            0 :                 Ok(()) => {}
     109            0 :                 Err(e) => {
     110            0 :                     error!(
     111            0 :                         "per-client task finished with an error: failed to set socket option: {e:#}"
     112              :                     );
     113            0 :                     return;
     114              :                 }
     115              :             }
     116              : 
     117            0 :             let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
     118              : 
     119            0 :             let res = handle_connection(
     120            0 :                 config,
     121            0 :                 auth_backend,
     122            0 :                 &ctx,
     123            0 :                 cancellation_handler,
     124            0 :                 socket,
     125            0 :                 ClientMode::Tcp,
     126            0 :                 endpoint_rate_limiter2,
     127            0 :                 conn_gauge,
     128            0 :                 cancellations,
     129            0 :             )
     130            0 :             .instrument(ctx.span())
     131            0 :             .boxed()
     132            0 :             .await;
     133              : 
     134            0 :             match res {
     135            0 :                 Err(e) => {
     136            0 :                     ctx.set_error_kind(e.get_error_kind());
     137            0 :                     warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
     138              :                 }
     139            0 :                 Ok(None) => {
     140            0 :                     ctx.set_success();
     141            0 :                 }
     142            0 :                 Ok(Some(p)) => {
     143            0 :                     ctx.set_success();
     144            0 :                     let _disconnect = ctx.log_connect();
     145            0 :                     match p.proxy_pass().await {
     146            0 :                         Ok(()) => {}
     147            0 :                         Err(ErrorSource::Client(e)) => {
     148            0 :                             warn!(
     149              :                                 ?session_id,
     150            0 :                                 "per-client task finished with an IO error from the client: {e:#}"
     151              :                             );
     152              :                         }
     153            0 :                         Err(ErrorSource::Compute(e)) => {
     154            0 :                             error!(
     155              :                                 ?session_id,
     156            0 :                                 "per-client task finished with an IO error from the compute: {e:#}"
     157              :                             );
     158              :                         }
     159              :                     }
     160              :                 }
     161              :             }
     162            0 :         });
     163              :     }
     164              : 
     165            0 :     connections.close();
     166            0 :     cancellations.close();
     167            0 :     drop(listener);
     168              : 
     169              :     // Drain connections
     170            0 :     connections.wait().await;
     171            0 :     cancellations.wait().await;
     172              : 
     173            0 :     Ok(())
     174            0 : }
     175              : 
     176              : pub(crate) enum ClientMode {
     177              :     Tcp,
     178              :     Websockets { hostname: Option<String> },
     179              : }
     180              : 
     181              : /// Abstracts the logic of handling TCP vs WS clients
     182              : impl ClientMode {
     183            0 :     pub fn allow_cleartext(&self) -> bool {
     184            0 :         match self {
     185            0 :             ClientMode::Tcp => false,
     186            0 :             ClientMode::Websockets { .. } => true,
     187              :         }
     188            0 :     }
     189              : 
     190            0 :     pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     191            0 :         match self {
     192            0 :             ClientMode::Tcp => s.sni_hostname(),
     193            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     194              :         }
     195            0 :     }
     196              : 
     197            0 :     pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     198            0 :         match self {
     199            0 :             ClientMode::Tcp => tls,
     200              :             // TLS is None here if using websockets, because the connection is already encrypted.
     201            0 :             ClientMode::Websockets { .. } => None,
     202              :         }
     203            0 :     }
     204              : }
     205              : 
     206              : #[derive(Debug, Error)]
     207              : // almost all errors should be reported to the user, but there's a few cases where we cannot
     208              : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
     209              : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
     210              : //    we cannot be sure the client even understands our error message
     211              : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
     212              : pub(crate) enum ClientRequestError {
     213              :     #[error("{0}")]
     214              :     Cancellation(#[from] cancellation::CancelError),
     215              :     #[error("{0}")]
     216              :     Handshake(#[from] HandshakeError),
     217              :     #[error("{0}")]
     218              :     HandshakeTimeout(#[from] tokio::time::error::Elapsed),
     219              :     #[error("{0}")]
     220              :     PrepareClient(#[from] std::io::Error),
     221              :     #[error("{0}")]
     222              :     ReportedError(#[from] crate::stream::ReportedError),
     223              : }
     224              : 
     225              : impl ReportableError for ClientRequestError {
     226            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     227            0 :         match self {
     228            0 :             ClientRequestError::Cancellation(e) => e.get_error_kind(),
     229            0 :             ClientRequestError::Handshake(e) => e.get_error_kind(),
     230            0 :             ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
     231            0 :             ClientRequestError::ReportedError(e) => e.get_error_kind(),
     232            0 :             ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
     233              :         }
     234            0 :     }
     235              : }
     236              : 
     237              : #[allow(clippy::too_many_arguments)]
     238            0 : pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
     239            0 :     config: &'static ProxyConfig,
     240            0 :     auth_backend: &'static auth::Backend<'static, ()>,
     241            0 :     ctx: &RequestContext,
     242            0 :     cancellation_handler: Arc<CancellationHandler>,
     243            0 :     client: S,
     244            0 :     mode: ClientMode,
     245            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     246            0 :     conn_gauge: NumClientConnectionsGuard<'static>,
     247            0 :     cancellations: tokio_util::task::task_tracker::TaskTracker,
     248            0 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
     249            0 :     debug!(
     250            0 :         protocol = %ctx.protocol(),
     251            0 :         "handling interactive connection from client"
     252              :     );
     253              : 
     254            0 :     let metrics = &Metrics::get().proxy;
     255            0 :     let proto = ctx.protocol();
     256            0 :     let request_gauge = metrics.connection_requests.guard(proto);
     257              : 
     258            0 :     let tls = config.tls_config.load();
     259            0 :     let tls = tls.as_deref();
     260              : 
     261            0 :     let record_handshake_error = !ctx.has_private_peer_addr();
     262            0 :     let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
     263            0 :     let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
     264              : 
     265            0 :     let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
     266            0 :         .await??
     267              :     {
     268            0 :         HandshakeData::Startup(client, params) => (client, params),
     269            0 :         HandshakeData::Cancel(cancel_key_data) => {
     270              :             // spawn a task to cancel the session, but don't wait for it
     271            0 :             cancellations.spawn({
     272            0 :                 let cancellation_handler_clone = Arc::clone(&cancellation_handler);
     273            0 :                 let ctx = ctx.clone();
     274            0 :                 let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
     275            0 :                 cancel_span.follows_from(tracing::Span::current());
     276            0 :                 async move {
     277            0 :                     cancellation_handler_clone
     278            0 :                         .cancel_session(
     279            0 :                             cancel_key_data,
     280            0 :                             ctx,
     281            0 :                             config.authentication_config.ip_allowlist_check_enabled,
     282            0 :                             config.authentication_config.is_vpc_acccess_proxy,
     283            0 :                             auth_backend.get_api(),
     284            0 :                         )
     285            0 :                         .await
     286            0 :                         .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
     287            0 :                 }.instrument(cancel_span)
     288              :             });
     289              : 
     290            0 :             return Ok(None);
     291              :         }
     292              :     };
     293            0 :     drop(pause);
     294              : 
     295            0 :     ctx.set_db_options(params.clone());
     296              : 
     297            0 :     let common_names = tls.map(|tls| &tls.common_names);
     298              : 
     299            0 :     let (node, cancel_on_shutdown) = handle_client(
     300            0 :         config,
     301            0 :         auth_backend,
     302            0 :         ctx,
     303            0 :         cancellation_handler,
     304            0 :         &mut client,
     305            0 :         &mode,
     306            0 :         endpoint_rate_limiter,
     307            0 :         common_names,
     308            0 :         &params,
     309            0 :     )
     310            0 :     .await?;
     311              : 
     312            0 :     let client = client.flush_and_into_inner().await?;
     313              : 
     314            0 :     let private_link_id = match ctx.extra() {
     315            0 :         Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
     316            0 :         Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
     317            0 :         None => None,
     318              :     };
     319              : 
     320            0 :     Ok(Some(ProxyPassthrough {
     321            0 :         client,
     322            0 :         compute: node.stream,
     323            0 : 
     324            0 :         aux: node.aux,
     325            0 :         private_link_id,
     326            0 : 
     327            0 :         _cancel_on_shutdown: cancel_on_shutdown,
     328            0 : 
     329            0 :         _req: request_gauge,
     330            0 :         _conn: conn_gauge,
     331            0 :         _db_conn: node.guage,
     332            0 :     }))
     333            0 : }
        

Generated by: LCOV version 2.1-beta