TLA Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub mod connect_compute;
5 : pub mod retry;
6 :
7 : use crate::{
8 : auth,
9 : cancellation::{self, CancelMap},
10 : compute,
11 : config::{AuthenticationConfig, ProxyConfig, TlsConfig},
12 : console::{self, messages::MetricsAuxInfo},
13 : context::RequestMonitoring,
14 : metrics::{
15 : NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
16 : NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
17 : },
18 : protocol2::WithClientIp,
19 : rate_limiter::EndpointRateLimiter,
20 : stream::{PqStream, Stream},
21 : usage_metrics::{Ids, USAGE_METRICS},
22 : };
23 : use anyhow::{bail, Context};
24 : use futures::TryFutureExt;
25 : use itertools::Itertools;
26 : use once_cell::sync::OnceCell;
27 : use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
28 : use regex::Regex;
29 : use std::sync::Arc;
30 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
31 : use tokio_util::sync::CancellationToken;
32 : use tracing::{error, info, info_span, Instrument};
33 : use utils::measured_stream::MeasuredStream;
34 :
35 : use self::connect_compute::{connect_to_compute, TcpMechanism};
36 :
37 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
38 : const ERR_PROTO_VIOLATION: &str = "protocol violation";
39 :
40 CBC 85 : pub async fn run_until_cancelled<F: std::future::Future>(
41 85 : f: F,
42 85 : cancellation_token: &CancellationToken,
43 85 : ) -> Option<F::Output> {
44 85 : match futures::future::select(
45 85 : std::pin::pin!(f),
46 85 : std::pin::pin!(cancellation_token.cancelled()),
47 85 : )
48 84 : .await
49 : {
50 62 : futures::future::Either::Left((f, _)) => Some(f),
51 23 : futures::future::Either::Right(((), _)) => None,
52 : }
53 85 : }
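// Illustrative sketch, not part of the measured source above: once the token has been
// cancelled, the helper resolves to `None` without waiting on the wrapped future.
// A test along these lines (assuming the tokio test macro) could live in `mod tests`:
#[tokio::test]
async fn run_until_cancelled_returns_none_after_cancel() {
    let token = CancellationToken::new();
    token.cancel();
    let output = run_until_cancelled(std::future::pending::<()>(), &token).await;
    assert!(output.is_none());
}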
54 :
55 22 : pub async fn task_main(
56 22 : config: &'static ProxyConfig,
57 22 : listener: tokio::net::TcpListener,
58 22 : cancellation_token: CancellationToken,
59 22 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
60 22 : ) -> anyhow::Result<()> {
61 22 : scopeguard::defer! {
62 22 : info!("proxy has shut down");
63 22 : }
64 22 :
65 22 : // When set for the server socket, the keepalive setting
66 22 : // will be inherited by all accepted client sockets.
67 22 : socket2::SockRef::from(&listener).set_keepalive(true)?;
68 :
69 22 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
70 22 : let cancel_map = Arc::new(CancelMap::default());
71 :
72 60 : while let Some(accept_result) =
73 82 : run_until_cancelled(listener.accept(), &cancellation_token).await
74 : {
75 60 : let (socket, peer_addr) = accept_result?;
76 :
77 60 : let session_id = uuid::Uuid::new_v4();
78 60 : let cancel_map = Arc::clone(&cancel_map);
79 60 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
80 60 :
81 60 : connections.spawn(
82 60 : async move {
83 60 : info!("accepted postgres client connection");
84 :
85 60 : let mut socket = WithClientIp::new(socket);
86 60 : let mut peer_addr = peer_addr.ip();
87 60 : if let Some(addr) = socket.wait_for_addr().await? {
88 UBC 0 : peer_addr = addr.ip();
89 0 : tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
90 CBC 60 : } else if config.require_client_ip {
91 UBC 0 : bail!("missing required client IP");
92 CBC 60 : }
93 :
94 60 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
95 60 :
96 60 : socket
97 60 : .inner
98 60 : .set_nodelay(true)
99 60 : .context("failed to set socket option")?;
100 :
101 60 : handle_client(
102 60 : config,
103 60 : &mut ctx,
104 60 : &cancel_map,
105 60 : socket,
106 60 : ClientMode::Tcp,
107 60 : endpoint_rate_limiter,
108 60 : )
109 977 : .await
110 60 : }
111 60 : .instrument(info_span!(
112 60 : "handle_client",
113 60 : ?session_id,
114 60 : peer_addr = tracing::field::Empty
115 60 : ))
116 60 : .unwrap_or_else(move |e| {
117 60 : // Acknowledge that the task has finished with an error.
118 60 : error!(?session_id, "per-client task finished with an error: {e:#}");
119 60 : }),
120 : );
121 : }
122 :
123 22 : connections.close();
124 22 : drop(listener);
125 22 :
126 22 : // Drain connections
127 22 : connections.wait().await;
128 :
129 22 : Ok(())
130 22 : }
131 :
132 : pub enum ClientMode {
133 : Tcp,
134 : Websockets { hostname: Option<String> },
135 : }
136 :
137 : /// Abstracts the logic of handling TCP vs WS clients
138 : impl ClientMode {
139 49 : fn allow_cleartext(&self) -> bool {
140 49 : match self {
141 49 : ClientMode::Tcp => false,
142 UBC 0 : ClientMode::Websockets { .. } => true,
143 : }
144 CBC 49 : }
145 :
146 49 : fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
147 49 : match self {
148 49 : ClientMode::Tcp => config.allow_self_signed_compute,
149 UBC 0 : ClientMode::Websockets { .. } => false,
150 : }
151 CBC 49 : }
152 :
153 49 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
154 49 : match self {
155 49 : ClientMode::Tcp => s.sni_hostname(),
156 UBC 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
157 : }
158 CBC 49 : }
159 :
160 60 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
161 60 : match self {
162 60 : ClientMode::Tcp => tls,
163 : // TLS is None here if using websockets, because the connection is already encrypted.
164 UBC 0 : ClientMode::Websockets { .. } => None,
165 : }
166 CBC 60 : }
167 : }
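// Descriptive summary, not part of the measured source: per the match arms above, TCP
// clients never get cleartext auth, may use self-signed compute certs when the config
// allows it, take their hostname from SNI, and go through the TLS handshake; websocket
// clients are already TLS-terminated upstream, so cleartext is allowed, self-signed
// compute certs are not, the hostname is supplied explicitly, and no TLS config is used.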
168 :
169 60 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
170 60 : config: &'static ProxyConfig,
171 60 : ctx: &mut RequestMonitoring,
172 60 : cancel_map: &CancelMap,
173 60 : stream: S,
174 60 : mode: ClientMode,
175 60 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
176 60 : ) -> anyhow::Result<()> {
177 60 : info!(
178 60 : protocol = ctx.protocol,
179 60 : "handling interactive connection from client"
180 60 : );
181 :
182 60 : let proto = ctx.protocol;
183 60 : let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
184 60 : .with_label_values(&[proto])
185 60 : .guard();
186 60 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
187 60 : .with_label_values(&[proto])
188 60 : .guard();
189 60 :
190 60 : let tls = config.tls_config.as_ref();
191 60 :
192 60 : let pause = ctx.latency_timer.pause();
193 60 : let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
194 100 : let (mut stream, params) = match do_handshake.await? {
195 49 : Some(x) => x,
196 UBC 0 : None => return Ok(()), // it's a cancellation request
197 : };
198 CBC 49 : drop(pause);
199 :
200 : // Extract credentials which we're going to use for auth.
201 49 : let creds = {
202 49 : let hostname = mode.hostname(stream.get_ref());
203 49 :
204 49 : let common_names = tls.and_then(|tls| tls.common_names.clone());
205 49 : let result = config
206 49 : .auth_backend
207 49 : .as_ref()
208 49 : .map(|_| auth::ClientCredentials::parse(ctx, ¶ms, hostname, common_names))
209 49 : .transpose();
210 49 :
211 49 : match result {
212 49 : Ok(creds) => creds,
213 UBC 0 : Err(e) => stream.throw_error(e).await?,
214 : }
215 : };
216 :
217 CBC 49 : ctx.set_endpoint_id(creds.get_endpoint());
218 49 :
219 49 : let client = Client::new(
220 49 : stream,
221 49 : creds,
222 49 : &params,
223 49 : mode.allow_self_signed_compute(config),
224 49 : endpoint_rate_limiter,
225 49 : );
226 49 : cancel_map
227 49 : .with_session(|session| {
228 49 : client.connect_to_db(ctx, session, mode, &config.authentication_config)
229 49 : })
230 877 : .await
231 60 : }
232 :
233 : /// Establish a (most probably secure) connection with the client.
234 : /// For a better testing experience, `stream` can be any object satisfying the traits.
235 : /// It's easier to work with an owned `stream` here as we need to upgrade it to TLS;
236 : /// we also take extra care to propagate only select handshake errors to the client.
237 151 : #[tracing::instrument(skip_all)]
238 : async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
239 : stream: S,
240 : mut tls: Option<&TlsConfig>,
241 : cancel_map: &CancelMap,
242 : ) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
243 : // Client may try upgrading to each protocol only once
244 : let (mut tried_ssl, mut tried_gss) = (false, false);
245 :
246 : let mut stream = PqStream::new(Stream::from_raw(stream));
247 : loop {
248 : let msg = stream.read_startup_packet().await?;
249 109 : info!("received {msg:?}");
250 :
251 : use FeStartupPacket::*;
252 : match msg {
253 : SslRequest => match stream.get_ref() {
254 : Stream::Raw { .. } if !tried_ssl => {
255 : tried_ssl = true;
256 :
257 : // We can't perform TLS handshake without a config
258 : let enc = tls.is_some();
259 : stream.write_message(&Be::EncryptionResponse(enc)).await?;
260 : if let Some(tls) = tls.take() {
261 : // Upgrade raw stream into a secure TLS-backed stream.
262 : // NOTE: We've consumed `tls`; this fact will be used later.
263 :
264 : let (raw, read_buf) = stream.into_inner();
265 : // TODO: Normally, the client doesn't send any data before
266 : // the server says the TLS handshake is ok, and read_buf is empty.
267 : // However, you could imagine pipelining the postgres
268 : // SSLRequest + TLS ClientHello in one chunk, similar to
269 : // the pipelining in our Node.js driver. We should probably
270 : // support that by chaining read_buf with the stream.
271 : if !read_buf.is_empty() {
272 : bail!("data is sent before server replied with EncryptionResponse");
273 : }
274 : let tls_stream = raw.upgrade(tls.to_server_config()).await?;
275 :
276 : let (_, tls_server_end_point) = tls
277 : .cert_resolver
278 : .resolve(tls_stream.get_ref().1.server_name())
279 : .context("missing certificate")?;
280 :
281 : stream = PqStream::new(Stream::Tls {
282 : tls: Box::new(tls_stream),
283 : tls_server_end_point,
284 : });
285 : }
286 : }
287 : _ => bail!(ERR_PROTO_VIOLATION),
288 : },
289 : GssEncRequest => match stream.get_ref() {
290 : Stream::Raw { .. } if !tried_gss => {
291 : tried_gss = true;
292 :
293 : // Currently, we don't support GSSAPI
294 : stream.write_message(&Be::EncryptionResponse(false)).await?;
295 : }
296 : _ => bail!(ERR_PROTO_VIOLATION),
297 : },
298 : StartupMessage { params, .. } => {
299 : // Check that the config has been consumed during upgrade
300 : // OR we didn't provide it at all (for dev purposes).
301 : if tls.is_some() {
302 : stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
303 : }
304 :
305 49 : info!(session_type = "normal", "successful handshake");
306 : break Ok(Some((stream, params)));
307 : }
308 : CancelRequest(cancel_key_data) => {
309 : cancel_map.cancel_session(cancel_key_data).await?;
310 :
311 UBC 0 : info!(session_type = "cancellation", "successful handshake");
312 : break Ok(None);
313 : }
314 : }
315 : }
316 : }
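// Summary of the accepted startup flows, derived from the match above (not part of the
// measured source):
//   SslRequest     -> EncryptionResponse(tls configured?) -> optional TLS upgrade -> keep looping
//   GssEncRequest  -> EncryptionResponse(false)                                   -> keep looping
//   StartupMessage -> rejected unless TLS was negotiated (or no TLS was configured) -> Ok(Some(..))
//   CancelRequest  -> cancels the target session                                    -> Ok(None)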
317 :
318 : /// Finish client connection initialization: confirm auth success, send params, etc.
319 CBC 38 : #[tracing::instrument(skip_all)]
320 : async fn prepare_client_connection(
321 : node: &compute::PostgresConnection,
322 : session: cancellation::Session<'_>,
323 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
324 : ) -> anyhow::Result<()> {
325 : // Register compute's query cancellation token and produce a new, unique one.
326 : // The new token (cancel_key_data) will be sent to the client.
327 : let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
328 :
329 : // Forward all postgres connection params to the client.
330 : // Right now the implementation is very hacky and inefficient (ideally,
331 : // we don't need an intermediate hashmap), but at least it should be correct.
332 : for (name, value) in &node.params {
333 : // TODO: Theoretically, this could result in a big pile of params...
334 : stream.write_message_noflush(&Be::ParameterStatus {
335 : name: name.as_bytes(),
336 : value: value.as_bytes(),
337 : })?;
338 : }
339 :
340 : stream
341 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
342 : .write_message(&Be::ReadyForQuery)
343 : .await?;
344 :
345 : Ok(())
346 : }
347 :
348 : /// Forward bytes in both directions (client <-> compute).
349 UBC 0 : #[tracing::instrument(skip_all)]
350 : pub async fn proxy_pass(
351 : ctx: &mut RequestMonitoring,
352 : client: impl AsyncRead + AsyncWrite + Unpin,
353 : compute: impl AsyncRead + AsyncWrite + Unpin,
354 : aux: MetricsAuxInfo,
355 : ) -> anyhow::Result<()> {
356 : ctx.log();
357 :
358 : let usage = USAGE_METRICS.register(Ids {
359 : endpoint_id: aux.endpoint_id.clone(),
360 : branch_id: aux.branch_id.clone(),
361 : });
362 :
363 : let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
364 : let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
365 : let mut client = MeasuredStream::new(
366 : client,
367 CBC 117 : |_| {},
368 78 : |cnt| {
369 78 : // Number of bytes we sent to the client (outbound).
370 78 : m_sent.inc_by(cnt as u64);
371 78 : m_sent2.inc_by(cnt as u64);
372 78 : usage.record_egress(cnt as u64);
373 78 : },
374 : );
375 :
376 : let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
377 : let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
378 : let mut compute = MeasuredStream::new(
379 : compute,
380 78 : |_| {},
381 78 : |cnt| {
382 78 : // Number of bytes the client sent to the compute node (inbound).
383 78 : m_recv.inc_by(cnt as u64);
384 78 : m_recv2.inc_by(cnt as u64);
385 78 : },
386 : );
387 :
388 : // Starting from here we only proxy the client's traffic.
389 39 : info!("performing the proxy pass...");
390 : let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
391 :
392 : Ok(())
393 : }
394 :
395 : /// Thin connection context.
396 : struct Client<'a, S> {
397 : /// The underlying libpq protocol stream.
398 : stream: PqStream<Stream<S>>,
399 : /// Client credentials that we care about.
400 : creds: auth::BackendType<'a, auth::ClientCredentials>,
401 : /// KV-dictionary with PostgreSQL connection params.
402 : params: &'a StartupMessageParams,
403 : /// Allow self-signed certificates (for testing).
404 : allow_self_signed_compute: bool,
405 : /// Rate limiter for endpoints
406 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
407 : }
408 :
409 : impl<'a, S> Client<'a, S> {
410 : /// Construct a new connection context.
411 49 : fn new(
412 49 : stream: PqStream<Stream<S>>,
413 49 : creds: auth::BackendType<'a, auth::ClientCredentials>,
414 49 : params: &'a StartupMessageParams,
415 49 : allow_self_signed_compute: bool,
416 49 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
417 49 : ) -> Self {
418 49 : Self {
419 49 : stream,
420 49 : creds,
421 49 : params,
422 49 : allow_self_signed_compute,
423 49 : endpoint_rate_limiter,
424 49 : }
425 49 : }
426 : }
427 :
428 : impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
429 : /// Let the client authenticate and connect to the designated compute node.
430 : // Instrumentation logs endpoint name everywhere. Doesn't work for link
431 : // auth; strictly speaking we don't know endpoint name in its case.
432 UBC 0 : #[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
433 : async fn connect_to_db(
434 : self,
435 : ctx: &mut RequestMonitoring,
436 : session: cancellation::Session<'_>,
437 : mode: ClientMode,
438 : config: &'static AuthenticationConfig,
439 : ) -> anyhow::Result<()> {
440 : let Self {
441 : mut stream,
442 : creds,
443 : params,
444 : allow_self_signed_compute,
445 : endpoint_rate_limiter,
446 : } = self;
447 :
448 : // check rate limit
449 : if let Some(ep) = creds.get_endpoint() {
450 : if !endpoint_rate_limiter.check(ep) {
451 : return stream
452 : .throw_error(auth::AuthError::too_many_connections())
453 : .await;
454 : }
455 : }
456 :
457 : let extra = console::ConsoleReqExtra {
458 : options: neon_options(params),
459 : };
460 :
461 : let user = creds.get_user().to_owned();
462 : let auth_result = match creds
463 : .authenticate(ctx, &extra, &mut stream, mode.allow_cleartext(), config)
464 : .await
465 : {
466 : Ok(auth_result) => auth_result,
467 : Err(e) => {
468 : let db = params.get("database");
469 : let app = params.get("application_name");
470 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
471 :
472 : return stream.throw_error(e).instrument(params_span).await;
473 : }
474 : };
475 :
476 : let (mut node_info, creds) = auth_result;
477 :
478 : node_info.allow_self_signed_compute = allow_self_signed_compute;
479 :
480 : let aux = node_info.aux.clone();
481 : let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &extra, &creds)
482 0 : .or_else(|e| stream.throw_error(e))
483 : .await?;
484 :
485 : prepare_client_connection(&node, session, &mut stream).await?;
486 : // Before proxy passing, forward to compute whatever data is left in the
487 : // PqStream input buffer. Normally there is none, but our serverless npm
488 : // driver in pipeline mode sends startup, password and first query
489 : // immediately after opening the connection.
490 : let (stream, read_buf) = stream.into_inner();
491 : node.stream.write_all(&read_buf).await?;
492 : proxy_pass(ctx, stream, node.stream, aux).await
493 : }
494 : }
495 :
496 CBC 188 : pub fn neon_options(params: &StartupMessageParams) -> Vec<(String, String)> {
497 188 : #[allow(unstable_name_collisions)]
498 188 : match params.options_raw() {
499 183 : Some(options) => options.filter_map(neon_option).collect(),
500 5 : None => vec![],
501 : }
502 188 : }
503 :
504 98 : pub fn neon_options_str(params: &StartupMessageParams) -> String {
505 98 : #[allow(unstable_name_collisions)]
506 98 : neon_options(params)
507 98 : .iter()
508 98 : .map(|(k, v)| format!("{}:{}", k, v))
509 98 : .sorted() // we sort it to use as a cache key
510 98 : .intersperse(" ".to_owned())
511 98 : .collect()
512 98 : }
513 :
514 218 : pub fn neon_option(bytes: &str) -> Option<(String, String)> {
515 218 : static RE: OnceCell<Regex> = OnceCell::new();
516 218 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
517 :
518 218 : let cap = re.captures(bytes)?;
519 4 : let (_, [k, v]) = cap.extract();
520 4 : Some((k.to_owned(), v.to_owned()))
521 218 : }
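// Illustrative sketch, not part of the measured source: only `neon_`-prefixed options
// match the regex, and the prefix is stripped by the first capture group. (The option
// key below is an arbitrary example, not a known proxy option.)
#[test]
fn neon_option_parses_prefixed_keys_only() {
    assert_eq!(
        neon_option("neon_timeout:30"),
        Some(("timeout".to_owned(), "30".to_owned()))
    );
    assert_eq!(neon_option("statement_timeout:30"), None);
}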