Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub(crate) mod connect_compute;
5 : mod copy_bidirectional;
6 : pub(crate) mod handshake;
7 : pub(crate) mod passthrough;
8 : pub(crate) mod retry;
9 : pub(crate) mod wake_compute;
10 : use std::sync::Arc;
11 :
12 : pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource};
13 : use futures::TryFutureExt;
14 : use itertools::Itertools;
15 : use once_cell::sync::OnceCell;
16 : use pq_proto::{BeMessage as Be, StartupMessageParams};
17 : use regex::Regex;
18 : use smol_str::{format_smolstr, SmolStr};
19 : use thiserror::Error;
20 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
21 : use tokio_util::sync::CancellationToken;
22 : use tracing::{debug, error, info, warn, Instrument};
23 :
24 : use self::connect_compute::{connect_to_compute, TcpMechanism};
25 : use self::passthrough::ProxyPassthrough;
26 : use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
27 : use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
28 : use crate::context::RequestMonitoring;
29 : use crate::error::ReportableError;
30 : use crate::metrics::{Metrics, NumClientConnectionsGuard};
31 : use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo};
32 : use crate::proxy::handshake::{handshake, HandshakeData};
33 : use crate::rate_limiter::EndpointRateLimiter;
34 : use crate::stream::{PqStream, Stream};
35 : use crate::types::EndpointCacheKey;
36 : use crate::{auth, compute};
37 :
/// Error text for a connection that is considered insecure; it suggests `sslmode=require`.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
39 :
40 0 : pub async fn run_until_cancelled<F: std::future::Future>(
41 0 : f: F,
42 0 : cancellation_token: &CancellationToken,
43 0 : ) -> Option<F::Output> {
44 0 : match futures::future::select(
45 0 : std::pin::pin!(f),
46 0 : std::pin::pin!(cancellation_token.cancelled()),
47 0 : )
48 0 : .await
49 : {
50 0 : futures::future::Either::Left((f, _)) => Some(f),
51 0 : futures::future::Either::Right(((), _)) => None,
52 : }
53 0 : }
54 :
55 0 : pub async fn task_main(
56 0 : config: &'static ProxyConfig,
57 0 : auth_backend: &'static auth::Backend<'static, ()>,
58 0 : listener: tokio::net::TcpListener,
59 0 : cancellation_token: CancellationToken,
60 0 : cancellation_handler: Arc<CancellationHandlerMain>,
61 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
62 0 : ) -> anyhow::Result<()> {
63 0 : scopeguard::defer! {
64 0 : info!("proxy has shut down");
65 0 : }
66 0 :
67 0 : // When set for the server socket, the keepalive setting
68 0 : // will be inherited by all accepted client sockets.
69 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
70 :
71 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
72 :
73 0 : while let Some(accept_result) =
74 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
75 : {
76 0 : let (socket, peer_addr) = accept_result?;
77 :
78 0 : let conn_gauge = Metrics::get()
79 0 : .proxy
80 0 : .client_connections
81 0 : .guard(crate::metrics::Protocol::Tcp);
82 0 :
83 0 : let session_id = uuid::Uuid::new_v4();
84 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
85 0 :
86 0 : debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
87 0 : let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
88 0 :
89 0 : connections.spawn(async move {
90 0 : let (socket, conn_info) = match read_proxy_protocol(socket).await {
91 0 : Err(e) => {
92 0 : warn!("per-client task finished with an error: {e:#}");
93 0 : return;
94 : }
95 : // our load balancers will not send any more data. let's just exit immediately
96 0 : Ok((_socket, ConnectHeader::Local)) => {
97 0 : debug!("healthcheck received");
98 0 : return;
99 : }
100 0 : Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
101 0 : warn!("missing required proxy protocol header");
102 0 : return;
103 : }
104 0 : Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
105 0 : warn!("proxy protocol header not supported");
106 0 : return;
107 : }
108 0 : Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
109 0 : Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
110 : };
111 :
112 0 : match socket.inner.set_nodelay(true) {
113 0 : Ok(()) => {}
114 0 : Err(e) => {
115 0 : error!("per-client task finished with an error: failed to set socket option: {e:#}");
116 0 : return;
117 : }
118 : };
119 :
120 0 : let ctx = RequestMonitoring::new(
121 0 : session_id,
122 0 : conn_info,
123 0 : crate::metrics::Protocol::Tcp,
124 0 : &config.region,
125 0 : );
126 0 : let span = ctx.span();
127 0 :
128 0 : let startup = Box::pin(
129 0 : handle_client(
130 0 : config,
131 0 : auth_backend,
132 0 : &ctx,
133 0 : cancellation_handler,
134 0 : socket,
135 0 : ClientMode::Tcp,
136 0 : endpoint_rate_limiter2,
137 0 : conn_gauge,
138 0 : )
139 0 : .instrument(span.clone()),
140 0 : );
141 0 : let res = startup.await;
142 :
143 0 : match res {
144 0 : Err(e) => {
145 0 : // todo: log and push to ctx the error kind
146 0 : ctx.set_error_kind(e.get_error_kind());
147 0 : warn!(parent: &span, "per-client task finished with an error: {e:#}");
148 : }
149 0 : Ok(None) => {
150 0 : ctx.set_success();
151 0 : }
152 0 : Ok(Some(p)) => {
153 0 : ctx.set_success();
154 0 : ctx.log_connect();
155 0 : match p.proxy_pass().instrument(span.clone()).await {
156 0 : Ok(()) => {}
157 0 : Err(ErrorSource::Client(e)) => {
158 0 : warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
159 : }
160 0 : Err(ErrorSource::Compute(e)) => {
161 0 : error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
162 : }
163 : }
164 : }
165 : }
166 0 : });
167 : }
168 :
169 0 : connections.close();
170 0 : drop(listener);
171 0 :
172 0 : // Drain connections
173 0 : connections.wait().await;
174 :
175 0 : Ok(())
176 0 : }
177 :
178 : pub(crate) enum ClientMode {
179 : Tcp,
180 : Websockets { hostname: Option<String> },
181 : }
182 :
183 : /// Abstracts the logic of handling TCP vs WS clients
184 : impl ClientMode {
185 0 : pub(crate) fn allow_cleartext(&self) -> bool {
186 0 : match self {
187 0 : ClientMode::Tcp => false,
188 0 : ClientMode::Websockets { .. } => true,
189 : }
190 0 : }
191 :
192 0 : pub(crate) fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
193 0 : match self {
194 0 : ClientMode::Tcp => config.allow_self_signed_compute,
195 0 : ClientMode::Websockets { .. } => false,
196 : }
197 0 : }
198 :
199 0 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
200 0 : match self {
201 0 : ClientMode::Tcp => s.sni_hostname(),
202 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
203 : }
204 0 : }
205 :
206 0 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
207 0 : match self {
208 0 : ClientMode::Tcp => tls,
209 : // TLS is None here if using websockets, because the connection is already encrypted.
210 0 : ClientMode::Websockets { .. } => None,
211 : }
212 0 : }
213 : }
214 :
215 0 : #[derive(Debug, Error)]
216 : // almost all errors should be reported to the user, but there's a few cases where we cannot
217 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
218 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
219 : // we cannot be sure the client even understands our error message
220 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
221 : pub(crate) enum ClientRequestError {
222 : #[error("{0}")]
223 : Cancellation(#[from] cancellation::CancelError),
224 : #[error("{0}")]
225 : Handshake(#[from] handshake::HandshakeError),
226 : #[error("{0}")]
227 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
228 : #[error("{0}")]
229 : PrepareClient(#[from] std::io::Error),
230 : #[error("{0}")]
231 : ReportedError(#[from] crate::stream::ReportedError),
232 : }
233 :
234 : impl ReportableError for ClientRequestError {
235 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
236 0 : match self {
237 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
238 0 : ClientRequestError::Handshake(e) => e.get_error_kind(),
239 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
240 0 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
241 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
242 : }
243 0 : }
244 : }
245 :
246 : #[allow(clippy::too_many_arguments)]
247 0 : pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
248 0 : config: &'static ProxyConfig,
249 0 : auth_backend: &'static auth::Backend<'static, ()>,
250 0 : ctx: &RequestMonitoring,
251 0 : cancellation_handler: Arc<CancellationHandlerMain>,
252 0 : stream: S,
253 0 : mode: ClientMode,
254 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
255 0 : conn_gauge: NumClientConnectionsGuard<'static>,
256 0 : ) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
257 0 : info!(
258 0 : protocol = %ctx.protocol(),
259 0 : "handling interactive connection from client"
260 : );
261 :
262 0 : let metrics = &Metrics::get().proxy;
263 0 : let proto = ctx.protocol();
264 0 : let request_gauge = metrics.connection_requests.guard(proto);
265 0 :
266 0 : let tls = config.tls_config.as_ref();
267 0 :
268 0 : let record_handshake_error = !ctx.has_private_peer_addr();
269 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
270 0 : let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
271 0 : let (mut stream, params) =
272 0 : match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
273 0 : HandshakeData::Startup(stream, params) => (stream, params),
274 0 : HandshakeData::Cancel(cancel_key_data) => {
275 0 : return Ok(cancellation_handler
276 0 : .cancel_session(cancel_key_data, ctx.session_id())
277 0 : .await
278 0 : .map(|()| None)?)
279 : }
280 : };
281 0 : drop(pause);
282 0 :
283 0 : ctx.set_db_options(params.clone());
284 0 :
285 0 : let hostname = mode.hostname(stream.get_ref());
286 0 :
287 0 : let common_names = tls.map(|tls| &tls.common_names);
288 0 :
289 0 : // Extract credentials which we're going to use for auth.
290 0 : let result = auth_backend
291 0 : .as_ref()
292 0 : .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
293 0 : .transpose();
294 :
295 0 : let user_info = match result {
296 0 : Ok(user_info) => user_info,
297 0 : Err(e) => stream.throw_error(e).await?,
298 : };
299 :
300 0 : let user = user_info.get_user().to_owned();
301 0 : let user_info = match user_info
302 0 : .authenticate(
303 0 : ctx,
304 0 : &mut stream,
305 0 : mode.allow_cleartext(),
306 0 : &config.authentication_config,
307 0 : endpoint_rate_limiter,
308 0 : )
309 0 : .await
310 : {
311 0 : Ok(auth_result) => auth_result,
312 0 : Err(e) => {
313 0 : let db = params.get("database");
314 0 : let app = params.get("application_name");
315 0 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
316 :
317 0 : return stream.throw_error(e).instrument(params_span).await?;
318 : }
319 : };
320 :
321 0 : let mut node = connect_to_compute(
322 0 : ctx,
323 0 : &TcpMechanism {
324 0 : params: ¶ms,
325 0 : locks: &config.connect_compute_locks,
326 0 : },
327 0 : &user_info,
328 0 : mode.allow_self_signed_compute(config),
329 0 : config.wake_compute_retry_config,
330 0 : config.connect_to_compute_retry_config,
331 0 : )
332 0 : .or_else(|e| stream.throw_error(e))
333 0 : .await?;
334 :
335 0 : let session = cancellation_handler.get_session();
336 0 : prepare_client_connection(&node, &session, &mut stream).await?;
337 :
338 : // Before proxy passing, forward to compute whatever data is left in the
339 : // PqStream input buffer. Normally there is none, but our serverless npm
340 : // driver in pipeline mode sends startup, password and first query
341 : // immediately after opening the connection.
342 0 : let (stream, read_buf) = stream.into_inner();
343 0 : node.stream.write_all(&read_buf).await?;
344 :
345 0 : Ok(Some(ProxyPassthrough {
346 0 : client: stream,
347 0 : aux: node.aux.clone(),
348 0 : compute: node,
349 0 : _req: request_gauge,
350 0 : _conn: conn_gauge,
351 0 : _cancel: session,
352 0 : }))
353 0 : }
354 :
355 : /// Finish client connection initialization: confirm auth success, send params, etc.
356 0 : #[tracing::instrument(skip_all)]
357 : pub(crate) async fn prepare_client_connection<P>(
358 : node: &compute::PostgresConnection,
359 : session: &cancellation::Session<P>,
360 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
361 : ) -> Result<(), std::io::Error> {
362 : // Register compute's query cancellation token and produce a new, unique one.
363 : // The new token (cancel_key_data) will be sent to the client.
364 : let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
365 :
366 : // Forward all postgres connection params to the client.
367 : // Right now the implementation is very hacky and inefficent (ideally,
368 : // we don't need an intermediate hashmap), but at least it should be correct.
369 : for (name, value) in &node.params {
370 : // TODO: Theoretically, this could result in a big pile of params...
371 : stream.write_message_noflush(&Be::ParameterStatus {
372 : name: name.as_bytes(),
373 : value: value.as_bytes(),
374 : })?;
375 : }
376 :
377 : stream
378 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
379 : .write_message(&Be::ReadyForQuery)
380 : .await?;
381 :
382 : Ok(())
383 : }
384 :
385 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
386 : pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
387 :
388 : impl NeonOptions {
389 11 : pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
390 11 : params
391 11 : .options_raw()
392 11 : .map(Self::parse_from_iter)
393 11 : .unwrap_or_default()
394 11 : }
395 7 : pub(crate) fn parse_options_raw(options: &str) -> Self {
396 7 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
397 7 : }
398 :
399 2 : pub(crate) fn is_ephemeral(&self) -> bool {
400 2 : // Currently, neon endpoint options are all reserved for ephemeral endpoints.
401 2 : !self.0.is_empty()
402 2 : }
403 :
404 13 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
405 13 : let mut options = options
406 13 : .filter_map(neon_option)
407 13 : .map(|(k, v)| (k.into(), v.into()))
408 13 : .collect_vec();
409 13 : options.sort();
410 13 : Self(options)
411 13 : }
412 :
413 4 : pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
414 4 : // prefix + format!(" {k}:{v}")
415 4 : // kinda jank because SmolStr is immutable
416 4 : std::iter::once(prefix)
417 4 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
418 4 : .collect::<SmolStr>()
419 4 : .into()
420 4 : }
421 :
422 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
423 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
424 0 : pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
425 0 : self.0
426 0 : .iter()
427 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
428 0 : .collect()
429 0 : }
430 : }
431 :
432 32 : pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
433 : static RE: OnceCell<Regex> = OnceCell::new();
434 32 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
435 :
436 32 : let cap = re.captures(bytes)?;
437 4 : let (_, [k, v]) = cap.extract();
438 4 : Some((k, v))
439 32 : }
|