Line data Source code
1 : //! Routers for our serverless APIs
2 : //!
3 : //! Handles both SQL over HTTP and SQL over Websockets.
4 :
5 : mod backend;
6 : mod conn_pool;
7 : mod json;
8 : mod sql_over_http;
9 : mod websocket;
10 :
11 : pub use conn_pool::GlobalConnPoolOptions;
12 :
13 : use anyhow::bail;
14 : use hyper::StatusCode;
15 : use metrics::IntCounterPairGuard;
16 : use rand::rngs::StdRng;
17 : use rand::SeedableRng;
18 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
19 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
20 : use tokio_util::task::TaskTracker;
21 :
22 : use crate::context::RequestMonitoring;
23 : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
24 : use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
25 : use crate::rate_limiter::EndpointRateLimiter;
26 : use crate::serverless::backend::PoolingBackend;
27 : use crate::{cancellation::CancellationHandler, config::ProxyConfig};
28 : use futures::StreamExt;
29 : use hyper::{
30 : server::{
31 : accept,
32 : conn::{AddrIncoming, AddrStream},
33 : },
34 : Body, Method, Request, Response,
35 : };
36 :
37 : use std::convert::Infallible;
38 : use std::net::IpAddr;
39 : use std::task::Poll;
40 : use std::{future::ready, sync::Arc};
41 : use tls_listener::TlsListener;
42 : use tokio::net::TcpListener;
43 : use tokio_util::sync::CancellationToken;
44 : use tracing::{error, info, warn, Instrument};
45 : use utils::http::{error::ApiError, json::json_response};
46 :
/// SNI value associated with the serverless driver endpoint (`"api"`).
///
/// NOTE(review): this module only declares the constant; how it is matched against the
/// TLS server name of incoming connections is defined elsewhere — confirm at the use site.
pub const SERVERLESS_DRIVER_SNI: &str = "api";
48 :
49 0 : pub async fn task_main(
50 0 : config: &'static ProxyConfig,
51 0 : ws_listener: TcpListener,
52 0 : cancellation_token: CancellationToken,
53 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
54 0 : cancellation_handler: Arc<CancellationHandler>,
55 0 : ) -> anyhow::Result<()> {
56 0 : scopeguard::defer! {
57 0 : info!("websocket server has shut down");
58 : }
59 :
60 0 : let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
61 0 : {
62 0 : let conn_pool = Arc::clone(&conn_pool);
63 0 : tokio::spawn(async move {
64 0 : conn_pool.gc_worker(StdRng::from_entropy()).await;
65 0 : });
66 0 : }
67 0 :
68 0 : // shutdown the connection pool
69 0 : tokio::spawn({
70 0 : let cancellation_token = cancellation_token.clone();
71 0 : let conn_pool = conn_pool.clone();
72 0 : async move {
73 0 : cancellation_token.cancelled().await;
74 0 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
75 0 : .await
76 0 : .unwrap();
77 0 : }
78 0 : });
79 0 :
80 0 : let backend = Arc::new(PoolingBackend {
81 0 : pool: Arc::clone(&conn_pool),
82 0 : config,
83 0 : });
84 :
85 0 : let tls_config = match config.tls_config.as_ref() {
86 0 : Some(config) => config,
87 : None => {
88 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
89 0 : return Ok(());
90 : }
91 : };
92 0 : let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
93 0 : // prefer http2, but support http/1.1
94 0 : tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
95 0 : let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
96 :
97 0 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
98 0 : let _ = addr_incoming.set_nodelay(true);
99 0 : let addr_incoming = ProxyProtocolAccept {
100 0 : incoming: addr_incoming,
101 0 : };
102 0 :
103 0 : let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
104 0 : ws_connections.close(); // allows `ws_connections.wait to complete`
105 0 :
106 0 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
107 0 : if let Err(err) = conn {
108 0 : error!("failed to accept TLS connection for websockets: {err:?}");
109 0 : ready(false)
110 : } else {
111 0 : ready(true)
112 : }
113 0 : });
114 0 :
115 0 : let make_svc = hyper::service::make_service_fn(
116 0 : |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
117 0 : let (io, _) = stream.get_ref();
118 0 : let client_addr = io.client_addr();
119 0 : let remote_addr = io.inner.remote_addr();
120 0 : let backend = backend.clone();
121 0 : let ws_connections = ws_connections.clone();
122 0 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
123 0 : let cancellation_handler = cancellation_handler.clone();
124 0 : async move {
125 0 : let peer_addr = match client_addr {
126 0 : Some(addr) => addr,
127 0 : None if config.require_client_ip => bail!("missing required client ip"),
128 0 : None => remote_addr,
129 : };
130 0 : Ok(MetricService::new(hyper::service::service_fn(
131 0 : move |req: Request<Body>| {
132 0 : let backend = backend.clone();
133 0 : let ws_connections = ws_connections.clone();
134 0 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
135 0 : let cancellation_handler = cancellation_handler.clone();
136 :
137 0 : async move {
138 0 : Ok::<_, Infallible>(
139 0 : request_handler(
140 0 : req,
141 0 : config,
142 0 : backend,
143 0 : ws_connections,
144 0 : cancellation_handler,
145 0 : peer_addr.ip(),
146 0 : endpoint_rate_limiter,
147 0 : )
148 0 : .await
149 0 : .map_or_else(|e| e.into_response(), |r| r),
150 0 : )
151 0 : }
152 0 : },
153 0 : )))
154 0 : }
155 0 : },
156 0 : );
157 0 :
158 0 : hyper::Server::builder(accept::from_stream(tls_listener))
159 0 : .serve(make_svc)
160 0 : .with_graceful_shutdown(cancellation_token.cancelled())
161 0 : .await?;
162 :
163 : // await websocket connections
164 0 : ws_connections.wait().await;
165 :
166 0 : Ok(())
167 0 : }
168 :
/// Per-connection wrapper around a hyper service that owns a connection-gauge guard.
///
/// `MetricService::new` takes the guard from the `"http"` series of
/// `NUM_CLIENT_CONNECTION_GAUGE`, and the guard is dropped together with the wrapper.
/// NOTE(review): presumably `IntCounterPairGuard` increments on creation and decrements on
/// drop, so the gauge counts live HTTP connections — confirm in the `metrics` crate.
struct MetricService<S> {
    /// The wrapped service; every request is forwarded to it unchanged.
    inner: S,
    /// Held only for its `Drop` side effect; never read.
    _gauge: IntCounterPairGuard,
}
173 :
174 : impl<S> MetricService<S> {
175 0 : fn new(inner: S) -> MetricService<S> {
176 0 : MetricService {
177 0 : inner,
178 0 : _gauge: NUM_CLIENT_CONNECTION_GAUGE
179 0 : .with_label_values(&["http"])
180 0 : .guard(),
181 0 : }
182 0 : }
183 : }
184 :
185 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
186 : where
187 : S: hyper::service::Service<Request<ReqBody>>,
188 : {
189 : type Response = S::Response;
190 : type Error = S::Error;
191 : type Future = S::Future;
192 :
193 0 : fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
194 0 : self.inner.poll_ready(cx)
195 0 : }
196 :
197 0 : fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
198 0 : self.inner.call(req)
199 0 : }
200 : }
201 :
202 : #[allow(clippy::too_many_arguments)]
203 0 : async fn request_handler(
204 0 : mut request: Request<Body>,
205 0 : config: &'static ProxyConfig,
206 0 : backend: Arc<PoolingBackend>,
207 0 : ws_connections: TaskTracker,
208 0 : cancellation_handler: Arc<CancellationHandler>,
209 0 : peer_addr: IpAddr,
210 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
211 0 : ) -> Result<Response<Body>, ApiError> {
212 0 : let session_id = uuid::Uuid::new_v4();
213 0 :
214 0 : let host = request
215 0 : .headers()
216 0 : .get("host")
217 0 : .and_then(|h| h.to_str().ok())
218 0 : .and_then(|h| h.split(':').next())
219 0 : .map(|s| s.to_string());
220 0 :
221 0 : // Check if the request is a websocket upgrade request.
222 0 : if hyper_tungstenite::is_upgrade_request(&request) {
223 0 : let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
224 0 : let span = ctx.span.clone();
225 0 : info!(parent: &span, "performing websocket upgrade");
226 :
227 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
228 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
229 :
230 0 : ws_connections.spawn(
231 0 : async move {
232 0 : if let Err(e) = websocket::serve_websocket(
233 0 : config,
234 0 : ctx,
235 0 : websocket,
236 0 : cancellation_handler,
237 0 : host,
238 0 : endpoint_rate_limiter,
239 0 : )
240 0 : .await
241 : {
242 0 : error!("error in websocket connection: {e:#}");
243 0 : }
244 0 : }
245 0 : .instrument(span),
246 0 : );
247 0 :
248 0 : // Return the response so the spawned future can continue.
249 0 : Ok(response)
250 0 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
251 0 : let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
252 0 : let span = ctx.span.clone();
253 0 :
254 0 : sql_over_http::handle(config, ctx, request, backend)
255 0 : .instrument(span)
256 0 : .await
257 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
258 0 : Response::builder()
259 0 : .header("Allow", "OPTIONS, POST")
260 0 : .header("Access-Control-Allow-Origin", "*")
261 0 : .header(
262 0 : "Access-Control-Allow-Headers",
263 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
264 0 : )
265 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
266 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
267 0 : .body(Body::empty())
268 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
269 : } else {
270 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
271 : }
272 0 : }
|