LCOV - code coverage report
Current view: top level - proxy/src/serverless - mod.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 310 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 30 0

            Line data    Source code
       1              : //! Routers for our serverless APIs
       2              : //!
       3              : //! Handles both SQL over HTTP and SQL over Websockets.
       4              : 
       5              : mod backend;
       6              : pub mod cancel_set;
       7              : mod conn_pool;
       8              : mod conn_pool_lib;
       9              : mod error;
      10              : mod http_conn_pool;
      11              : mod http_util;
      12              : mod json;
      13              : mod local_conn_pool;
      14              : mod sql_over_http;
      15              : mod websocket;
      16              : 
      17              : use std::net::{IpAddr, SocketAddr};
      18              : use std::pin::{Pin, pin};
      19              : use std::sync::Arc;
      20              : 
      21              : use anyhow::Context;
      22              : use arc_swap::ArcSwapOption;
      23              : use async_trait::async_trait;
      24              : use atomic_take::AtomicTake;
      25              : use bytes::Bytes;
      26              : pub use conn_pool_lib::GlobalConnPoolOptions;
      27              : use futures::TryFutureExt;
      28              : use futures::future::{Either, select};
      29              : use http::{Method, Response, StatusCode};
      30              : use http_body_util::combinators::BoxBody;
      31              : use http_body_util::{BodyExt, Empty};
      32              : use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
      33              : use http_utils::error::ApiError;
      34              : use hyper::body::Incoming;
      35              : use hyper_util::rt::TokioExecutor;
      36              : use hyper_util::server::conn::auto::Builder;
      37              : use rand::SeedableRng;
      38              : use rand::rngs::StdRng;
      39              : use tokio::io::{AsyncRead, AsyncWrite};
      40              : use tokio::net::{TcpListener, TcpStream};
      41              : use tokio::time::timeout;
      42              : use tokio_rustls::TlsAcceptor;
      43              : use tokio_util::sync::CancellationToken;
      44              : use tokio_util::task::TaskTracker;
      45              : use tracing::{Instrument, info, warn};
      46              : 
      47              : use crate::cancellation::CancellationHandler;
      48              : use crate::config::{ProxyConfig, ProxyProtocolV2};
      49              : use crate::context::RequestContext;
      50              : use crate::ext::TaskExt;
      51              : use crate::metrics::Metrics;
      52              : use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
      53              : use crate::rate_limiter::EndpointRateLimiter;
      54              : use crate::serverless::backend::PoolingBackend;
      55              : use crate::serverless::http_util::{api_error_into_response, json_response};
      56              : use crate::util::run_until_cancelled;
      57              : 
      58              : pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
      59              : pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
      60              : 
      61            0 : pub async fn task_main(
      62            0 :     config: &'static ProxyConfig,
      63            0 :     auth_backend: &'static crate::auth::Backend<'static, ()>,
      64            0 :     ws_listener: TcpListener,
      65            0 :     cancellation_token: CancellationToken,
      66            0 :     cancellation_handler: Arc<CancellationHandler>,
      67            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      68            0 : ) -> anyhow::Result<()> {
      69            0 :     scopeguard::defer! {
      70              :         info!("websocket server has shut down");
      71              :     }
      72              : 
      73            0 :     let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
      74            0 :     let conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
      75              :     {
      76            0 :         let conn_pool = Arc::clone(&conn_pool);
      77            0 :         tokio::spawn(async move {
      78            0 :             conn_pool.gc_worker(StdRng::from_entropy()).await;
      79            0 :         });
      80              :     }
      81              : 
      82              :     // shutdown the connection pool
      83            0 :     tokio::spawn({
      84            0 :         let cancellation_token = cancellation_token.clone();
      85            0 :         let conn_pool = conn_pool.clone();
      86            0 :         async move {
      87            0 :             cancellation_token.cancelled().await;
      88            0 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
      89            0 :                 .await
      90            0 :                 .propagate_task_panic();
      91            0 :         }
      92              :     });
      93              : 
      94            0 :     let http_conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
      95              :     {
      96            0 :         let http_conn_pool = Arc::clone(&http_conn_pool);
      97            0 :         tokio::spawn(async move {
      98            0 :             http_conn_pool.gc_worker(StdRng::from_entropy()).await;
      99            0 :         });
     100              :     }
     101              : 
     102              :     // shutdown the connection pool
     103            0 :     tokio::spawn({
     104            0 :         let cancellation_token = cancellation_token.clone();
     105            0 :         let http_conn_pool = http_conn_pool.clone();
     106            0 :         async move {
     107            0 :             cancellation_token.cancelled().await;
     108            0 :             tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
     109            0 :                 .await
     110            0 :                 .propagate_task_panic();
     111            0 :         }
     112              :     });
     113              : 
     114            0 :     let backend = Arc::new(PoolingBackend {
     115            0 :         http_conn_pool: Arc::clone(&http_conn_pool),
     116            0 :         local_pool,
     117            0 :         pool: Arc::clone(&conn_pool),
     118            0 :         config,
     119            0 :         auth_backend,
     120            0 :         endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
     121            0 :     });
     122            0 :     let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = Arc::new(&config.tls_config);
     123              : 
     124            0 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     125            0 :     connections.close(); // allows `connections.wait to complete`
     126              : 
     127            0 :     let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
     128            0 :     while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
     129            0 :         let (conn, peer_addr) = res.context("could not accept TCP stream")?;
     130            0 :         if let Err(e) = conn.set_nodelay(true) {
     131            0 :             tracing::error!("could not set nodelay: {e}");
     132            0 :             continue;
     133            0 :         }
     134            0 :         let conn_id = uuid::Uuid::new_v4();
     135            0 :         let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
     136              : 
     137            0 :         let n_connections = Metrics::get()
     138            0 :             .proxy
     139            0 :             .client_connections
     140            0 :             .sample(crate::metrics::Protocol::Http);
     141            0 :         tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
     142            0 :         if n_connections > config.http_config.client_conn_threshold {
     143            0 :             tracing::trace!("attempting to cancel a random connection");
     144            0 :             if let Some(token) = config.http_config.cancel_set.take() {
     145            0 :                 tracing::debug!("cancelling a random connection");
     146            0 :                 token.cancel();
     147            0 :             }
     148            0 :         }
     149              : 
     150            0 :         let conn_token = cancellation_token.child_token();
     151            0 :         let tls_acceptor = tls_acceptor.clone();
     152            0 :         let backend = backend.clone();
     153            0 :         let connections2 = connections.clone();
     154            0 :         let cancellation_handler = cancellation_handler.clone();
     155            0 :         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     156            0 :         let cancellations = cancellations.clone();
     157            0 :         connections.spawn(
     158            0 :             async move {
     159            0 :                 let conn_token2 = conn_token.clone();
     160            0 :                 let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
     161              : 
     162            0 :                 let session_id = uuid::Uuid::new_v4();
     163              : 
     164            0 :                 let _gauge = Metrics::get()
     165            0 :                     .proxy
     166            0 :                     .client_connections
     167            0 :                     .guard(crate::metrics::Protocol::Http);
     168              : 
     169            0 :                 let startup_result = Box::pin(connection_startup(
     170            0 :                     config,
     171            0 :                     tls_acceptor,
     172            0 :                     session_id,
     173            0 :                     conn,
     174            0 :                     peer_addr,
     175            0 :                 ))
     176            0 :                 .await;
     177            0 :                 let Some((conn, conn_info)) = startup_result else {
     178            0 :                     return;
     179              :                 };
     180              : 
     181            0 :                 Box::pin(connection_handler(
     182            0 :                     config,
     183            0 :                     backend,
     184            0 :                     connections2,
     185            0 :                     cancellations,
     186            0 :                     cancellation_handler,
     187            0 :                     endpoint_rate_limiter,
     188            0 :                     conn_token,
     189            0 :                     conn,
     190            0 :                     conn_info,
     191            0 :                     session_id,
     192            0 :                 ))
     193            0 :                 .await;
     194            0 :             }
     195            0 :             .instrument(http_conn_span),
     196              :         );
     197              :     }
     198              : 
     199            0 :     connections.wait().await;
     200              : 
     201            0 :     Ok(())
     202            0 : }
     203              : 
     204              : pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
     205              : impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
     206              : pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
     207              : 
     208              : #[async_trait]
     209              : trait MaybeTlsAcceptor: Send + Sync + 'static {
     210              :     async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW>;
     211              : }
     212              : 
     213              : #[async_trait]
     214              : impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
     215            0 :     async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW> {
     216            0 :         match &*self.load() {
     217            0 :             Some(config) => Ok(Box::pin(
     218            0 :                 TlsAcceptor::from(config.http_config.clone())
     219            0 :                     .accept(conn)
     220            0 :                     .await?,
     221              :             )),
     222            0 :             None => Ok(Box::pin(conn)),
     223              :         }
     224            0 :     }
     225              : }
     226              : 
     227              : /// Handles the TCP startup lifecycle.
     228              : /// 1. Parses PROXY protocol V2
     229              : /// 2. Handles TLS handshake
     230            0 : async fn connection_startup(
     231            0 :     config: &ProxyConfig,
     232            0 :     tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
     233            0 :     session_id: uuid::Uuid,
     234            0 :     conn: TcpStream,
     235            0 :     peer_addr: SocketAddr,
     236            0 : ) -> Option<(AsyncRW, ConnectionInfo)> {
     237              :     // handle PROXY protocol
     238            0 :     let (conn, conn_info) = match config.proxy_protocol_v2 {
     239              :         ProxyProtocolV2::Required => {
     240            0 :             match read_proxy_protocol(conn).await {
     241            0 :                 Err(e) => {
     242            0 :                     warn!("per-client task finished with an error: {e:#}");
     243            0 :                     return None;
     244              :                 }
     245              :                 // our load balancers will not send any more data. let's just exit immediately
     246            0 :                 Ok((_conn, ConnectHeader::Local)) => {
     247            0 :                     tracing::debug!("healthcheck received");
     248            0 :                     return None;
     249              :                 }
     250            0 :                 Ok((conn, ConnectHeader::Proxy(info))) => (conn, info),
     251              :             }
     252              :         }
     253              :         // ignore the header - it cannot be confused for a postgres or http connection so will
     254              :         // error later.
     255            0 :         ProxyProtocolV2::Rejected => (
     256            0 :             conn,
     257            0 :             ConnectionInfo {
     258            0 :                 addr: peer_addr,
     259            0 :                 extra: None,
     260            0 :             },
     261            0 :         ),
     262              :     };
     263              : 
     264            0 :     let has_private_peer_addr = match conn_info.addr.ip() {
     265            0 :         IpAddr::V4(ip) => ip.is_private(),
     266            0 :         IpAddr::V6(_) => false,
     267              :     };
     268            0 :     info!(?session_id, %conn_info, "accepted new TCP connection");
     269              : 
     270              :     // try upgrade to TLS, but with a timeout.
     271            0 :     let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
     272            0 :         Ok(Ok(conn)) => {
     273            0 :             info!(?session_id, %conn_info, "accepted new TLS connection");
     274            0 :             conn
     275              :         }
     276              :         // The handshake failed
     277            0 :         Ok(Err(e)) => {
     278            0 :             if !has_private_peer_addr {
     279            0 :                 Metrics::get().proxy.tls_handshake_failures.inc();
     280            0 :             }
     281            0 :             warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
     282            0 :             return None;
     283              :         }
     284              :         // The handshake timed out
     285            0 :         Err(e) => {
     286            0 :             if !has_private_peer_addr {
     287            0 :                 Metrics::get().proxy.tls_handshake_failures.inc();
     288            0 :             }
     289            0 :             warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
     290            0 :             return None;
     291              :         }
     292              :     };
     293              : 
     294            0 :     Some((conn, conn_info))
     295            0 : }
     296              : 
     297              : /// Handles HTTP connection
     298              : /// 1. With graceful shutdowns
     299              : /// 2. With graceful request cancellation with connection failure
     300              : /// 3. With websocket upgrade support.
     301              : #[allow(clippy::too_many_arguments)]
     302            0 : async fn connection_handler(
     303            0 :     config: &'static ProxyConfig,
     304            0 :     backend: Arc<PoolingBackend>,
     305            0 :     connections: TaskTracker,
     306            0 :     cancellations: TaskTracker,
     307            0 :     cancellation_handler: Arc<CancellationHandler>,
     308            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     309            0 :     cancellation_token: CancellationToken,
     310            0 :     conn: AsyncRW,
     311            0 :     conn_info: ConnectionInfo,
     312            0 :     session_id: uuid::Uuid,
     313            0 : ) {
     314            0 :     let session_id = AtomicTake::new(session_id);
     315              : 
     316              :     // Cancel all current inflight HTTP requests if the HTTP connection is closed.
     317            0 :     let http_cancellation_token = CancellationToken::new();
     318            0 :     let _cancel_connection = http_cancellation_token.clone().drop_guard();
     319              : 
     320            0 :     let conn_info2 = conn_info.clone();
     321            0 :     let server = Builder::new(TokioExecutor::new());
     322            0 :     let conn = server.serve_connection_with_upgrades(
     323            0 :         hyper_util::rt::TokioIo::new(conn),
     324            0 :         hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
     325              :             // First HTTP request shares the same session ID
     326            0 :             let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
     327              : 
     328            0 :             if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
     329              :                 // take session_id from request, if given.
     330            0 :                 if let Some(id) = req
     331            0 :                     .headers()
     332            0 :                     .get(&NEON_REQUEST_ID)
     333            0 :                     .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
     334            0 :                 {
     335            0 :                     session_id = id;
     336            0 :                 }
     337            0 :             }
     338              : 
     339              :             // Cancel the current inflight HTTP request if the requets stream is closed.
     340              :             // This is slightly different to `_cancel_connection` in that
     341              :             // h2 can cancel individual requests with a `RST_STREAM`.
     342            0 :             let http_request_token = http_cancellation_token.child_token();
     343            0 :             let cancel_request = http_request_token.clone().drop_guard();
     344              : 
     345              :             // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
     346              :             // By spawning the future, we ensure it never gets cancelled until it decides to.
     347            0 :             let cancellations = cancellations.clone();
     348            0 :             let handler = connections.spawn(
     349            0 :                 request_handler(
     350            0 :                     req,
     351            0 :                     config,
     352            0 :                     backend.clone(),
     353            0 :                     connections.clone(),
     354            0 :                     cancellation_handler.clone(),
     355            0 :                     session_id,
     356            0 :                     conn_info2.clone(),
     357            0 :                     http_request_token,
     358            0 :                     endpoint_rate_limiter.clone(),
     359            0 :                     cancellations,
     360              :                 )
     361            0 :                 .in_current_span()
     362            0 :                 .map_ok_or_else(api_error_into_response, |r| r),
     363              :             );
     364            0 :             async move {
     365            0 :                 let mut res = handler.await;
     366            0 :                 cancel_request.disarm();
     367              : 
     368              :                 // add the session ID to the response
     369            0 :                 if let Ok(resp) = &mut res {
     370            0 :                     resp.headers_mut()
     371            0 :                         .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
     372            0 :                 }
     373              : 
     374            0 :                 res
     375            0 :             }
     376            0 :         }),
     377              :     );
     378              : 
     379              :     // On cancellation, trigger the HTTP connection handler to shut down.
     380            0 :     let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
     381            0 :         Either::Left((_cancelled, mut conn)) => {
     382            0 :             tracing::debug!(%conn_info, "cancelling connection");
     383            0 :             conn.as_mut().graceful_shutdown();
     384            0 :             conn.await
     385              :         }
     386            0 :         Either::Right((res, _)) => res,
     387              :     };
     388              : 
     389            0 :     match res {
     390            0 :         Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
     391            0 :         Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
     392              :     }
     393            0 : }
     394              : 
     395              : #[allow(clippy::too_many_arguments)]
     396            0 : async fn request_handler(
     397            0 :     mut request: hyper::Request<Incoming>,
     398            0 :     config: &'static ProxyConfig,
     399            0 :     backend: Arc<PoolingBackend>,
     400            0 :     ws_connections: TaskTracker,
     401            0 :     cancellation_handler: Arc<CancellationHandler>,
     402            0 :     session_id: uuid::Uuid,
     403            0 :     conn_info: ConnectionInfo,
     404            0 :     // used to cancel in-flight HTTP requests. not used to cancel websockets
     405            0 :     http_cancellation_token: CancellationToken,
     406            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     407            0 :     cancellations: TaskTracker,
     408            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     409            0 :     let host = request
     410            0 :         .headers()
     411            0 :         .get("host")
     412            0 :         .and_then(|h| h.to_str().ok())
     413            0 :         .and_then(|h| h.split(':').next())
     414            0 :         .map(|s| s.to_string());
     415              : 
     416              :     // Check if the request is a websocket upgrade request.
     417            0 :     if config.http_config.accept_websockets
     418            0 :         && framed_websockets::upgrade::is_upgrade_request(&request)
     419              :     {
     420            0 :         let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
     421              : 
     422            0 :         ctx.set_user_agent(
     423            0 :             request
     424            0 :                 .headers()
     425            0 :                 .get(hyper::header::USER_AGENT)
     426            0 :                 .and_then(|h| h.to_str().ok())
     427            0 :                 .map(Into::into),
     428              :         );
     429              : 
     430            0 :         let span = ctx.span();
     431            0 :         info!(parent: &span, "performing websocket upgrade");
     432              : 
     433            0 :         let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
     434            0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     435              : 
     436            0 :         let cancellations = cancellations.clone();
     437            0 :         ws_connections.spawn(
     438            0 :             async move {
     439            0 :                 if let Err(e) = websocket::serve_websocket(
     440            0 :                     config,
     441            0 :                     backend.auth_backend,
     442            0 :                     ctx,
     443            0 :                     websocket,
     444            0 :                     cancellation_handler,
     445            0 :                     endpoint_rate_limiter,
     446            0 :                     host,
     447            0 :                     cancellations,
     448              :                 )
     449            0 :                 .await
     450              :                 {
     451            0 :                     warn!("error in websocket connection: {e:#}");
     452            0 :                 }
     453            0 :             }
     454            0 :             .instrument(span),
     455              :         );
     456              : 
     457              :         // Return the response so the spawned future can continue.
     458            0 :         Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
     459            0 :     } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
     460            0 :         let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
     461            0 :         let span = ctx.span();
     462              : 
     463            0 :         let testodrome_id = request
     464            0 :             .headers()
     465            0 :             .get("X-Neon-Query-ID")
     466            0 :             .and_then(|value| value.to_str().ok())
     467            0 :             .map(|s| s.to_string());
     468              : 
     469            0 :         if let Some(query_id) = testodrome_id {
     470            0 :             info!(parent: &ctx.span(), "testodrome query ID: {query_id}");
     471            0 :             ctx.set_testodrome_id(query_id.into());
     472            0 :         }
     473              : 
     474            0 :         sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
     475            0 :             .instrument(span)
     476            0 :             .await
     477            0 :     } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
     478            0 :         Response::builder()
     479            0 :             .header("Allow", "OPTIONS, POST")
     480            0 :             .header("Access-Control-Allow-Origin", "*")
     481            0 :             .header(
     482              :                 "Access-Control-Allow-Headers",
     483              :                 "Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
     484              :             )
     485            0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     486            0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     487            0 :             .body(Empty::new().map_err(|x| match x {}).boxed())
     488            0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     489              :     } else {
     490            0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     491              :     }
     492            0 : }
        

Generated by: LCOV version 2.1-beta