LCOV - code coverage report
Current view: top level - proxy/src - serverless.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 77.8 % 189 147
Test Date: 2024-02-07 07:37:29 Functions: 67.7 % 31 21

            Line data    Source code
       1              : //! Routers for our serverless APIs
       2              : //!
       3              : //! Handles both SQL over HTTP and SQL over Websockets.
       4              : 
       5              : mod conn_pool;
       6              : mod json;
       7              : mod sql_over_http;
       8              : mod websocket;
       9              : 
      10              : pub use conn_pool::GlobalConnPoolOptions;
      11              : 
      12              : use anyhow::bail;
      13              : use hyper::StatusCode;
      14              : use metrics::IntCounterPairGuard;
      15              : use rand::rngs::StdRng;
      16              : use rand::SeedableRng;
      17              : pub use reqwest_middleware::{ClientWithMiddleware, Error};
      18              : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
      19              : use tokio_util::task::TaskTracker;
      20              : 
      21              : use crate::config::TlsConfig;
      22              : use crate::context::RequestMonitoring;
      23              : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
      24              : use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
      25              : use crate::rate_limiter::EndpointRateLimiter;
      26              : use crate::{cancellation::CancelMap, config::ProxyConfig};
      27              : use futures::StreamExt;
      28              : use hyper::{
      29              :     server::{
      30              :         accept,
      31              :         conn::{AddrIncoming, AddrStream},
      32              :     },
      33              :     Body, Method, Request, Response,
      34              : };
      35              : 
      36              : use std::net::IpAddr;
      37              : use std::task::Poll;
      38              : use std::{future::ready, sync::Arc};
      39              : use tls_listener::TlsListener;
      40              : use tokio::net::TcpListener;
      41              : use tokio_util::sync::CancellationToken;
      42              : use tracing::{error, info, info_span, warn, Instrument};
      43              : use utils::http::{error::ApiError, json::json_response};
      44              : 
      45              : pub const SERVERLESS_DRIVER_SNI: &str = "api";
      46              : 
      47           23 : pub async fn task_main(
      48           23 :     config: &'static ProxyConfig,
      49           23 :     ws_listener: TcpListener,
      50           23 :     cancellation_token: CancellationToken,
      51           23 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      52           23 : ) -> anyhow::Result<()> {
      53           23 :     scopeguard::defer! {
      54           23 :         info!("websocket server has shut down");
      55           23 :     }
      56           23 : 
      57           23 :     let conn_pool = conn_pool::GlobalConnPool::new(config);
      58           23 : 
      59           23 :     let conn_pool2 = Arc::clone(&conn_pool);
      60           23 :     tokio::spawn(async move {
      61           27 :         conn_pool2.gc_worker(StdRng::from_entropy()).await;
      62           23 :     });
      63           23 : 
      64           23 :     // shutdown the connection pool
      65           23 :     tokio::spawn({
      66           23 :         let cancellation_token = cancellation_token.clone();
      67           23 :         let conn_pool = conn_pool.clone();
      68           23 :         async move {
      69           23 :             cancellation_token.cancelled().await;
      70           23 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
      71           22 :                 .await
      72           23 :                 .unwrap();
      73           23 :         }
      74           23 :     });
      75              : 
      76           23 :     let tls_config = match config.tls_config.as_ref() {
      77           23 :         Some(config) => config,
      78              :         None => {
      79            0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
      80            0 :             return Ok(());
      81              :         }
      82              :     };
      83           23 :     let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
      84              : 
      85           23 :     let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
      86           23 :     let _ = addr_incoming.set_nodelay(true);
      87           23 :     let addr_incoming = ProxyProtocolAccept {
      88           23 :         incoming: addr_incoming,
      89           23 :     };
      90           23 : 
      91           23 :     let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
      92           23 :     ws_connections.close(); // allows `ws_connections.wait to complete`
      93           23 : 
      94           46 :     let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
      95           46 :         if let Err(err) = conn {
      96            0 :             error!("failed to accept TLS connection for websockets: {err:?}");
      97            0 :             ready(false)
      98              :         } else {
      99           46 :             ready(true)
     100              :         }
     101           46 :     });
     102           23 : 
     103           23 :     let make_svc = hyper::service::make_service_fn(
     104           46 :         |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
     105           46 :             let (io, tls) = stream.get_ref();
     106           46 :             let client_addr = io.client_addr();
     107           46 :             let remote_addr = io.inner.remote_addr();
     108           46 :             let sni_name = tls.server_name().map(|s| s.to_string());
     109           46 :             let conn_pool = conn_pool.clone();
     110           46 :             let ws_connections = ws_connections.clone();
     111           46 :             let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     112              : 
     113           46 :             async move {
     114           46 :                 let peer_addr = match client_addr {
     115            0 :                     Some(addr) => addr,
     116            0 :                     None if config.require_client_ip => bail!("missing required client ip"),
     117           46 :                     None => remote_addr,
     118              :                 };
     119           46 :                 Ok(MetricService::new(hyper::service::service_fn(
     120           46 :                     move |req: Request<Body>| {
     121           46 :                         let sni_name = sni_name.clone();
     122           46 :                         let conn_pool = conn_pool.clone();
     123           46 :                         let ws_connections = ws_connections.clone();
     124           46 :                         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     125              : 
     126           46 :                         async move {
     127           46 :                             let cancel_map = Arc::new(CancelMap::default());
     128           46 :                             let session_id = uuid::Uuid::new_v4();
     129           46 : 
     130           46 :                             request_handler(
     131           46 :                                 req,
     132           46 :                                 config,
     133           46 :                                 tls_config,
     134           46 :                                 conn_pool,
     135           46 :                                 ws_connections,
     136           46 :                                 cancel_map,
     137           46 :                                 session_id,
     138           46 :                                 sni_name,
     139           46 :                                 peer_addr.ip(),
     140           46 :                                 endpoint_rate_limiter,
     141           46 :                             )
     142           46 :                             .instrument(info_span!(
     143           46 :                                 "serverless",
     144           46 :                                 session = %session_id,
     145           46 :                                 %peer_addr,
     146           46 :                             ))
     147          832 :                             .await
     148           46 :                         }
     149           46 :                     },
     150           46 :                 )))
     151           46 :             }
     152           46 :         },
     153           23 :     );
     154           23 : 
     155           23 :     hyper::Server::builder(accept::from_stream(tls_listener))
     156           23 :         .serve(make_svc)
     157           23 :         .with_graceful_shutdown(cancellation_token.cancelled())
     158          214 :         .await?;
     159              : 
     160              :     // await websocket connections
     161           23 :     ws_connections.wait().await;
     162              : 
     163           23 :     Ok(())
     164           23 : }
     165              : 
     166              : struct MetricService<S> {
     167              :     inner: S,
     168              :     _gauge: IntCounterPairGuard,
     169              : }
     170              : 
     171              : impl<S> MetricService<S> {
     172           46 :     fn new(inner: S) -> MetricService<S> {
     173           46 :         MetricService {
     174           46 :             inner,
     175           46 :             _gauge: NUM_CLIENT_CONNECTION_GAUGE
     176           46 :                 .with_label_values(&["http"])
     177           46 :                 .guard(),
     178           46 :         }
     179           46 :     }
     180              : }
     181              : 
     182              : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
     183              : where
     184              :     S: hyper::service::Service<Request<ReqBody>>,
     185              : {
     186              :     type Response = S::Response;
     187              :     type Error = S::Error;
     188              :     type Future = S::Future;
     189              : 
     190          125 :     fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
     191          125 :         self.inner.poll_ready(cx)
     192          125 :     }
     193              : 
     194           46 :     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
     195           46 :         self.inner.call(req)
     196           46 :     }
     197              : }
     198              : 
     199              : #[allow(clippy::too_many_arguments)]
     200           46 : async fn request_handler(
     201           46 :     mut request: Request<Body>,
     202           46 :     config: &'static ProxyConfig,
     203           46 :     tls: &'static TlsConfig,
     204           46 :     conn_pool: Arc<conn_pool::GlobalConnPool>,
     205           46 :     ws_connections: TaskTracker,
     206           46 :     cancel_map: Arc<CancelMap>,
     207           46 :     session_id: uuid::Uuid,
     208           46 :     sni_hostname: Option<String>,
     209           46 :     peer_addr: IpAddr,
     210           46 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     211           46 : ) -> Result<Response<Body>, ApiError> {
     212           46 :     let host = request
     213           46 :         .headers()
     214           46 :         .get("host")
     215           46 :         .and_then(|h| h.to_str().ok())
     216           46 :         .and_then(|h| h.split(':').next())
     217           46 :         .map(|s| s.to_string());
     218           46 : 
     219           46 :     // Check if the request is a websocket upgrade request.
     220           46 :     if hyper_tungstenite::is_upgrade_request(&request) {
     221            0 :         info!(session_id = ?session_id, "performing websocket upgrade");
     222              : 
     223            0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     224            0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     225              : 
     226            0 :         ws_connections.spawn(
     227            0 :             async move {
     228            0 :                 let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
     229              : 
     230            0 :                 if let Err(e) = websocket::serve_websocket(
     231            0 :                     config,
     232            0 :                     &mut ctx,
     233            0 :                     websocket,
     234            0 :                     cancel_map,
     235            0 :                     host,
     236            0 :                     endpoint_rate_limiter,
     237            0 :                 )
     238            0 :                 .await
     239              :                 {
     240            0 :                     error!(session_id = ?session_id, "error in websocket connection: {e:#}");
     241            0 :                 }
     242            0 :             }
     243            0 :             .in_current_span(),
     244            0 :         );
     245            0 : 
     246            0 :         // Return the response so the spawned future can continue.
     247            0 :         Ok(response)
     248           46 :     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
     249           46 :         let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
     250           46 : 
     251           46 :         sql_over_http::handle(
     252           46 :             tls,
     253           46 :             &config.http_config,
     254           46 :             &mut ctx,
     255           46 :             request,
     256           46 :             sni_hostname,
     257           46 :             conn_pool,
     258           46 :         )
     259          832 :         .await
     260            0 :     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
     261            0 :         Response::builder()
     262            0 :             .header("Allow", "OPTIONS, POST")
     263            0 :             .header("Access-Control-Allow-Origin", "*")
     264            0 :             .header(
     265            0 :                 "Access-Control-Allow-Headers",
     266            0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
     267            0 :             )
     268            0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     269            0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     270            0 :             .body(Body::empty())
     271            0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     272              :     } else {
     273            0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     274              :     }
     275           46 : }
        

Generated by: LCOV version 2.1-beta