Line data Source code
1 : //! Routers for our serverless APIs
2 : //!
3 : //! Handles both SQL over HTTP and SQL over Websockets.
4 :
5 : mod conn_pool;
6 : mod json;
7 : mod sql_over_http;
8 : mod websocket;
9 :
10 : pub use conn_pool::GlobalConnPoolOptions;
11 :
12 : use anyhow::bail;
13 : use hyper::StatusCode;
14 : use metrics::IntCounterPairGuard;
15 : use rand::rngs::StdRng;
16 : use rand::SeedableRng;
17 : pub use reqwest_middleware::{ClientWithMiddleware, Error};
18 : pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
19 : use tokio_util::task::TaskTracker;
20 :
21 : use crate::config::TlsConfig;
22 : use crate::context::RequestMonitoring;
23 : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
24 : use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
25 : use crate::rate_limiter::EndpointRateLimiter;
26 : use crate::{cancellation::CancelMap, config::ProxyConfig};
27 : use futures::StreamExt;
28 : use hyper::{
29 : server::{
30 : accept,
31 : conn::{AddrIncoming, AddrStream},
32 : },
33 : Body, Method, Request, Response,
34 : };
35 :
36 : use std::net::IpAddr;
37 : use std::task::Poll;
38 : use std::{future::ready, sync::Arc};
39 : use tls_listener::TlsListener;
40 : use tokio::net::TcpListener;
41 : use tokio_util::sync::CancellationToken;
42 : use tracing::{error, info, info_span, warn, Instrument};
43 : use utils::http::{error::ApiError, json::json_response};
44 :
/// Host name label reserved for the serverless driver (`"api"`).
///
/// NOTE(review): judging by the name, this is compared against the TLS SNI
/// host name of incoming connections — confirm at the usage sites.
pub const SERVERLESS_DRIVER_SNI: &str = "api";
46 :
47 23 : pub async fn task_main(
48 23 : config: &'static ProxyConfig,
49 23 : ws_listener: TcpListener,
50 23 : cancellation_token: CancellationToken,
51 23 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
52 23 : ) -> anyhow::Result<()> {
53 23 : scopeguard::defer! {
54 23 : info!("websocket server has shut down");
55 23 : }
56 23 :
57 23 : let conn_pool = conn_pool::GlobalConnPool::new(config);
58 23 :
59 23 : let conn_pool2 = Arc::clone(&conn_pool);
60 23 : tokio::spawn(async move {
61 27 : conn_pool2.gc_worker(StdRng::from_entropy()).await;
62 23 : });
63 23 :
64 23 : // shutdown the connection pool
65 23 : tokio::spawn({
66 23 : let cancellation_token = cancellation_token.clone();
67 23 : let conn_pool = conn_pool.clone();
68 23 : async move {
69 23 : cancellation_token.cancelled().await;
70 23 : tokio::task::spawn_blocking(move || conn_pool.shutdown())
71 22 : .await
72 23 : .unwrap();
73 23 : }
74 23 : });
75 :
76 23 : let tls_config = match config.tls_config.as_ref() {
77 23 : Some(config) => config,
78 : None => {
79 0 : warn!("TLS config is missing, WebSocket Secure server will not be started");
80 0 : return Ok(());
81 : }
82 : };
83 23 : let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
84 :
85 23 : let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
86 23 : let _ = addr_incoming.set_nodelay(true);
87 23 : let addr_incoming = ProxyProtocolAccept {
88 23 : incoming: addr_incoming,
89 23 : };
90 23 :
91 23 : let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
92 23 : ws_connections.close(); // allows `ws_connections.wait to complete`
93 23 :
94 46 : let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
95 46 : if let Err(err) = conn {
96 0 : error!("failed to accept TLS connection for websockets: {err:?}");
97 0 : ready(false)
98 : } else {
99 46 : ready(true)
100 : }
101 46 : });
102 23 :
103 23 : let make_svc = hyper::service::make_service_fn(
104 46 : |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
105 46 : let (io, tls) = stream.get_ref();
106 46 : let client_addr = io.client_addr();
107 46 : let remote_addr = io.inner.remote_addr();
108 46 : let sni_name = tls.server_name().map(|s| s.to_string());
109 46 : let conn_pool = conn_pool.clone();
110 46 : let ws_connections = ws_connections.clone();
111 46 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
112 :
113 46 : async move {
114 46 : let peer_addr = match client_addr {
115 0 : Some(addr) => addr,
116 0 : None if config.require_client_ip => bail!("missing required client ip"),
117 46 : None => remote_addr,
118 : };
119 46 : Ok(MetricService::new(hyper::service::service_fn(
120 46 : move |req: Request<Body>| {
121 46 : let sni_name = sni_name.clone();
122 46 : let conn_pool = conn_pool.clone();
123 46 : let ws_connections = ws_connections.clone();
124 46 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
125 :
126 46 : async move {
127 46 : let cancel_map = Arc::new(CancelMap::default());
128 46 : let session_id = uuid::Uuid::new_v4();
129 46 :
130 46 : request_handler(
131 46 : req,
132 46 : config,
133 46 : tls_config,
134 46 : conn_pool,
135 46 : ws_connections,
136 46 : cancel_map,
137 46 : session_id,
138 46 : sni_name,
139 46 : peer_addr.ip(),
140 46 : endpoint_rate_limiter,
141 46 : )
142 46 : .instrument(info_span!(
143 46 : "serverless",
144 46 : session = %session_id,
145 46 : %peer_addr,
146 46 : ))
147 832 : .await
148 46 : }
149 46 : },
150 46 : )))
151 46 : }
152 46 : },
153 23 : );
154 23 :
155 23 : hyper::Server::builder(accept::from_stream(tls_listener))
156 23 : .serve(make_svc)
157 23 : .with_graceful_shutdown(cancellation_token.cancelled())
158 214 : .await?;
159 :
160 : // await websocket connections
161 23 : ws_connections.wait().await;
162 :
163 23 : Ok(())
164 23 : }
165 :
166 : struct MetricService<S> {
167 : inner: S,
168 : _gauge: IntCounterPairGuard,
169 : }
170 :
171 : impl<S> MetricService<S> {
172 46 : fn new(inner: S) -> MetricService<S> {
173 46 : MetricService {
174 46 : inner,
175 46 : _gauge: NUM_CLIENT_CONNECTION_GAUGE
176 46 : .with_label_values(&["http"])
177 46 : .guard(),
178 46 : }
179 46 : }
180 : }
181 :
182 : impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
183 : where
184 : S: hyper::service::Service<Request<ReqBody>>,
185 : {
186 : type Response = S::Response;
187 : type Error = S::Error;
188 : type Future = S::Future;
189 :
190 125 : fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
191 125 : self.inner.poll_ready(cx)
192 125 : }
193 :
194 46 : fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
195 46 : self.inner.call(req)
196 46 : }
197 : }
198 :
199 : #[allow(clippy::too_many_arguments)]
200 46 : async fn request_handler(
201 46 : mut request: Request<Body>,
202 46 : config: &'static ProxyConfig,
203 46 : tls: &'static TlsConfig,
204 46 : conn_pool: Arc<conn_pool::GlobalConnPool>,
205 46 : ws_connections: TaskTracker,
206 46 : cancel_map: Arc<CancelMap>,
207 46 : session_id: uuid::Uuid,
208 46 : sni_hostname: Option<String>,
209 46 : peer_addr: IpAddr,
210 46 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
211 46 : ) -> Result<Response<Body>, ApiError> {
212 46 : let host = request
213 46 : .headers()
214 46 : .get("host")
215 46 : .and_then(|h| h.to_str().ok())
216 46 : .and_then(|h| h.split(':').next())
217 46 : .map(|s| s.to_string());
218 46 :
219 46 : // Check if the request is a websocket upgrade request.
220 46 : if hyper_tungstenite::is_upgrade_request(&request) {
221 0 : info!(session_id = ?session_id, "performing websocket upgrade");
222 :
223 0 : let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
224 0 : .map_err(|e| ApiError::BadRequest(e.into()))?;
225 :
226 0 : ws_connections.spawn(
227 0 : async move {
228 0 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
229 :
230 0 : if let Err(e) = websocket::serve_websocket(
231 0 : config,
232 0 : &mut ctx,
233 0 : websocket,
234 0 : cancel_map,
235 0 : host,
236 0 : endpoint_rate_limiter,
237 0 : )
238 0 : .await
239 : {
240 0 : error!(session_id = ?session_id, "error in websocket connection: {e:#}");
241 0 : }
242 0 : }
243 0 : .in_current_span(),
244 0 : );
245 0 :
246 0 : // Return the response so the spawned future can continue.
247 0 : Ok(response)
248 46 : } else if request.uri().path() == "/sql" && request.method() == Method::POST {
249 46 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
250 46 :
251 46 : sql_over_http::handle(
252 46 : tls,
253 46 : &config.http_config,
254 46 : &mut ctx,
255 46 : request,
256 46 : sni_hostname,
257 46 : conn_pool,
258 46 : )
259 832 : .await
260 0 : } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
261 0 : Response::builder()
262 0 : .header("Allow", "OPTIONS, POST")
263 0 : .header("Access-Control-Allow-Origin", "*")
264 0 : .header(
265 0 : "Access-Control-Allow-Headers",
266 0 : "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
267 0 : )
268 0 : .header("Access-Control-Max-Age", "86400" /* 24 hours */)
269 0 : .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
270 0 : .body(Body::empty())
271 0 : .map_err(|e| ApiError::InternalServerError(e.into()))
272 : } else {
273 0 : json_response(StatusCode::BAD_REQUEST, "query is not supported")
274 : }
275 46 : }
|