Line data Source code
1 : //! Routers for our serverless APIs
2 : //!
3 : //! Handles both SQL over HTTP and SQL over Websockets.
4 :
5 : mod backend;
6 : mod conn_pool;
7 : mod http_util;
8 : mod json;
9 : mod sql_over_http;
10 : mod websocket;
11 :
12 : use atomic_take::AtomicTake;
13 : use bytes::Bytes;
14 : pub use conn_pool::GlobalConnPoolOptions;
15 :
16 : use anyhow::Context;
17 : use futures::future::{select, Either};
18 : use futures::TryFutureExt;
19 : use http::{Method, Response, StatusCode};
20 : use http_body_util::Full;
21 : use hyper1::body::Incoming;
22 : use hyper_util::rt::TokioExecutor;
23 : use hyper_util::server::conn::auto::Builder;
24 : use rand::rngs::StdRng;
25 : use rand::SeedableRng;
26 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
27 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
28 : use tokio::time::timeout;
29 : use tokio_rustls::TlsAcceptor;
30 : use tokio_util::task::TaskTracker;
31 :
32 : use crate::cancellation::CancellationHandlerMain;
33 : use crate::config::ProxyConfig;
34 : use crate::context::RequestMonitoring;
35 : use crate::metrics::Metrics;
36 : use crate::protocol2::WithClientIp;
37 : use crate::proxy::run_until_cancelled;
38 : use crate::serverless::backend::PoolingBackend;
39 : use crate::serverless::http_util::{api_error_into_response, json_response};
40 :
41 : use std::net::{IpAddr, SocketAddr};
42 : use std::pin::pin;
43 : use std::sync::Arc;
44 : use tokio::net::{TcpListener, TcpStream};
45 : use tokio_util::sync::CancellationToken;
46 : use tracing::{error, info, warn, Instrument};
47 : use utils::http::error::ApiError;
48 :
/// NOTE(review): presumably the SNI name associated with the serverless driver; where it is
/// consumed is not visible in this file, so confirm the exact meaning against its callers.
49 : pub const SERVERLESS_DRIVER_SNI: &str = "api";
50 :
51 0 : pub async fn task_main(
52 0 : config: &'static ProxyConfig,
53 0 : ws_listener: TcpListener,
54 0 : cancellation_token: CancellationToken,
55 0 : cancellation_handler: Arc<CancellationHandlerMain>,
56 0 : ) -> anyhow::Result<()> {
57 0 : scopeguard::defer! {
58 0 : info!("websocket server has shut down");
59 : }
60 :
61 0 : let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
62 0 : {
63 0 : let conn_pool = Arc::clone(&conn_pool);
64 0 : tokio::spawn(async move {
65 0 : conn_pool.gc_worker(StdRng::from_entropy()).await;
66 0 : });
67 0 : }
68 0 :
69 0 : // shutdown the connection pool
70 0 : tokio::spawn({
71 0 : let cancellation_token = cancellation_token.clone();
72 0 : let conn_pool = conn_pool.clone();
73 0 : async move {
74 0 : cancellation_token.cancelled().await;
75 0 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
76 0 : .await
77 0 : .unwrap();
78 0 : }
79 0 : });
80 0 :
81 0 : let backend = Arc::new(PoolingBackend {
82 0 : pool: Arc::clone(&conn_pool),
83 0 : config,
84 0 : });
85 :
86 0 : let tls_config = match config.tls_config.as_ref() {
87 0 : Some(config) => config,
88 : None => {
89 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
90 0 : return Ok(());
91 : }
92 : };
93 0 : let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
94 0 : // prefer http2, but support http/1.1
95 0 : tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
96 0 : let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
97 0 :
98 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
99 0 : connections.close(); // allows `connections.wait to complete`
100 0 :
101 0 : let server = Builder::new(hyper_util::rt::TokioExecutor::new());
102 :
103 0 : while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
104 0 : let (conn, peer_addr) = res.context("could not accept TCP stream")?;
105 0 : if let Err(e) = conn.set_nodelay(true) {
106 0 : tracing::error!("could not set nodelay: {e}");
107 0 : continue;
108 0 : }
109 0 : let conn_id = uuid::Uuid::new_v4();
110 0 : let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
111 :
112 0 : connections.spawn(
113 0 : connection_handler(
114 0 : config,
115 0 : backend.clone(),
116 0 : connections.clone(),
117 0 : cancellation_handler.clone(),
118 0 : cancellation_token.clone(),
119 0 : server.clone(),
120 0 : tls_acceptor.clone(),
121 0 : conn,
122 0 : peer_addr,
123 0 : )
124 0 : .instrument(http_conn_span),
125 0 : );
126 : }
127 :
128 0 : connections.wait().await;
129 :
130 0 : Ok(())
131 0 : }
132 :
133 : /// Handles the TCP lifecycle.
134 : ///
135 : /// 1. Parses PROXY protocol V2
136 : /// 2. Handles TLS handshake
137 : /// 3. Handles HTTP connection
138 : /// 1. With graceful shutdowns
139 : /// 2. With graceful request cancellation with connection failure
140 : /// 3. With websocket upgrade support.
141 : #[allow(clippy::too_many_arguments)]
142 0 : async fn connection_handler(
143 0 : config: &'static ProxyConfig,
144 0 : backend: Arc<PoolingBackend>,
145 0 : connections: TaskTracker,
146 0 : cancellation_handler: Arc<CancellationHandlerMain>,
147 0 : cancellation_token: CancellationToken,
148 0 : server: Builder<TokioExecutor>,
149 0 : tls_acceptor: TlsAcceptor,
150 0 : conn: TcpStream,
151 0 : peer_addr: SocketAddr,
152 0 : ) {
153 0 : let session_id = uuid::Uuid::new_v4();
154 0 :
155 0 : let _gauge = Metrics::get()
156 0 : .proxy
157 0 : .client_connections
158 0 : .guard(crate::metrics::Protocol::Http);
159 0 :
160 0 : // handle PROXY protocol
161 0 : let mut conn = WithClientIp::new(conn);
162 0 : let peer = match conn.wait_for_addr().await {
163 0 : Ok(peer) => peer,
164 0 : Err(e) => {
165 0 : tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
166 0 : return;
167 : }
168 : };
169 :
170 0 : let peer_addr = peer.unwrap_or(peer_addr).ip();
171 0 : let has_private_peer_addr = match peer_addr {
172 0 : IpAddr::V4(ip) => ip.is_private(),
173 0 : _ => false,
174 : };
175 0 : info!(?session_id, %peer_addr, "accepted new TCP connection");
176 :
177 : // try upgrade to TLS, but with a timeout.
178 0 : let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
179 0 : Ok(Ok(conn)) => {
180 0 : info!(?session_id, %peer_addr, "accepted new TLS connection");
181 0 : conn
182 : }
183 : // The handshake failed
184 0 : Ok(Err(e)) => {
185 0 : if !has_private_peer_addr {
186 0 : Metrics::get().proxy.tls_handshake_failures.inc();
187 0 : }
188 0 : warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
189 0 : return;
190 : }
191 : // The handshake timed out
192 0 : Err(e) => {
193 0 : if !has_private_peer_addr {
194 0 : Metrics::get().proxy.tls_handshake_failures.inc();
195 0 : }
196 0 : warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
197 0 : return;
198 : }
199 : };
200 :
201 0 : let session_id = AtomicTake::new(session_id);
202 0 :
203 0 : // Cancel all current inflight HTTP requests if the HTTP connection is closed.
204 0 : let http_cancellation_token = CancellationToken::new();
205 0 : let _cancel_connection = http_cancellation_token.clone().drop_guard();
206 0 :
207 0 : let conn = server.serve_connection_with_upgrades(
208 0 : hyper_util::rt::TokioIo::new(conn),
209 0 : hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
210 0 : // First HTTP request shares the same session ID
211 0 : let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
212 0 :
213 0 : // Cancel the current inflight HTTP request if the requets stream is closed.
214 0 : // This is slightly different to `_cancel_connection` in that
215 0 : // h2 can cancel individual requests with a `RST_STREAM`.
216 0 : let http_request_token = http_cancellation_token.child_token();
217 0 : let cancel_request = http_request_token.clone().drop_guard();
218 0 :
219 0 : // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
220 0 : // By spawning the future, we ensure it never gets cancelled until it decides to.
221 0 : let handler = connections.spawn(
222 0 : request_handler(
223 0 : req,
224 0 : config,
225 0 : backend.clone(),
226 0 : connections.clone(),
227 0 : cancellation_handler.clone(),
228 0 : session_id,
229 0 : peer_addr,
230 0 : http_request_token,
231 0 : )
232 0 : .in_current_span()
233 0 : .map_ok_or_else(api_error_into_response, |r| r),
234 0 : );
235 :
236 0 : async move {
237 0 : let res = handler.await;
238 0 : cancel_request.disarm();
239 0 : res
240 0 : }
241 0 : }),
242 0 : );
243 :
244 : // On cancellation, trigger the HTTP connection handler to shut down.
245 0 : let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
246 0 : Either::Left((_cancelled, mut conn)) => {
247 0 : conn.as_mut().graceful_shutdown();
248 0 : conn.await
249 : }
250 0 : Either::Right((res, _)) => res,
251 : };
252 :
253 0 : match res {
254 0 : Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
255 0 : Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
256 : }
257 0 : }
258 :
259 : #[allow(clippy::too_many_arguments)]
260 0 : async fn request_handler(
261 0 : mut request: hyper1::Request<Incoming>,
262 0 : config: &'static ProxyConfig,
263 0 : backend: Arc<PoolingBackend>,
264 0 : ws_connections: TaskTracker,
265 0 : cancellation_handler: Arc<CancellationHandlerMain>,
266 0 : session_id: uuid::Uuid,
267 0 : peer_addr: IpAddr,
268 0 : // used to cancel in-flight HTTP requests. not used to cancel websockets
269 0 : http_cancellation_token: CancellationToken,
270 0 : ) -> Result<Response<Full<Bytes>>, ApiError> {
271 0 : let host = request
272 0 : .headers()
273 0 : .get("host")
274 0 : .and_then(|h| h.to_str().ok())
275 0 : .and_then(|h| h.split(':').next())
276 0 : .map(|s| s.to_string());
277 0 :
278 0 : // Check if the request is a websocket upgrade request.
279 0 : if hyper_tungstenite::is_upgrade_request(&request) {
280 0 : let ctx = RequestMonitoring::new(
281 0 : session_id,
282 0 : peer_addr,
283 0 : crate::metrics::Protocol::Ws,
284 0 : &config.region,
285 0 : );
286 0 :
287 0 : let span = ctx.span.clone();
288 0 : info!(parent: &span, "performing websocket upgrade");
289 :
290 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
291 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
292 :
293 0 : ws_connections.spawn(
294 0 : async move {
295 0 : if let Err(e) =
296 0 : websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
297 0 : .await
298 : {
299 0 : error!("error in websocket connection: {e:#}");
300 0 : }
301 0 : }
302 0 : .instrument(span),
303 0 : );
304 0 :
305 0 : // Return the response so the spawned future can continue.
306 0 : Ok(response)
307 0 : } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
308 0 : let ctx = RequestMonitoring::new(
309 0 : session_id,
310 0 : peer_addr,
311 0 : crate::metrics::Protocol::Http,
312 0 : &config.region,
313 0 : );
314 0 : let span = ctx.span.clone();
315 0 :
316 0 : sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
317 0 : .instrument(span)
318 0 : .await
319 0 : } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
320 0 : Response::builder()
321 0 : .header("Allow", "OPTIONS, POST")
322 0 : .header("Access-Control-Allow-Origin", "*")
323 0 : .header(
324 0 : "Access-Control-Allow-Headers",
325 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
326 0 : )
327 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
328 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
329 0 : .body(Full::new(Bytes::new()))
330 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
331 : } else {
332 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
333 : }
334 0 : }
|