LCOV - code coverage report
Current view: top level - proxy/src/http - websocket.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 50.0 % 240 120
Test Date: 2023-09-06 10:18:01 Functions: 51.4 % 37 19

            Line data    Source code
       1              : use crate::{
       2              :     cancellation::CancelMap,
       3              :     config::ProxyConfig,
       4              :     error::io_error,
       5              :     protocol2::{ProxyProtocolAccept, WithClientIp},
       6              :     proxy::{handle_client, ClientMode},
       7              : };
       8              : use bytes::{Buf, Bytes};
       9              : use futures::{Sink, Stream, StreamExt};
      10              : use hashbrown::HashMap;
      11              : use hyper::{
      12              :     server::{
      13              :         accept,
      14              :         conn::{AddrIncoming, AddrStream},
      15              :     },
      16              :     upgrade::Upgraded,
      17              :     Body, Method, Request, Response, StatusCode,
      18              : };
      19              : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
      20              : use pin_project_lite::pin_project;
      21              : use serde_json::{json, Value};
      22              : 
      23              : use std::{
      24              :     convert::Infallible,
      25              :     future::ready,
      26              :     pin::Pin,
      27              :     sync::Arc,
      28              :     task::{ready, Context, Poll},
      29              : };
      30              : use tls_listener::TlsListener;
      31              : use tokio::{
      32              :     io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
      33              :     net::TcpListener,
      34              : };
      35              : use tokio_util::sync::CancellationToken;
      36              : use tracing::{error, info, info_span, warn, Instrument};
      37              : use utils::http::{error::ApiError, json::json_response};
      38              : 
      39              : // TODO: use `std::sync::Exclusive` once it's stabilized.
      40              : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
      41              : use sync_wrapper::SyncWrapper;
      42              : 
      43              : use super::{conn_pool::GlobalConnPool, sql_over_http};
      44              : 
      45              : pin_project! {
      46              :     /// This is a wrapper around a [`WebSocketStream`] that
      47              :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      48              :     pub struct WebSocketRw {
      49              :         #[pin]
      50              :         stream: SyncWrapper<WebSocketStream<Upgraded>>,
      51              :         bytes: Bytes,
      52              :     }
      53              : }
      54              : 
      55              : impl WebSocketRw {
      56            0 :     pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
      57            0 :         Self {
      58            0 :             stream: stream.into(),
      59            0 :             bytes: Bytes::new(),
      60            0 :         }
      61            0 :     }
      62              : }
      63              : 
      64              : impl AsyncWrite for WebSocketRw {
      65            0 :     fn poll_write(
      66            0 :         self: Pin<&mut Self>,
      67            0 :         cx: &mut Context<'_>,
      68            0 :         buf: &[u8],
      69            0 :     ) -> Poll<io::Result<usize>> {
      70            0 :         let mut stream = self.project().stream.get_pin_mut();
      71              : 
      72            0 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      73            0 :         match stream.as_mut().start_send(Message::Binary(buf.into())) {
      74            0 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      75            0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      76              :         }
      77            0 :     }
      78              : 
      79            0 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      80            0 :         let stream = self.project().stream.get_pin_mut();
      81            0 :         stream.poll_flush(cx).map_err(io_error)
      82            0 :     }
      83              : 
      84            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      85            0 :         let stream = self.project().stream.get_pin_mut();
      86            0 :         stream.poll_close(cx).map_err(io_error)
      87            0 :     }
      88              : }
      89              : 
      90              : impl AsyncRead for WebSocketRw {
      91            0 :     fn poll_read(
      92            0 :         mut self: Pin<&mut Self>,
      93            0 :         cx: &mut Context<'_>,
      94            0 :         buf: &mut ReadBuf<'_>,
      95            0 :     ) -> Poll<io::Result<()>> {
      96            0 :         if buf.remaining() > 0 {
      97            0 :             let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      98            0 :             let len = std::cmp::min(bytes.len(), buf.remaining());
      99            0 :             buf.put_slice(&bytes[..len]);
     100            0 :             self.consume(len);
     101            0 :         }
     102              : 
     103            0 :         Poll::Ready(Ok(()))
     104            0 :     }
     105              : }
     106              : 
     107              : impl AsyncBufRead for WebSocketRw {
     108            0 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
     109            0 :         // Please refer to poll_fill_buf's documentation.
     110            0 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
     111            0 : 
     112            0 :         let mut this = self.project();
     113            0 :         loop {
     114            0 :             if !this.bytes.chunk().is_empty() {
     115            0 :                 let chunk = (*this.bytes).chunk();
     116            0 :                 return Poll::Ready(Ok(chunk));
     117            0 :             }
     118              : 
     119            0 :             let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
     120            0 :             match res.transpose().map_err(io_error)? {
     121            0 :                 Some(message) => match message {
     122            0 :                     Message::Ping(_) => {}
     123            0 :                     Message::Pong(_) => {}
     124            0 :                     Message::Text(text) => {
     125            0 :                         // We expect to see only binary messages.
     126            0 :                         let error = "unexpected text message in the websocket";
     127            0 :                         warn!(length = text.len(), error);
     128            0 :                         return Poll::Ready(Err(io_error(error)));
     129              :                     }
     130              :                     Message::Frame(_) => {
     131              :                         // This case is impossible according to Frame's doc.
     132            0 :                         panic!("unexpected raw frame in the websocket");
     133              :                     }
     134            0 :                     Message::Binary(chunk) => {
     135            0 :                         assert!(this.bytes.is_empty());
     136            0 :                         *this.bytes = Bytes::from(chunk);
     137              :                     }
     138            0 :                     Message::Close(_) => return EOF,
     139              :                 },
     140            0 :                 None => return EOF,
     141              :             }
     142              :         }
     143            0 :     }
     144              : 
     145            0 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     146            0 :         self.project().bytes.advance(amount);
     147            0 :     }
     148              : }
     149              : 
     150            0 : async fn serve_websocket(
     151            0 :     websocket: HyperWebsocket,
     152            0 :     config: &'static ProxyConfig,
     153            0 :     cancel_map: &CancelMap,
     154            0 :     session_id: uuid::Uuid,
     155            0 :     hostname: Option<String>,
     156            0 : ) -> anyhow::Result<()> {
     157            0 :     let websocket = websocket.await?;
     158            0 :     handle_client(
     159            0 :         config,
     160            0 :         cancel_map,
     161            0 :         session_id,
     162            0 :         WebSocketRw::new(websocket),
     163            0 :         ClientMode::Websockets { hostname },
     164            0 :     )
     165            0 :     .await?;
     166            0 :     Ok(())
     167            0 : }
     168              : 
     169           22 : async fn ws_handler(
     170           22 :     mut request: Request<Body>,
     171           22 :     config: &'static ProxyConfig,
     172           22 :     conn_pool: Arc<GlobalConnPool>,
     173           22 :     cancel_map: Arc<CancelMap>,
     174           22 :     session_id: uuid::Uuid,
     175           22 :     sni_hostname: Option<String>,
     176           22 : ) -> Result<Response<Body>, ApiError> {
     177           22 :     let host = request
     178           22 :         .headers()
     179           22 :         .get("host")
     180           22 :         .and_then(|h| h.to_str().ok())
     181           22 :         .and_then(|h| h.split(':').next())
     182           22 :         .map(|s| s.to_string());
     183           22 : 
     184           22 :     // Check if the request is a websocket upgrade request.
     185           22 :     if hyper_tungstenite::is_upgrade_request(&request) {
     186            0 :         info!(session_id = ?session_id, "performing websocket upgrade");
     187              : 
     188            0 :         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
     189            0 :             .map_err(|e| ApiError::BadRequest(e.into()))?;
     190              : 
     191            0 :         tokio::spawn(
     192            0 :             async move {
     193            0 :                 if let Err(e) =
     194            0 :                     serve_websocket(websocket, config, &cancel_map, session_id, host).await
     195              :                 {
     196            0 :                     error!(session_id = ?session_id, "error in websocket connection: {e:#}");
     197            0 :                 }
     198            0 :             }
     199            0 :             .in_current_span(),
     200            0 :         );
     201            0 : 
     202            0 :         // Return the response so the spawned future can continue.
     203            0 :         Ok(response)
     204              :     // TODO: that deserves a refactor as now this function also handles http json client besides websockets.
     205              :     // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
     206           22 :     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
     207           22 :         let result = sql_over_http::handle(request, sni_hostname, conn_pool, session_id)
     208           22 :             .instrument(info_span!("sql-over-http"))
     209          149 :             .await;
     210           22 :         let status_code = match result {
     211           20 :             Ok(_) => StatusCode::OK,
     212            2 :             Err(_) => StatusCode::BAD_REQUEST,
     213              :         };
     214           22 :         let (json, headers) = match result {
     215           20 :             Ok(r) => r,
     216            2 :             Err(e) => {
     217            2 :                 let message = format!("{:?}", e);
     218            2 :                 let code = match e.downcast_ref::<tokio_postgres::Error>() {
     219            2 :                     Some(e) => match e.code() {
     220            2 :                         Some(e) => serde_json::to_value(e.code()).unwrap(),
     221            0 :                         None => Value::Null,
     222              :                     },
     223            0 :                     None => Value::Null,
     224              :                 };
     225            2 :                 error!(
     226            2 :                     ?code,
     227            2 :                     "sql-over-http per-client task finished with an error: {e:#}"
     228            2 :                 );
     229            2 :                 (
     230            2 :                     json!({ "message": message, "code": code }),
     231            2 :                     HashMap::default(),
     232            2 :                 )
     233              :             }
     234              :         };
     235           22 :         json_response(status_code, json).map(|mut r| {
     236           22 :             r.headers_mut().insert(
     237           22 :                 "Access-Control-Allow-Origin",
     238           22 :                 hyper::http::HeaderValue::from_static("*"),
     239           22 :             );
     240           26 :             for (k, v) in headers {
     241            4 :                 r.headers_mut().insert(k, v);
     242            4 :             }
     243           22 :             r
     244           22 :         })
     245            0 :     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
     246            0 :         Response::builder()
     247            0 :             .header("Allow", "OPTIONS, POST")
     248            0 :             .header("Access-Control-Allow-Origin", "*")
     249            0 :             .header(
     250            0 :                 "Access-Control-Allow-Headers",
     251            0 :                 "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
     252            0 :             )
     253            0 :             .header("Access-Control-Max-Age", "86400" /* 24 hours */)
     254            0 :             .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
     255            0 :             .body(Body::empty())
     256            0 :             .map_err(|e| ApiError::BadRequest(e.into()))
     257              :     } else {
     258            0 :         json_response(StatusCode::BAD_REQUEST, "query is not supported")
     259              :     }
     260           22 : }
     261              : 
     262           14 : pub async fn task_main(
     263           14 :     config: &'static ProxyConfig,
     264           14 :     ws_listener: TcpListener,
     265           14 :     cancellation_token: CancellationToken,
     266           14 : ) -> anyhow::Result<()> {
     267           14 :     scopeguard::defer! {
     268           14 :         info!("websocket server has shut down");
     269           14 :     }
     270           14 : 
     271           14 :     let conn_pool: Arc<GlobalConnPool> = GlobalConnPool::new(config);
     272           14 : 
     273           14 :     // shutdown the connection pool
     274           14 :     tokio::spawn({
     275           14 :         let cancellation_token = cancellation_token.clone();
     276           14 :         let conn_pool = conn_pool.clone();
     277           14 :         async move {
     278           14 :             cancellation_token.cancelled().await;
     279           14 :             tokio::task::spawn_blocking(move || conn_pool.shutdown())
     280           14 :                 .await
     281           14 :                 .unwrap();
     282           14 :         }
     283           14 :     });
     284           14 : 
     285           14 :     let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
     286           14 :     let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
     287           14 :         Some(config) => config.into(),
     288              :         None => {
     289            0 :             warn!("TLS config is missing, WebSocket Secure server will not be started");
     290            0 :             return Ok(());
     291              :         }
     292              :     };
     293              : 
     294           14 :     let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
     295           14 :     let _ = addr_incoming.set_nodelay(true);
     296           14 :     let addr_incoming = ProxyProtocolAccept {
     297           14 :         incoming: addr_incoming,
     298           14 :     };
     299           14 : 
     300           14 :     let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
     301           22 :         if let Err(err) = conn {
     302            0 :             error!("failed to accept TLS connection for websockets: {err:?}");
     303            0 :             ready(false)
     304              :         } else {
     305           22 :             ready(true)
     306              :         }
     307           22 :     });
     308           14 : 
     309           14 :     let make_svc = hyper::service::make_service_fn(
     310           22 :         |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
     311           22 :             let (io, tls) = stream.get_ref();
     312           22 :             let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr());
     313           22 :             let sni_name = tls.server_name().map(|s| s.to_string());
     314           22 :             let conn_pool = conn_pool.clone();
     315              : 
     316           22 :             async move {
     317           22 :                 Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request<Body>| {
     318           22 :                     let sni_name = sni_name.clone();
     319           22 :                     let conn_pool = conn_pool.clone();
     320              : 
     321           22 :                     async move {
     322           22 :                         let cancel_map = Arc::new(CancelMap::default());
     323           22 :                         let session_id = uuid::Uuid::new_v4();
     324           22 : 
     325           22 :                         ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
     326           22 :                             .instrument(info_span!(
     327           22 :                                 "ws-client",
     328           22 :                                 session = %session_id,
     329           22 :                                 %peer_addr,
     330           22 :                             ))
     331          149 :                             .await
     332           22 :                     }
     333           22 :                 }))
     334           22 :             }
     335           22 :         },
     336           14 :     );
     337           14 : 
     338           14 :     hyper::Server::builder(accept::from_stream(tls_listener))
     339           14 :         .serve(make_svc)
     340           14 :         .with_graceful_shutdown(cancellation_token.cancelled())
     341          124 :         .await?;
     342              : 
     343           14 :     Ok(())
     344           14 : }
        

Generated by: LCOV version 2.1-beta