Line data Source code
1 : //! Routers for our serverless APIs
2 : //!
3 : //! Handles both SQL over HTTP and SQL over Websockets.
4 :
5 : mod backend;
6 : mod conn_pool;
7 : mod json;
8 : mod sql_over_http;
9 : pub mod tls_listener;
10 : mod websocket;
11 :
12 : pub use conn_pool::GlobalConnPoolOptions;
13 :
14 : use anyhow::bail;
15 : use hyper::StatusCode;
16 : use metrics::IntCounterPairGuard;
17 : use rand::rngs::StdRng;
18 : use rand::SeedableRng;
19 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
20 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
21 : use tokio_util::task::TaskTracker;
22 : use tracing::instrument::Instrumented;
23 :
24 : use crate::cancellation::CancellationHandlerMain;
25 : use crate::config::ProxyConfig;
26 : use crate::context::RequestMonitoring;
27 : use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
28 : use crate::rate_limiter::EndpointRateLimiter;
29 : use crate::serverless::backend::PoolingBackend;
30 : use hyper::{
31 : server::conn::{AddrIncoming, AddrStream},
32 : Body, Method, Request, Response,
33 : };
34 :
35 : use std::net::IpAddr;
36 : use std::sync::Arc;
37 : use std::task::Poll;
38 : use tls_listener::TlsListener;
39 : use tokio::net::TcpListener;
40 : use tokio_util::sync::{CancellationToken, DropGuard};
41 : use tracing::{error, info, warn, Instrument};
42 : use utils::http::{error::ApiError, json::json_response};
43 :
44 : pub const SERVERLESS_DRIVER_SNI: &str = "api";
45 :
46 0 : pub async fn task_main(
47 0 : config: &'static ProxyConfig,
48 0 : ws_listener: TcpListener,
49 0 : cancellation_token: CancellationToken,
50 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
51 0 : cancellation_handler: Arc<CancellationHandlerMain>,
52 0 : ) -> anyhow::Result<()> {
53 0 : scopeguard::defer! {
54 0 : info!("websocket server has shut down");
55 : }
56 :
57 0 : let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
58 0 : {
59 0 : let conn_pool = Arc::clone(&conn_pool);
60 0 : tokio::spawn(async move {
61 0 : conn_pool.gc_worker(StdRng::from_entropy()).await;
62 0 : });
63 0 : }
64 0 :
65 0 : // shutdown the connection pool
66 0 : tokio::spawn({
67 0 : let cancellation_token = cancellation_token.clone();
68 0 : let conn_pool = conn_pool.clone();
69 0 : async move {
70 0 : cancellation_token.cancelled().await;
71 0 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
72 0 : .await
73 0 : .unwrap();
74 0 : }
75 0 : });
76 0 :
77 0 : let backend = Arc::new(PoolingBackend {
78 0 : pool: Arc::clone(&conn_pool),
79 0 : config,
80 0 : });
81 :
82 0 : let tls_config = match config.tls_config.as_ref() {
83 0 : Some(config) => config,
84 : None => {
85 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
86 0 : return Ok(());
87 : }
88 : };
89 0 : let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
90 0 : // prefer http2, but support http/1.1
91 0 : tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
92 0 : let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
93 :
94 0 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
95 0 : let _ = addr_incoming.set_nodelay(true);
96 0 : let addr_incoming = ProxyProtocolAccept {
97 0 : incoming: addr_incoming,
98 0 : protocol: "http",
99 0 : };
100 0 :
101 0 : let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
102 0 : ws_connections.close(); // allows `ws_connections.wait to complete`
103 0 :
104 0 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);
105 0 :
106 0 : let make_svc = hyper::service::make_service_fn(
107 0 : |stream: &tokio_rustls::server::TlsStream<
108 : WithConnectionGuard<WithClientIp<AddrStream>>,
109 0 : >| {
110 0 : let (conn, _) = stream.get_ref();
111 0 :
112 0 : // this is jank. should dissapear with hyper 1.0 migration.
113 0 : let gauge = conn
114 0 : .gauge
115 0 : .lock()
116 0 : .expect("lock should not be poisoned")
117 0 : .take()
118 0 : .expect("gauge should be set on connection start");
119 0 :
120 0 : // Cancel all current inflight HTTP requests if the HTTP connection is closed.
121 0 : let http_cancellation_token = CancellationToken::new();
122 0 : let cancel_connection = http_cancellation_token.clone().drop_guard();
123 0 :
124 0 : let span = conn.span.clone();
125 0 : let client_addr = conn.inner.client_addr();
126 0 : let remote_addr = conn.inner.inner.remote_addr();
127 0 : let backend = backend.clone();
128 0 : let ws_connections = ws_connections.clone();
129 0 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
130 0 : let cancellation_handler = cancellation_handler.clone();
131 0 : async move {
132 0 : let peer_addr = match client_addr {
133 0 : Some(addr) => addr,
134 0 : None if config.require_client_ip => bail!("missing required client ip"),
135 0 : None => remote_addr,
136 : };
137 0 : Ok(MetricService::new(
138 0 : hyper::service::service_fn(move |req: Request<Body>| {
139 0 : let backend = backend.clone();
140 0 : let ws_connections2 = ws_connections.clone();
141 0 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
142 0 : let cancellation_handler = cancellation_handler.clone();
143 0 : let http_cancellation_token = http_cancellation_token.child_token();
144 0 :
145 0 : // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
146 0 : // By spawning the future, we ensure it never gets cancelled until it decides to.
147 0 : ws_connections.spawn(
148 0 : async move {
149 0 : // Cancel the current inflight HTTP request if the requets stream is closed.
150 0 : // This is slightly different to `_cancel_connection` in that
151 0 : // h2 can cancel individual requests with a `RST_STREAM`.
152 0 : let _cancel_session = http_cancellation_token.clone().drop_guard();
153 :
154 0 : let res = request_handler(
155 0 : req,
156 0 : config,
157 0 : backend,
158 0 : ws_connections2,
159 0 : cancellation_handler,
160 0 : peer_addr.ip(),
161 0 : endpoint_rate_limiter,
162 0 : http_cancellation_token,
163 0 : )
164 0 : .await
165 0 : .map_or_else(|e| e.into_response(), |r| r);
166 0 :
167 0 : _cancel_session.disarm();
168 0 :
169 0 : res
170 0 : }
171 0 : .in_current_span(),
172 0 : )
173 0 : }),
174 0 : gauge,
175 0 : cancel_connection,
176 0 : span,
177 0 : ))
178 0 : }
179 0 : },
180 0 : );
181 0 :
182 0 : hyper::Server::builder(tls_listener)
183 0 : .serve(make_svc)
184 0 : .with_graceful_shutdown(cancellation_token.cancelled())
185 0 : .await?;
186 :
187 : // await websocket connections
188 0 : ws_connections.wait().await;
189 :
190 0 : Ok(())
191 0 : }
192 :
193 : struct MetricService<S> {
194 : inner: S,
195 : _gauge: IntCounterPairGuard,
196 : _cancel: DropGuard,
197 : span: tracing::Span,
198 : }
199 :
200 : impl<S> MetricService<S> {
201 0 : fn new(
202 0 : inner: S,
203 0 : _gauge: IntCounterPairGuard,
204 0 : _cancel: DropGuard,
205 0 : span: tracing::Span,
206 0 : ) -> MetricService<S> {
207 0 : MetricService {
208 0 : inner,
209 0 : _gauge,
210 0 : _cancel,
211 0 : span,
212 0 : }
213 0 : }
214 : }
215 :
216 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
217 : where
218 : S: hyper::service::Service<Request<ReqBody>>,
219 : {
220 : type Response = S::Response;
221 : type Error = S::Error;
222 : type Future = Instrumented<S::Future>;
223 :
224 0 : fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
225 0 : self.inner.poll_ready(cx)
226 0 : }
227 :
228 0 : fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
229 0 : self.span
230 0 : .in_scope(|| self.inner.call(req))
231 0 : .instrument(self.span.clone())
232 0 : }
233 : }
234 :
235 : #[allow(clippy::too_many_arguments)]
236 0 : async fn request_handler(
237 0 : mut request: Request<Body>,
238 0 : config: &'static ProxyConfig,
239 0 : backend: Arc<PoolingBackend>,
240 0 : ws_connections: TaskTracker,
241 0 : cancellation_handler: Arc<CancellationHandlerMain>,
242 0 : peer_addr: IpAddr,
243 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
244 0 : // used to cancel in-flight HTTP requests. not used to cancel websockets
245 0 : http_cancellation_token: CancellationToken,
246 0 : ) -> Result<Response<Body>, ApiError> {
247 0 : let session_id = uuid::Uuid::new_v4();
248 0 :
249 0 : let host = request
250 0 : .headers()
251 0 : .get("host")
252 0 : .and_then(|h| h.to_str().ok())
253 0 : .and_then(|h| h.split(':').next())
254 0 : .map(|s| s.to_string());
255 0 :
256 0 : // Check if the request is a websocket upgrade request.
257 0 : if hyper_tungstenite::is_upgrade_request(&request) {
258 0 : let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
259 0 : let span = ctx.span.clone();
260 0 : info!(parent: &span, "performing websocket upgrade");
261 :
262 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
263 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
264 :
265 0 : ws_connections.spawn(
266 0 : async move {
267 0 : if let Err(e) = websocket::serve_websocket(
268 0 : config,
269 0 : ctx,
270 0 : websocket,
271 0 : cancellation_handler,
272 0 : host,
273 0 : endpoint_rate_limiter,
274 0 : )
275 0 : .await
276 : {
277 0 : error!("error in websocket connection: {e:#}");
278 0 : }
279 0 : }
280 0 : .instrument(span),
281 0 : );
282 0 :
283 0 : // Return the response so the spawned future can continue.
284 0 : Ok(response)
285 0 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
286 0 : let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
287 0 : let span = ctx.span.clone();
288 0 :
289 0 : sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
290 0 : .instrument(span)
291 0 : .await
292 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
293 0 : Response::builder()
294 0 : .header("Allow", "OPTIONS, POST")
295 0 : .header("Access-Control-Allow-Origin", "*")
296 0 : .header(
297 0 : "Access-Control-Allow-Headers",
298 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
299 0 : )
300 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
301 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
302 0 : .body(Body::empty())
303 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
304 : } else {
305 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
306 : }
307 0 : }
|