LCOV - code coverage report
Current view: top level - proxy/src - serverless.rs (source / functions)
Test:      b837401fb09d2d9818b70e630fdb67e9799b7b0d.info
Test Date: 2024-04-18 15:32:49

              Coverage    Total    Hit
Lines:           0.0 %      236      0
Functions:       0.0 %       31      0

            Line data    Source code
       1              : //! Routers for our serverless APIs
       2              : //!
       3              : //! Handles both SQL over HTTP and SQL over Websockets.
       4              : 
       5              : mod backend;
       6              : mod conn_pool;
       7              : mod http_util;
       8              : mod json;
       9              : mod sql_over_http;
      10              : mod websocket;
      11              : 
      12              : use atomic_take::AtomicTake;
      13              : use bytes::Bytes;
      14              : pub use conn_pool::GlobalConnPoolOptions;
      15              : 
      16              : use anyhow::Context;
      17              : use futures::future::{select, Either};
      18              : use futures::TryFutureExt;
      19              : use http::{Method, Response, StatusCode};
      20              : use http_body_util::Full;
      21              : use hyper1::body::Incoming;
      22              : use hyper_util::rt::TokioExecutor;
      23              : use hyper_util::server::conn::auto::Builder;
      24              : use rand::rngs::StdRng;
      25              : use rand::SeedableRng;
      26              : pub use reqwest_middleware::{ClientWithMiddleware, Error};
      27              : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
      28              : use tokio::time::timeout;
      29              : use tokio_rustls::TlsAcceptor;
      30              : use tokio_util::task::TaskTracker;
      31              : 
      32              : use crate::cancellation::CancellationHandlerMain;
      33              : use crate::config::ProxyConfig;
      34              : use crate::context::RequestMonitoring;
      35              : use crate::metrics::Metrics;
      36              : use crate::protocol2::WithClientIp;
      37              : use crate::proxy::run_until_cancelled;
      38              : use crate::serverless::backend::PoolingBackend;
      39              : use crate::serverless::http_util::{api_error_into_response, json_response};
      40              : 
      41              : use std::net::{IpAddr, SocketAddr};
      42              : use std::pin::pin;
      43              : use std::sync::Arc;
      44              : use tokio::net::{TcpListener, TcpStream};
      45              : use tokio_util::sync::CancellationToken;
      46              : use tracing::{error, info, warn, Instrument};
      47              : use utils::http::error::ApiError;
      48              : 
      49              : pub const SERVERLESS_DRIVER_SNI: &str = "api";
      50              : 
      51            0 : pub async fn task_main(
      52            0 :     config: &'static ProxyConfig,
      53            0 :     ws_listener: TcpListener,
      54            0 :     cancellation_token: CancellationToken,
      55            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
      56            0 : ) -> anyhow::Result<()> {
      57            0 :     scopeguard::defer! {
      58            0 :         info!("websocket server has shut down");
      59              :     }
      60              : 
      61            0 :     let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
      62            0 :     {
      63            0 :         let conn_pool = Arc::clone(&conn_pool);
      64            0 :         tokio::spawn(async move {
      65            0 :             conn_pool.gc_worker(StdRng::from_entropy()).await;
      66            0 :         });
      67            0 :     }
      68            0 : 
      69            0 :     // shutdown the connection pool
      70            0 :     tokio::spawn({
      71            0 :         let cancellation_token = cancellation_token.clone();
      72            0 :         let conn_pool = conn_pool.clone();
      73            0 :         async move {
      74            0 :             cancellation_token.cancelled().await;
      75            0 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
      76            0 :                 .await
      77            0 :                 .unwrap();
      78            0 :         }
      79            0 :     });
      80            0 : 
      81            0 :     let backend = Arc::new(PoolingBackend {
      82            0 :         pool: Arc::clone(&conn_pool),
      83            0 :         config,
      84            0 :     });
      85              : 
      86            0 :     let tls_config = match config.tls_config.as_ref() {
      87            0 :         Some(config) => config,
      88              :         None => {
      89            0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
      90            0 :             return Ok(());
      91              :         }
      92              :     };
      93            0 :     let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
      94            0 :     // prefer http2, but support http/1.1
      95            0 :     tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
      96            0 :     let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
      97            0 : 
      98            0 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
      99            0 :     connections.close(); // allows `connections.wait` to complete
     100            0 : 
     101            0 :     let server = Builder::new(hyper_util::rt::TokioExecutor::new());
     102              : 
     103            0 :     while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
     104            0 :         let (conn, peer_addr) = res.context("could not accept TCP stream")?;
     105            0 :         if let Err(e) = conn.set_nodelay(true) {
     106            0 :             tracing::error!("could not set nodelay: {e}");
     107            0 :             continue;
     108            0 :         }
     109            0 :         let conn_id = uuid::Uuid::new_v4();
     110            0 :         let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
     111              : 
     112            0 :         connections.spawn(
     113            0 :             connection_handler(
     114            0 :                 config,
     115            0 :                 backend.clone(),
     116            0 :                 connections.clone(),
     117            0 :                 cancellation_handler.clone(),
     118            0 :                 cancellation_token.clone(),
     119            0 :                 server.clone(),
     120            0 :                 tls_acceptor.clone(),
     121            0 :                 conn,
     122            0 :                 peer_addr,
     123            0 :             )
     124            0 :             .instrument(http_conn_span),
     125            0 :         );
     126              :     }
     127              : 
     128            0 :     connections.wait().await;
     129              : 
     130            0 :     Ok(())
     131            0 : }
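
The shutdown wiring in `task_main` above is a standard tokio pattern: each `accept()` is raced against the `CancellationToken` via `run_until_cancelled`, every accepted connection is spawned onto a `TaskTracker` that is `close()`d up front, and `connections.wait()` then resolves once the accept loop has exited and all tracked connection tasks have finished. A minimal self-contained sketch of that pattern follows; the `run_until_cancelled` body here is an assumption about its shape, not code taken from this crate:

    use std::future::Future;
    use std::pin::pin;

    use futures::future::{select, Either};
    use tokio::net::TcpListener;
    use tokio_util::sync::CancellationToken;
    use tokio_util::task::TaskTracker;

    // Resolve to Some(output) if the inner future wins, None once the token fires.
    async fn run_until_cancelled<F: Future>(f: F, token: &CancellationToken) -> Option<F::Output> {
        match select(pin!(f), pin!(token.cancelled())).await {
            Either::Left((out, _)) => Some(out),
            Either::Right(((), _)) => None,
        }
    }

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let token = CancellationToken::new();

        let connections = TaskTracker::new();
        connections.close(); // `wait()` may complete once all tracked tasks are done

        while let Some(res) = run_until_cancelled(listener.accept(), &token).await {
            let (stream, peer_addr) = res?;
            connections.spawn(async move {
                // per-connection work goes here
                drop((stream, peer_addr));
            });
        }

        // The loop exited because `token` was cancelled; drain in-flight connections.
        connections.wait().await;
        Ok(())
    }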
     132              : 
     133              : /// Handles the TCP lifecycle.
     134              : ///
     135              : /// 1. Parses PROXY protocol V2
     136              : /// 2. Handles TLS handshake
     137              : /// 3. Handles HTTP connection
     138              : ///     1. With graceful shutdowns
      139              : ///     2. With graceful request cancellation on connection failure
     140              : ///     3. With websocket upgrade support.
     141              : #[allow(clippy::too_many_arguments)]
     142            0 : async fn connection_handler(
     143            0 :     config: &'static ProxyConfig,
     144            0 :     backend: Arc<PoolingBackend>,
     145            0 :     connections: TaskTracker,
     146            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
     147            0 :     cancellation_token: CancellationToken,
     148            0 :     server: Builder<TokioExecutor>,
     149            0 :     tls_acceptor: TlsAcceptor,
     150            0 :     conn: TcpStream,
     151            0 :     peer_addr: SocketAddr,
     152            0 : ) {
     153            0 :     let session_id = uuid::Uuid::new_v4();
     154            0 : 
     155            0 :     let _gauge = Metrics::get()
     156            0 :         .proxy
     157            0 :         .client_connections
     158            0 :         .guard(crate::metrics::Protocol::Http);
     159            0 : 
     160            0 :     // handle PROXY protocol
     161            0 :     let mut conn = WithClientIp::new(conn);
     162            0 :     let peer = match conn.wait_for_addr().await {
     163            0 :         Ok(peer) => peer,
     164            0 :         Err(e) => {
     165            0 :             tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
     166            0 :             return;
     167              :         }
     168              :     };
     169              : 
     170            0 :     let peer_addr = peer.unwrap_or(peer_addr).ip();
     171            0 :     let has_private_peer_addr = match peer_addr {
     172            0 :         IpAddr::V4(ip) => ip.is_private(),
     173            0 :         _ => false,
     174              :     };
     175            0 :     info!(?session_id, %peer_addr, "accepted new TCP connection");
     176              : 
     177              :     // try upgrade to TLS, but with a timeout.
     178            0 :     let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
     179            0 :         Ok(Ok(conn)) => {
     180            0 :             info!(?session_id, %peer_addr, "accepted new TLS connection");
     181            0 :             conn
     182              :         }
     183              :         // The handshake failed
     184            0 :         Ok(Err(e)) => {
     185            0 :             if !has_private_peer_addr {
     186            0 :                 Metrics::get().proxy.tls_handshake_failures.inc();
     187            0 :             }
     188            0 :             warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
     189            0 :             return;
     190              :         }
     191              :         // The handshake timed out
     192            0 :         Err(e) => {
     193            0 :             if !has_private_peer_addr {
     194            0 :                 Metrics::get().proxy.tls_handshake_failures.inc();
     195            0 :             }
     196            0 :             warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
     197            0 :             return;
     198              :         }
     199              :     };
     200              : 
     201            0 :     let session_id = AtomicTake::new(session_id);
     202            0 : 
     203            0 :     // Cancel all current inflight HTTP requests if the HTTP connection is closed.
     204            0 :     let http_cancellation_token = CancellationToken::new();
     205            0 :     let _cancel_connection = http_cancellation_token.clone().drop_guard();
     206            0 : 
     207            0 :     let conn = server.serve_connection_with_upgrades(
     208            0 :         hyper_util::rt::TokioIo::new(conn),
     209            0 :         hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
     210            0 :             // First HTTP request shares the same session ID
     211            0 :             let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
     212            0 : 
     213            0 :             // Cancel the current inflight HTTP request if the request stream is closed.
     214            0 :             // This is slightly different from `_cancel_connection` in that
     215            0 :             // h2 can cancel individual requests with a `RST_STREAM`.
     216            0 :             let http_request_token = http_cancellation_token.child_token();
     217            0 :             let cancel_request = http_request_token.clone().drop_guard();
     218            0 : 
     219            0 :             // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
     220            0 :             // By spawning the future, we ensure it never gets cancelled until it decides to.
     221            0 :             let handler = connections.spawn(
     222            0 :                 request_handler(
     223            0 :                     req,
     224            0 :                     config,
     225            0 :                     backend.clone(),
     226            0 :                     connections.clone(),
     227            0 :                     cancellation_handler.clone(),
     228            0 :                     session_id,
     229            0 :                     peer_addr,
     230            0 :                     http_request_token,
     231            0 :                 )
     232            0 :                 .in_current_span()
     233            0 :                 .map_ok_or_else(api_error_into_response, |r| r),
     234            0 :             );
     235              : 
     236            0 :             async move {
     237            0 :                 let res = handler.await;
     238            0 :                 cancel_request.disarm();
     239            0 :                 res
     240            0 :             }
     241            0 :         }),
     242            0 :     );
     243              : 
     244              :     // On cancellation, trigger the HTTP connection handler to shut down.
     245            0 :     let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
     246            0 :         Either::Left((_cancelled, mut conn)) => {
     247            0 :             conn.as_mut().graceful_shutdown();
     248            0 :             conn.await
     249              :         }
     250            0 :         Either::Right((res, _)) => res,
     251              :     };
     252              : 
     253            0 :     match res {
     254            0 :         Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
     255            0 :         Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
     256              :     }
     257            0 : }
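
The cancellation layering in `connection_handler` is worth spelling out: the connection-level token is cancelled automatically when the handler returns, because `_cancel_connection` is a drop guard; each request gets a `child_token()` with its own guard, and that guard is `disarm()`ed only after the spawned request handler has produced a response, so a request that completed normally is never cancelled after the fact. A small standalone sketch of just that mechanism, with made-up task bodies for illustration:

    use std::time::Duration;

    use tokio_util::sync::CancellationToken;

    #[tokio::main]
    async fn main() {
        // Connection-level token: cancelled when `_cancel_connection` is dropped,
        // i.e. when the connection task ends for any reason.
        let http_cancellation_token = CancellationToken::new();
        let _cancel_connection = http_cancellation_token.clone().drop_guard();

        // Request-level token: a child of the connection token, cancelled either
        // with its parent or when its own guard is dropped without `disarm()`.
        let http_request_token = http_cancellation_token.child_token();
        let cancel_request = http_request_token.clone().drop_guard();

        let handler = tokio::spawn(async move {
            // A real handler would do work here and only observe the token at
            // points where stopping early is safe.
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_millis(10)) => "done",
                _ = http_request_token.cancelled() => "cancelled",
            }
        });

        let res = handler.await;
        // The handler finished on its own; dropping the guard must not cancel it.
        cancel_request.disarm();
        println!("{res:?}");
    }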
     258              : 
     259              : #[allow(clippy::too_many_arguments)]
     260            0 : async fn request_handler(
     261            0 :     mut request: hyper1::Request<Incoming>,
     262            0 :     config: &'static ProxyConfig,
     263            0 :     backend: Arc<PoolingBackend>,
     264            0 :     ws_connections: TaskTracker,
     265            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
     266            0 :     session_id: uuid::Uuid,
     267            0 :     peer_addr: IpAddr,
     268            0 :     // used to cancel in-flight HTTP requests. not used to cancel websockets
     269            0 :     http_cancellation_token: CancellationToken,
     270            0 : ) -> Result<Response<Full<Bytes>>, ApiError> {
     271            0 :     let host = request
     272            0 :         .headers()
     273            0 :         .get("host")
     274            0 :         .and_then(|h| h.to_str().ok())
     275            0 :         .and_then(|h| h.split(':').next())
     276            0 :         .map(|s| s.to_string());
     277            0 : 
     278            0 :     // Check if the request is a websocket upgrade request.
     279            0 :     if hyper_tungstenite::is_upgrade_request(&request) {
     280            0 :         let ctx = RequestMonitoring::new(
     281            0 :             session_id,
     282            0 :             peer_addr,
     283            0 :             crate::metrics::Protocol::Ws,
     284            0 :             &config.region,
     285            0 :         );
     286            0 : 
     287            0 :         let span = ctx.span.clone();
     288            0 :         info!(parent: &span, "performing websocket upgrade");
     289              : 
     290            0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     291            0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     292              : 
     293            0 :         ws_connections.spawn(
     294            0 :             async move {
     295            0 :                 if let Err(e) =
     296            0 :                     websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
     297            0 :                         .await
     298              :                 {
     299            0 :                     error!("error in websocket connection: {e:#}");
     300            0 :                 }
     301            0 :             }
     302            0 :             .instrument(span),
     303            0 :         );
     304            0 : 
     305            0 :         // Return the response so the spawned future can continue.
     306            0 :         Ok(response)
     307            0 :     } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
     308            0 :         let ctx = RequestMonitoring::new(
     309            0 :             session_id,
     310            0 :             peer_addr,
     311            0 :             crate::metrics::Protocol::Http,
     312            0 :             &config.region,
     313            0 :         );
     314            0 :         let span = ctx.span.clone();
     315            0 : 
     316            0 :         sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
     317            0 :             .instrument(span)
     318            0 :             .await
     319            0 :     } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
     320            0 :         Response::builder()
     321            0 :             .header("Allow", "OPTIONS, POST")
     322            0 :             .header("Access-Control-Allow-Origin", "*")
     323            0 :             .header(
     324            0 :                 "Access-Control-Allow-Headers",
     325            0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
     326            0 :             )
     327            0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     328            0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     329            0 :             .body(Full::new(Bytes::new()))
     330            0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     331              :     } else {
     332            0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     333              :     }
     334            0 : }
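
For orientation, the `POST /sql` branch above is the endpoint the serverless driver talks to over plain HTTP. A hypothetical client-side call is sketched below; the host, the connection string value, and the JSON body shape (`query` plus positional `params`) are illustrative assumptions, while the `Neon-Connection-String` header name comes from the `Access-Control-Allow-Headers` list in the `OPTIONS` branch:

    use reqwest::Client;
    use serde_json::json;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Hypothetical endpoint; the real host depends on the deployment.
        let url = "https://api.example.invalid/sql";

        let resp = Client::new()
            .post(url)
            // Placeholder connection string identifying the target database.
            .header("Neon-Connection-String", "postgres://user:password@host/db")
            // Assumed request shape: SQL text plus positional parameters.
            .json(&json!({ "query": "SELECT 1", "params": [] }))
            .send()
            .await?;

        println!("status: {}", resp.status());
        println!("body:   {}", resp.text().await?);
        Ok(())
    }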
        

Generated by: LCOV version 2.1-beta