LCOV - code coverage report
Current view: top level - proxy/src - proxy.rs
Test:         8ac049b474321fdc72ddcb56d7165153a1a900e8.info
Test Date:    2023-09-06 10:18:01
Coverage:     Lines: 88.1 % (214 of 243)    Functions: 45.8 % (87 of 190)

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : use crate::{
       5              :     auth::{self, backend::AuthSuccess},
       6              :     cancellation::{self, CancelMap},
       7              :     compute::{self, PostgresConnection},
       8              :     config::{ProxyConfig, TlsConfig},
       9              :     console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
      10              :     protocol2::WithClientIp,
      11              :     stream::{PqStream, Stream},
      12              : };
      13              : use anyhow::{bail, Context};
      14              : use async_trait::async_trait;
      15              : use futures::TryFutureExt;
      16              : use metrics::{
      17              :     exponential_buckets, register_histogram, register_int_counter_vec, Histogram, IntCounterVec,
      18              : };
      19              : use once_cell::sync::Lazy;
      20              : use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
      21              : use std::{error::Error, io, ops::ControlFlow, sync::Arc};
      22              : use tokio::{
      23              :     io::{AsyncRead, AsyncWrite, AsyncWriteExt},
      24              :     time,
      25              : };
      26              : use tokio_util::sync::CancellationToken;
      27              : use tracing::{error, info, info_span, warn, Instrument};
      28              : use utils::measured_stream::MeasuredStream;
      29              : 
      30              : /// Number of times we should retry the `/proxy_wake_compute` http request.
      31              : /// Retry duration is BASE_RETRY_WAIT_DURATION * 1.5^n
      32              : pub const NUM_RETRIES_CONNECT: u32 = 10;
      33              : const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
      34              : const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(100);
      35              : 
      36              : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      37              : const ERR_PROTO_VIOLATION: &str = "protocol violation";
      38              : 
      39           13 : static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      40           13 :     register_int_counter_vec!(
      41           13 :         "proxy_accepted_connections_total",
      42           13 :         "Number of TCP client connections accepted.",
      43           13 :         &["protocol"],
      44           13 :     )
      45           13 :     .unwrap()
      46           13 : });
      47              : 
      48           13 : static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      49           13 :     register_int_counter_vec!(
      50           13 :         "proxy_closed_connections_total",
      51           13 :         "Number of TCP client connections closed.",
      52           13 :         &["protocol"],
      53           13 :     )
      54           13 :     .unwrap()
      55           13 : });
      56              : 
      57           14 : static COMPUTE_CONNECTION_LATENCY: Lazy<Histogram> = Lazy::new(|| {
      58           14 :     register_histogram!(
      59           14 :         "proxy_compute_connection_latency_seconds",
      60           14 :         "Time it took for proxy to establish a connection to the compute endpoint",
       61           14 :         // largest bucket = 2^15 * 0.5ms ≈ 16.4s (plus the implicit +Inf bucket)
      62           14 :         exponential_buckets(0.0005, 2.0, 16).unwrap(),
      63           14 :     )
      64           14 :     .unwrap()
      65           14 : });
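
As a cross-check of the bucket comment above, here is a dependency-free sketch (an assumption: it mirrors the usual prometheus semantics, where exponential_buckets(start, factor, count) yields start * factor^i for i in 0..count) that re-derives the registered bounds: 16 buckets from 0.5 ms up to roughly 16.4 s, before the implicit +Inf bucket.

    // Sketch only: re-derives the histogram bounds registered above.
    fn exponential_buckets(start: f64, factor: f64, count: usize) -> Vec<f64> {
        (0..count).map(|i| start * factor.powi(i as i32)).collect()
    }

    fn main() {
        let buckets = exponential_buckets(0.0005, 2.0, 16);
        assert_eq!(buckets.len(), 16);
        assert!((buckets[0] - 0.0005).abs() < 1e-12); // smallest bound: 0.5 ms
        assert!((buckets[15] - 16.384).abs() < 1e-9); // largest explicit bound: ~16.4 s
        println!("{buckets:?}");
    }
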
      66              : 
      67            2 : static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
      68            2 :     register_int_counter_vec!(
      69            2 :         "proxy_connection_failures_total",
      70            2 :         "Number of connection failures (per kind).",
      71            2 :         &["kind"],
      72            2 :     )
      73            2 :     .unwrap()
      74            2 : });
      75              : 
      76           14 : static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      77           14 :     register_int_counter_vec!(
      78           14 :         "proxy_io_bytes_per_client",
      79           14 :         "Number of bytes sent/received between client and backend.",
      80           14 :         crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
      81           14 :     )
      82           14 :     .unwrap()
      83           14 : });
      84              : 
      85           14 : pub async fn task_main(
      86           14 :     config: &'static ProxyConfig,
      87           14 :     listener: tokio::net::TcpListener,
      88           14 :     cancellation_token: CancellationToken,
      89           14 : ) -> anyhow::Result<()> {
      90           14 :     scopeguard::defer! {
      91           14 :         info!("proxy has shut down");
      92              :     }
      93              : 
      94              :     // When set for the server socket, the keepalive setting
      95              :     // will be inherited by all accepted client sockets.
      96           14 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
      97              : 
      98           14 :     let mut connections = tokio::task::JoinSet::new();
      99           14 :     let cancel_map = Arc::new(CancelMap::default());
     100              : 
     101              :     loop {
     102           49 :         tokio::select! {
     103           35 :             accept_result = listener.accept() => {
     104              :                 let (socket, _) = accept_result?;
     105              : 
     106              :                 let session_id = uuid::Uuid::new_v4();
     107              :                 let cancel_map = Arc::clone(&cancel_map);
     108              :                 connections.spawn(
     109           35 :                     async move {
     110           35 :                         info!("accepted postgres client connection");
     111              : 
     112           35 :                         let mut socket = WithClientIp::new(socket);
     113           35 :                         if let Some(ip) = socket.wait_for_addr().await? {
     114            0 :                             tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
     115           35 :                         }
     116              : 
     117           35 :                         socket
     118           35 :                             .inner
     119           35 :                             .set_nodelay(true)
     120           35 :                             .context("failed to set socket option")?;
     121              : 
     122          439 :                         handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await
     123           35 :                     }
     124              :                     .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
     125           35 :                     .unwrap_or_else(move |e| {
     126           35 :                         // Acknowledge that the task has finished with an error.
     127           35 :                         error!(?session_id, "per-client task finished with an error: {e:#}");
     128           35 :                     }),
     129              :                 );
     130              :             }
     131              :             _ = cancellation_token.cancelled() => {
     132              :                 drop(listener);
     133              :                 break;
     134              :             }
     135              :         }
     136              :     }
     137              :     // Drain connections
     138           49 :     while let Some(res) = connections.join_next().await {
     139           35 :         if let Err(e) = res {
     140            0 :             if !e.is_panic() && !e.is_cancelled() {
     141            0 :                 warn!("unexpected error from joined connection task: {e:?}");
     142            0 :             }
     143           35 :         }
     144              :     }
     145           14 :     Ok(())
     146           14 : }
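
task_main above combines three tokio building blocks: a select! between accept() and a CancellationToken, a JoinSet holding one task per client, and a drain loop so in-flight sessions can finish after shutdown is requested. Below is a standalone sketch of the same pattern under stated assumptions: a hypothetical echo-style placeholder stands in for handle_client, and only tokio and tokio_util are used.

    use tokio::{io::AsyncWriteExt, net::TcpListener, task::JoinSet};
    use tokio_util::sync::CancellationToken;

    async fn serve(listener: TcpListener, cancel: CancellationToken) -> std::io::Result<()> {
        let mut connections = JoinSet::new();
        loop {
            tokio::select! {
                accepted = listener.accept() => {
                    let (mut socket, _) = accepted?;
                    connections.spawn(async move {
                        // Stand-in for the per-client work (handle_client in the proxy).
                        let _ = socket.write_all(b"hello\n").await;
                    });
                }
                _ = cancel.cancelled() => {
                    drop(listener); // stop accepting new clients
                    break;
                }
            }
        }
        // Drain: let already-accepted connections run to completion.
        while let Some(res) = connections.join_next().await {
            if let Err(e) = res {
                if !e.is_panic() && !e.is_cancelled() {
                    eprintln!("connection task failed: {e:?}");
                }
            }
        }
        Ok(())
    }

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let cancel = CancellationToken::new();
        cancel.cancel(); // request shutdown immediately so the demo exits
        serve(listener, cancel).await
    }
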
     147              : 
     148              : pub enum ClientMode {
     149              :     Tcp,
     150              :     Websockets { hostname: Option<String> },
     151              : }
     152              : 
     153              : /// Abstracts the logic of handling TCP vs WS clients
     154              : impl ClientMode {
     155          105 :     fn protocol_label(&self) -> &'static str {
     156          105 :         match self {
     157          105 :             ClientMode::Tcp => "tcp",
     158            0 :             ClientMode::Websockets { .. } => "ws",
     159              :         }
     160          105 :     }
     161              : 
     162           31 :     fn allow_cleartext(&self) -> bool {
     163           31 :         match self {
     164           31 :             ClientMode::Tcp => false,
     165            0 :             ClientMode::Websockets { .. } => true,
     166              :         }
     167           31 :     }
     168              : 
     169           31 :     fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     170           31 :         match self {
     171           31 :             ClientMode::Tcp => config.allow_self_signed_compute,
     172            0 :             ClientMode::Websockets { .. } => false,
     173              :         }
     174           31 :     }
     175              : 
     176           31 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     177           31 :         match self {
     178           31 :             ClientMode::Tcp => s.sni_hostname(),
     179            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     180              :         }
     181           31 :     }
     182              : 
     183           35 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     184           35 :         match self {
     185           35 :             ClientMode::Tcp => tls,
     186              :             // TLS is None here if using websockets, because the connection is already encrypted.
     187            0 :             ClientMode::Websockets { .. } => None,
     188              :         }
     189           35 :     }
     190              : }
     191              : 
     192           35 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     193           35 :     config: &'static ProxyConfig,
     194           35 :     cancel_map: &CancelMap,
     195           35 :     session_id: uuid::Uuid,
     196           35 :     stream: S,
     197           35 :     mode: ClientMode,
     198           35 : ) -> anyhow::Result<()> {
     199           35 :     info!(
     200           35 :         protocol = mode.protocol_label(),
     201           35 :         "handling interactive connection from client"
     202           35 :     );
     203              : 
     204              :     // The `closed` counter will increase when this future is destroyed.
     205           35 :     NUM_CONNECTIONS_ACCEPTED_COUNTER
     206           35 :         .with_label_values(&[mode.protocol_label()])
     207           35 :         .inc();
     208           35 :     scopeguard::defer! {
     209           35 :         NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[mode.protocol_label()]).inc();
     210           35 :     }
     211              : 
     212           35 :     let tls = config.tls_config.as_ref();
     213           35 : 
     214           35 :     let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
     215           81 :     let (mut stream, params) = match do_handshake.await? {
     216           31 :         Some(x) => x,
     217            0 :         None => return Ok(()), // it's a cancellation request
     218              :     };
     219              : 
     220              :     // Extract credentials which we're going to use for auth.
     221           31 :     let creds = {
     222           31 :         let hostname = mode.hostname(stream.get_ref());
     223           31 :         let common_names = tls.and_then(|tls| tls.common_names.clone());
     224           31 :         let result = config
     225           31 :             .auth_backend
     226           31 :             .as_ref()
     227           31 :             .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
     228           31 :             .transpose();
     229           31 : 
     230           31 :         match result {
     231           31 :             Ok(creds) => creds,
     232            0 :             Err(e) => stream.throw_error(e).await?,
     233              :         }
     234              :     };
     235              : 
     236           31 :     let client = Client::new(
     237           31 :         stream,
     238           31 :         creds,
     239           31 :         &params,
     240           31 :         session_id,
     241           31 :         mode.allow_self_signed_compute(config),
     242           31 :     );
     243           31 :     cancel_map
     244           31 :         .with_session(|session| client.connect_to_db(session, mode.allow_cleartext()))
     245          358 :         .await
     246           35 : }
     247              : 
      248              : /// Establish a (most probably secure) connection with the client.
      249              : /// For a better testing experience, `stream` can be any object satisfying the traits.
      250              : /// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
      251              : /// we also take extra care to propagate only select handshake errors to the client.
     252          133 : #[tracing::instrument(skip_all)]
     253              : async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     254              :     stream: S,
     255              :     mut tls: Option<&TlsConfig>,
     256              :     cancel_map: &CancelMap,
     257              : ) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
     258              :     // Client may try upgrading to each protocol only once
     259              :     let (mut tried_ssl, mut tried_gss) = (false, false);
     260              : 
     261              :     let mut stream = PqStream::new(Stream::from_raw(stream));
     262              :     loop {
     263              :         let msg = stream.read_startup_packet().await?;
     264           66 :         info!("received {msg:?}");
     265              : 
     266              :         use FeStartupPacket::*;
     267              :         match msg {
     268              :             SslRequest => match stream.get_ref() {
     269              :                 Stream::Raw { .. } if !tried_ssl => {
     270              :                     tried_ssl = true;
     271              : 
     272              :                     // We can't perform TLS handshake without a config
     273              :                     let enc = tls.is_some();
     274              :                     stream.write_message(&Be::EncryptionResponse(enc)).await?;
     275              :                     if let Some(tls) = tls.take() {
     276              :                         // Upgrade raw stream into a secure TLS-backed stream.
     277              :                         // NOTE: We've consumed `tls`; this fact will be used later.
     278              : 
     279              :                         let (raw, read_buf) = stream.into_inner();
      280              :                         // TODO: Normally, the client doesn't send any data before
      281              :                         // the server says the TLS handshake is ok, so read_buf is empty.
     282              :                         // However, you could imagine pipelining of postgres
     283              :                         // SSLRequest + TLS ClientHello in one hunk similar to
     284              :                         // pipelining in our node js driver. We should probably
     285              :                         // support that by chaining read_buf with the stream.
     286              :                         if !read_buf.is_empty() {
     287              :                             bail!("data is sent before server replied with EncryptionResponse");
     288              :                         }
     289              :                         stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
     290              :                     }
     291              :                 }
     292              :                 _ => bail!(ERR_PROTO_VIOLATION),
     293              :             },
     294              :             GssEncRequest => match stream.get_ref() {
     295              :                 Stream::Raw { .. } if !tried_gss => {
     296              :                     tried_gss = true;
     297              : 
     298              :                     // Currently, we don't support GSSAPI
     299              :                     stream.write_message(&Be::EncryptionResponse(false)).await?;
     300              :                 }
     301              :                 _ => bail!(ERR_PROTO_VIOLATION),
     302              :             },
     303              :             StartupMessage { params, .. } => {
     304              :                 // Check that the config has been consumed during upgrade
     305              :                 // OR we didn't provide it at all (for dev purposes).
     306              :                 if tls.is_some() {
     307              :                     stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
     308              :                 }
     309              : 
     310           31 :                 info!(session_type = "normal", "successful handshake");
     311              :                 break Ok(Some((stream, params)));
     312              :             }
     313              :             CancelRequest(cancel_key_data) => {
     314              :                 cancel_map.cancel_session(cancel_key_data).await?;
     315              : 
     316            0 :                 info!(session_type = "cancellation", "successful handshake");
     317              :                 break Ok(None);
     318              :             }
     319              :         }
     320              :     }
     321              : }
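
For reference, the SslRequest arm above reacts to an 8-byte magic packet defined by the PostgreSQL wire protocol: a big-endian length of 8 followed by the request code 80877103 (1234 << 16 | 5679). A minimal sketch of the bytes a client would send to start the TLS negotiation handled here; ssl_request_packet is a hypothetical helper for illustration, plain std only.

    // Sketch: builds the SSLRequest packet a client sends before upgrading to TLS.
    fn ssl_request_packet() -> [u8; 8] {
        let mut buf = [0u8; 8];
        buf[..4].copy_from_slice(&8i32.to_be_bytes());        // packet length, including itself
        buf[4..].copy_from_slice(&80877103i32.to_be_bytes()); // SSLRequest code (1234 << 16 | 5679)
        buf
    }

    fn main() {
        let pkt = ssl_request_packet();
        assert_eq!(&pkt[4..], &[0x04, 0xd2, 0x16, 0x2f]);
        println!("SSLRequest packet: {pkt:02x?}");
    }
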
     322              : 
      323              : /// If we couldn't connect, the cached connection info might be to blame
     324              : /// (e.g. the compute node's address might've changed at the wrong time).
     325              : /// Invalidate the cache entry (if any) to prevent subsequent errors.
     326           14 : #[tracing::instrument(name = "invalidate_cache", skip_all)]
     327              : pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
     328              :     let is_cached = node_info.cached();
     329              :     if is_cached {
     330            0 :         warn!("invalidating stalled compute node info cache entry");
     331              :     }
     332              :     let label = match is_cached {
     333              :         true => "compute_cached",
     334              :         false => "compute_uncached",
     335              :     };
     336              :     NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
     337              : 
     338              :     node_info.invalidate().config
     339              : }
     340              : 
     341              : /// Try to connect to the compute node once.
     342          108 : #[tracing::instrument(name = "connect_once", skip_all)]
     343              : async fn connect_to_compute_once(
     344              :     node_info: &console::CachedNodeInfo,
     345              :     timeout: time::Duration,
     346              : ) -> Result<PostgresConnection, compute::ConnectionError> {
     347              :     let allow_self_signed_compute = node_info.allow_self_signed_compute;
     348              : 
     349              :     node_info
     350              :         .config
     351              :         .connect(allow_self_signed_compute, timeout)
     352              :         .await
     353              : }
     354              : 
     355              : #[async_trait]
     356              : pub trait ConnectMechanism {
     357              :     type Connection;
     358              :     type ConnectError;
     359              :     type Error: From<Self::ConnectError>;
     360              :     async fn connect_once(
     361              :         &self,
     362              :         node_info: &console::CachedNodeInfo,
     363              :         timeout: time::Duration,
     364              :     ) -> Result<Self::Connection, Self::ConnectError>;
     365              : 
     366              :     fn update_connect_config(&self, conf: &mut compute::ConnCfg);
     367              : }
     368              : 
     369              : pub struct TcpMechanism<'a> {
     370              :     /// KV-dictionary with PostgreSQL connection params.
     371              :     pub params: &'a StartupMessageParams,
     372              : }
     373              : 
     374              : #[async_trait]
     375              : impl ConnectMechanism for TcpMechanism<'_> {
     376              :     type Connection = PostgresConnection;
     377              :     type ConnectError = compute::ConnectionError;
     378              :     type Error = compute::ConnectionError;
     379              : 
     380           27 :     async fn connect_once(
     381           27 :         &self,
     382           27 :         node_info: &console::CachedNodeInfo,
     383           27 :         timeout: time::Duration,
     384           27 :     ) -> Result<PostgresConnection, Self::Error> {
     385           80 :         connect_to_compute_once(node_info, timeout).await
     386           54 :     }
     387              : 
     388           27 :     fn update_connect_config(&self, config: &mut compute::ConnCfg) {
     389           27 :         config.set_startup_params(self.params);
     390           27 :     }
     391              : }
     392              : 
     393              : /// Try to connect to the compute node, retrying if necessary.
      394              : /// This function might update `node_info`, so it takes it by value (note the `mut` binding).
     395          115 : #[tracing::instrument(skip_all)]
     396              : pub async fn connect_to_compute<M: ConnectMechanism>(
     397              :     mechanism: &M,
     398              :     mut node_info: console::CachedNodeInfo,
     399              :     extra: &console::ConsoleReqExtra<'_>,
     400              :     creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
     401              : ) -> Result<M::Connection, M::Error>
     402              : where
     403              :     M::ConnectError: ShouldRetry + std::fmt::Debug,
     404              :     M::Error: From<WakeComputeError>,
     405              : {
     406              :     let _timer = COMPUTE_CONNECTION_LATENCY.start_timer();
     407              : 
     408              :     mechanism.update_connect_config(&mut node_info.config);
     409              : 
     410              :     // try once
     411              :     let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
     412              :         Ok(res) => return Ok(res),
     413              :         Err(e) => {
     414            2 :             error!(error = ?e, "could not connect to compute node");
     415              :             (invalidate_cache(node_info), e)
     416              :         }
     417              :     };
     418              : 
     419              :     let mut num_retries = 1;
     420              : 
      421              :     // if we failed to connect, the compute node was likely suspended, so wake a new one
     422            2 :     info!("compute node's state has likely changed; requesting a wake-up");
     423              :     let node_info = loop {
     424              :         let wake_res = match creds {
     425              :             auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
     426              :             auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
     427              :             // nothing to do?
     428              :             auth::BackendType::Link(_) => return Err(err.into()),
     429              :             // test backend
     430              :             auth::BackendType::Test(x) => x.wake_compute(),
     431              :         };
     432              : 
     433              :         match handle_try_wake(wake_res, num_retries) {
     434              :             Err(e) => {
     435            0 :                 error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
     436              :                 return Err(e.into());
     437              :             }
     438              :             // failed to wake up but we can continue to retry
     439              :             Ok(ControlFlow::Continue(e)) => {
     440            0 :                 warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
     441              :             }
     442              :             // successfully woke up a compute node and can break the wakeup loop
     443              :             Ok(ControlFlow::Break(mut node_info)) => {
     444              :                 node_info.config.reuse_password(&config);
     445              :                 mechanism.update_connect_config(&mut node_info.config);
     446              :                 break node_info;
     447              :             }
     448              :         }
     449              : 
     450              :         let wait_duration = retry_after(num_retries);
     451              :         num_retries += 1;
     452              : 
     453              :         time::sleep(wait_duration).await;
     454              :     };
     455              : 
      456              :     // now that we have a new node, try to connect to it repeatedly.
     457              :     // this can error for a few reasons, for instance:
     458              :     // * DNS connection settings haven't quite propagated yet
     459            2 :     info!("wake_compute success. attempting to connect");
     460              :     loop {
     461              :         match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
     462              :             Ok(res) => return Ok(res),
     463              :             Err(e) => {
     464              :                 let retriable = e.should_retry(num_retries);
     465              :                 if !retriable {
     466            2 :                     error!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
     467              :                     return Err(e.into());
     468              :                 }
     469            0 :                 warn!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
     470              :             }
     471              :         }
     472              : 
     473              :         let wait_duration = retry_after(num_retries);
     474              :         num_retries += 1;
     475              : 
     476              :         time::sleep(wait_duration).await;
     477              :     }
     478              : }
     479              : 
     480              : /// Attempts to wake up the compute node.
     481              : /// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
     482              : /// * Returns Ok(Break(node)) if the wakeup succeeded
     483              : /// * Returns Err(e) if there was an error
     484           31 : pub fn handle_try_wake(
     485           31 :     result: Result<console::CachedNodeInfo, WakeComputeError>,
     486           31 :     num_retries: u32,
     487           31 : ) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
     488           31 :     match result {
     489            2 :         Err(err) => match &err {
     490            2 :             WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
     491            1 :                 Ok(ControlFlow::Continue(err))
     492              :             }
     493            1 :             _ => Err(err),
     494              :         },
     495              :         // Ready to try again.
     496           29 :         Ok(new) => Ok(ControlFlow::Break(new)),
     497              :     }
     498           31 : }
     499              : 
     500              : pub trait ShouldRetry {
     501              :     fn could_retry(&self) -> bool;
     502           18 :     fn should_retry(&self, num_retries: u32) -> bool {
     503           18 :         match self {
     504           18 :             _ if num_retries >= NUM_RETRIES_CONNECT => false,
     505           17 :             err => err.could_retry(),
     506              :         }
     507           18 :     }
     508              : }
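
The trait above splits the retry decision in two: could_retry classifies the error itself, while the provided should_retry additionally caps the attempt count at NUM_RETRIES_CONNECT. A self-contained sketch of how a new error type would opt in; it uses a simplified copy of the trait and a purely hypothetical MockError, not the crate's real error types.

    const NUM_RETRIES_CONNECT: u32 = 10;

    trait ShouldRetry {
        fn could_retry(&self) -> bool;
        fn should_retry(&self, num_retries: u32) -> bool {
            match self {
                _ if num_retries >= NUM_RETRIES_CONNECT => false,
                err => err.could_retry(),
            }
        }
    }

    // Hypothetical error type, for illustration only.
    enum MockError {
        Timeout,        // transient: another attempt may succeed
        BadCredentials, // permanent: retrying cannot help
    }

    impl ShouldRetry for MockError {
        fn could_retry(&self) -> bool {
            matches!(self, MockError::Timeout)
        }
    }

    fn main() {
        assert!(MockError::Timeout.should_retry(3));
        assert!(!MockError::Timeout.should_retry(10)); // retry budget exhausted
        assert!(!MockError::BadCredentials.should_retry(0));
    }
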
     509              : 
     510              : impl ShouldRetry for io::Error {
     511            0 :     fn could_retry(&self) -> bool {
     512              :         use std::io::ErrorKind;
     513            0 :         matches!(
     514            0 :             self.kind(),
     515              :             ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
     516              :         )
     517            0 :     }
     518              : }
     519              : 
     520              : impl ShouldRetry for tokio_postgres::error::DbError {
     521            2 :     fn could_retry(&self) -> bool {
     522              :         use tokio_postgres::error::SqlState;
     523            2 :         matches!(
     524            2 :             self.code(),
     525              :             &SqlState::CONNECTION_FAILURE
     526              :                 | &SqlState::CONNECTION_EXCEPTION
     527              :                 | &SqlState::CONNECTION_DOES_NOT_EXIST
     528              :                 | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
     529              :         )
     530            2 :     }
     531              : }
     532              : 
     533              : impl ShouldRetry for tokio_postgres::Error {
     534              :     fn could_retry(&self) -> bool {
     535            2 :         if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
     536            0 :             io::Error::could_retry(io_err)
     537            2 :         } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
     538            2 :             tokio_postgres::error::DbError::could_retry(db_err)
     539              :         } else {
     540            0 :             false
     541              :         }
     542            2 :     }
     543              : }
     544              : 
     545              : impl ShouldRetry for compute::ConnectionError {
     546            0 :     fn could_retry(&self) -> bool {
     547            0 :         match self {
     548            0 :             compute::ConnectionError::Postgres(err) => err.could_retry(),
     549            0 :             compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
     550            0 :             _ => false,
     551              :         }
     552            0 :     }
     553              : }
     554              : 
     555           22 : pub fn retry_after(num_retries: u32) -> time::Duration {
     556           22 :     // 1.5 seems to be an ok growth factor heuristic
     557           22 :     BASE_RETRY_WAIT_DURATION.mul_f64(1.5_f64.powi(num_retries as i32))
     558           22 : }
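
Plugging in the constants from the top of the file: the retry loops start at num_retries = 1, so the waits grow as 150 ms, 225 ms, 337.5 ms, ..., summing to roughly 11 s before the retry budget (NUM_RETRIES_CONNECT = 10) runs out. A tiny sketch, std only and using the same formula as above, that prints the schedule:

    use std::time::Duration;

    const BASE_RETRY_WAIT_DURATION: Duration = Duration::from_millis(100);

    fn retry_after(num_retries: u32) -> Duration {
        BASE_RETRY_WAIT_DURATION.mul_f64(1.5_f64.powi(num_retries as i32))
    }

    fn main() {
        let mut total = Duration::ZERO;
        for n in 1..10 {
            let wait = retry_after(n);
            total += wait;
            println!("after attempt {n}: wait {wait:?}");
        }
        println!("total back-off across the retry budget: {total:?}");
    }
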
     559              : 
     560              : /// Finish client connection initialization: confirm auth success, send params, etc.
     561          108 : #[tracing::instrument(skip_all)]
     562              : async fn prepare_client_connection(
     563              :     node: &compute::PostgresConnection,
     564              :     reported_auth_ok: bool,
     565              :     session: cancellation::Session<'_>,
     566              :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     567              : ) -> anyhow::Result<()> {
     568              :     // Register compute's query cancellation token and produce a new, unique one.
     569              :     // The new token (cancel_key_data) will be sent to the client.
     570              :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     571              : 
     572              :     // Report authentication success if we haven't done this already.
     573              :     // Note that we do this only (for the most part) after we've connected
     574              :     // to a compute (see above) which performs its own authentication.
     575              :     if !reported_auth_ok {
     576              :         stream.write_message_noflush(&Be::AuthenticationOk)?;
     577              :     }
     578              : 
     579              :     // Forward all postgres connection params to the client.
      580              :     // Right now the implementation is very hacky and inefficient (ideally,
     581              :     // we don't need an intermediate hashmap), but at least it should be correct.
     582              :     for (name, value) in &node.params {
     583              :         // TODO: Theoretically, this could result in a big pile of params...
     584              :         stream.write_message_noflush(&Be::ParameterStatus {
     585              :             name: name.as_bytes(),
     586              :             value: value.as_bytes(),
     587              :         })?;
     588              :     }
     589              : 
     590              :     stream
     591              :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     592              :         .write_message(&Be::ReadyForQuery)
     593              :         .await?;
     594              : 
     595              :     Ok(())
     596              : }
     597              : 
     598              : /// Forward bytes in both directions (client <-> compute).
     599          112 : #[tracing::instrument(skip_all)]
     600              : pub async fn proxy_pass(
     601              :     client: impl AsyncRead + AsyncWrite + Unpin,
     602              :     compute: impl AsyncRead + AsyncWrite + Unpin,
     603              :     aux: &MetricsAuxInfo,
     604              : ) -> anyhow::Result<()> {
     605              :     let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
     606              :     let mut client = MeasuredStream::new(
     607              :         client,
     608           84 :         |_| {},
     609           56 :         |cnt| {
     610           56 :             // Number of bytes we sent to the client (outbound).
     611           56 :             m_sent.inc_by(cnt as u64);
     612           56 :         },
     613              :     );
     614              : 
     615              :     let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
     616              :     let mut compute = MeasuredStream::new(
     617              :         compute,
     618           56 :         |_| {},
     619           56 :         |cnt| {
     620           56 :             // Number of bytes the client sent to the compute node (inbound).
     621           56 :             m_recv.inc_by(cnt as u64);
     622           56 :         },
     623              :     );
     624              : 
     625              :     // Starting from here we only proxy the client's traffic.
     626           28 :     info!("performing the proxy pass...");
     627              :     let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
     628              : 
     629              :     Ok(())
     630              : }
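
proxy_pass wraps both sides in MeasuredStream so bytes are counted incrementally (keeping the Prometheus counters live during long sessions) and then delegates the actual forwarding to tokio::io::copy_bidirectional. For comparison, a standalone sketch under stated assumptions: in-memory duplex pipes stand in for the client and compute sockets, and the byte totals come from copy_bidirectional's own return value rather than per-chunk callbacks.

    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        // Two in-memory pipes stand in for the client and compute connections.
        let (mut client_far, mut client_near) = duplex(1024);
        let (mut compute_far, mut compute_near) = duplex(1024);

        // The "proxy" forwards bytes in both directions and reports totals on exit.
        let proxy = tokio::spawn(async move {
            tokio::io::copy_bidirectional(&mut client_near, &mut compute_near).await
        });

        client_far.write_all(b"SELECT 1;").await?; // client -> compute (9 bytes)
        compute_far.write_all(b"ok").await?;       // compute -> client (2 bytes)

        let mut query = [0u8; 9];
        compute_far.read_exact(&mut query).await?;
        let mut reply = [0u8; 2];
        client_far.read_exact(&mut reply).await?;

        // Close both far ends so each direction sees EOF and the forwarder returns.
        drop(client_far);
        drop(compute_far);

        let (tx, rx) = proxy.await.unwrap()?;
        println!("client->compute: {tx} bytes, compute->client: {rx} bytes");
        Ok(())
    }
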
     631              : 
     632              : /// Thin connection context.
     633              : struct Client<'a, S> {
     634              :     /// The underlying libpq protocol stream.
     635              :     stream: PqStream<S>,
     636              :     /// Client credentials that we care about.
     637              :     creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
     638              :     /// KV-dictionary with PostgreSQL connection params.
     639              :     params: &'a StartupMessageParams,
     640              :     /// Unique connection ID.
     641              :     session_id: uuid::Uuid,
     642              :     /// Allow self-signed certificates (for testing).
     643              :     allow_self_signed_compute: bool,
     644              : }
     645              : 
     646              : impl<'a, S> Client<'a, S> {
     647              :     /// Construct a new connection context.
     648           31 :     fn new(
     649           31 :         stream: PqStream<S>,
     650           31 :         creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
     651           31 :         params: &'a StartupMessageParams,
     652           31 :         session_id: uuid::Uuid,
     653           31 :         allow_self_signed_compute: bool,
     654           31 :     ) -> Self {
     655           31 :         Self {
     656           31 :             stream,
     657           31 :             creds,
     658           31 :             params,
     659           31 :             session_id,
     660           31 :             allow_self_signed_compute,
     661           31 :         }
     662           31 :     }
     663              : }
     664              : 
     665              : impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
     666              :     /// Let the client authenticate and connect to the designated compute node.
      667              :     // Instrumentation logs the endpoint name everywhere. This doesn't work for link
      668              :     // auth; strictly speaking, we don't know the endpoint name in that case.
     669           93 :     #[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)]
     670              :     async fn connect_to_db(
     671              :         self,
     672              :         session: cancellation::Session<'_>,
     673              :         allow_cleartext: bool,
     674              :     ) -> anyhow::Result<()> {
     675              :         let Self {
     676              :             mut stream,
     677              :             mut creds,
     678              :             params,
     679              :             session_id,
     680              :             allow_self_signed_compute,
     681              :         } = self;
     682              : 
     683              :         let extra = console::ConsoleReqExtra {
     684              :             session_id, // aka this connection's id
     685              :             application_name: params.get("application_name"),
     686              :         };
     687              : 
     688              :         let auth_result = match creds
     689              :             .authenticate(&extra, &mut stream, allow_cleartext)
     690              :             .await
     691              :         {
     692              :             Ok(auth_result) => auth_result,
     693              :             Err(e) => return stream.throw_error(e).await,
     694              :         };
     695              : 
     696              :         let AuthSuccess {
     697              :             reported_auth_ok,
     698              :             value: mut node_info,
     699              :         } = auth_result;
     700              : 
     701              :         node_info.allow_self_signed_compute = allow_self_signed_compute;
     702              : 
     703              :         let aux = node_info.aux.clone();
     704              :         let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds)
     705            0 :             .or_else(|e| stream.throw_error(e))
     706              :             .await?;
     707              : 
     708              :         prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
     709              :         // Before proxy passing, forward to compute whatever data is left in the
     710              :         // PqStream input buffer. Normally there is none, but our serverless npm
     711              :         // driver in pipeline mode sends startup, password and first query
     712              :         // immediately after opening the connection.
     713              :         let (stream, read_buf) = stream.into_inner();
     714              :         node.stream.write_all(&read_buf).await?;
     715              :         proxy_pass(stream, node.stream, &aux).await
     716              :     }
     717              : }
        

Generated by: LCOV version 2.1-beta