LCOV - differential code coverage report
Current view: top level - proxy/src - serverless.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 77.5 % 187 145 42 145
Current Date: 2024-01-09 02:06:09 Functions: 68.8 % 32 22 10 22
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Routers for our serverless APIs
       2                 : //!
       3                 : //! Handles both SQL over HTTP and SQL over Websockets.
       4                 : 
       5                 : mod conn_pool;
       6                 : mod sql_over_http;
       7                 : mod websocket;
       8                 : 
       9                 : pub use conn_pool::GlobalConnPoolOptions;
      10                 : 
      11                 : use anyhow::bail;
      12                 : use hyper::StatusCode;
      13                 : use metrics::IntCounterPairGuard;
      14                 : use rand::rngs::StdRng;
      15                 : use rand::SeedableRng;
      16                 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
      17                 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
      18                 : use tokio_util::task::TaskTracker;
      19                 : 
      20                 : use crate::context::RequestMonitoring;
      21                 : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
      22                 : use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
      23                 : use crate::rate_limiter::EndpointRateLimiter;
      24                 : use crate::{cancellation::CancelMap, config::ProxyConfig};
      25                 : use futures::StreamExt;
      26                 : use hyper::{
      27                 :     server::{
      28                 :         accept,
      29                 :         conn::{AddrIncoming, AddrStream},
      30                 :     },
      31                 :     Body, Method, Request, Response,
      32                 : };
      33                 : 
      34                 : use std::net::IpAddr;
      35                 : use std::task::Poll;
      36                 : use std::{future::ready, sync::Arc};
      37                 : use tls_listener::TlsListener;
      38                 : use tokio::net::TcpListener;
      39                 : use tokio_util::sync::CancellationToken;
      40                 : use tracing::{error, info, info_span, warn, Instrument};
      41                 : use utils::http::{error::ApiError, json::json_response};
      42                 : 
      43 CBC          22 : pub async fn task_main(
      44              22 :     config: &'static ProxyConfig,
      45              22 :     ws_listener: TcpListener,
      46              22 :     cancellation_token: CancellationToken,
      47              22 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      48              22 : ) -> anyhow::Result<()> {
      49              22 :     scopeguard::defer! {
      50              22 :         info!("websocket server has shut down");
      51              22 :     }
      52              22 : 
      53              22 :     let conn_pool = conn_pool::GlobalConnPool::new(config);
      54              22 : 
      55              22 :     let conn_pool2 = Arc::clone(&conn_pool);
      56              22 :     tokio::spawn(async move {
      57              30 :         conn_pool2.gc_worker(StdRng::from_entropy()).await;
      58              22 :     });
      59              22 : 
      60              22 :     // shutdown the connection pool
      61              22 :     tokio::spawn({
      62              22 :         let cancellation_token = cancellation_token.clone();
      63              22 :         let conn_pool = conn_pool.clone();
      64              22 :         async move {
      65              22 :             cancellation_token.cancelled().await;
      66              22 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
      67              20 :                 .await
      68              21 :                 .unwrap();
      69              22 :         }
      70              22 :     });
      71              22 : 
      72              22 :     let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
      73              22 :     let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
      74              22 :         Some(config) => config.into(),
      75                 :         None => {
      76 UBC           0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
      77               0 :             return Ok(());
      78                 :         }
      79                 :     };
      80                 : 
      81 CBC          22 :     let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
      82              22 :     let _ = addr_incoming.set_nodelay(true);
      83              22 :     let addr_incoming = ProxyProtocolAccept {
      84              22 :         incoming: addr_incoming,
      85              22 :     };
      86              22 : 
      87              22 :     let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
      88              22 :     ws_connections.close(); // allows `ws_connections.wait to complete`
      89              22 : 
      90              45 :     let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
      91              45 :         if let Err(err) = conn {
      92 UBC           0 :             error!("failed to accept TLS connection for websockets: {err:?}");
      93               0 :             ready(false)
      94                 :         } else {
      95 CBC          45 :             ready(true)
      96                 :         }
      97              45 :     });
      98              22 : 
      99              22 :     let make_svc = hyper::service::make_service_fn(
     100              45 :         |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
     101              45 :             let (io, tls) = stream.get_ref();
     102              45 :             let client_addr = io.client_addr();
     103              45 :             let remote_addr = io.inner.remote_addr();
     104              45 :             let sni_name = tls.server_name().map(|s| s.to_string());
     105              45 :             let conn_pool = conn_pool.clone();
     106              45 :             let ws_connections = ws_connections.clone();
     107              45 :             let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     108                 : 
     109              45 :             async move {
     110              45 :                 let peer_addr = match client_addr {
     111 UBC           0 :                     Some(addr) => addr,
     112               0 :                     None if config.require_client_ip => bail!("missing required client ip"),
     113 CBC          45 :                     None => remote_addr,
     114                 :                 };
     115              45 :                 Ok(MetricService::new(hyper::service::service_fn(
     116              45 :                     move |req: Request<Body>| {
     117              45 :                         let sni_name = sni_name.clone();
     118              45 :                         let conn_pool = conn_pool.clone();
     119              45 :                         let ws_connections = ws_connections.clone();
     120              45 :                         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
     121                 : 
     122              45 :                         async move {
     123              45 :                             let cancel_map = Arc::new(CancelMap::default());
     124              45 :                             let session_id = uuid::Uuid::new_v4();
     125              45 : 
     126              45 :                             request_handler(
     127              45 :                                 req,
     128              45 :                                 config,
     129              45 :                                 conn_pool,
     130              45 :                                 ws_connections,
     131              45 :                                 cancel_map,
     132              45 :                                 session_id,
     133              45 :                                 sni_name,
     134              45 :                                 peer_addr.ip(),
     135              45 :                                 endpoint_rate_limiter,
     136              45 :                             )
     137              45 :                             .instrument(info_span!(
     138              45 :                                 "serverless",
     139              45 :                                 session = %session_id,
     140              45 :                                 %peer_addr,
     141              45 :                             ))
     142             727 :                             .await
     143              45 :                         }
     144              45 :                     },
     145              45 :                 )))
     146              45 :             }
     147              45 :         },
     148              22 :     );
     149              22 : 
     150              22 :     hyper::Server::builder(accept::from_stream(tls_listener))
     151              22 :         .serve(make_svc)
     152              22 :         .with_graceful_shutdown(cancellation_token.cancelled())
     153             207 :         .await?;
     154                 : 
     155                 :     // await websocket connections
     156              22 :     ws_connections.wait().await;
     157                 : 
     158              22 :     Ok(())
     159              22 : }
     160                 : 
     161                 : struct MetricService<S> {
     162                 :     inner: S,
     163                 :     _gauge: IntCounterPairGuard,
     164                 : }
     165                 : 
     166                 : impl<S> MetricService<S> {
     167              45 :     fn new(inner: S) -> MetricService<S> {
     168              45 :         MetricService {
     169              45 :             inner,
     170              45 :             _gauge: NUM_CLIENT_CONNECTION_GAUGE
     171              45 :                 .with_label_values(&["http"])
     172              45 :                 .guard(),
     173              45 :         }
     174              45 :     }
     175                 : }
     176                 : 
     177                 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
     178                 : where
     179                 :     S: hyper::service::Service<Request<ReqBody>>,
     180                 : {
     181                 :     type Response = S::Response;
     182                 :     type Error = S::Error;
     183                 :     type Future = S::Future;
     184                 : 
     185             124 :     fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
     186             124 :         self.inner.poll_ready(cx)
     187             124 :     }
     188                 : 
     189              45 :     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
     190              45 :         self.inner.call(req)
     191              45 :     }
     192                 : }
     193                 : 
     194                 : #[allow(clippy::too_many_arguments)]
     195              45 : async fn request_handler(
     196              45 :     mut request: Request<Body>,
     197              45 :     config: &'static ProxyConfig,
     198              45 :     conn_pool: Arc<conn_pool::GlobalConnPool>,
     199              45 :     ws_connections: TaskTracker,
     200              45 :     cancel_map: Arc<CancelMap>,
     201              45 :     session_id: uuid::Uuid,
     202              45 :     sni_hostname: Option<String>,
     203              45 :     peer_addr: IpAddr,
     204              45 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     205              45 : ) -> Result<Response<Body>, ApiError> {
     206              45 :     let host = request
     207              45 :         .headers()
     208              45 :         .get("host")
     209              45 :         .and_then(|h| h.to_str().ok())
     210              45 :         .and_then(|h| h.split(':').next())
     211              45 :         .map(|s| s.to_string());
     212              45 : 
     213              45 :     // Check if the request is a websocket upgrade request.
     214              45 :     if hyper_tungstenite::is_upgrade_request(&request) {
     215 UBC           0 :         info!(session_id = ?session_id, "performing websocket upgrade");
     216                 : 
     217               0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     218               0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     219                 : 
     220               0 :         ws_connections.spawn(
     221               0 :             async move {
     222               0 :                 let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
     223                 : 
     224               0 :                 if let Err(e) = websocket::serve_websocket(
     225               0 :                     config,
     226               0 :                     &mut ctx,
     227               0 :                     websocket,
     228               0 :                     &cancel_map,
     229               0 :                     host,
     230               0 :                     endpoint_rate_limiter,
     231               0 :                 )
     232               0 :                 .await
     233                 :                 {
     234               0 :                     error!(session_id = ?session_id, "error in websocket connection: {e:#}");
     235               0 :                 }
     236               0 :             }
     237               0 :             .in_current_span(),
     238               0 :         );
     239               0 : 
     240               0 :         // Return the response so the spawned future can continue.
     241               0 :         Ok(response)
     242 CBC          45 :     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
     243              45 :         let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
     244              45 : 
     245              45 :         sql_over_http::handle(
     246              45 :             &config.http_config,
     247              45 :             &mut ctx,
     248              45 :             request,
     249              45 :             sni_hostname,
     250              45 :             conn_pool,
     251              45 :         )
     252             727 :         .await
     253 UBC           0 :     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
     254               0 :         Response::builder()
     255               0 :             .header("Allow", "OPTIONS, POST")
     256               0 :             .header("Access-Control-Allow-Origin", "*")
     257               0 :             .header(
     258               0 :                 "Access-Control-Allow-Headers",
     259               0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
     260               0 :             )
     261               0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     262               0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     263               0 :             .body(Body::empty())
     264               0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     265                 :     } else {
     266               0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     267                 :     }
     268 CBC          45 : }
        

Generated by: LCOV version 2.1-beta