TLA Line data Source code
1 : use crate::{
2 : cancellation::CancelMap,
3 : config::ProxyConfig,
4 : context::RequestMonitoring,
5 : error::io_error,
6 : proxy::{handle_client, ClientMode},
7 : rate_limiter::EndpointRateLimiter,
8 : };
9 : use bytes::{Buf, Bytes};
10 : use futures::{Sink, Stream};
11 : use hyper::upgrade::Upgraded;
12 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
13 : use pin_project_lite::pin_project;
14 :
15 : use std::{
16 : pin::Pin,
17 : sync::Arc,
18 : task::{ready, Context, Poll},
19 : };
20 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
21 : use tracing::warn;
22 :
23 : // TODO: use `std::sync::Exclusive` once it's stabilized.
24 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
25 : use sync_wrapper::SyncWrapper;
26 :
27 : pin_project! {
28 : /// This is a wrapper around a [`WebSocketStream`] that
29 : /// implements [`AsyncRead`] and [`AsyncWrite`].
30 : pub struct WebSocketRw<S = Upgraded> {
31 : #[pin]
32 : stream: SyncWrapper<WebSocketStream<S>>,
33 : bytes: Bytes,
34 : }
35 : }
36 :
37 : impl<S> WebSocketRw<S> {
38 CBC 1 : pub fn new(stream: WebSocketStream<S>) -> Self {
39 1 : Self {
40 1 : stream: stream.into(),
41 1 : bytes: Bytes::new(),
42 1 : }
43 1 : }
44 : }
45 :
46 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
47 1 : fn poll_write(
48 1 : self: Pin<&mut Self>,
49 1 : cx: &mut Context<'_>,
50 1 : buf: &[u8],
51 1 : ) -> Poll<io::Result<usize>> {
52 1 : let mut stream = self.project().stream.get_pin_mut();
53 :
54 1 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
55 1 : match stream.as_mut().start_send(Message::Binary(buf.into())) {
56 1 : Ok(()) => Poll::Ready(Ok(buf.len())),
57 UBC 0 : Err(e) => Poll::Ready(Err(io_error(e))),
58 : }
59 CBC 1 : }
60 :
61 1 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
62 1 : let stream = self.project().stream.get_pin_mut();
63 1 : stream.poll_flush(cx).map_err(io_error)
64 1 : }
65 :
66 UBC 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67 0 : let stream = self.project().stream.get_pin_mut();
68 0 : stream.poll_close(cx).map_err(io_error)
69 0 : }
70 : }
71 :
72 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
73 CBC 3 : fn poll_read(
74 3 : mut self: Pin<&mut Self>,
75 3 : cx: &mut Context<'_>,
76 3 : buf: &mut ReadBuf<'_>,
77 3 : ) -> Poll<io::Result<()>> {
78 3 : if buf.remaining() > 0 {
79 3 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
80 2 : let len = std::cmp::min(bytes.len(), buf.remaining());
81 2 : buf.put_slice(&bytes[..len]);
82 2 : self.consume(len);
83 UBC 0 : }
84 :
85 CBC 2 : Poll::Ready(Ok(()))
86 3 : }
87 : }
88 :
89 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
90 3 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
91 3 : // Please refer to poll_fill_buf's documentation.
92 3 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
93 3 :
94 3 : let mut this = self.project();
95 4 : loop {
96 4 : if !this.bytes.chunk().is_empty() {
97 1 : let chunk = (*this.bytes).chunk();
98 1 : return Poll::Ready(Ok(chunk));
99 3 : }
100 :
101 3 : let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
102 2 : match res.transpose().map_err(io_error)? {
103 2 : Some(message) => match message {
104 UBC 0 : Message::Ping(_) => {}
105 0 : Message::Pong(_) => {}
106 0 : Message::Text(text) => {
107 0 : // We expect to see only binary messages.
108 0 : let error = "unexpected text message in the websocket";
109 0 : warn!(length = text.len(), error);
110 0 : return Poll::Ready(Err(io_error(error)));
111 : }
112 : Message::Frame(_) => {
113 : // This case is impossible according to Frame's doc.
114 0 : panic!("unexpected raw frame in the websocket");
115 : }
116 CBC 1 : Message::Binary(chunk) => {
117 1 : assert!(this.bytes.is_empty());
118 1 : *this.bytes = Bytes::from(chunk);
119 : }
120 1 : Message::Close(_) => return EOF,
121 : },
122 UBC 0 : None => return EOF,
123 : }
124 : }
125 CBC 3 : }
126 :
127 2 : fn consume(self: Pin<&mut Self>, amount: usize) {
128 2 : self.project().bytes.advance(amount);
129 2 : }
130 : }
131 :
132 UBC 0 : pub async fn serve_websocket(
133 0 : config: &'static ProxyConfig,
134 0 : ctx: &mut RequestMonitoring,
135 0 : websocket: HyperWebsocket,
136 0 : cancel_map: &CancelMap,
137 0 : hostname: Option<String>,
138 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
139 0 : ) -> anyhow::Result<()> {
140 0 : let websocket = websocket.await?;
141 0 : handle_client(
142 0 : config,
143 0 : ctx,
144 0 : cancel_map,
145 0 : WebSocketRw::new(websocket),
146 0 : ClientMode::Websockets { hostname },
147 0 : endpoint_rate_limiter,
148 0 : )
149 0 : .await?;
150 0 : Ok(())
151 0 : }
152 :
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures::{SinkExt, StreamExt};
    use hyper_tungstenite::{
        tungstenite::{protocol::Role, Message},
        WebSocketStream,
    };
    use tokio::{
        io::{duplex, AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };

    use super::WebSocketRw;

    /// Round trip over an in-memory duplex pipe: a raw websocket client talks
    /// to a `WebSocketRw` that is driven purely through the byte-stream API.
    #[tokio::test]
    async fn websocket_stream_wrapper_happy_path() {
        let (client_io, server_io) = duplex(1024);

        let mut tasks = JoinSet::new();

        // Client side: speaks websocket messages directly.
        tasks.spawn(async move {
            let mut ws = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;

            let greeting = Message::Binary(b"hello world".to_vec());
            ws.send(greeting).await.unwrap();

            let reply = ws.next().await.unwrap().unwrap();
            assert_eq!(reply, Message::Binary(b"websockets are cool".to_vec()));

            ws.close(None).await.unwrap();
        });

        // Server side: reads and writes plain bytes through the wrapper.
        tasks.spawn(async move {
            let ws = WebSocketStream::from_raw_socket(server_io, Role::Server, None).await;
            let mut rw = pin!(WebSocketRw::new(ws));

            let mut buf = vec![0; 1024];
            let received = rw.read(&mut buf).await.unwrap();
            assert_eq!(&buf[..received], b"hello world");

            rw.write_all(b"websockets are cool").await.unwrap();
            rw.flush().await.unwrap();

            // The client closes the connection, so nothing more is read.
            let trailing = rw.read_to_end(&mut buf).await.unwrap();
            assert_eq!(trailing, 0);
        });

        // Both tasks have to finish without panicking.
        for _ in 0..2 {
            tasks.join_next().await.unwrap().unwrap();
        }
    }
}
|