LCOV - code coverage report
Current view: top level - proxy/src/serverless - websocket.rs (source / functions) Coverage Total Hit
Test: 691a4c28fe7169edd60b367c52d448a0a6605f1f.info Lines: 64.5 % 141 91
Test Date: 2024-05-10 13:18:37 Functions: 38.5 % 26 10

            Line data    Source code
       1              : use crate::{
       2              :     cancellation::CancellationHandlerMain,
       3              :     config::ProxyConfig,
       4              :     context::RequestMonitoring,
       5              :     error::{io_error, ReportableError},
       6              :     metrics::Metrics,
       7              :     proxy::{handle_client, ClientMode},
       8              :     rate_limiter::EndpointRateLimiter,
       9              : };
      10              : use bytes::{Buf, Bytes};
      11              : use futures::{Sink, Stream};
      12              : use hyper::upgrade::Upgraded;
      13              : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
      14              : use pin_project_lite::pin_project;
      15              : 
      16              : use std::{
      17              :     pin::Pin,
      18              :     sync::Arc,
      19              :     task::{ready, Context, Poll},
      20              : };
      21              : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
      22              : use tracing::warn;
      23              : 
      24              : // TODO: use `std::sync::Exclusive` once it's stabilized.
      25              : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
      26              : use sync_wrapper::SyncWrapper;
      27              : 
      28              : pin_project! {
      29              :     /// This is a wrapper around a [`WebSocketStream`] that
      30              :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      31              :     pub struct WebSocketRw<S = Upgraded> {
      32              :         #[pin]
      33              :         stream: SyncWrapper<WebSocketStream<S>>,
      34              :         bytes: Bytes,
      35              :     }
      36              : }
      37              : 
      38              : impl<S> WebSocketRw<S> {
      39            2 :     pub fn new(stream: WebSocketStream<S>) -> Self {
      40            2 :         Self {
      41            2 :             stream: stream.into(),
      42            2 :             bytes: Bytes::new(),
      43            2 :         }
      44            2 :     }
      45              : }
      46              : 
      47              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
      48            2 :     fn poll_write(
      49            2 :         self: Pin<&mut Self>,
      50            2 :         cx: &mut Context<'_>,
      51            2 :         buf: &[u8],
      52            2 :     ) -> Poll<io::Result<usize>> {
      53            2 :         let mut stream = self.project().stream.get_pin_mut();
      54              : 
      55            2 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      56            2 :         match stream.as_mut().start_send(Message::Binary(buf.into())) {
      57            2 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      58            0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      59              :         }
      60            2 :     }
      61              : 
      62            2 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      63            2 :         let stream = self.project().stream.get_pin_mut();
      64            2 :         stream.poll_flush(cx).map_err(io_error)
      65            2 :     }
      66              : 
      67            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      68            0 :         let stream = self.project().stream.get_pin_mut();
      69            0 :         stream.poll_close(cx).map_err(io_error)
      70            0 :     }
      71              : }
      72              : 
      73              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
      74            6 :     fn poll_read(
      75            6 :         mut self: Pin<&mut Self>,
      76            6 :         cx: &mut Context<'_>,
      77            6 :         buf: &mut ReadBuf<'_>,
      78            6 :     ) -> Poll<io::Result<()>> {
      79            6 :         if buf.remaining() > 0 {
      80            6 :             let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      81            4 :             let len = std::cmp::min(bytes.len(), buf.remaining());
      82            4 :             buf.put_slice(&bytes[..len]);
      83            4 :             self.consume(len);
      84            0 :         }
      85              : 
      86            4 :         Poll::Ready(Ok(()))
      87            6 :     }
      88              : }
      89              : 
      90              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
      91            6 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
      92            6 :         // Please refer to poll_fill_buf's documentation.
      93            6 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
      94            6 : 
      95            6 :         let mut this = self.project();
      96            8 :         loop {
      97            8 :             if !this.bytes.chunk().is_empty() {
      98            2 :                 let chunk = (*this.bytes).chunk();
      99            2 :                 return Poll::Ready(Ok(chunk));
     100            6 :             }
     101              : 
     102            6 :             let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
     103            4 :             match res.transpose().map_err(io_error)? {
     104            4 :                 Some(message) => match message {
     105            0 :                     Message::Ping(_) => {}
     106            0 :                     Message::Pong(_) => {}
     107            0 :                     Message::Text(text) => {
     108            0 :                         // We expect to see only binary messages.
     109            0 :                         let error = "unexpected text message in the websocket";
     110            0 :                         warn!(length = text.len(), error);
     111            0 :                         return Poll::Ready(Err(io_error(error)));
     112              :                     }
     113              :                     Message::Frame(_) => {
     114              :                         // This case is impossible according to Frame's doc.
     115            0 :                         panic!("unexpected raw frame in the websocket");
     116              :                     }
     117            2 :                     Message::Binary(chunk) => {
     118            2 :                         assert!(this.bytes.is_empty());
     119            2 :                         *this.bytes = Bytes::from(chunk);
     120              :                     }
     121            2 :                     Message::Close(_) => return EOF,
     122              :                 },
     123            0 :                 None => return EOF,
     124              :             }
     125              :         }
     126            6 :     }
     127              : 
     128            4 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     129            4 :         self.project().bytes.advance(amount);
     130            4 :     }
     131              : }
     132              : 
     133            0 : pub async fn serve_websocket(
     134            0 :     config: &'static ProxyConfig,
     135            0 :     mut ctx: RequestMonitoring,
     136            0 :     websocket: HyperWebsocket,
     137            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
     138            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     139            0 :     hostname: Option<String>,
     140            0 : ) -> anyhow::Result<()> {
     141            0 :     let websocket = websocket.await?;
     142            0 :     let conn_gauge = Metrics::get()
     143            0 :         .proxy
     144            0 :         .client_connections
     145            0 :         .guard(crate::metrics::Protocol::Ws);
     146              : 
     147            0 :     let res = handle_client(
     148            0 :         config,
     149            0 :         &mut ctx,
     150            0 :         cancellation_handler,
     151            0 :         WebSocketRw::new(websocket),
     152            0 :         ClientMode::Websockets { hostname },
     153            0 :         endpoint_rate_limiter,
     154            0 :         conn_gauge,
     155            0 :     )
     156            0 :     .await;
     157              : 
     158            0 :     match res {
     159            0 :         Err(e) => {
     160            0 :             // todo: log and push to ctx the error kind
     161            0 :             ctx.set_error_kind(e.get_error_kind());
     162            0 :             Err(e.into())
     163              :         }
     164              :         Ok(None) => {
     165            0 :             ctx.set_success();
     166            0 :             Ok(())
     167              :         }
     168            0 :         Ok(Some(p)) => {
     169            0 :             ctx.set_success();
     170            0 :             ctx.log_connect();
     171            0 :             p.proxy_pass().await
     172              :         }
     173              :     }
     174            0 : }
     175              : 
     176              : #[cfg(test)]
     177              : mod tests {
     178              :     use std::pin::pin;
     179              : 
     180              :     use futures::{SinkExt, StreamExt};
     181              :     use hyper_tungstenite::{
     182              :         tungstenite::{protocol::Role, Message},
     183              :         WebSocketStream,
     184              :     };
     185              :     use tokio::{
     186              :         io::{duplex, AsyncReadExt, AsyncWriteExt},
     187              :         task::JoinSet,
     188              :     };
     189              : 
     190              :     use super::WebSocketRw;
     191              : 
     192              :     #[tokio::test]
     193            2 :     async fn websocket_stream_wrapper_happy_path() {
     194            2 :         let (stream1, stream2) = duplex(1024);
     195            2 : 
     196            2 :         let mut js = JoinSet::new();
     197            2 : 
     198            2 :         js.spawn(async move {
     199            2 :             let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
     200            2 : 
     201            2 :             client
     202            2 :                 .send(Message::Binary(b"hello world".to_vec()))
     203            2 :                 .await
     204            2 :                 .unwrap();
     205            2 : 
     206            2 :             let message = client.next().await.unwrap().unwrap();
     207            2 :             assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
     208            2 : 
     209            2 :             client.close(None).await.unwrap();
     210            2 :         });
     211            2 : 
     212            2 :         js.spawn(async move {
     213            2 :             let mut rw = pin!(WebSocketRw::new(
     214            2 :                 WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
     215            2 :             ));
     216            2 : 
     217            2 :             let mut buf = vec![0; 1024];
     218            2 :             let n = rw.read(&mut buf).await.unwrap();
     219            2 :             assert_eq!(&buf[..n], b"hello world");
     220            2 : 
     221            2 :             rw.write_all(b"websockets are cool").await.unwrap();
     222            2 :             rw.flush().await.unwrap();
     223            2 : 
     224            2 :             let n = rw.read_to_end(&mut buf).await.unwrap();
     225            2 :             assert_eq!(n, 0);
     226            2 :         });
     227            2 : 
     228            2 :         js.join_next().await.unwrap().unwrap();
     229            2 :         js.join_next().await.unwrap().unwrap();
     230            2 :     }
     231              : }
        

Generated by: LCOV version 2.1-beta