//! Routers for our serverless APIs
//!
//! Handles both SQL over HTTP and SQL over Websockets.
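//!
//! `task_main` is the entry point: it accepts TCP connections, performs the
//! PROXY protocol and TLS handshake in `connection_startup`, and then serves
//! HTTP (including websocket upgrades) in `connection_handler`.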

mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_util;
mod json;
mod sql_over_http;
mod websocket;

use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;

use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;

use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};

use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;

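/// SNI hostname label that identifies connections coming from the serverless driver.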
pub const SERVERLESS_DRIVER_SNI: &str = "api";

pub async fn task_main(
    config: &'static ProxyConfig,
    ws_listener: TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandlerMain>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("websocket server has shut down");
    }

    let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
    {
        let conn_pool = Arc::clone(&conn_pool);
        tokio::spawn(async move {
            conn_pool.gc_worker(StdRng::from_entropy()).await;
        });
    }

    // shutdown the connection pool
    tokio::spawn({
        let cancellation_token = cancellation_token.clone();
        let conn_pool = conn_pool.clone();
        async move {
            cancellation_token.cancelled().await;
            tokio::task::spawn_blocking(move || conn_pool.shutdown())
                .await
                .unwrap();
        }
    });

    let backend = Arc::new(PoolingBackend {
        pool: Arc::clone(&conn_pool),
        config,
        endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
    });
    let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
        Some(config) => {
            let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
            // prefer http2, but support http/1.1
            tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            Arc::new(tls_server_config) as Arc<_>
        }
        None => {
            warn!("TLS config is missing");
            Arc::new(NoTls) as Arc<_>
        }
    };

    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    connections.close(); // allows `connections.wait` to complete

    while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
        let (conn, peer_addr) = res.context("could not accept TCP stream")?;
        if let Err(e) = conn.set_nodelay(true) {
            tracing::error!("could not set nodelay: {e}");
            continue;
        }
        let conn_id = uuid::Uuid::new_v4();
        let http_conn_span = tracing::info_span!("http_conn", ?conn_id);

        let n_connections = Metrics::get()
            .proxy
            .client_connections
            .sample(crate::metrics::Protocol::Http);
        tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
        if n_connections > config.http_config.client_conn_threshold {
            tracing::trace!("attempting to cancel a random connection");
            if let Some(token) = config.http_config.cancel_set.take() {
                tracing::debug!("cancelling a random connection");
                token.cancel();
            }
        }

        let conn_token = cancellation_token.child_token();
        let tls_acceptor = tls_acceptor.clone();
        let backend = backend.clone();
        let connections2 = connections.clone();
        let cancellation_handler = cancellation_handler.clone();
        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
        connections.spawn(
            async move {
                let conn_token2 = conn_token.clone();
                let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);

                let session_id = uuid::Uuid::new_v4();

                let _gauge = Metrics::get()
                    .proxy
                    .client_connections
                    .guard(crate::metrics::Protocol::Http);

                let startup_result = Box::pin(connection_startup(
                    config,
                    tls_acceptor,
                    session_id,
                    conn,
                    peer_addr,
                ))
                .await;
                let Some((conn, peer_addr)) = startup_result else {
                    return;
                };

                Box::pin(connection_handler(
                    config,
                    backend,
                    connections2,
                    cancellation_handler,
                    endpoint_rate_limiter,
                    conn_token,
                    conn,
                    peer_addr,
                    session_id,
                ))
                .await;
            }
            .instrument(http_conn_span),
        );
    }

    connections.wait().await;

    Ok(())
}

pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
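/// Convenience alias for a boxed, pinned bidirectional stream.
///
/// A minimal sketch (not part of the proxy itself) of how a concrete stream
/// becomes an `AsyncRW`; an in-memory duplex stream stands in for a real TCP
/// or TLS connection:
///
/// ```ignore
/// let (client, _server) = tokio::io::duplex(1024);
/// let rw: AsyncRW = Box::pin(client);
/// ```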
pub type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;

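/// Accepts an incoming connection, optionally wrapping it in TLS.
///
/// Two implementations follow: `rustls::ServerConfig` performs a real TLS
/// handshake, while `NoTls` passes the stream through unchanged when no TLS
/// config is supplied.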
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
    async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
}

#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
    async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
        Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
    }
}

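/// Fallback acceptor used when TLS is not configured: the TCP stream is passed
/// through as-is.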
struct NoTls;

#[async_trait]
impl MaybeTlsAcceptor for NoTls {
    async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
        Ok(Box::pin(conn))
    }
}

/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
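///
/// Returns the (possibly TLS-wrapped) stream together with the client IP,
/// preferring the address reported by the PROXY protocol header when present;
/// returns `None` (after logging) if either step fails.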
async fn connection_startup(
    config: &ProxyConfig,
    tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
    session_id: uuid::Uuid,
    conn: TcpStream,
    peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> {
    // handle PROXY protocol
    let (conn, peer) = match read_proxy_protocol(conn).await {
        Ok(c) => c,
        Err(e) => {
            tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
            return None;
        }
    };

    let peer_addr = peer.unwrap_or(peer_addr).ip();
    let has_private_peer_addr = match peer_addr {
        IpAddr::V4(ip) => ip.is_private(),
        IpAddr::V6(_) => false,
    };
    info!(?session_id, %peer_addr, "accepted new TCP connection");

    // try upgrade to TLS, but with a timeout.
    let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
        Ok(Ok(conn)) => {
            info!(?session_id, %peer_addr, "accepted new TLS connection");
            conn
        }
        // The handshake failed
        Ok(Err(e)) => {
            if !has_private_peer_addr {
                Metrics::get().proxy.tls_handshake_failures.inc();
            }
            warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
            return None;
        }
        // The handshake timed out
        Err(e) => {
            if !has_private_peer_addr {
                Metrics::get().proxy.tls_handshake_failures.inc();
            }
            warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
            return None;
        }
    };

    Some((conn, peer_addr))
}

/// Handles an HTTP connection:
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation on connection failure
/// 3. With websocket upgrade support.
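///
/// Cancellation is layered: the connection-level `cancellation_token` triggers
/// a graceful shutdown of the hyper connection, while `http_cancellation_token`
/// is dropped together with the connection and fans out per-request child
/// tokens, so individual in-flight requests can also be cancelled (e.g. by an
/// h2 `RST_STREAM`).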
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandlerMain>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    cancellation_token: CancellationToken,
    conn: AsyncRW,
    peer_addr: IpAddr,
    session_id: uuid::Uuid,
) {
    let session_id = AtomicTake::new(session_id);

    // Cancel all current inflight HTTP requests if the HTTP connection is closed.
    let http_cancellation_token = CancellationToken::new();
    let _cancel_connection = http_cancellation_token.clone().drop_guard();

    let server = Builder::new(TokioExecutor::new());
    let conn = server.serve_connection_with_upgrades(
        hyper_util::rt::TokioIo::new(conn),
        hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
            // First HTTP request shares the same session ID
            let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);

            // Cancel the current inflight HTTP request if the request stream is closed.
            // This is slightly different to `_cancel_connection` in that
            // h2 can cancel individual requests with a `RST_STREAM`.
            let http_request_token = http_cancellation_token.child_token();
            let cancel_request = http_request_token.clone().drop_guard();

            // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
            // By spawning the future, we ensure it never gets cancelled until it decides to.
            let handler = connections.spawn(
                request_handler(
                    req,
                    config,
                    backend.clone(),
                    connections.clone(),
                    cancellation_handler.clone(),
                    session_id,
                    peer_addr,
                    http_request_token,
                    endpoint_rate_limiter.clone(),
                )
                .in_current_span()
                .map_ok_or_else(api_error_into_response, |r| r),
            );
            async move {
                let res = handler.await;
                cancel_request.disarm();
                res
            }
        }),
    );

    // On cancellation, trigger the HTTP connection handler to shut down.
    let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
        Either::Left((_cancelled, mut conn)) => {
            tracing::debug!(%peer_addr, "cancelling connection");
            conn.as_mut().graceful_shutdown();
            conn.await
        }
        Either::Right((res, _)) => res,
    };

    match res {
        Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
        Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
    }
}

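/// Routes a single HTTP request:
/// - websocket upgrade requests are handed off to `websocket::serve_websocket`
///   (when `accept_websockets` is enabled),
/// - `POST /sql` is served by `sql_over_http::handle`,
/// - `OPTIONS /sql` answers the CORS preflight,
/// - anything else gets a `400 Bad Request`.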
#[allow(clippy::too_many_arguments)]
async fn request_handler(
    mut request: hyper1::Request<Incoming>,
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    ws_connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandlerMain>,
    session_id: uuid::Uuid,
    peer_addr: IpAddr,
    // used to cancel in-flight HTTP requests; not used to cancel websockets
    http_cancellation_token: CancellationToken,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
    let host = request
        .headers()
        .get("host")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.split(':').next())
        .map(|s| s.to_string());

    // Check if the request is a websocket upgrade request.
    if config.http_config.accept_websockets
        && framed_websockets::upgrade::is_upgrade_request(&request)
    {
        let ctx = RequestMonitoring::new(
            session_id,
            peer_addr,
            crate::metrics::Protocol::Ws,
            &config.region,
        );

        let span = ctx.span();
        info!(parent: &span, "performing websocket upgrade");

        let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
            .map_err(|e| ApiError::BadRequest(e.into()))?;

        ws_connections.spawn(
            async move {
                if let Err(e) = websocket::serve_websocket(
                    config,
                    ctx,
                    websocket,
                    cancellation_handler,
                    endpoint_rate_limiter,
                    host,
                )
                .await
                {
                    error!("error in websocket connection: {e:#}");
                }
            }
            .instrument(span),
        );

        // Return the response so the spawned future can continue.
        Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
    } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
        let ctx = RequestMonitoring::new(
            session_id,
            peer_addr,
            crate::metrics::Protocol::Http,
            &config.region,
        );
        let span = ctx.span();

        sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
            .instrument(span)
            .await
    } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
        Response::builder()
            .header("Allow", "OPTIONS, POST")
            .header("Access-Control-Allow-Origin", "*")
            .header(
                "Access-Control-Allow-Headers",
                "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
            )
            .header("Access-Control-Max-Age", "86400" /* 24 hours */)
            .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
            .body(Full::new(Bytes::new()))
            .map_err(|e| ApiError::InternalServerError(e.into()))
    } else {
        json_response(StatusCode::BAD_REQUEST, "query is not supported")
    }
}