#[cfg(test)]
mod tests;

use crate::{
    auth::{self, backend::AuthSuccess},
    cancellation::{self, CancelMap},
    compute::{self, PostgresConnection},
    config::{ProxyConfig, TlsConfig},
    console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
    protocol2::WithClientIp,
    stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use metrics::{
    exponential_buckets, register_histogram, register_int_counter_vec, Histogram, IntCounterVec,
};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::{error::Error, io, ops::ControlFlow, sync::Arc};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
    time,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::measured_stream::MeasuredStream;

/// Number of times we should retry the `/proxy_wake_compute` http request.
/// Retry duration is `BASE_RETRY_WAIT_DURATION * 1.5^n`.
pub const NUM_RETRIES_CONNECT: u32 = 10;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(100);
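
// With the constants above, the backoff schedule works out to 100ms * 1.5^n:
// 150ms, 225ms, ~338ms, ... -- roughly 11s of cumulative sleep across all
// NUM_RETRIES_CONNECT attempts (not counting the 2s connect timeouts).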

const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";

static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_accepted_connections_total",
        "Number of TCP client connections accepted.",
        &["protocol"],
    )
    .unwrap()
});

static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_connections_total",
        "Number of TCP client connections closed.",
        &["protocol"],
    )
    .unwrap()
});

static COMPUTE_CONNECTION_LATENCY: Lazy<Histogram> = Lazy::new(|| {
    register_histogram!(
        "proxy_compute_connection_latency_seconds",
        "Time it took for proxy to establish a connection to the compute endpoint",
        // largest bucket = 0.5ms * 2^15 = 16.384s (16 buckets, factor 2)
        exponential_buckets(0.0005, 2.0, 16).unwrap(),
    )
    .unwrap()
});

static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_connection_failures_total",
        "Number of connection failures (per kind).",
        &["kind"],
    )
    .unwrap()
});

static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_io_bytes_per_client",
        "Number of bytes sent/received between client and backend.",
        crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
    )
    .unwrap()
});

pub async fn task_main(
    config: &'static ProxyConfig,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let mut connections = tokio::task::JoinSet::new();
    let cancel_map = Arc::new(CancelMap::default());

    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                let (socket, _) = accept_result?;

                let session_id = uuid::Uuid::new_v4();
                let cancel_map = Arc::clone(&cancel_map);
                connections.spawn(
                    async move {
                        info!("accepted postgres client connection");

                        let mut socket = WithClientIp::new(socket);
                        if let Some(ip) = socket.wait_for_addr().await? {
                            tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
                        }

                        socket
                            .inner
                            .set_nodelay(true)
                            .context("failed to set socket option")?;

                        handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await
                    }
                    .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
                    .unwrap_or_else(move |e| {
                        // Acknowledge that the task has finished with an error.
                        error!(?session_id, "per-client task finished with an error: {e:#}");
                    }),
                );
            }
            _ = cancellation_token.cancelled() => {
                drop(listener);
                break;
            }
        }
    }

    // Drain connections
    while let Some(res) = connections.join_next().await {
        if let Err(e) = res {
            if !e.is_panic() && !e.is_cancelled() {
                warn!("unexpected error from joined connection task: {e:?}");
            }
        }
    }

    Ok(())
}

pub enum ClientMode {
    Tcp,
    Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    fn protocol_label(&self) -> &'static str {
        match self {
            ClientMode::Tcp => "tcp",
            ClientMode::Websockets { .. } => "ws",
        }
    }

    fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
            ClientMode::Websockets { .. } => true,
        }
    }

    fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
        match self {
            ClientMode::Tcp => config.allow_self_signed_compute,
            ClientMode::Websockets { .. } => false,
        }
    }

    fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
        match self {
            ClientMode::Tcp => s.sni_hostname(),
            ClientMode::Websockets { hostname } => hostname.as_deref(),
        }
    }

    fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
        match self {
            ClientMode::Tcp => tls,
            // TLS is None here if using websockets, because the connection is already encrypted.
            ClientMode::Websockets { .. } => None,
        }
    }
}

pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    stream: S,
    mode: ClientMode,
) -> anyhow::Result<()> {
    info!(
        protocol = mode.protocol_label(),
        "handling interactive connection from client"
    );

    // The `closed` counter will increase when this future is destroyed.
    NUM_CONNECTIONS_ACCEPTED_COUNTER
        .with_label_values(&[mode.protocol_label()])
        .inc();
    scopeguard::defer! {
        NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[mode.protocol_label()]).inc();
    }

    let tls = config.tls_config.as_ref();

    let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
    let (mut stream, params) = match do_handshake.await? {
        Some(x) => x,
        None => return Ok(()), // it's a cancellation request
    };

    // Extract credentials which we're going to use for auth.
    let creds = {
        let hostname = mode.hostname(stream.get_ref());
        let common_names = tls.and_then(|tls| tls.common_names.clone());
        let result = config
            .auth_backend
            .as_ref()
            .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
            .transpose();

        match result {
            Ok(creds) => creds,
            Err(e) => stream.throw_error(e).await?,
        }
    };

    let client = Client::new(
        stream,
        creds,
        &params,
        session_id,
        mode.allow_self_signed_compute(config),
    );
    cancel_map
        .with_session(|session| client.connect_to_db(session, mode.allow_cleartext()))
        .await
}

/// Establish a (most probably, secure) connection with the client.
/// For a better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
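///
/// In terms of the protocol below: `SslRequest` is answered with an
/// `EncryptionResponse` (positive only if we hold a TLS config), after which
/// the raw stream is upgraded to TLS; `GssEncRequest` is always declined;
/// `StartupMessage` completes the handshake and yields the startup params;
/// `CancelRequest` cancels the target session and yields nothing.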
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
    mut tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
    // Client may try upgrading to each protocol only once
    let (mut tried_ssl, mut tried_gss) = (false, false);

    let mut stream = PqStream::new(Stream::from_raw(stream));
    loop {
        let msg = stream.read_startup_packet().await?;
        info!("received {msg:?}");

        use FeStartupPacket::*;
        match msg {
            SslRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_ssl => {
                    tried_ssl = true;

                    // We can't perform TLS handshake without a config
                    let enc = tls.is_some();
                    stream.write_message(&Be::EncryptionResponse(enc)).await?;
                    if let Some(tls) = tls.take() {
                        // Upgrade raw stream into a secure TLS-backed stream.
                        // NOTE: We've consumed `tls`; this fact will be used later.

                        let (raw, read_buf) = stream.into_inner();
                        // TODO: Normally, the client doesn't send any data before
                        // the server says the TLS handshake is ok and read_buf is empty.
                        // However, you could imagine pipelining the postgres
                        // SSLRequest + TLS ClientHello in one hunk, similar to
                        // the pipelining in our node js driver. We should probably
                        // support that by chaining read_buf with the stream.
                        if !read_buf.is_empty() {
                            bail!("data is sent before server replied with EncryptionResponse");
                        }
                        stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
                    }
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            GssEncRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_gss => {
                    tried_gss = true;

                    // Currently, we don't support GSSAPI
                    stream.write_message(&Be::EncryptionResponse(false)).await?;
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            StartupMessage { params, .. } => {
                // Check that the config has been consumed during upgrade
                // OR we didn't provide it at all (for dev purposes).
                if tls.is_some() {
                    stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
                }

                info!(session_type = "normal", "successful handshake");
                break Ok(Some((stream, params)));
            }
            CancelRequest(cancel_key_data) => {
                cancel_map.cancel_session(cancel_key_data).await?;

                info!(session_type = "cancellation", "successful handshake");
                break Ok(None);
            }
        }
    }
}

/// If we couldn't connect, cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
    let is_cached = node_info.cached();
    if is_cached {
        warn!("invalidating stale compute node info cache entry");
    }
    let label = match is_cached {
        true => "compute_cached",
        false => "compute_uncached",
    };
    NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();

    node_info.invalidate().config
}

/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
    let allow_self_signed_compute = node_info.allow_self_signed_compute;

    node_info
        .config
        .connect(allow_self_signed_compute, timeout)
        .await
}

#[async_trait]
pub trait ConnectMechanism {
    type Connection;
    type ConnectError;
    type Error: From<Self::ConnectError>;
    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError>;

    fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
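
// A minimal sketch (hypothetical; not part of the original file) of what an
// implementor of `ConnectMechanism` provides: `connect_once` performs a single
// dial attempt, and `update_connect_config` patches the compute config before
// dialing. (Driving it through `connect_to_compute` below additionally
// requires `Error: From<WakeComputeError>` and a retryable `ConnectError`.)
#[cfg(test)]
#[allow(dead_code)]
struct NoopMechanism;

#[cfg(test)]
#[async_trait]
impl ConnectMechanism for NoopMechanism {
    type Connection = ();
    type ConnectError = io::Error;
    type Error = io::Error;

    async fn connect_once(
        &self,
        _node_info: &console::CachedNodeInfo,
        _timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        // Pretend the dial always succeeds instantly.
        Ok(())
    }

    fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}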

pub struct TcpMechanism<'a> {
    /// KV-dictionary with PostgreSQL connection params.
    pub params: &'a StartupMessageParams,
}

#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
    type Connection = PostgresConnection;
    type ConnectError = compute::ConnectionError;
    type Error = compute::ConnectionError;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<PostgresConnection, Self::Error> {
        connect_to_compute_once(node_info, timeout).await
    }

    fn update_connect_config(&self, config: &mut compute::ConnCfg) {
        config.set_startup_params(self.params);
    }
}

/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by value (`mut`).
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism>(
    mechanism: &M,
    mut node_info: console::CachedNodeInfo,
    extra: &console::ConsoleReqExtra<'_>,
    creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<M::Connection, M::Error>
where
    M::ConnectError: ShouldRetry + std::fmt::Debug,
    M::Error: From<WakeComputeError>,
{
    let _timer = COMPUTE_CONNECTION_LATENCY.start_timer();

    mechanism.update_connect_config(&mut node_info.config);

    // try once
    let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
        Ok(res) => return Ok(res),
        Err(e) => {
            error!(error = ?e, "could not connect to compute node");
            (invalidate_cache(node_info), e)
        }
    };

    let mut num_retries = 1;

    // If we failed to connect, the compute node was likely suspended; request a wake-up.
    info!("compute node's state has likely changed; requesting a wake-up");
    let node_info = loop {
        let wake_res = match creds {
            auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
            auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
            // link auth has no wake-compute API; give up with the original error
            auth::BackendType::Link(_) => return Err(err.into()),
            // test backend
            auth::BackendType::Test(x) => x.wake_compute(),
        };

        match handle_try_wake(wake_res, num_retries) {
            Err(e) => {
                error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
                return Err(e.into());
            }
            // failed to wake up but we can continue to retry
            Ok(ControlFlow::Continue(e)) => {
                warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
            }
            // successfully woke up a compute node and can break the wakeup loop
            Ok(ControlFlow::Break(mut node_info)) => {
                node_info.config.reuse_password(&config);
                mechanism.update_connect_config(&mut node_info.config);
                break node_info;
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    };

    // Now that we have a new node, try to connect to it repeatedly.
    // This can still fail for a few reasons, for instance:
    // * DNS connection settings haven't quite propagated yet
    info!("wake_compute success. attempting to connect");
    loop {
        match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
            Ok(res) => return Ok(res),
            Err(e) => {
                let retriable = e.should_retry(num_retries);
                if !retriable {
                    error!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
                    return Err(e.into());
                }
                warn!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    }
}

/// Attempts to wake up the compute node.
/// * Returns `Ok(Continue(e))` if there was an error waking but retries are acceptable
/// * Returns `Ok(Break(node))` if the wakeup succeeded
/// * Returns `Err(e)` if there was an error that should not be retried
pub fn handle_try_wake(
    result: Result<console::CachedNodeInfo, WakeComputeError>,
    num_retries: u32,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
    match result {
        Err(err) => match &err {
            WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
                Ok(ControlFlow::Continue(err))
            }
            _ => Err(err),
        },
        // Wakeup succeeded: break out of the retry loop with fresh node info.
        Ok(new) => Ok(ControlFlow::Break(new)),
    }
}

pub trait ShouldRetry {
    fn could_retry(&self) -> bool;
    fn should_retry(&self, num_retries: u32) -> bool {
        match self {
            _ if num_retries >= NUM_RETRIES_CONNECT => false,
            err => err.could_retry(),
        }
    }
}

impl ShouldRetry for io::Error {
    fn could_retry(&self) -> bool {
        use std::io::ErrorKind;
        matches!(
            self.kind(),
            ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
        )
    }
}

impl ShouldRetry for tokio_postgres::error::DbError {
    fn could_retry(&self) -> bool {
        use tokio_postgres::error::SqlState;
        matches!(
            self.code(),
            &SqlState::CONNECTION_FAILURE
                | &SqlState::CONNECTION_EXCEPTION
                | &SqlState::CONNECTION_DOES_NOT_EXIST
                | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
        )
    }
}

impl ShouldRetry for tokio_postgres::Error {
    fn could_retry(&self) -> bool {
        if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
            io::Error::could_retry(io_err)
        } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
            tokio_postgres::error::DbError::could_retry(db_err)
        } else {
            false
        }
    }
}

impl ShouldRetry for compute::ConnectionError {
    fn could_retry(&self) -> bool {
        match self {
            compute::ConnectionError::Postgres(err) => err.could_retry(),
            compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
            _ => false,
        }
    }
}

pub fn retry_after(num_retries: u32) -> time::Duration {
    // 1.5 seems to be an ok growth factor heuristic
    BASE_RETRY_WAIT_DURATION.mul_f64(1.5_f64.powi(num_retries as i32))
}
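
// A worked example (illustrative sketch, not part of the original file) of the
// schedule `retry_after` produces: 100ms * 1.5^n, truncated to milliseconds.
#[cfg(test)]
mod retry_after_example {
    use super::*;

    #[test]
    fn backoff_grows_geometrically() {
        assert_eq!(retry_after(0).as_millis(), 100);
        assert_eq!(retry_after(1).as_millis(), 150);
        assert_eq!(retry_after(2).as_millis(), 225);
        // By the last permitted retry (n = 9), the wait is ~3.8s.
        assert_eq!(retry_after(9).as_millis(), 3844);
    }
}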

/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
    node: &compute::PostgresConnection,
    reported_auth_ok: bool,
    session: cancellation::Session<'_>,
    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
    // Register compute's query cancellation token and produce a new, unique one.
    // The new token (cancel_key_data) will be sent to the client.
    let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

    // Report authentication success if we haven't done this already.
    // Note that we do this only (for the most part) after we've connected
    // to a compute (see above) which performs its own authentication.
    if !reported_auth_ok {
        stream.write_message_noflush(&Be::AuthenticationOk)?;
    }

    // Forward all postgres connection params to the client.
    // Right now the implementation is very hacky and inefficient (ideally,
    // we don't need an intermediate hashmap), but at least it should be correct.
    for (name, value) in &node.params {
        // TODO: Theoretically, this could result in a big pile of params...
        stream.write_message_noflush(&Be::ParameterStatus {
            name: name.as_bytes(),
            value: value.as_bytes(),
        })?;
    }

    stream
        .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
        .write_message(&Be::ReadyForQuery)
        .await?;

    Ok(())
}

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
    client: impl AsyncRead + AsyncWrite + Unpin,
    compute: impl AsyncRead + AsyncWrite + Unpin,
    aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
    let mut client = MeasuredStream::new(
        client,
        |_| {},
        |cnt| {
            // Number of bytes we sent to the client (outbound).
            m_sent.inc_by(cnt as u64);
        },
    );

    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
    let mut compute = MeasuredStream::new(
        compute,
        |_| {},
        |cnt| {
            // Number of bytes the client sent to the compute node (inbound).
            m_recv.inc_by(cnt as u64);
        },
    );

    // Starting from here we only proxy the client's traffic.
    info!("performing the proxy pass...");
    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

    Ok(())
}

/// Thin connection context.
struct Client<'a, S> {
    /// The underlying libpq protocol stream.
    stream: PqStream<S>,
    /// Client credentials that we care about.
    creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
    /// KV-dictionary with PostgreSQL connection params.
    params: &'a StartupMessageParams,
    /// Unique connection ID.
    session_id: uuid::Uuid,
    /// Allow self-signed certificates (for testing).
    allow_self_signed_compute: bool,
}

impl<'a, S> Client<'a, S> {
    /// Construct a new connection context.
    fn new(
        stream: PqStream<S>,
        creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
        params: &'a StartupMessageParams,
        session_id: uuid::Uuid,
        allow_self_signed_compute: bool,
    ) -> Self {
        Self {
            stream,
            creds,
            params,
            session_id,
            allow_self_signed_compute,
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
    /// Let the client authenticate and connect to the designated compute node.
    // Instrumentation logs the endpoint name everywhere. This doesn't work for
    // link auth; strictly speaking, we don't know the endpoint name in that case.
    #[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)]
    async fn connect_to_db(
        self,
        session: cancellation::Session<'_>,
        allow_cleartext: bool,
    ) -> anyhow::Result<()> {
        let Self {
            mut stream,
            mut creds,
            params,
            session_id,
            allow_self_signed_compute,
        } = self;

        let extra = console::ConsoleReqExtra {
            session_id, // aka this connection's id
            application_name: params.get("application_name"),
        };

        let auth_result = match creds
            .authenticate(&extra, &mut stream, allow_cleartext)
            .await
        {
            Ok(auth_result) => auth_result,
            Err(e) => return stream.throw_error(e).await,
        };

        let AuthSuccess {
            reported_auth_ok,
            value: mut node_info,
        } = auth_result;

        node_info.allow_self_signed_compute = allow_self_signed_compute;

        let aux = node_info.aux.clone();
        let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds)
            .or_else(|e| stream.throw_error(e))
            .await?;

        prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
        // Before proxy passing, forward to compute whatever data is left in the
        // PqStream input buffer. Normally there is none, but our serverless npm
        // driver in pipeline mode sends startup, password and first query
        // immediately after opening the connection.
        let (stream, read_buf) = stream.into_inner();
        node.stream.write_all(&read_buf).await?;
        proxy_pass(stream, node.stream, &aux).await
    }
}