LCOV - code coverage report
Current view: top level - proxy/src - serverless.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 0.0 % 221 0
Test Date: 2024-04-08 10:22:05 Functions: 0.0 % 32 0

            Line data    Source code
       1              : //! Routers for our serverless APIs
       2              : //!
       3              : //! Handles both SQL over HTTP and SQL over Websockets.
       4              : 
       5              : mod backend;
       6              : mod conn_pool;
       7              : mod json;
       8              : mod sql_over_http;
       9              : pub mod tls_listener;
      10              : mod websocket;
      11              : 
      12              : pub use conn_pool::GlobalConnPoolOptions;
      13              : 
      14              : use anyhow::bail;
      15              : use hyper::StatusCode;
      16              : use metrics::IntCounterPairGuard;
      17              : use rand::rngs::StdRng;
      18              : use rand::SeedableRng;
      19              : pub use reqwest_middleware::{ClientWithMiddleware, Error};
      20              : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
      21              : use tokio_util::task::TaskTracker;
      22              : use tracing::instrument::Instrumented;
      23              : 
      24              : use crate::cancellation::CancellationHandlerMain;
      25              : use crate::config::ProxyConfig;
      26              : use crate::context::RequestMonitoring;
      27              : use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
      28              : use crate::rate_limiter::EndpointRateLimiter;
      29              : use crate::serverless::backend::PoolingBackend;
      30              : use hyper::{
      31              :     server::conn::{AddrIncoming, AddrStream},
      32              :     Body, Method, Request, Response,
      33              : };
      34              : 
      35              : use std::net::IpAddr;
      36              : use std::sync::Arc;
      37              : use std::task::Poll;
      38              : use tls_listener::TlsListener;
      39              : use tokio::net::TcpListener;
      40              : use tokio_util::sync::{CancellationToken, DropGuard};
      41              : use tracing::{error, info, warn, Instrument};
      42              : use utils::http::{error::ApiError, json::json_response};
      43              : 
      44              : pub const SERVERLESS_DRIVER_SNI: &str = "api";
      45              : 
      46            0 : pub async fn task_main(
      47            0 :     config: &'static ProxyConfig,
      48            0 :     ws_listener: TcpListener,
      49            0 :     cancellation_token: CancellationToken,
      50            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      51            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
      52            0 : ) -> anyhow::Result<()> {
      53            0 :     scopeguard::defer! {
      54            0 :         info!("websocket server has shut down");
      55              :     }
      56              : 
      57            0 :     let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
      58            0 :     {
      59            0 :         let conn_pool = Arc::clone(&conn_pool);
      60            0 :         tokio::spawn(async move {
      61            0 :             conn_pool.gc_worker(StdRng::from_entropy()).await;
      62            0 :         });
      63            0 :     }
      64            0 : 
      65            0 :     // shutdown the connection pool
      66            0 :     tokio::spawn({
      67            0 :         let cancellation_token = cancellation_token.clone();
      68            0 :         let conn_pool = conn_pool.clone();
      69            0 :         async move {
      70            0 :             cancellation_token.cancelled().await;
      71            0 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
      72            0 :                 .await
      73            0 :                 .unwrap();
      74            0 :         }
      75            0 :     });
      76            0 : 
      77            0 :     let backend = Arc::new(PoolingBackend {
      78            0 :         pool: Arc::clone(&conn_pool),
      79            0 :         config,
      80            0 :     });
      81              : 
      82            0 :     let tls_config = match config.tls_config.as_ref() {
      83            0 :         Some(config) => config,
      84              :         None => {
      85            0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
      86            0 :             return Ok(());
      87              :         }
      88              :     };
      89            0 :     let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
      90            0 :     // prefer http2, but support http/1.1
      91            0 :     tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
      92            0 :     let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
      93              : 
      94            0 :     let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
      95            0 :     let _ = addr_incoming.set_nodelay(true);
      96            0 :     let addr_incoming = ProxyProtocolAccept {
      97            0 :         incoming: addr_incoming,
      98            0 :         protocol: "http",
      99            0 :     };
     100            0 : 
     101            0 :     let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
     102            0 :     ws_connections.close(); // allows `ws_connections.wait to complete`
     103            0 : 
     104            0 :     let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);
     105            0 : 
     106            0 :     let make_svc = hyper::service::make_service_fn(
     107            0 :         |stream: &tokio_rustls::server::TlsStream<
     108              :             WithConnectionGuard<WithClientIp<AddrStream>>,
     109            0 :         >| {
     110            0 :             let (conn, _) = stream.get_ref();
     111            0 : 
     112            0 :             // this is jank. should dissapear with hyper 1.0 migration.
     113            0 :             let gauge = conn
     114            0 :                 .gauge
     115            0 :                 .lock()
     116            0 :                 .expect("lock should not be poisoned")
     117            0 :                 .take()
     118            0 :                 .expect("gauge should be set on connection start");
     119            0 : 
     120            0 :             // Cancel all current inflight HTTP requests if the HTTP connection is closed.
     121            0 :             let http_cancellation_token = CancellationToken::new();
     122            0 :             let cancel_connection = http_cancellation_token.clone().drop_guard();
     123            0 : 
     124            0 :             let span = conn.span.clone();
     125            0 :             let client_addr = conn.inner.client_addr();
     126            0 :             let remote_addr = conn.inner.inner.remote_addr();
     127            0 :             let backend = backend.clone();
     128            0 :             let ws_connections = ws_connections.clone();
     129            0 :             let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     130            0 :             let cancellation_handler = cancellation_handler.clone();
     131            0 :             async move {
     132            0 :                 let peer_addr = match client_addr {
     133            0 :                     Some(addr) => addr,
     134            0 :                     None if config.require_client_ip => bail!("missing required client ip"),
     135            0 :                     None => remote_addr,
     136              :                 };
     137            0 :                 Ok(MetricService::new(
     138            0 :                     hyper::service::service_fn(move |req: Request<Body>| {
     139            0 :                         let backend = backend.clone();
     140            0 :                         let ws_connections2 = ws_connections.clone();
     141            0 :                         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     142            0 :                         let cancellation_handler = cancellation_handler.clone();
     143            0 :                         let http_cancellation_token = http_cancellation_token.child_token();
     144            0 : 
     145            0 :                         // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
     146            0 :                         // By spawning the future, we ensure it never gets cancelled until it decides to.
     147            0 :                         ws_connections.spawn(
     148            0 :                             async move {
     149            0 :                                 // Cancel the current inflight HTTP request if the requets stream is closed.
     150            0 :                                 // This is slightly different to `_cancel_connection` in that
     151            0 :                                 // h2 can cancel individual requests with a `RST_STREAM`.
     152            0 :                                 let _cancel_session = http_cancellation_token.clone().drop_guard();
     153              : 
     154            0 :                                 let res = request_handler(
     155            0 :                                     req,
     156            0 :                                     config,
     157            0 :                                     backend,
     158            0 :                                     ws_connections2,
     159            0 :                                     cancellation_handler,
     160            0 :                                     peer_addr.ip(),
     161            0 :                                     endpoint_rate_limiter,
     162            0 :                                     http_cancellation_token,
     163            0 :                                 )
     164            0 :                                 .await
     165            0 :                                 .map_or_else(|e| e.into_response(), |r| r);
     166            0 : 
     167            0 :                                 _cancel_session.disarm();
     168            0 : 
     169            0 :                                 res
     170            0 :                             }
     171            0 :                             .in_current_span(),
     172            0 :                         )
     173            0 :                     }),
     174            0 :                     gauge,
     175            0 :                     cancel_connection,
     176            0 :                     span,
     177            0 :                 ))
     178            0 :             }
     179            0 :         },
     180            0 :     );
     181            0 : 
     182            0 :     hyper::Server::builder(tls_listener)
     183            0 :         .serve(make_svc)
     184            0 :         .with_graceful_shutdown(cancellation_token.cancelled())
     185            0 :         .await?;
     186              : 
     187              :     // await websocket connections
     188            0 :     ws_connections.wait().await;
     189              : 
     190            0 :     Ok(())
     191            0 : }
     192              : 
     193              : struct MetricService<S> {
     194              :     inner: S,
     195              :     _gauge: IntCounterPairGuard,
     196              :     _cancel: DropGuard,
     197              :     span: tracing::Span,
     198              : }
     199              : 
     200              : impl<S> MetricService<S> {
     201            0 :     fn new(
     202            0 :         inner: S,
     203            0 :         _gauge: IntCounterPairGuard,
     204            0 :         _cancel: DropGuard,
     205            0 :         span: tracing::Span,
     206            0 :     ) -> MetricService<S> {
     207            0 :         MetricService {
     208            0 :             inner,
     209            0 :             _gauge,
     210            0 :             _cancel,
     211            0 :             span,
     212            0 :         }
     213            0 :     }
     214              : }
     215              : 
     216              : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
     217              : where
     218              :     S: hyper::service::Service<Request<ReqBody>>,
     219              : {
     220              :     type Response = S::Response;
     221              :     type Error = S::Error;
     222              :     type Future = Instrumented<S::Future>;
     223              : 
     224            0 :     fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
     225            0 :         self.inner.poll_ready(cx)
     226            0 :     }
     227              : 
     228            0 :     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
     229            0 :         self.span
     230            0 :             .in_scope(|| self.inner.call(req))
     231            0 :             .instrument(self.span.clone())
     232            0 :     }
     233              : }
     234              : 
     235              : #[allow(clippy::too_many_arguments)]
     236            0 : async fn request_handler(
     237            0 :     mut request: Request<Body>,
     238            0 :     config: &'static ProxyConfig,
     239            0 :     backend: Arc<PoolingBackend>,
     240            0 :     ws_connections: TaskTracker,
     241            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
     242            0 :     peer_addr: IpAddr,
     243            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     244            0 :     // used to cancel in-flight HTTP requests. not used to cancel websockets
     245            0 :     http_cancellation_token: CancellationToken,
     246            0 : ) -> Result<Response<Body>, ApiError> {
     247            0 :     let session_id = uuid::Uuid::new_v4();
     248            0 : 
     249            0 :     let host = request
     250            0 :         .headers()
     251            0 :         .get("host")
     252            0 :         .and_then(|h| h.to_str().ok())
     253            0 :         .and_then(|h| h.split(':').next())
     254            0 :         .map(|s| s.to_string());
     255            0 : 
     256            0 :     // Check if the request is a websocket upgrade request.
     257            0 :     if hyper_tungstenite::is_upgrade_request(&request) {
     258            0 :         let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
     259            0 :         let span = ctx.span.clone();
     260            0 :         info!(parent: &span, "performing websocket upgrade");
     261              : 
     262            0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     263            0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     264              : 
     265            0 :         ws_connections.spawn(
     266            0 :             async move {
     267            0 :                 if let Err(e) = websocket::serve_websocket(
     268            0 :                     config,
     269            0 :                     ctx,
     270            0 :                     websocket,
     271            0 :                     cancellation_handler,
     272            0 :                     host,
     273            0 :                     endpoint_rate_limiter,
     274            0 :                 )
     275            0 :                 .await
     276              :                 {
     277            0 :                     error!("error in websocket connection: {e:#}");
     278            0 :                 }
     279            0 :             }
     280            0 :             .instrument(span),
     281            0 :         );
     282            0 : 
     283            0 :         // Return the response so the spawned future can continue.
     284            0 :         Ok(response)
     285            0 :     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
     286            0 :         let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
     287            0 :         let span = ctx.span.clone();
     288            0 : 
     289            0 :         sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
     290            0 :             .instrument(span)
     291            0 :             .await
     292            0 :     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
     293            0 :         Response::builder()
     294            0 :             .header("Allow", "OPTIONS, POST")
     295            0 :             .header("Access-Control-Allow-Origin", "*")
     296            0 :             .header(
     297            0 :                 "Access-Control-Allow-Headers",
     298            0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
     299            0 :             )
     300            0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     301            0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     302            0 :             .body(Body::empty())
     303            0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     304              :     } else {
     305            0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     306              :     }
     307            0 : }
        

Generated by: LCOV version 2.1-beta