TLA Line data Source code
1 : //! Routers for our serverless APIs
2 : //!
3 : //! Handles both SQL over HTTP and SQL over Websockets.
4 :
5 : mod conn_pool;
6 : mod sql_over_http;
7 : mod websocket;
8 :
9 : pub use conn_pool::GlobalConnPoolOptions;
10 :
11 : use anyhow::bail;
12 : use hyper::StatusCode;
13 : use metrics::IntCounterPairGuard;
14 : use rand::rngs::StdRng;
15 : use rand::SeedableRng;
16 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
17 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
18 : use tokio_util::task::TaskTracker;
19 :
20 : use crate::context::RequestMonitoring;
21 : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
22 : use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
23 : use crate::rate_limiter::EndpointRateLimiter;
24 : use crate::{cancellation::CancelMap, config::ProxyConfig};
25 : use futures::StreamExt;
26 : use hyper::{
27 : server::{
28 : accept,
29 : conn::{AddrIncoming, AddrStream},
30 : },
31 : Body, Method, Request, Response,
32 : };
33 :
34 : use std::net::IpAddr;
35 : use std::task::Poll;
36 : use std::{future::ready, sync::Arc};
37 : use tls_listener::TlsListener;
38 : use tokio::net::TcpListener;
39 : use tokio_util::sync::CancellationToken;
40 : use tracing::{error, info, info_span, warn, Instrument};
41 : use utils::http::{error::ApiError, json::json_response};
42 :
43 CBC 22 : pub async fn task_main(
44 22 : config: &'static ProxyConfig,
45 22 : ws_listener: TcpListener,
46 22 : cancellation_token: CancellationToken,
47 22 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
48 22 : ) -> anyhow::Result<()> {
49 22 : scopeguard::defer! {
50 22 : info!("websocket server has shut down");
51 22 : }
52 22 :
53 22 : let conn_pool = conn_pool::GlobalConnPool::new(config);
54 22 :
55 22 : let conn_pool2 = Arc::clone(&conn_pool);
56 22 : tokio::spawn(async move {
57 30 : conn_pool2.gc_worker(StdRng::from_entropy()).await;
58 22 : });
59 22 :
60 22 : // shutdown the connection pool
61 22 : tokio::spawn({
62 22 : let cancellation_token = cancellation_token.clone();
63 22 : let conn_pool = conn_pool.clone();
64 22 : async move {
65 22 : cancellation_token.cancelled().await;
66 22 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
67 20 : .await
68 21 : .unwrap();
69 22 : }
70 22 : });
71 22 :
72 22 : let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
73 22 : let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
74 22 : Some(config) => config.into(),
75 : None => {
76 UBC 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
77 0 : return Ok(());
78 : }
79 : };
80 :
81 CBC 22 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
82 22 : let _ = addr_incoming.set_nodelay(true);
83 22 : let addr_incoming = ProxyProtocolAccept {
84 22 : incoming: addr_incoming,
85 22 : };
86 22 :
87 22 : let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
88 22 : ws_connections.close(); // allows `ws_connections.wait to complete`
89 22 :
90 45 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
91 45 : if let Err(err) = conn {
92 UBC 0 : error!("failed to accept TLS connection for websockets: {err:?}");
93 0 : ready(false)
94 : } else {
95 CBC 45 : ready(true)
96 : }
97 45 : });
98 22 :
99 22 : let make_svc = hyper::service::make_service_fn(
100 45 : |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
101 45 : let (io, tls) = stream.get_ref();
102 45 : let client_addr = io.client_addr();
103 45 : let remote_addr = io.inner.remote_addr();
104 45 : let sni_name = tls.server_name().map(|s| s.to_string());
105 45 : let conn_pool = conn_pool.clone();
106 45 : let ws_connections = ws_connections.clone();
107 45 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
108 :
109 45 : async move {
110 45 : let peer_addr = match client_addr {
111 UBC 0 : Some(addr) => addr,
112 0 : None if config.require_client_ip => bail!("missing required client ip"),
113 CBC 45 : None => remote_addr,
114 : };
115 45 : Ok(MetricService::new(hyper::service::service_fn(
116 45 : move |req: Request<Body>| {
117 45 : let sni_name = sni_name.clone();
118 45 : let conn_pool = conn_pool.clone();
119 45 : let ws_connections = ws_connections.clone();
120 45 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
121 :
122 45 : async move {
123 45 : let cancel_map = Arc::new(CancelMap::default());
124 45 : let session_id = uuid::Uuid::new_v4();
125 45 :
126 45 : request_handler(
127 45 : req,
128 45 : config,
129 45 : conn_pool,
130 45 : ws_connections,
131 45 : cancel_map,
132 45 : session_id,
133 45 : sni_name,
134 45 : peer_addr.ip(),
135 45 : endpoint_rate_limiter,
136 45 : )
137 45 : .instrument(info_span!(
138 45 : "serverless",
139 45 : session = %session_id,
140 45 : %peer_addr,
141 45 : ))
142 727 : .await
143 45 : }
144 45 : },
145 45 : )))
146 45 : }
147 45 : },
148 22 : );
149 22 :
150 22 : hyper::Server::builder(accept::from_stream(tls_listener))
151 22 : .serve(make_svc)
152 22 : .with_graceful_shutdown(cancellation_token.cancelled())
153 207 : .await?;
154 :
155 : // await websocket connections
156 22 : ws_connections.wait().await;
157 :
158 22 : Ok(())
159 22 : }
160 :
161 : struct MetricService<S> {
162 : inner: S,
163 : _gauge: IntCounterPairGuard,
164 : }
165 :
166 : impl<S> MetricService<S> {
167 45 : fn new(inner: S) -> MetricService<S> {
168 45 : MetricService {
169 45 : inner,
170 45 : _gauge: NUM_CLIENT_CONNECTION_GAUGE
171 45 : .with_label_values(&["http"])
172 45 : .guard(),
173 45 : }
174 45 : }
175 : }
176 :
177 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
178 : where
179 : S: hyper::service::Service<Request<ReqBody>>,
180 : {
181 : type Response = S::Response;
182 : type Error = S::Error;
183 : type Future = S::Future;
184 :
185 124 : fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
186 124 : self.inner.poll_ready(cx)
187 124 : }
188 :
189 45 : fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
190 45 : self.inner.call(req)
191 45 : }
192 : }
193 :
194 : #[allow(clippy::too_many_arguments)]
195 45 : async fn request_handler(
196 45 : mut request: Request<Body>,
197 45 : config: &'static ProxyConfig,
198 45 : conn_pool: Arc<conn_pool::GlobalConnPool>,
199 45 : ws_connections: TaskTracker,
200 45 : cancel_map: Arc<CancelMap>,
201 45 : session_id: uuid::Uuid,
202 45 : sni_hostname: Option<String>,
203 45 : peer_addr: IpAddr,
204 45 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
205 45 : ) -> Result<Response<Body>, ApiError> {
206 45 : let host = request
207 45 : .headers()
208 45 : .get("host")
209 45 : .and_then(|h| h.to_str().ok())
210 45 : .and_then(|h| h.split(':').next())
211 45 : .map(|s| s.to_string());
212 45 :
213 45 : // Check if the request is a websocket upgrade request.
214 45 : if hyper_tungstenite::is_upgrade_request(&request) {
215 UBC 0 : info!(session_id = ?session_id, "performing websocket upgrade");
216 :
217 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
218 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
219 :
220 0 : ws_connections.spawn(
221 0 : async move {
222 0 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
223 :
224 0 : if let Err(e) = websocket::serve_websocket(
225 0 : config,
226 0 : &mut ctx,
227 0 : websocket,
228 0 : &cancel_map,
229 0 : host,
230 0 : endpoint_rate_limiter,
231 0 : )
232 0 : .await
233 : {
234 0 : error!(session_id = ?session_id, "error in websocket connection: {e:#}");
235 0 : }
236 0 : }
237 0 : .in_current_span(),
238 0 : );
239 0 :
240 0 : // Return the response so the spawned future can continue.
241 0 : Ok(response)
242 CBC 45 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
243 45 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
244 45 :
245 45 : sql_over_http::handle(
246 45 : &config.http_config,
247 45 : &mut ctx,
248 45 : request,
249 45 : sni_hostname,
250 45 : conn_pool,
251 45 : )
252 727 : .await
253 UBC 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
254 0 : Response::builder()
255 0 : .header("Allow", "OPTIONS, POST")
256 0 : .header("Access-Control-Allow-Origin", "*")
257 0 : .header(
258 0 : "Access-Control-Allow-Headers",
259 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
260 0 : )
261 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
262 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
263 0 : .body(Body::empty())
264 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
265 : } else {
266 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
267 : }
268 CBC 45 : }
|