LCOV - differential code coverage report
Current view: top level - proxy/src/http - websocket.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 49.8 % 239 119 120 119
Current Date: 2023-10-19 02:04:12 Functions: 48.8 % 43 21 22 21
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     cancellation::CancelMap,
       3                 :     config::ProxyConfig,
       4                 :     error::io_error,
       5                 :     protocol2::{ProxyProtocolAccept, WithClientIp},
       6                 :     proxy::{
       7                 :         handle_client, ClientMode, NUM_CLIENT_CONNECTION_CLOSED_COUNTER,
       8                 :         NUM_CLIENT_CONNECTION_OPENED_COUNTER,
       9                 :     },
      10                 : };
      11                 : use anyhow::bail;
      12                 : use bytes::{Buf, Bytes};
      13                 : use futures::{Sink, Stream, StreamExt};
      14                 : use hyper::{
      15                 :     server::{
      16                 :         accept,
      17                 :         conn::{AddrIncoming, AddrStream},
      18                 :     },
      19                 :     upgrade::Upgraded,
      20                 :     Body, Method, Request, Response, StatusCode,
      21                 : };
      22                 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
      23                 : use pin_project_lite::pin_project;
      24                 : 
      25                 : use std::{
      26                 :     future::ready,
      27                 :     pin::Pin,
      28                 :     sync::Arc,
      29                 :     task::{ready, Context, Poll},
      30                 : };
      31                 : use tls_listener::TlsListener;
      32                 : use tokio::{
      33                 :     io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
      34                 :     net::TcpListener,
      35                 : };
      36                 : use tokio_util::sync::CancellationToken;
      37                 : use tracing::{error, info, info_span, warn, Instrument};
      38                 : use utils::http::{error::ApiError, json::json_response};
      39                 : 
      40                 : // TODO: use `std::sync::Exclusive` once it's stabilized.
      41                 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
      42                 : use sync_wrapper::SyncWrapper;
      43                 : 
      44                 : use super::{conn_pool::GlobalConnPool, sql_over_http};
      45                 : 
      46                 : pin_project! {
      47                 :     /// This is a wrapper around a [`WebSocketStream`] that
      48                 :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      49                 :     pub struct WebSocketRw {
      50                 :         #[pin]
      51                 :         stream: SyncWrapper<WebSocketStream<Upgraded>>,
      52                 :         bytes: Bytes,
      53                 :     }
      54                 : }
      55                 : 
      56                 : impl WebSocketRw {
      57 UBC           0 :     pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
      58               0 :         Self {
      59               0 :             stream: stream.into(),
      60               0 :             bytes: Bytes::new(),
      61               0 :         }
      62               0 :     }
      63                 : }
      64                 : 
      65                 : impl AsyncWrite for WebSocketRw {
      66               0 :     fn poll_write(
      67               0 :         self: Pin<&mut Self>,
      68               0 :         cx: &mut Context<'_>,
      69               0 :         buf: &[u8],
      70               0 :     ) -> Poll<io::Result<usize>> {
      71               0 :         let mut stream = self.project().stream.get_pin_mut();
      72                 : 
      73               0 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      74               0 :         match stream.as_mut().start_send(Message::Binary(buf.into())) {
      75               0 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      76               0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      77                 :         }
      78               0 :     }
      79                 : 
      80               0 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      81               0 :         let stream = self.project().stream.get_pin_mut();
      82               0 :         stream.poll_flush(cx).map_err(io_error)
      83               0 :     }
      84                 : 
      85               0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      86               0 :         let stream = self.project().stream.get_pin_mut();
      87               0 :         stream.poll_close(cx).map_err(io_error)
      88               0 :     }
      89                 : }
      90                 : 
      91                 : impl AsyncRead for WebSocketRw {
      92               0 :     fn poll_read(
      93               0 :         mut self: Pin<&mut Self>,
      94               0 :         cx: &mut Context<'_>,
      95               0 :         buf: &mut ReadBuf<'_>,
      96               0 :     ) -> Poll<io::Result<()>> {
      97               0 :         if buf.remaining() > 0 {
      98               0 :             let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      99               0 :             let len = std::cmp::min(bytes.len(), buf.remaining());
     100               0 :             buf.put_slice(&bytes[..len]);
     101               0 :             self.consume(len);
     102               0 :         }
     103                 : 
     104               0 :         Poll::Ready(Ok(()))
     105               0 :     }
     106                 : }
     107                 : 
     108                 : impl AsyncBufRead for WebSocketRw {
     109               0 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
     110               0 :         // Please refer to poll_fill_buf's documentation.
     111               0 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
     112               0 : 
     113               0 :         let mut this = self.project();
     114               0 :         loop {
     115               0 :             if !this.bytes.chunk().is_empty() {
     116               0 :                 let chunk = (*this.bytes).chunk();
     117               0 :                 return Poll::Ready(Ok(chunk));
     118               0 :             }
     119                 : 
     120               0 :             let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
     121               0 :             match res.transpose().map_err(io_error)? {
     122               0 :                 Some(message) => match message {
     123               0 :                     Message::Ping(_) => {}
     124               0 :                     Message::Pong(_) => {}
     125               0 :                     Message::Text(text) => {
     126               0 :                         // We expect to see only binary messages.
     127               0 :                         let error = "unexpected text message in the websocket";
     128               0 :                         warn!(length = text.len(), error);
     129               0 :                         return Poll::Ready(Err(io_error(error)));
     130                 :                     }
     131                 :                     Message::Frame(_) => {
     132                 :                         // This case is impossible according to Frame's doc.
     133               0 :                         panic!("unexpected raw frame in the websocket");
     134                 :                     }
     135               0 :                     Message::Binary(chunk) => {
     136               0 :                         assert!(this.bytes.is_empty());
     137               0 :                         *this.bytes = Bytes::from(chunk);
     138                 :                     }
     139               0 :                     Message::Close(_) => return EOF,
     140                 :                 },
     141               0 :                 None => return EOF,
     142                 :             }
     143                 :         }
     144               0 :     }
     145                 : 
     146               0 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     147               0 :         self.project().bytes.advance(amount);
     148               0 :     }
     149                 : }
     150                 : 
     151               0 : async fn serve_websocket(
     152               0 :     websocket: HyperWebsocket,
     153               0 :     config: &'static ProxyConfig,
     154               0 :     cancel_map: &CancelMap,
     155               0 :     session_id: uuid::Uuid,
     156               0 :     hostname: Option<String>,
     157               0 : ) -> anyhow::Result<()> {
     158               0 :     let websocket = websocket.await?;
     159               0 :     handle_client(
     160               0 :         config,
     161               0 :         cancel_map,
     162               0 :         session_id,
     163               0 :         WebSocketRw::new(websocket),
     164               0 :         ClientMode::Websockets { hostname },
     165               0 :     )
     166               0 :     .await?;
     167               0 :     Ok(())
     168               0 : }
     169                 : 
     170 CBC          30 : async fn ws_handler(
     171              30 :     mut request: Request<Body>,
     172              30 :     config: &'static ProxyConfig,
     173              30 :     conn_pool: Arc<GlobalConnPool>,
     174              30 :     cancel_map: Arc<CancelMap>,
     175              30 :     session_id: uuid::Uuid,
     176              30 :     sni_hostname: Option<String>,
     177              30 : ) -> Result<Response<Body>, ApiError> {
     178              30 :     let host = request
     179              30 :         .headers()
     180              30 :         .get("host")
     181              30 :         .and_then(|h| h.to_str().ok())
     182              30 :         .and_then(|h| h.split(':').next())
     183              30 :         .map(|s| s.to_string());
     184              30 : 
     185              30 :     // Check if the request is a websocket upgrade request.
     186              30 :     if hyper_tungstenite::is_upgrade_request(&request) {
     187 UBC           0 :         info!(session_id = ?session_id, "performing websocket upgrade");
     188                 : 
     189               0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     190               0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     191                 : 
     192               0 :         tokio::spawn(
     193               0 :             async move {
     194               0 :                 if let Err(e) =
     195               0 :                     serve_websocket(websocket, config, &cancel_map, session_id, host).await
     196                 :                 {
     197               0 :                     error!(session_id = ?session_id, "error in websocket connection: {e:#}");
     198               0 :                 }
     199               0 :             }
     200               0 :             .in_current_span(),
     201               0 :         );
     202               0 : 
     203               0 :         // Return the response so the spawned future can continue.
     204               0 :         Ok(response)
     205                 :     // TODO: that deserves a refactor as now this function also handles http json client besides websockets.
     206                 :     // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
     207 CBC          30 :     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
     208              30 :         sql_over_http::handle(
     209              30 :             request,
     210              30 :             sni_hostname,
     211              30 :             conn_pool,
     212              30 :             session_id,
     213              30 :             &config.http_config,
     214              30 :         )
     215             233 :         .await
     216 UBC           0 :     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
     217               0 :         Response::builder()
     218               0 :             .header("Allow", "OPTIONS, POST")
     219               0 :             .header("Access-Control-Allow-Origin", "*")
     220               0 :             .header(
     221               0 :                 "Access-Control-Allow-Headers",
     222               0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
     223               0 :             )
     224               0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     225               0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     226               0 :             .body(Body::empty())
     227               0 :             .map_err(|e| ApiError::InternalServerError(e.into()))
     228                 :     } else {
     229               0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     230                 :     }
     231 CBC          30 : }
     232                 : 
     233              16 : pub async fn task_main(
     234              16 :     config: &'static ProxyConfig,
     235              16 :     ws_listener: TcpListener,
     236              16 :     cancellation_token: CancellationToken,
     237              16 : ) -> anyhow::Result<()> {
     238              16 :     scopeguard::defer! {
     239              16 :         info!("websocket server has shut down");
     240              16 :     }
     241              16 : 
     242              16 :     let conn_pool: Arc<GlobalConnPool> = GlobalConnPool::new(config);
     243              16 : 
     244              16 :     // shutdown the connection pool
     245              16 :     tokio::spawn({
     246              16 :         let cancellation_token = cancellation_token.clone();
     247              16 :         let conn_pool = conn_pool.clone();
     248              16 :         async move {
     249              16 :             cancellation_token.cancelled().await;
     250              16 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
     251              16 :                 .await
     252              16 :                 .unwrap();
     253              16 :         }
     254              16 :     });
     255              16 : 
     256              16 :     let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
     257              16 :     let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
     258              16 :         Some(config) => config.into(),
     259                 :         None => {
     260 UBC           0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
     261               0 :             return Ok(());
     262                 :         }
     263                 :     };
     264                 : 
     265 CBC          16 :     let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
     266              16 :     let _ = addr_incoming.set_nodelay(true);
     267              16 :     let addr_incoming = ProxyProtocolAccept {
     268              16 :         incoming: addr_incoming,
     269              16 :     };
     270              16 : 
     271              16 :     let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
     272              30 :         if let Err(err) = conn {
     273 UBC           0 :             error!("failed to accept TLS connection for websockets: {err:?}");
     274               0 :             ready(false)
     275                 :         } else {
     276 CBC          30 :             ready(true)
     277                 :         }
     278              30 :     });
     279              16 : 
     280              16 :     let make_svc = hyper::service::make_service_fn(
     281              30 :         |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
     282              30 :             let (io, tls) = stream.get_ref();
     283              30 :             let client_addr = io.client_addr();
     284              30 :             let remote_addr = io.inner.remote_addr();
     285              30 :             let sni_name = tls.server_name().map(|s| s.to_string());
     286              30 :             let conn_pool = conn_pool.clone();
     287                 : 
     288              30 :             async move {
     289              30 :                 let peer_addr = match client_addr {
     290 UBC           0 :                     Some(addr) => addr,
     291               0 :                     None if config.require_client_ip => bail!("missing required client ip"),
     292 CBC          30 :                     None => remote_addr,
     293                 :                 };
     294              30 :                 Ok(MetricService::new(hyper::service::service_fn(
     295              30 :                     move |req: Request<Body>| {
     296              30 :                         let sni_name = sni_name.clone();
     297              30 :                         let conn_pool = conn_pool.clone();
     298                 : 
     299              30 :                         async move {
     300              30 :                             let cancel_map = Arc::new(CancelMap::default());
     301              30 :                             let session_id = uuid::Uuid::new_v4();
     302              30 : 
     303              30 :                             ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
     304              30 :                                 .instrument(info_span!(
     305              30 :                                     "ws-client",
     306              30 :                                     session = %session_id,
     307              30 :                                     %peer_addr,
     308              30 :                                 ))
     309             233 :                                 .await
     310              30 :                         }
     311              30 :                     },
     312              30 :                 )))
     313              30 :             }
     314              30 :         },
     315              16 :     );
     316              16 : 
     317              16 :     hyper::Server::builder(accept::from_stream(tls_listener))
     318              16 :         .serve(make_svc)
     319              16 :         .with_graceful_shutdown(cancellation_token.cancelled())
     320             166 :         .await?;
     321                 : 
     322              16 :     Ok(())
     323              16 : }
     324                 : 
     325                 : struct MetricService<S> {
     326                 :     inner: S,
     327                 : }
     328                 : 
     329                 : impl<S> MetricService<S> {
     330              30 :     fn new(inner: S) -> MetricService<S> {
     331              30 :         NUM_CLIENT_CONNECTION_OPENED_COUNTER
     332              30 :             .with_label_values(&["http"])
     333              30 :             .inc();
     334              30 :         MetricService { inner }
     335              30 :     }
     336                 : }
     337                 : 
     338                 : impl<S> Drop for MetricService<S> {
     339              30 :     fn drop(&mut self) {
     340              30 :         NUM_CLIENT_CONNECTION_CLOSED_COUNTER
     341              30 :             .with_label_values(&["http"])
     342              30 :             .inc();
     343              30 :     }
     344                 : }
     345                 : 
     346                 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
     347                 : where
     348                 :     S: hyper::service::Service<Request<ReqBody>>,
     349                 : {
     350                 :     type Response = S::Response;
     351                 :     type Error = S::Error;
     352                 :     type Future = S::Future;
     353                 : 
     354              90 :     fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
     355              90 :         self.inner.poll_ready(cx)
     356              90 :     }
     357                 : 
     358              30 :     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
     359              30 :         self.inner.call(req)
     360              30 :     }
     361                 : }
        

Generated by: LCOV version 2.1-beta