TLA Line data Source code
1 : use crate::{
2 : cancellation::CancelMap,
3 : config::ProxyConfig,
4 : error::io_error,
5 : protocol2::{ProxyProtocolAccept, WithClientIp},
6 : proxy::{
7 : handle_client, ClientMode, NUM_CLIENT_CONNECTION_CLOSED_COUNTER,
8 : NUM_CLIENT_CONNECTION_OPENED_COUNTER,
9 : },
10 : };
11 : use anyhow::bail;
12 : use bytes::{Buf, Bytes};
13 : use futures::{Sink, Stream, StreamExt};
14 : use hyper::{
15 : server::{
16 : accept,
17 : conn::{AddrIncoming, AddrStream},
18 : },
19 : upgrade::Upgraded,
20 : Body, Method, Request, Response, StatusCode,
21 : };
22 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
23 : use pin_project_lite::pin_project;
24 :
25 : use std::{
26 : future::ready,
27 : pin::Pin,
28 : sync::Arc,
29 : task::{ready, Context, Poll},
30 : };
31 : use tls_listener::TlsListener;
32 : use tokio::{
33 : io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
34 : net::TcpListener,
35 : };
36 : use tokio_util::sync::CancellationToken;
37 : use tracing::{error, info, info_span, warn, Instrument};
38 : use utils::http::{error::ApiError, json::json_response};
39 :
40 : // TODO: use `std::sync::Exclusive` once it's stabilized.
41 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
42 : use sync_wrapper::SyncWrapper;
43 :
44 : use super::{conn_pool::GlobalConnPool, sql_over_http};
45 :
46 : pin_project! {
47 : /// This is a wrapper around a [`WebSocketStream`] that
48 : /// implements [`AsyncRead`] and [`AsyncWrite`].
49 : pub struct WebSocketRw {
50 : #[pin]
51 : stream: SyncWrapper<WebSocketStream<Upgraded>>,
52 : bytes: Bytes,
53 : }
54 : }
55 :
56 : impl WebSocketRw {
57 UBC 0 : pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
58 0 : Self {
59 0 : stream: stream.into(),
60 0 : bytes: Bytes::new(),
61 0 : }
62 0 : }
63 : }
64 :
65 : impl AsyncWrite for WebSocketRw {
66 0 : fn poll_write(
67 0 : self: Pin<&mut Self>,
68 0 : cx: &mut Context<'_>,
69 0 : buf: &[u8],
70 0 : ) -> Poll<io::Result<usize>> {
71 0 : let mut stream = self.project().stream.get_pin_mut();
72 :
73 0 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
74 0 : match stream.as_mut().start_send(Message::Binary(buf.into())) {
75 0 : Ok(()) => Poll::Ready(Ok(buf.len())),
76 0 : Err(e) => Poll::Ready(Err(io_error(e))),
77 : }
78 0 : }
79 :
80 0 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 0 : let stream = self.project().stream.get_pin_mut();
82 0 : stream.poll_flush(cx).map_err(io_error)
83 0 : }
84 :
85 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
86 0 : let stream = self.project().stream.get_pin_mut();
87 0 : stream.poll_close(cx).map_err(io_error)
88 0 : }
89 : }
90 :
91 : impl AsyncRead for WebSocketRw {
92 0 : fn poll_read(
93 0 : mut self: Pin<&mut Self>,
94 0 : cx: &mut Context<'_>,
95 0 : buf: &mut ReadBuf<'_>,
96 0 : ) -> Poll<io::Result<()>> {
97 0 : if buf.remaining() > 0 {
98 0 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
99 0 : let len = std::cmp::min(bytes.len(), buf.remaining());
100 0 : buf.put_slice(&bytes[..len]);
101 0 : self.consume(len);
102 0 : }
103 :
104 0 : Poll::Ready(Ok(()))
105 0 : }
106 : }
107 :
108 : impl AsyncBufRead for WebSocketRw {
109 0 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
110 0 : // Please refer to poll_fill_buf's documentation.
111 0 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
112 0 :
113 0 : let mut this = self.project();
114 0 : loop {
115 0 : if !this.bytes.chunk().is_empty() {
116 0 : let chunk = (*this.bytes).chunk();
117 0 : return Poll::Ready(Ok(chunk));
118 0 : }
119 :
120 0 : let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
121 0 : match res.transpose().map_err(io_error)? {
122 0 : Some(message) => match message {
123 0 : Message::Ping(_) => {}
124 0 : Message::Pong(_) => {}
125 0 : Message::Text(text) => {
126 0 : // We expect to see only binary messages.
127 0 : let error = "unexpected text message in the websocket";
128 0 : warn!(length = text.len(), error);
129 0 : return Poll::Ready(Err(io_error(error)));
130 : }
131 : Message::Frame(_) => {
132 : // This case is impossible according to Frame's doc.
133 0 : panic!("unexpected raw frame in the websocket");
134 : }
135 0 : Message::Binary(chunk) => {
136 0 : assert!(this.bytes.is_empty());
137 0 : *this.bytes = Bytes::from(chunk);
138 : }
139 0 : Message::Close(_) => return EOF,
140 : },
141 0 : None => return EOF,
142 : }
143 : }
144 0 : }
145 :
146 0 : fn consume(self: Pin<&mut Self>, amount: usize) {
147 0 : self.project().bytes.advance(amount);
148 0 : }
149 : }
150 :
151 0 : async fn serve_websocket(
152 0 : websocket: HyperWebsocket,
153 0 : config: &'static ProxyConfig,
154 0 : cancel_map: &CancelMap,
155 0 : session_id: uuid::Uuid,
156 0 : hostname: Option<String>,
157 0 : ) -> anyhow::Result<()> {
158 0 : let websocket = websocket.await?;
159 0 : handle_client(
160 0 : config,
161 0 : cancel_map,
162 0 : session_id,
163 0 : WebSocketRw::new(websocket),
164 0 : ClientMode::Websockets { hostname },
165 0 : )
166 0 : .await?;
167 0 : Ok(())
168 0 : }
169 :
170 CBC 30 : async fn ws_handler(
171 30 : mut request: Request<Body>,
172 30 : config: &'static ProxyConfig,
173 30 : conn_pool: Arc<GlobalConnPool>,
174 30 : cancel_map: Arc<CancelMap>,
175 30 : session_id: uuid::Uuid,
176 30 : sni_hostname: Option<String>,
177 30 : ) -> Result<Response<Body>, ApiError> {
178 30 : let host = request
179 30 : .headers()
180 30 : .get("host")
181 30 : .and_then(|h| h.to_str().ok())
182 30 : .and_then(|h| h.split(':').next())
183 30 : .map(|s| s.to_string());
184 30 :
185 30 : // Check if the request is a websocket upgrade request.
186 30 : if hyper_tungstenite::is_upgrade_request(&request) {
187 UBC 0 : info!(session_id = ?session_id, "performing websocket upgrade");
188 :
189 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
190 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
191 :
192 0 : tokio::spawn(
193 0 : async move {
194 0 : if let Err(e) =
195 0 : serve_websocket(websocket, config, &cancel_map, session_id, host).await
196 : {
197 0 : error!(session_id = ?session_id, "error in websocket connection: {e:#}");
198 0 : }
199 0 : }
200 0 : .in_current_span(),
201 0 : );
202 0 :
203 0 : // Return the response so the spawned future can continue.
204 0 : Ok(response)
205 : // TODO: that deserves a refactor as now this function also handles http json client besides websockets.
206 : // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
207 CBC 30 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
208 30 : sql_over_http::handle(
209 30 : request,
210 30 : sni_hostname,
211 30 : conn_pool,
212 30 : session_id,
213 30 : &config.http_config,
214 30 : )
215 233 : .await
216 UBC 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
217 0 : Response::builder()
218 0 : .header("Allow", "OPTIONS, POST")
219 0 : .header("Access-Control-Allow-Origin", "*")
220 0 : .header(
221 0 : "Access-Control-Allow-Headers",
222 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
223 0 : )
224 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
225 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
226 0 : .body(Body::empty())
227 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
228 : } else {
229 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
230 : }
231 CBC 30 : }
232 :
233 16 : pub async fn task_main(
234 16 : config: &'static ProxyConfig,
235 16 : ws_listener: TcpListener,
236 16 : cancellation_token: CancellationToken,
237 16 : ) -> anyhow::Result<()> {
238 16 : scopeguard::defer! {
239 16 : info!("websocket server has shut down");
240 16 : }
241 16 :
242 16 : let conn_pool: Arc<GlobalConnPool> = GlobalConnPool::new(config);
243 16 :
244 16 : // shutdown the connection pool
245 16 : tokio::spawn({
246 16 : let cancellation_token = cancellation_token.clone();
247 16 : let conn_pool = conn_pool.clone();
248 16 : async move {
249 16 : cancellation_token.cancelled().await;
250 16 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
251 16 : .await
252 16 : .unwrap();
253 16 : }
254 16 : });
255 16 :
256 16 : let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
257 16 : let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
258 16 : Some(config) => config.into(),
259 : None => {
260 UBC 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
261 0 : return Ok(());
262 : }
263 : };
264 :
265 CBC 16 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
266 16 : let _ = addr_incoming.set_nodelay(true);
267 16 : let addr_incoming = ProxyProtocolAccept {
268 16 : incoming: addr_incoming,
269 16 : };
270 16 :
271 16 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
272 30 : if let Err(err) = conn {
273 UBC 0 : error!("failed to accept TLS connection for websockets: {err:?}");
274 0 : ready(false)
275 : } else {
276 CBC 30 : ready(true)
277 : }
278 30 : });
279 16 :
280 16 : let make_svc = hyper::service::make_service_fn(
281 30 : |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
282 30 : let (io, tls) = stream.get_ref();
283 30 : let client_addr = io.client_addr();
284 30 : let remote_addr = io.inner.remote_addr();
285 30 : let sni_name = tls.server_name().map(|s| s.to_string());
286 30 : let conn_pool = conn_pool.clone();
287 :
288 30 : async move {
289 30 : let peer_addr = match client_addr {
290 UBC 0 : Some(addr) => addr,
291 0 : None if config.require_client_ip => bail!("missing required client ip"),
292 CBC 30 : None => remote_addr,
293 : };
294 30 : Ok(MetricService::new(hyper::service::service_fn(
295 30 : move |req: Request<Body>| {
296 30 : let sni_name = sni_name.clone();
297 30 : let conn_pool = conn_pool.clone();
298 :
299 30 : async move {
300 30 : let cancel_map = Arc::new(CancelMap::default());
301 30 : let session_id = uuid::Uuid::new_v4();
302 30 :
303 30 : ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
304 30 : .instrument(info_span!(
305 30 : "ws-client",
306 30 : session = %session_id,
307 30 : %peer_addr,
308 30 : ))
309 233 : .await
310 30 : }
311 30 : },
312 30 : )))
313 30 : }
314 30 : },
315 16 : );
316 16 :
317 16 : hyper::Server::builder(accept::from_stream(tls_listener))
318 16 : .serve(make_svc)
319 16 : .with_graceful_shutdown(cancellation_token.cancelled())
320 166 : .await?;
321 :
322 16 : Ok(())
323 16 : }
324 :
325 : struct MetricService<S> {
326 : inner: S,
327 : }
328 :
329 : impl<S> MetricService<S> {
330 30 : fn new(inner: S) -> MetricService<S> {
331 30 : NUM_CLIENT_CONNECTION_OPENED_COUNTER
332 30 : .with_label_values(&["http"])
333 30 : .inc();
334 30 : MetricService { inner }
335 30 : }
336 : }
337 :
338 : impl<S> Drop for MetricService<S> {
339 30 : fn drop(&mut self) {
340 30 : NUM_CLIENT_CONNECTION_CLOSED_COUNTER
341 30 : .with_label_values(&["http"])
342 30 : .inc();
343 30 : }
344 : }
345 :
346 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
347 : where
348 : S: hyper::service::Service<Request<ReqBody>>,
349 : {
350 : type Response = S::Response;
351 : type Error = S::Error;
352 : type Future = S::Future;
353 :
354 90 : fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
355 90 : self.inner.poll_ready(cx)
356 90 : }
357 :
358 30 : fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
359 30 : self.inner.call(req)
360 30 : }
361 : }
|