//! Routers for our serverless APIs
//!
//! Handles both SQL over HTTP and SQL over Websockets.
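//!
//! Three kinds of requests arrive on the shared listener (see `request_handler`):
//! WebSocket upgrades, `POST /sql` for SQL over HTTP, and the matching CORS
//! preflight `OPTIONS /sql`. A minimal sketch of the HTTP entrypoint, assuming the
//! connection string travels in the `Neon-Connection-String` header advertised by
//! the preflight response below (the exact body format is defined by
//! `sql_over_http`):
//!
//! ```text
//! POST /sql HTTP/1.1
//! Host: <endpoint hostname>
//! Neon-Connection-String: postgres://user:password@<endpoint>/<database>
//!
//! <JSON payload understood by `sql_over_http`>
//! ```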

mod backend;
pub mod cancel_set;
mod conn_pool;
mod conn_pool_lib;
mod error;
mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
mod sql_over_http;
mod websocket;

use std::net::{IpAddr, SocketAddr};
use std::pin::{Pin, pin};
use std::sync::Arc;

use anyhow::Context;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::TryFutureExt;
use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{Instrument, info, warn};

use crate::cancellation::CancellationHandler;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;

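// SNI names used to tell serverless-driver traffic ("api") apart from auth-broker
// traffic ("apiauth"). The names suggest their purpose; the actual routing on these
// values happens outside this module.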
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";

pub async fn task_main(
    config: &'static ProxyConfig,
    auth_backend: &'static crate::auth::Backend<'static, ()>,
    ws_listener: TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandler>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("websocket server has shut down");
    }

    let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
    let conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
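    // Garbage-collect idle pooled connections in the background (what `gc_worker`
    // does is inferred from its name; it lives in `conn_pool_lib`).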
    {
        let conn_pool = Arc::clone(&conn_pool);
        tokio::spawn(async move {
            conn_pool.gc_worker(StdRng::from_entropy()).await;
        });
    }

    // shutdown the connection pool
    tokio::spawn({
        let cancellation_token = cancellation_token.clone();
        let conn_pool = conn_pool.clone();
        async move {
            cancellation_token.cancelled().await;
            tokio::task::spawn_blocking(move || conn_pool.shutdown())
                .await
                .propagate_task_panic();
        }
    });

    let http_conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
    {
        let http_conn_pool = Arc::clone(&http_conn_pool);
        tokio::spawn(async move {
            http_conn_pool.gc_worker(StdRng::from_entropy()).await;
        });
    }

    // shutdown the HTTP connection pool
    tokio::spawn({
        let cancellation_token = cancellation_token.clone();
        let http_conn_pool = http_conn_pool.clone();
        async move {
            cancellation_token.cancelled().await;
            tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
                .await
                .propagate_task_panic();
        }
    });

    let backend = Arc::new(PoolingBackend {
        http_conn_pool: Arc::clone(&http_conn_pool),
        local_pool,
        pool: Arc::clone(&conn_pool),
        config,
        auth_backend,
        endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
    });
    let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = Arc::new(&config.tls_config);

    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    connections.close(); // allows `connections.wait()` to complete

    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
    while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
        let (conn, peer_addr) = res.context("could not accept TCP stream")?;
        if let Err(e) = conn.set_nodelay(true) {
            tracing::error!("could not set nodelay: {e}");
            continue;
        }
        let conn_id = uuid::Uuid::new_v4();
        let http_conn_span = tracing::info_span!("http_conn", ?conn_id);

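        // Load shedding: when the number of live HTTP client connections exceeds the
        // configured threshold, cancel a random existing connection taken from the
        // cancel set.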
        let n_connections = Metrics::get()
            .proxy
            .client_connections
            .sample(crate::metrics::Protocol::Http);
        tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
        if n_connections > config.http_config.client_conn_threshold {
            tracing::trace!("attempting to cancel a random connection");
            if let Some(token) = config.http_config.cancel_set.take() {
                tracing::debug!("cancelling a random connection");
                token.cancel();
            }
        }

        let conn_token = cancellation_token.child_token();
        let tls_acceptor = tls_acceptor.clone();
        let backend = backend.clone();
        let connections2 = connections.clone();
        let cancellation_handler = cancellation_handler.clone();
        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
        let cancellations = cancellations.clone();
        connections.spawn(
            async move {
                let conn_token2 = conn_token.clone();
                let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);

                let session_id = uuid::Uuid::new_v4();

                let _gauge = Metrics::get()
                    .proxy
                    .client_connections
                    .guard(crate::metrics::Protocol::Http);

                let startup_result = Box::pin(connection_startup(
                    config,
                    tls_acceptor,
                    session_id,
                    conn,
                    peer_addr,
                ))
                .await;
                let Some((conn, conn_info)) = startup_result else {
                    return;
                };

                Box::pin(connection_handler(
                    config,
                    backend,
                    connections2,
                    cancellations,
                    cancellation_handler,
                    endpoint_rate_limiter,
                    conn_token,
                    conn,
                    conn_info,
                    session_id,
                ))
                .await;
            }
            .instrument(http_conn_span),
        );
    }

    connections.wait().await;

    Ok(())
}

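/// Blanket trait so that both plain TCP streams and TLS streams can be handled
/// uniformly behind the boxed, type-erased [`AsyncRW`] alias.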
pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;

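/// Accepts an incoming TCP connection and performs a TLS handshake only if TLS is
/// currently configured; otherwise the plain stream is returned unchanged.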
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
    async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW>;
}

#[async_trait]
impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
    async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW> {
        match &*self.load() {
            Some(config) => Ok(Box::pin(
                TlsAcceptor::from(config.http_config.clone())
                    .accept(conn)
                    .await?,
            )),
            None => Ok(Box::pin(conn)),
        }
    }
}

/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
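///
/// Returns `None` when the connection should be dropped: a load-balancer
/// healthcheck, a PROXY protocol error, or a failed or timed-out TLS handshake.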
async fn connection_startup(
    config: &ProxyConfig,
    tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
    session_id: uuid::Uuid,
    conn: TcpStream,
    peer_addr: SocketAddr,
) -> Option<(AsyncRW, ConnectionInfo)> {
    // handle PROXY protocol
    let (conn, conn_info) = match config.proxy_protocol_v2 {
        ProxyProtocolV2::Required => {
            match read_proxy_protocol(conn).await {
                Err(e) => {
                    warn!("per-client task finished with an error: {e:#}");
                    return None;
                }
                // our load balancers will not send any more data. let's just exit immediately
                Ok((_conn, ConnectHeader::Local)) => {
                    tracing::debug!("healthcheck received");
                    return None;
                }
                Ok((conn, ConnectHeader::Proxy(info))) => (conn, info),
            }
        }
        // Ignore the header: it cannot be confused with a postgres or http connection,
        // so it will error later.
        ProxyProtocolV2::Rejected => (
            conn,
            ConnectionInfo {
                addr: peer_addr,
                extra: None,
            },
        ),
    };

    let has_private_peer_addr = match conn_info.addr.ip() {
        IpAddr::V4(ip) => ip.is_private(),
        IpAddr::V6(_) => false,
    };
    info!(?session_id, %conn_info, "accepted new TCP connection");

    // try to upgrade to TLS, but with a timeout.
    let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
        Ok(Ok(conn)) => {
            info!(?session_id, %conn_info, "accepted new TLS connection");
            conn
        }
        // The handshake failed
        Ok(Err(e)) => {
            if !has_private_peer_addr {
                Metrics::get().proxy.tls_handshake_failures.inc();
            }
            warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
            return None;
        }
        // The handshake timed out
        Err(e) => {
            if !has_private_peer_addr {
                Metrics::get().proxy.tls_handshake_failures.inc();
            }
            warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
            return None;
        }
    };

    Some((conn, conn_info))
}

/// Handles an HTTP connection:
/// 1. with graceful shutdown,
/// 2. with graceful request cancellation on connection failure,
/// 3. with websocket upgrade support.
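///
/// Serves both HTTP/1 and HTTP/2 via hyper's auto connection builder. Each request
/// is spawned onto the `connections` task tracker because `request_handler` is not
/// cancel-safe; see the comments inside.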
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    connections: TaskTracker,
    cancellations: TaskTracker,
    cancellation_handler: Arc<CancellationHandler>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    cancellation_token: CancellationToken,
    conn: AsyncRW,
    conn_info: ConnectionInfo,
    session_id: uuid::Uuid,
) {
    let session_id = AtomicTake::new(session_id);

    // Cancel all current inflight HTTP requests if the HTTP connection is closed.
    let http_cancellation_token = CancellationToken::new();
    let _cancel_connection = http_cancellation_token.clone().drop_guard();

    let conn_info2 = conn_info.clone();
    let server = Builder::new(TokioExecutor::new());
    let conn = server.serve_connection_with_upgrades(
        hyper_util::rt::TokioIo::new(conn),
        hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
            // First HTTP request shares the same session ID
            let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);

            if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
                // take session_id from request, if given.
                if let Some(id) = req
                    .headers()
                    .get(&NEON_REQUEST_ID)
                    .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
                {
                    session_id = id;
                }
            }

            // Cancel the current inflight HTTP request if the request stream is closed.
            // This is slightly different to `_cancel_connection` in that
            // h2 can cancel individual requests with a `RST_STREAM`.
            let http_request_token = http_cancellation_token.child_token();
            let cancel_request = http_request_token.clone().drop_guard();

            // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
            // By spawning the future, we ensure it never gets cancelled until it decides to.
            let cancellations = cancellations.clone();
            let handler = connections.spawn(
                request_handler(
                    req,
                    config,
                    backend.clone(),
                    connections.clone(),
                    cancellation_handler.clone(),
                    session_id,
                    conn_info2.clone(),
                    http_request_token,
                    endpoint_rate_limiter.clone(),
                    cancellations,
                )
                .in_current_span()
                .map_ok_or_else(api_error_into_response, |r| r),
            );
            async move {
                let mut res = handler.await;
                cancel_request.disarm();

                // add the session ID to the response
                if let Ok(resp) = &mut res {
                    resp.headers_mut()
                        .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
                }

                res
            }
        }),
    );

    // On cancellation, trigger the HTTP connection handler to shut down.
    let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
        Either::Left((_cancelled, mut conn)) => {
            tracing::debug!(%conn_info, "cancelling connection");
            conn.as_mut().graceful_shutdown();
            conn.await
        }
        Either::Right((res, _)) => res,
    };

    match res {
        Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
        Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
    }
}

#[allow(clippy::too_many_arguments)]
async fn request_handler(
    mut request: hyper::Request<Incoming>,
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    ws_connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandler>,
    session_id: uuid::Uuid,
    conn_info: ConnectionInfo,
    // used to cancel in-flight HTTP requests. not used to cancel websockets
    http_cancellation_token: CancellationToken,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
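    // Take the hostname from the `Host` header, dropping any `:port` suffix. It is
    // forwarded to the websocket handler below (presumably to identify the endpoint
    // when no SNI is available).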
    let host = request
        .headers()
        .get("host")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.split(':').next())
        .map(|s| s.to_string());

    // Check if the request is a websocket upgrade request.
    if config.http_config.accept_websockets
        && framed_websockets::upgrade::is_upgrade_request(&request)
    {
        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);

        ctx.set_user_agent(
            request
                .headers()
                .get(hyper::header::USER_AGENT)
                .and_then(|h| h.to_str().ok())
                .map(Into::into),
        );

        let span = ctx.span();
        info!(parent: &span, "performing websocket upgrade");

        let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
            .map_err(|e| ApiError::BadRequest(e.into()))?;

        let cancellations = cancellations.clone();
        ws_connections.spawn(
            async move {
                if let Err(e) = websocket::serve_websocket(
                    config,
                    backend.auth_backend,
                    ctx,
                    websocket,
                    cancellation_handler,
                    endpoint_rate_limiter,
                    host,
                    cancellations,
                )
                .await
                {
                    warn!("error in websocket connection: {e:#}");
                }
            }
            .instrument(span),
        );

        // Return the response so the spawned future can continue.
        Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
    } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
        let span = ctx.span();

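        // If the request carries an `X-Neon-Query-ID` header (set by the testodrome
        // harness, judging by the name), record it on the request context so the query
        // can be correlated in logs.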
        let testodrome_id = request
            .headers()
            .get("X-Neon-Query-ID")
            .and_then(|value| value.to_str().ok())
            .map(|s| s.to_string());

        if let Some(query_id) = testodrome_id {
            info!(parent: &ctx.span(), "testodrome query ID: {query_id}");
            ctx.set_testodrome_id(query_id.into());
        }

        sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
            .instrument(span)
            .await
    } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
479 0 : .header("Allow", "OPTIONS, POST")
480 0 : .header("Access-Control-Allow-Origin", "*")
481 0 : .header(
482 : "Access-Control-Allow-Headers",
483 : "Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
484 : )
485 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
486 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
487 0 : .body(Empty::new().map_err(|x| match x {}).boxed())
488 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
489 : } else {
490 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
491 : }
492 0 : }