LCOV - code coverage report
Current view: top level - proxy/src/serverless - websocket.rs (source / functions) Coverage Total Hit
Test: f081ec316c96fa98335efd15ef501745aa4f015d.info Lines: 65.2 % 138 90
Test Date: 2024-06-25 15:11:17 Functions: 38.5 % 26 10

            Line data    Source code
       1              : use crate::{
       2              :     cancellation::CancellationHandlerMain,
       3              :     config::ProxyConfig,
       4              :     context::RequestMonitoring,
       5              :     error::{io_error, ReportableError},
       6              :     metrics::Metrics,
       7              :     proxy::{handle_client, ClientMode},
       8              :     rate_limiter::EndpointRateLimiter,
       9              : };
      10              : use bytes::{Buf, BufMut, Bytes, BytesMut};
      11              : use framed_websockets::{Frame, OpCode, WebSocketServer};
      12              : use futures::{Sink, Stream};
      13              : use hyper1::upgrade::OnUpgrade;
      14              : use hyper_util::rt::TokioIo;
      15              : use pin_project_lite::pin_project;
      16              : 
      17              : use std::{
      18              :     pin::Pin,
      19              :     sync::Arc,
      20              :     task::{ready, Context, Poll},
      21              : };
      22              : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
      23              : use tracing::warn;
      24              : 
      25              : pin_project! {
      26              :     /// This is a wrapper around a [`WebSocketStream`] that
      27              :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      28              :     pub struct WebSocketRw<S> {
      29              :         #[pin]
      30              :         stream: WebSocketServer<S>,
      31              :         recv: Bytes,
      32              :         send: BytesMut,
      33              :     }
      34              : }
      35              : 
      36              : impl<S> WebSocketRw<S> {
      37            2 :     pub fn new(stream: WebSocketServer<S>) -> Self {
      38            2 :         Self {
      39            2 :             stream,
      40            2 :             recv: Bytes::new(),
      41            2 :             send: BytesMut::new(),
      42            2 :         }
      43            2 :     }
      44              : }
      45              : 
      46              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
      47            2 :     fn poll_write(
      48            2 :         self: Pin<&mut Self>,
      49            2 :         cx: &mut Context<'_>,
      50            2 :         buf: &[u8],
      51            2 :     ) -> Poll<io::Result<usize>> {
      52            2 :         let this = self.project();
      53            2 :         let mut stream = this.stream;
      54              : 
      55            2 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      56              : 
      57            2 :         this.send.put(buf);
      58            2 :         match stream.as_mut().start_send(Frame::binary(this.send.split())) {
      59            2 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      60            0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      61              :         }
      62            2 :     }
      63              : 
      64            2 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      65            2 :         let stream = self.project().stream;
      66            2 :         stream.poll_flush(cx).map_err(io_error)
      67            2 :     }
      68              : 
      69            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      70            0 :         let stream = self.project().stream;
      71            0 :         stream.poll_close(cx).map_err(io_error)
      72            0 :     }
      73              : }
      74              : 
      75              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
      76            6 :     fn poll_read(
      77            6 :         mut self: Pin<&mut Self>,
      78            6 :         cx: &mut Context<'_>,
      79            6 :         buf: &mut ReadBuf<'_>,
      80            6 :     ) -> Poll<io::Result<()>> {
      81            6 :         let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      82            4 :         let len = std::cmp::min(bytes.len(), buf.remaining());
      83            4 :         buf.put_slice(&bytes[..len]);
      84            4 :         self.consume(len);
      85            4 :         Poll::Ready(Ok(()))
      86            6 :     }
      87              : }
      88              : 
      89              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
      90            6 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
      91            6 :         // Please refer to poll_fill_buf's documentation.
      92            6 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
      93            6 : 
      94            6 :         let mut this = self.project();
      95            8 :         loop {
      96            8 :             if !this.recv.chunk().is_empty() {
      97            2 :                 let chunk = (*this.recv).chunk();
      98            2 :                 return Poll::Ready(Ok(chunk));
      99            6 :             }
     100              : 
     101            6 :             let res = ready!(this.stream.as_mut().poll_next(cx));
     102            4 :             match res.transpose().map_err(io_error)? {
     103            4 :                 Some(message) => match message.opcode {
     104            0 :                     OpCode::Ping => {}
     105            0 :                     OpCode::Pong => {}
     106              :                     OpCode::Text => {
     107              :                         // We expect to see only binary messages.
     108            0 :                         let error = "unexpected text message in the websocket";
     109            0 :                         warn!(length = message.payload.len(), error);
     110            0 :                         return Poll::Ready(Err(io_error(error)));
     111              :                     }
     112              :                     OpCode::Binary | OpCode::Continuation => {
     113            2 :                         debug_assert!(this.recv.is_empty());
     114            2 :                         *this.recv = message.payload.freeze();
     115              :                     }
     116            2 :                     OpCode::Close => return EOF,
     117              :                 },
     118            0 :                 None => return EOF,
     119              :             }
     120              :         }
     121            6 :     }
     122              : 
     123            4 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     124            4 :         self.project().recv.advance(amount);
     125            4 :     }
     126              : }
     127              : 
     128            0 : pub async fn serve_websocket(
     129            0 :     config: &'static ProxyConfig,
     130            0 :     mut ctx: RequestMonitoring,
     131            0 :     websocket: OnUpgrade,
     132            0 :     cancellation_handler: Arc<CancellationHandlerMain>,
     133            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     134            0 :     hostname: Option<String>,
     135            0 : ) -> anyhow::Result<()> {
     136            0 :     let websocket = websocket.await?;
     137            0 :     let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
     138            0 : 
     139            0 :     let conn_gauge = Metrics::get()
     140            0 :         .proxy
     141            0 :         .client_connections
     142            0 :         .guard(crate::metrics::Protocol::Ws);
     143              : 
     144            0 :     let res = Box::pin(handle_client(
     145            0 :         config,
     146            0 :         &mut ctx,
     147            0 :         cancellation_handler,
     148            0 :         WebSocketRw::new(websocket),
     149            0 :         ClientMode::Websockets { hostname },
     150            0 :         endpoint_rate_limiter,
     151            0 :         conn_gauge,
     152            0 :     ))
     153            0 :     .await;
     154              : 
     155            0 :     match res {
     156            0 :         Err(e) => {
     157            0 :             // todo: log and push to ctx the error kind
     158            0 :             ctx.set_error_kind(e.get_error_kind());
     159            0 :             Err(e.into())
     160              :         }
     161              :         Ok(None) => {
     162            0 :             ctx.set_success();
     163            0 :             Ok(())
     164              :         }
     165            0 :         Ok(Some(p)) => {
     166            0 :             ctx.set_success();
     167            0 :             ctx.log_connect();
     168            0 :             p.proxy_pass().await
     169              :         }
     170              :     }
     171            0 : }
     172              : 
     173              : #[cfg(test)]
     174              : mod tests {
     175              :     use std::pin::pin;
     176              : 
     177              :     use framed_websockets::WebSocketServer;
     178              :     use futures::{SinkExt, StreamExt};
     179              :     use tokio::{
     180              :         io::{duplex, AsyncReadExt, AsyncWriteExt},
     181              :         task::JoinSet,
     182              :     };
     183              :     use tokio_tungstenite::{
     184              :         tungstenite::{protocol::Role, Message},
     185              :         WebSocketStream,
     186              :     };
     187              : 
     188              :     use super::WebSocketRw;
     189              : 
     190              :     #[tokio::test]
     191            2 :     async fn websocket_stream_wrapper_happy_path() {
     192            2 :         let (stream1, stream2) = duplex(1024);
     193            2 : 
     194            2 :         let mut js = JoinSet::new();
     195            2 : 
     196            2 :         js.spawn(async move {
     197            2 :             let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
     198            2 : 
     199            2 :             client
     200            2 :                 .send(Message::Binary(b"hello world".to_vec()))
     201            2 :                 .await
     202            2 :                 .unwrap();
     203            2 : 
     204            2 :             let message = client.next().await.unwrap().unwrap();
     205            2 :             assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
     206            2 : 
     207            2 :             client.close(None).await.unwrap();
     208            2 :         });
     209            2 : 
     210            2 :         js.spawn(async move {
     211            2 :             let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
     212            2 : 
     213            2 :             let mut buf = vec![0; 1024];
     214            2 :             let n = rw.read(&mut buf).await.unwrap();
     215            2 :             assert_eq!(&buf[..n], b"hello world");
     216            2 : 
     217            2 :             rw.write_all(b"websockets are cool").await.unwrap();
     218            2 :             rw.flush().await.unwrap();
     219            2 : 
     220            2 :             let n = rw.read_to_end(&mut buf).await.unwrap();
     221            2 :             assert_eq!(n, 0);
     222            2 :         });
     223            2 : 
     224            2 :         js.join_next().await.unwrap().unwrap();
     225            2 :         js.join_next().await.unwrap().unwrap();
     226            2 :     }
     227              : }
        

Generated by: LCOV version 2.1-beta