Line data Source code
1 : use crate::{
2 : cancellation::CancelMap,
3 : config::ProxyConfig,
4 : error::io_error,
5 : protocol2::{ProxyProtocolAccept, WithClientIp},
6 : proxy::{handle_client, ClientMode},
7 : };
8 : use bytes::{Buf, Bytes};
9 : use futures::{Sink, Stream, StreamExt};
10 : use hashbrown::HashMap;
11 : use hyper::{
12 : server::{
13 : accept,
14 : conn::{AddrIncoming, AddrStream},
15 : },
16 : upgrade::Upgraded,
17 : Body, Method, Request, Response, StatusCode,
18 : };
19 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
20 : use pin_project_lite::pin_project;
21 : use serde_json::{json, Value};
22 :
23 : use std::{
24 : convert::Infallible,
25 : future::ready,
26 : pin::Pin,
27 : sync::Arc,
28 : task::{ready, Context, Poll},
29 : };
30 : use tls_listener::TlsListener;
31 : use tokio::{
32 : io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
33 : net::TcpListener,
34 : };
35 : use tokio_util::sync::CancellationToken;
36 : use tracing::{error, info, info_span, warn, Instrument};
37 : use utils::http::{error::ApiError, json::json_response};
38 :
39 : // TODO: use `std::sync::Exclusive` once it's stabilized.
40 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
41 : use sync_wrapper::SyncWrapper;
42 :
43 : use super::{conn_pool::GlobalConnPool, sql_over_http};
44 :
45 : pin_project! {
46 : /// This is a wrapper around a [`WebSocketStream`] that
47 : /// implements [`AsyncRead`] and [`AsyncWrite`].
48 : pub struct WebSocketRw {
49 : #[pin]
50 : stream: SyncWrapper<WebSocketStream<Upgraded>>,
51 : bytes: Bytes,
52 : }
53 : }
54 :
55 : impl WebSocketRw {
56 0 : pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
57 0 : Self {
58 0 : stream: stream.into(),
59 0 : bytes: Bytes::new(),
60 0 : }
61 0 : }
62 : }
63 :
64 : impl AsyncWrite for WebSocketRw {
65 0 : fn poll_write(
66 0 : self: Pin<&mut Self>,
67 0 : cx: &mut Context<'_>,
68 0 : buf: &[u8],
69 0 : ) -> Poll<io::Result<usize>> {
70 0 : let mut stream = self.project().stream.get_pin_mut();
71 :
72 0 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
73 0 : match stream.as_mut().start_send(Message::Binary(buf.into())) {
74 0 : Ok(()) => Poll::Ready(Ok(buf.len())),
75 0 : Err(e) => Poll::Ready(Err(io_error(e))),
76 : }
77 0 : }
78 :
79 0 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80 0 : let stream = self.project().stream.get_pin_mut();
81 0 : stream.poll_flush(cx).map_err(io_error)
82 0 : }
83 :
84 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85 0 : let stream = self.project().stream.get_pin_mut();
86 0 : stream.poll_close(cx).map_err(io_error)
87 0 : }
88 : }
89 :
90 : impl AsyncRead for WebSocketRw {
91 0 : fn poll_read(
92 0 : mut self: Pin<&mut Self>,
93 0 : cx: &mut Context<'_>,
94 0 : buf: &mut ReadBuf<'_>,
95 0 : ) -> Poll<io::Result<()>> {
96 0 : if buf.remaining() > 0 {
97 0 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
98 0 : let len = std::cmp::min(bytes.len(), buf.remaining());
99 0 : buf.put_slice(&bytes[..len]);
100 0 : self.consume(len);
101 0 : }
102 :
103 0 : Poll::Ready(Ok(()))
104 0 : }
105 : }
106 :
107 : impl AsyncBufRead for WebSocketRw {
108 0 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
109 0 : // Please refer to poll_fill_buf's documentation.
110 0 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
111 0 :
112 0 : let mut this = self.project();
113 0 : loop {
114 0 : if !this.bytes.chunk().is_empty() {
115 0 : let chunk = (*this.bytes).chunk();
116 0 : return Poll::Ready(Ok(chunk));
117 0 : }
118 :
119 0 : let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
120 0 : match res.transpose().map_err(io_error)? {
121 0 : Some(message) => match message {
122 0 : Message::Ping(_) => {}
123 0 : Message::Pong(_) => {}
124 0 : Message::Text(text) => {
125 0 : // We expect to see only binary messages.
126 0 : let error = "unexpected text message in the websocket";
127 0 : warn!(length = text.len(), error);
128 0 : return Poll::Ready(Err(io_error(error)));
129 : }
130 : Message::Frame(_) => {
131 : // This case is impossible according to Frame's doc.
132 0 : panic!("unexpected raw frame in the websocket");
133 : }
134 0 : Message::Binary(chunk) => {
135 0 : assert!(this.bytes.is_empty());
136 0 : *this.bytes = Bytes::from(chunk);
137 : }
138 0 : Message::Close(_) => return EOF,
139 : },
140 0 : None => return EOF,
141 : }
142 : }
143 0 : }
144 :
145 0 : fn consume(self: Pin<&mut Self>, amount: usize) {
146 0 : self.project().bytes.advance(amount);
147 0 : }
148 : }
149 :
150 0 : async fn serve_websocket(
151 0 : websocket: HyperWebsocket,
152 0 : config: &'static ProxyConfig,
153 0 : cancel_map: &CancelMap,
154 0 : session_id: uuid::Uuid,
155 0 : hostname: Option<String>,
156 0 : ) -> anyhow::Result<()> {
157 0 : let websocket = websocket.await?;
158 0 : handle_client(
159 0 : config,
160 0 : cancel_map,
161 0 : session_id,
162 0 : WebSocketRw::new(websocket),
163 0 : ClientMode::Websockets { hostname },
164 0 : )
165 0 : .await?;
166 0 : Ok(())
167 0 : }
168 :
169 22 : async fn ws_handler(
170 22 : mut request: Request<Body>,
171 22 : config: &'static ProxyConfig,
172 22 : conn_pool: Arc<GlobalConnPool>,
173 22 : cancel_map: Arc<CancelMap>,
174 22 : session_id: uuid::Uuid,
175 22 : sni_hostname: Option<String>,
176 22 : ) -> Result<Response<Body>, ApiError> {
177 22 : let host = request
178 22 : .headers()
179 22 : .get("host")
180 22 : .and_then(|h| h.to_str().ok())
181 22 : .and_then(|h| h.split(':').next())
182 22 : .map(|s| s.to_string());
183 22 :
184 22 : // Check if the request is a websocket upgrade request.
185 22 : if hyper_tungstenite::is_upgrade_request(&request) {
186 0 : info!(session_id = ?session_id, "performing websocket upgrade");
187 :
188 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
189 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
190 :
191 0 : tokio::spawn(
192 0 : async move {
193 0 : if let Err(e) =
194 0 : serve_websocket(websocket, config, &cancel_map, session_id, host).await
195 : {
196 0 : error!(session_id = ?session_id, "error in websocket connection: {e:#}");
197 0 : }
198 0 : }
199 0 : .in_current_span(),
200 0 : );
201 0 :
202 0 : // Return the response so the spawned future can continue.
203 0 : Ok(response)
204 : // TODO: that deserves a refactor as now this function also handles http json client besides websockets.
205 : // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
206 22 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
207 22 : let result = sql_over_http::handle(request, sni_hostname, conn_pool, session_id)
208 22 : .instrument(info_span!("sql-over-http"))
209 149 : .await;
210 22 : let status_code = match result {
211 20 : Ok(_) => StatusCode::OK,
212 2 : Err(_) => StatusCode::BAD_REQUEST,
213 : };
214 22 : let (json, headers) = match result {
215 20 : Ok(r) => r,
216 2 : Err(e) => {
217 2 : let message = format!("{:?}", e);
218 2 : let code = match e.downcast_ref::<tokio_postgres::Error>() {
219 2 : Some(e) => match e.code() {
220 2 : Some(e) => serde_json::to_value(e.code()).unwrap(),
221 0 : None => Value::Null,
222 : },
223 0 : None => Value::Null,
224 : };
225 2 : error!(
226 2 : ?code,
227 2 : "sql-over-http per-client task finished with an error: {e:#}"
228 2 : );
229 2 : (
230 2 : json!({ "message": message, "code": code }),
231 2 : HashMap::default(),
232 2 : )
233 : }
234 : };
235 22 : json_response(status_code, json).map(|mut r| {
236 22 : r.headers_mut().insert(
237 22 : "Access-Control-Allow-Origin",
238 22 : hyper::http::HeaderValue::from_static("*"),
239 22 : );
240 26 : for (k, v) in headers {
241 4 : r.headers_mut().insert(k, v);
242 4 : }
243 22 : r
244 22 : })
245 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
246 0 : Response::builder()
247 0 : .header("Allow", "OPTIONS, POST")
248 0 : .header("Access-Control-Allow-Origin", "*")
249 0 : .header(
250 0 : "Access-Control-Allow-Headers",
251 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
252 0 : )
253 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
254 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
255 0 : .body(Body::empty())
256 0 : .map_err(|e| ApiError::BadRequest(e.into()))
257 : } else {
258 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
259 : }
260 22 : }
261 :
262 14 : pub async fn task_main(
263 14 : config: &'static ProxyConfig,
264 14 : ws_listener: TcpListener,
265 14 : cancellation_token: CancellationToken,
266 14 : ) -> anyhow::Result<()> {
267 14 : scopeguard::defer! {
268 14 : info!("websocket server has shut down");
269 14 : }
270 14 :
271 14 : let conn_pool: Arc<GlobalConnPool> = GlobalConnPool::new(config);
272 14 :
273 14 : // shutdown the connection pool
274 14 : tokio::spawn({
275 14 : let cancellation_token = cancellation_token.clone();
276 14 : let conn_pool = conn_pool.clone();
277 14 : async move {
278 14 : cancellation_token.cancelled().await;
279 14 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
280 14 : .await
281 14 : .unwrap();
282 14 : }
283 14 : });
284 14 :
285 14 : let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
286 14 : let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
287 14 : Some(config) => config.into(),
288 : None => {
289 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
290 0 : return Ok(());
291 : }
292 : };
293 :
294 14 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
295 14 : let _ = addr_incoming.set_nodelay(true);
296 14 : let addr_incoming = ProxyProtocolAccept {
297 14 : incoming: addr_incoming,
298 14 : };
299 14 :
300 14 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
301 22 : if let Err(err) = conn {
302 0 : error!("failed to accept TLS connection for websockets: {err:?}");
303 0 : ready(false)
304 : } else {
305 22 : ready(true)
306 : }
307 22 : });
308 14 :
309 14 : let make_svc = hyper::service::make_service_fn(
310 22 : |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
311 22 : let (io, tls) = stream.get_ref();
312 22 : let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr());
313 22 : let sni_name = tls.server_name().map(|s| s.to_string());
314 22 : let conn_pool = conn_pool.clone();
315 :
316 22 : async move {
317 22 : Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request<Body>| {
318 22 : let sni_name = sni_name.clone();
319 22 : let conn_pool = conn_pool.clone();
320 :
321 22 : async move {
322 22 : let cancel_map = Arc::new(CancelMap::default());
323 22 : let session_id = uuid::Uuid::new_v4();
324 22 :
325 22 : ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
326 22 : .instrument(info_span!(
327 22 : "ws-client",
328 22 : session = %session_id,
329 22 : %peer_addr,
330 22 : ))
331 149 : .await
332 22 : }
333 22 : }))
334 22 : }
335 22 : },
336 14 : );
337 14 :
338 14 : hyper::Server::builder(accept::from_stream(tls_listener))
339 14 : .serve(make_svc)
340 14 : .with_graceful_shutdown(cancellation_token.cancelled())
341 124 : .await?;
342 :
343 14 : Ok(())
344 14 : }
|