Line data Source code
1 : use crate::{
2 : cancellation::CancellationHandlerMain,
3 : config::ProxyConfig,
4 : context::RequestMonitoring,
5 : error::{io_error, ReportableError},
6 : metrics::Metrics,
7 : proxy::{handle_client, ClientMode},
8 : rate_limiter::EndpointRateLimiter,
9 : };
10 : use bytes::{Buf, Bytes};
11 : use futures::{Sink, Stream};
12 : use hyper::upgrade::Upgraded;
13 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
14 : use pin_project_lite::pin_project;
15 :
16 : use std::{
17 : pin::Pin,
18 : sync::Arc,
19 : task::{ready, Context, Poll},
20 : };
21 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
22 : use tracing::warn;
23 :
24 : // TODO: use `std::sync::Exclusive` once it's stabilized.
25 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
26 : use sync_wrapper::SyncWrapper;
27 :
28 : pin_project! {
29 : /// This is a wrapper around a [`WebSocketStream`] that
30 : /// implements [`AsyncRead`] and [`AsyncWrite`].
31 : pub struct WebSocketRw<S = Upgraded> {
32 : #[pin]
33 : stream: SyncWrapper<WebSocketStream<S>>,
34 : bytes: Bytes,
35 : }
36 : }
37 :
38 : impl<S> WebSocketRw<S> {
39 2 : pub fn new(stream: WebSocketStream<S>) -> Self {
40 2 : Self {
41 2 : stream: stream.into(),
42 2 : bytes: Bytes::new(),
43 2 : }
44 2 : }
45 : }
46 :
47 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
48 2 : fn poll_write(
49 2 : self: Pin<&mut Self>,
50 2 : cx: &mut Context<'_>,
51 2 : buf: &[u8],
52 2 : ) -> Poll<io::Result<usize>> {
53 2 : let mut stream = self.project().stream.get_pin_mut();
54 :
55 2 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
56 2 : match stream.as_mut().start_send(Message::Binary(buf.into())) {
57 2 : Ok(()) => Poll::Ready(Ok(buf.len())),
58 0 : Err(e) => Poll::Ready(Err(io_error(e))),
59 : }
60 2 : }
61 :
62 2 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
63 2 : let stream = self.project().stream.get_pin_mut();
64 2 : stream.poll_flush(cx).map_err(io_error)
65 2 : }
66 :
67 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68 0 : let stream = self.project().stream.get_pin_mut();
69 0 : stream.poll_close(cx).map_err(io_error)
70 0 : }
71 : }
72 :
73 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
74 6 : fn poll_read(
75 6 : mut self: Pin<&mut Self>,
76 6 : cx: &mut Context<'_>,
77 6 : buf: &mut ReadBuf<'_>,
78 6 : ) -> Poll<io::Result<()>> {
79 6 : if buf.remaining() > 0 {
80 6 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
81 4 : let len = std::cmp::min(bytes.len(), buf.remaining());
82 4 : buf.put_slice(&bytes[..len]);
83 4 : self.consume(len);
84 0 : }
85 :
86 4 : Poll::Ready(Ok(()))
87 6 : }
88 : }
89 :
90 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
91 6 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
92 6 : // Please refer to poll_fill_buf's documentation.
93 6 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
94 6 :
95 6 : let mut this = self.project();
96 8 : loop {
97 8 : if !this.bytes.chunk().is_empty() {
98 2 : let chunk = (*this.bytes).chunk();
99 2 : return Poll::Ready(Ok(chunk));
100 6 : }
101 :
102 6 : let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
103 4 : match res.transpose().map_err(io_error)? {
104 4 : Some(message) => match message {
105 0 : Message::Ping(_) => {}
106 0 : Message::Pong(_) => {}
107 0 : Message::Text(text) => {
108 0 : // We expect to see only binary messages.
109 0 : let error = "unexpected text message in the websocket";
110 0 : warn!(length = text.len(), error);
111 0 : return Poll::Ready(Err(io_error(error)));
112 : }
113 : Message::Frame(_) => {
114 : // This case is impossible according to Frame's doc.
115 0 : panic!("unexpected raw frame in the websocket");
116 : }
117 2 : Message::Binary(chunk) => {
118 2 : assert!(this.bytes.is_empty());
119 2 : *this.bytes = Bytes::from(chunk);
120 : }
121 2 : Message::Close(_) => return EOF,
122 : },
123 0 : None => return EOF,
124 : }
125 : }
126 6 : }
127 :
128 4 : fn consume(self: Pin<&mut Self>, amount: usize) {
129 4 : self.project().bytes.advance(amount);
130 4 : }
131 : }
132 :
133 0 : pub async fn serve_websocket(
134 0 : config: &'static ProxyConfig,
135 0 : mut ctx: RequestMonitoring,
136 0 : websocket: HyperWebsocket,
137 0 : cancellation_handler: Arc<CancellationHandlerMain>,
138 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
139 0 : hostname: Option<String>,
140 0 : ) -> anyhow::Result<()> {
141 0 : let websocket = websocket.await?;
142 0 : let conn_gauge = Metrics::get()
143 0 : .proxy
144 0 : .client_connections
145 0 : .guard(crate::metrics::Protocol::Ws);
146 :
147 0 : let res = handle_client(
148 0 : config,
149 0 : &mut ctx,
150 0 : cancellation_handler,
151 0 : WebSocketRw::new(websocket),
152 0 : ClientMode::Websockets { hostname },
153 0 : endpoint_rate_limiter,
154 0 : conn_gauge,
155 0 : )
156 0 : .await;
157 :
158 0 : match res {
159 0 : Err(e) => {
160 0 : // todo: log and push to ctx the error kind
161 0 : ctx.set_error_kind(e.get_error_kind());
162 0 : Err(e.into())
163 : }
164 : Ok(None) => {
165 0 : ctx.set_success();
166 0 : Ok(())
167 : }
168 0 : Ok(Some(p)) => {
169 0 : ctx.set_success();
170 0 : ctx.log_connect();
171 0 : p.proxy_pass().await
172 : }
173 : }
174 0 : }
175 :
176 : #[cfg(test)]
177 : mod tests {
178 : use std::pin::pin;
179 :
180 : use futures::{SinkExt, StreamExt};
181 : use hyper_tungstenite::{
182 : tungstenite::{protocol::Role, Message},
183 : WebSocketStream,
184 : };
185 : use tokio::{
186 : io::{duplex, AsyncReadExt, AsyncWriteExt},
187 : task::JoinSet,
188 : };
189 :
190 : use super::WebSocketRw;
191 :
192 : #[tokio::test]
193 2 : async fn websocket_stream_wrapper_happy_path() {
194 2 : let (stream1, stream2) = duplex(1024);
195 2 :
196 2 : let mut js = JoinSet::new();
197 2 :
198 2 : js.spawn(async move {
199 2 : let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
200 2 :
201 2 : client
202 2 : .send(Message::Binary(b"hello world".to_vec()))
203 2 : .await
204 2 : .unwrap();
205 2 :
206 2 : let message = client.next().await.unwrap().unwrap();
207 2 : assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
208 2 :
209 2 : client.close(None).await.unwrap();
210 2 : });
211 2 :
212 2 : js.spawn(async move {
213 2 : let mut rw = pin!(WebSocketRw::new(
214 2 : WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
215 2 : ));
216 2 :
217 2 : let mut buf = vec![0; 1024];
218 2 : let n = rw.read(&mut buf).await.unwrap();
219 2 : assert_eq!(&buf[..n], b"hello world");
220 2 :
221 2 : rw.write_all(b"websockets are cool").await.unwrap();
222 2 : rw.flush().await.unwrap();
223 2 :
224 2 : let n = rw.read_to_end(&mut buf).await.unwrap();
225 2 : assert_eq!(n, 0);
226 2 : });
227 2 :
228 2 : js.join_next().await.unwrap().unwrap();
229 2 : js.join_next().await.unwrap().unwrap();
230 2 : }
231 : }
|