LCOV - differential code coverage report
Current view: top level - proxy/src - proxy.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 88.2 % 346 305 41 305
Current Date: 2023-10-19 02:04:12 Functions: 49.5 % 210 104 106 104
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : #[cfg(test)]
       2                 : mod tests;
       3                 : 
       4                 : use crate::{
       5                 :     auth::{self, backend::AuthSuccess},
       6                 :     cancellation::{self, CancelMap},
       7                 :     compute::{self, PostgresConnection},
       8                 :     config::{ProxyConfig, TlsConfig},
       9                 :     console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
      10                 :     http::StatusCode,
      11                 :     metrics::{Ids, USAGE_METRICS},
      12                 :     protocol2::WithClientIp,
      13                 :     stream::{PqStream, Stream},
      14                 : };
      15                 : use anyhow::{bail, Context};
      16                 : use async_trait::async_trait;
      17                 : use futures::TryFutureExt;
      18                 : use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
      19                 : use once_cell::sync::Lazy;
      20                 : use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
      21                 : use prometheus::{register_histogram_vec, HistogramVec};
      22                 : use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
      23                 : use tokio::{
      24                 :     io::{AsyncRead, AsyncWrite, AsyncWriteExt},
      25                 :     time,
      26                 : };
      27                 : use tokio_util::sync::CancellationToken;
      28                 : use tracing::{error, info, info_span, warn, Instrument};
      29                 : use utils::measured_stream::MeasuredStream;
      30                 : 
      31                 : /// Number of times we should retry the `/proxy_wake_compute` http request.
      32                 : /// Retry duration is BASE_RETRY_WAIT_DURATION * RETRY_WAIT_EXPONENT_BASE ^ n, where n starts at 0
      33                 : pub const NUM_RETRIES_CONNECT: u32 = 16;
      34                 : const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
      35                 : const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(25);
      36                 : const RETRY_WAIT_EXPONENT_BASE: f64 = std::f64::consts::SQRT_2;
      37                 : 
      38                 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      39                 : const ERR_PROTO_VIOLATION: &str = "protocol violation";
      40                 : 
      41 CBC          15 : pub static NUM_DB_CONNECTIONS_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      42              15 :     register_int_counter_vec!(
      43              15 :         "proxy_opened_db_connections_total",
      44              15 :         "Number of opened connections to a database.",
      45              15 :         &["protocol"],
      46              15 :     )
      47              15 :     .unwrap()
      48              15 : });
      49                 : 
      50              15 : pub static NUM_DB_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      51              15 :     register_int_counter_vec!(
      52              15 :         "proxy_closed_db_connections_total",
      53              15 :         "Number of closed connections to a database.",
      54              15 :         &["protocol"],
      55              15 :     )
      56              15 :     .unwrap()
      57              15 : });
      58                 : 
      59              15 : pub static NUM_CLIENT_CONNECTION_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      60              15 :     register_int_counter_vec!(
      61              15 :         "proxy_opened_client_connections_total",
      62              15 :         "Number of opened connections from a client.",
      63              15 :         &["protocol"],
      64              15 :     )
      65              15 :     .unwrap()
      66              15 : });
      67                 : 
      68              15 : pub static NUM_CLIENT_CONNECTION_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      69              15 :     register_int_counter_vec!(
      70              15 :         "proxy_closed_client_connections_total",
      71              15 :         "Number of closed connections from a client.",
      72              15 :         &["protocol"],
      73              15 :     )
      74              15 :     .unwrap()
      75              15 : });
      76                 : 
      77              15 : pub static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      78              15 :     register_int_counter_vec!(
      79              15 :         "proxy_accepted_connections_total",
      80              15 :         "Number of client connections accepted.",
      81              15 :         &["protocol"],
      82              15 :     )
      83              15 :     .unwrap()
      84              15 : });
      85                 : 
      86              15 : pub static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
      87              15 :     register_int_counter_vec!(
      88              15 :         "proxy_closed_connections_total",
      89              15 :         "Number of client connections closed.",
      90              15 :         &["protocol"],
      91              15 :     )
      92              15 :     .unwrap()
      93              15 : });
      94                 : 
      95              16 : static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
      96              16 :     register_histogram_vec!(
      97              16 :         "proxy_compute_connection_latency_seconds",
      98              16 :         "Time it took for proxy to establish a connection to the compute endpoint",
      99              16 :         &["protocol", "cache_miss", "pool_miss"],
     100              16 :         // largest bucket = 2^15 * 0.5ms ~= 16.4s (16 buckets, the first one being 0.5ms)
     101              16 :         exponential_buckets(0.0005, 2.0, 16).unwrap(),
     102              16 :     )
     103              16 :     .unwrap()
     104              16 : });
     105                 : 
     106                 : pub struct LatencyTimer {
     107                 :     start: Instant,
     108                 :     pool_miss: bool,
     109                 :     cache_miss: bool,
     110                 :     protocol: &'static str,
     111                 : }
     112                 : 
     113                 : impl LatencyTimer {
     114              70 :     pub fn new(protocol: &'static str) -> Self {
     115              70 :         Self {
     116              70 :             start: Instant::now(),
     117              70 :             cache_miss: false,
     118              70 :             // by default we don't do pooling
     119              70 :             pool_miss: true,
     120              70 :             protocol,
     121              70 :         }
     122              70 :     }
     123                 : 
     124               8 :     pub fn cache_miss(&mut self) {
     125               8 :         self.cache_miss = true;
     126               8 :     }
     127                 : 
     128               3 :     pub fn pool_hit(&mut self) {
     129               3 :         self.pool_miss = false;
     130               3 :     }
     131                 : }
     132                 : 
     133                 : impl Drop for LatencyTimer {
     134              70 :     fn drop(&mut self) {
     135              70 :         let duration = self.start.elapsed().as_secs_f64();
     136              70 :         COMPUTE_CONNECTION_LATENCY
     137              70 :             .with_label_values(&[
     138              70 :                 self.protocol,
     139              70 :                 bool_to_str(self.cache_miss),
     140              70 :                 bool_to_str(self.pool_miss),
     141              70 :             ])
     142              70 :             .observe(duration)
     143              70 :     }
     144                 : }
     145                 : 
     146               2 : static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
     147               2 :     register_int_counter_vec!(
     148               2 :         "proxy_connection_failures_total",
     149               2 :         "Number of connection failures (per kind).",
     150               2 :         &["kind"],
     151               2 :     )
     152               2 :     .unwrap()
     153               2 : });
     154                 : 
     155               1 : static NUM_WAKEUP_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
     156               1 :     register_int_counter_vec!(
     157               1 :         "proxy_connection_failures_breakdown",
     158               1 :         "Number of wake-up failures (per kind).",
     159               1 :         &["retry", "kind"],
     160               1 :     )
     161               1 :     .unwrap()
     162               1 : });
     163                 : 
     164              16 : static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
     165              16 :     register_int_counter_vec!(
     166              16 :         "proxy_io_bytes_per_client",
     167              16 :         "Number of bytes sent/received between client and backend.",
     168              16 :         crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
     169              16 :     )
     170              16 :     .unwrap()
     171              16 : });
     172                 : 
     173              16 : pub async fn task_main(
     174              16 :     config: &'static ProxyConfig,
     175              16 :     listener: tokio::net::TcpListener,
     176              16 :     cancellation_token: CancellationToken,
     177              16 : ) -> anyhow::Result<()> {
     178              16 :     scopeguard::defer! {
     179              16 :         info!("proxy has shut down");
     180                 :     }
     181                 : 
     182                 :     // When set for the server socket, the keepalive setting
     183                 :     // will be inherited by all accepted client sockets.
     184              16 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     185                 : 
     186              16 :     let mut connections = tokio::task::JoinSet::new();
     187              16 :     let cancel_map = Arc::new(CancelMap::default());
     188                 : 
     189                 :     loop {
     190              53 :         tokio::select! {
     191              37 :             accept_result = listener.accept() => {
     192                 :                 let (socket, _) = accept_result?;
     193                 : 
     194                 :                 let session_id = uuid::Uuid::new_v4();
     195                 :                 let cancel_map = Arc::clone(&cancel_map);
     196                 :                 connections.spawn(
     197              37 :                     async move {
     198              37 :                         info!("accepted postgres client connection");
     199                 : 
     200              37 :                         let mut socket = WithClientIp::new(socket);
     201              37 :                         if let Some(ip) = socket.wait_for_addr().await? {
     202 UBC           0 :                             tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
     203 CBC          37 :                         } else if config.require_client_ip {
     204 UBC           0 :                             bail!("missing required client IP");
     205 CBC          37 :                         }
     206                 : 
     207              37 :                         socket
     208              37 :                             .inner
     209              37 :                             .set_nodelay(true)
     210              37 :                             .context("failed to set socket option")?;
     211                 : 
     212             485 :                         handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await
     213              37 :                     }
     214                 :                     .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
     215              37 :                     .unwrap_or_else(move |e| {
     216              37 :                         // Acknowledge that the task has finished with an error.
     217              37 :                         error!(?session_id, "per-client task finished with an error: {e:#}");
     218              37 :                     }),
     219                 :                 );
     220                 :             }
     221 UBC           0 :             Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
     222                 :                 if !e.is_panic() && !e.is_cancelled() {
     223               0 :                     warn!("unexpected error from joined connection task: {e:?}");
     224                 :                 }
     225                 :             }
     226                 :             _ = cancellation_token.cancelled() => {
     227                 :                 drop(listener);
     228                 :                 break;
     229                 :             }
     230                 :         }
     231                 :     }
     232                 :     // Drain connections
     233 CBC          21 :     while let Some(res) = connections.join_next().await {
     234               5 :         if let Err(e) = res {
     235 UBC           0 :             if !e.is_panic() && !e.is_cancelled() {
     236               0 :                 warn!("unexpected error from joined connection task: {e:?}");
     237               0 :             }
     238 CBC           5 :         }
     239                 :     }
     240              16 :     Ok(())
     241              16 : }
     242                 : 
     243                 : pub enum ClientMode {
     244                 :     Tcp,
     245                 :     Websockets { hostname: Option<String> },
     246                 : }
     247                 : 
     248                 : /// Abstracts the logic of handling TCP vs WS clients
     249                 : impl ClientMode {
     250             136 :     fn protocol_label(&self) -> &'static str {
     251             136 :         match self {
     252             136 :             ClientMode::Tcp => "tcp",
     253 UBC           0 :             ClientMode::Websockets { .. } => "ws",
     254                 :         }
     255 CBC         136 :     }
     256                 : 
     257              33 :     fn allow_cleartext(&self) -> bool {
     258              33 :         match self {
     259              33 :             ClientMode::Tcp => false,
     260 UBC           0 :             ClientMode::Websockets { .. } => true,
     261                 :         }
     262 CBC          33 :     }
     263                 : 
     264              33 :     fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     265              33 :         match self {
     266              33 :             ClientMode::Tcp => config.allow_self_signed_compute,
     267 UBC           0 :             ClientMode::Websockets { .. } => false,
     268                 :         }
     269 CBC          33 :     }
     270                 : 
     271              33 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     272              33 :         match self {
     273              33 :             ClientMode::Tcp => s.sni_hostname(),
     274 UBC           0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     275                 :         }
     276 CBC          33 :     }
     277                 : 
     278              37 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     279              37 :         match self {
     280              37 :             ClientMode::Tcp => tls,
     281                 :             // TLS is None here if using websockets, because the connection is already encrypted.
     282 UBC           0 :             ClientMode::Websockets { .. } => None,
     283                 :         }
     284 CBC          37 :     }
     285                 : }
     286                 : 
     287              37 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     288              37 :     config: &'static ProxyConfig,
     289              37 :     cancel_map: &CancelMap,
     290              37 :     session_id: uuid::Uuid,
     291              37 :     stream: S,
     292              37 :     mode: ClientMode,
     293              37 : ) -> anyhow::Result<()> {
     294              37 :     info!(
     295              37 :         protocol = mode.protocol_label(),
     296              37 :         "handling interactive connection from client"
     297              37 :     );
     298                 : 
     299              37 :     let proto = mode.protocol_label();
     300              37 :     NUM_CLIENT_CONNECTION_OPENED_COUNTER
     301              37 :         .with_label_values(&[proto])
     302              37 :         .inc();
     303              37 :     NUM_CONNECTIONS_ACCEPTED_COUNTER
     304              37 :         .with_label_values(&[proto])
     305              37 :         .inc();
     306              37 :     scopeguard::defer! {
     307              37 :         NUM_CLIENT_CONNECTION_CLOSED_COUNTER.with_label_values(&[proto]).inc();
     308              37 :         NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
     309              37 :     }
     310                 : 
     311              37 :     let tls = config.tls_config.as_ref();
     312              37 : 
     313              37 :     let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
     314              88 :     let (mut stream, params) = match do_handshake.await? {
     315              33 :         Some(x) => x,
     316 UBC           0 :         None => return Ok(()), // it's a cancellation request
     317                 :     };
     318                 : 
     319                 :     // Extract credentials which we're going to use for auth.
     320 CBC          33 :     let creds = {
     321              33 :         let hostname = mode.hostname(stream.get_ref());
     322              33 :         let common_names = tls.and_then(|tls| tls.common_names.clone());
     323              33 :         let result = config
     324              33 :             .auth_backend
     325              33 :             .as_ref()
     326              33 :             .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
     327              33 :             .transpose();
     328              33 : 
     329              33 :         match result {
     330              33 :             Ok(creds) => creds,
     331 UBC           0 :             Err(e) => stream.throw_error(e).await?,
     332                 :         }
     333                 :     };
     334                 : 
     335 CBC          33 :     let client = Client::new(
     336              33 :         stream,
     337              33 :         creds,
     338              33 :         &params,
     339              33 :         session_id,
     340              33 :         mode.allow_self_signed_compute(config),
     341              33 :     );
     342              33 :     cancel_map
     343              33 :         .with_session(|session| client.connect_to_db(session, mode))
     344             397 :         .await
     345              37 : }
     346                 : 
     347                 : /// Establish a (most probably, secure) connection with the client.
     348                 : /// For better testing experience, `stream` can be any object satisfying the traits.
     349                 : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
     350                 : /// we also take extra care to propagate only select handshake errors to the client.
     351             139 : #[tracing::instrument(skip_all)]
     352                 : async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     353                 :     stream: S,
     354                 :     mut tls: Option<&TlsConfig>,
     355                 :     cancel_map: &CancelMap,
     356                 : ) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
     357                 :     // Client may try upgrading to each protocol only once
     358                 :     let (mut tried_ssl, mut tried_gss) = (false, false);
     359                 : 
     360                 :     let mut stream = PqStream::new(Stream::from_raw(stream));
     361                 :     loop {
     362                 :         let msg = stream.read_startup_packet().await?;
     363              70 :         info!("received {msg:?}");
     364                 : 
     365                 :         use FeStartupPacket::*;
     366                 :         match msg {
     367                 :             SslRequest => match stream.get_ref() {
     368                 :                 Stream::Raw { .. } if !tried_ssl => {
     369                 :                     tried_ssl = true;
     370                 : 
     371                 :                     // We can't perform TLS handshake without a config
     372                 :                     let enc = tls.is_some();
     373                 :                     stream.write_message(&Be::EncryptionResponse(enc)).await?;
     374                 :                     if let Some(tls) = tls.take() {
     375                 :                         // Upgrade raw stream into a secure TLS-backed stream.
     376                 :                         // NOTE: We've consumed `tls`; this fact will be used later.
     377                 : 
     378                 :                         let (raw, read_buf) = stream.into_inner();
     379                 :                         // TODO: Normally, client doesn't send any data before
     380                 :                         // server says TLS handshake is ok and read_buf is empty.
     381                 :                         // However, you could imagine pipelining of postgres
     382                 :                         // SSLRequest + TLS ClientHello in one hunk similar to
     383                 :                         // pipelining in our node js driver. We should probably
     384                 :                         // support that by chaining read_buf with the stream.
     385                 :                         if !read_buf.is_empty() {
     386                 :                             bail!("data is sent before server replied with EncryptionResponse");
     387                 :                         }
     388                 :                         stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
     389                 :                     }
     390                 :                 }
     391                 :                 _ => bail!(ERR_PROTO_VIOLATION),
     392                 :             },
     393                 :             GssEncRequest => match stream.get_ref() {
     394                 :                 Stream::Raw { .. } if !tried_gss => {
     395                 :                     tried_gss = true;
     396                 : 
     397                 :                     // Currently, we don't support GSSAPI
     398                 :                     stream.write_message(&Be::EncryptionResponse(false)).await?;
     399                 :                 }
     400                 :                 _ => bail!(ERR_PROTO_VIOLATION),
     401                 :             },
     402                 :             StartupMessage { params, .. } => {
     403                 :                 // Check that the config has been consumed during upgrade
     404                 :                 // OR we didn't provide it at all (for dev purposes).
     405                 :                 if tls.is_some() {
     406                 :                     stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
     407                 :                 }
     408                 : 
     409              33 :                 info!(session_type = "normal", "successful handshake");
     410                 :                 break Ok(Some((stream, params)));
     411                 :             }
     412                 :             CancelRequest(cancel_key_data) => {
     413                 :                 cancel_map.cancel_session(cancel_key_data).await?;
     414                 : 
     415 UBC           0 :                 info!(session_type = "cancellation", "successful handshake");
     416                 :                 break Ok(None);
     417                 :             }
     418                 :         }
     419                 :     }
     420                 : }
     421                 : 
     422                 : /// If we couldn't connect, a cached connection info might be to blame
     423                 : /// (e.g. the compute node's address might've changed at the wrong time).
     424                 : /// Invalidate the cache entry (if any) to prevent subsequent errors.
     425 CBC          14 : #[tracing::instrument(name = "invalidate_cache", skip_all)]
     426                 : pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
     427                 :     let is_cached = node_info.cached();
     428                 :     if is_cached {
     429 UBC           0 :         warn!("invalidating stalled compute node info cache entry");
     430                 :     }
     431                 :     let label = match is_cached {
     432                 :         true => "compute_cached",
     433                 :         false => "compute_uncached",
     434                 :     };
     435                 :     NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
     436                 : 
     437                 :     node_info.invalidate().config
     438                 : }
     439                 : 
     440                 : /// Try to connect to the compute node once.
     441 CBC         116 : #[tracing::instrument(name = "connect_once", skip_all)]
     442                 : async fn connect_to_compute_once(
     443                 :     node_info: &console::CachedNodeInfo,
     444                 :     timeout: time::Duration,
     445                 : ) -> Result<PostgresConnection, compute::ConnectionError> {
     446                 :     let allow_self_signed_compute = node_info.allow_self_signed_compute;
     447                 : 
     448                 :     node_info
     449                 :         .config
     450                 :         .connect(allow_self_signed_compute, timeout)
     451                 :         .await
     452                 : }
     453                 : 
     454                 : #[async_trait]
     455                 : pub trait ConnectMechanism {
     456                 :     type Connection;
     457                 :     type ConnectError;
     458                 :     type Error: From<Self::ConnectError>;
     459                 :     async fn connect_once(
     460                 :         &self,
     461                 :         node_info: &console::CachedNodeInfo,
     462                 :         timeout: time::Duration,
     463                 :     ) -> Result<Self::Connection, Self::ConnectError>;
     464                 : 
     465                 :     fn update_connect_config(&self, conf: &mut compute::ConnCfg);
     466                 : }
     467                 : 
     468                 : pub struct TcpMechanism<'a> {
     469                 :     /// KV-dictionary with PostgreSQL connection params.
     470                 :     pub params: &'a StartupMessageParams,
     471                 : }
     472                 : 
     473                 : #[async_trait]
     474                 : impl ConnectMechanism for TcpMechanism<'_> {
     475                 :     type Connection = PostgresConnection;
     476                 :     type ConnectError = compute::ConnectionError;
     477                 :     type Error = compute::ConnectionError;
     478                 : 
     479              29 :     async fn connect_once(
     480              29 :         &self,
     481              29 :         node_info: &console::CachedNodeInfo,
     482              29 :         timeout: time::Duration,
     483              29 :     ) -> Result<PostgresConnection, Self::Error> {
     484              87 :         connect_to_compute_once(node_info, timeout).await
     485              58 :     }
     486                 : 
     487              29 :     fn update_connect_config(&self, config: &mut compute::ConnCfg) {
     488              29 :         config.set_startup_params(self.params);
     489              29 :     }
     490                 : }
     491                 : 
     492             142 : const fn bool_to_str(x: bool) -> &'static str {
     493             142 :     if x {
     494              76 :         "true"
     495                 :     } else {
     496              66 :         "false"
     497                 :     }
     498             142 : }
     499                 : 
     500               2 : fn report_error(e: &WakeComputeError, retry: bool) {
     501               2 :     use crate::console::errors::ApiError;
     502               2 :     let retry = bool_to_str(retry);
     503               2 :     let kind = match e {
     504 UBC           0 :         WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
     505               0 :         WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
     506                 :         WakeComputeError::ApiError(ApiError::Console {
     507                 :             status: StatusCode::LOCKED,
     508               0 :             ref text,
     509               0 :         }) if text.contains("written data quota exceeded")
     510               0 :             || text.contains("the limit for current plan reached") =>
     511               0 :         {
     512               0 :             "quota_exceeded"
     513                 :         }
     514                 :         WakeComputeError::ApiError(ApiError::Console {
     515                 :             status: StatusCode::LOCKED,
     516                 :             ..
     517               0 :         }) => "api_console_locked",
     518                 :         WakeComputeError::ApiError(ApiError::Console {
     519                 :             status: StatusCode::BAD_REQUEST,
     520                 :             ..
     521               0 :         }) => "api_console_bad_request",
     522 CBC           2 :         WakeComputeError::ApiError(ApiError::Console { status, .. })
     523               2 :             if status.is_server_error() =>
     524               1 :         {
     525               1 :             "api_console_other_server_error"
     526                 :         }
     527               1 :         WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
     528                 :     };
     529               2 :     NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
     530               2 : }
     531                 : 
     532                 : /// Try to connect to the compute node, retrying if necessary.
     533                 : /// This function might update `node_info`, so we take it by `&mut`.
     534             196 : #[tracing::instrument(skip_all)]
     535                 : pub async fn connect_to_compute<M: ConnectMechanism>(
     536                 :     mechanism: &M,
     537                 :     mut node_info: console::CachedNodeInfo,
     538                 :     extra: &console::ConsoleReqExtra<'_>,
     539                 :     creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
     540                 :     mut latency_timer: LatencyTimer,
     541                 : ) -> Result<M::Connection, M::Error>
     542                 : where
     543                 :     M::ConnectError: ShouldRetry + std::fmt::Debug,
     544                 :     M::Error: From<WakeComputeError>,
     545                 : {
     546                 :     mechanism.update_connect_config(&mut node_info.config);
     547                 : 
     548                 :     // try once
     549                 :     let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
     550                 :         Ok(res) => return Ok(res),
     551                 :         Err(e) => {
     552               2 :             error!(error = ?e, "could not connect to compute node");
     553                 :             (invalidate_cache(node_info), e)
     554                 :         }
     555                 :     };
     556                 : 
     557                 :     latency_timer.cache_miss();
     558                 : 
     559                 :     let mut num_retries = 1;
     560                 : 
     561                 :     // if we failed to connect, it's likely that the compute node was suspended, so wake a new compute node
     562               2 :     info!("compute node's state has likely changed; requesting a wake-up");
     563                 :     let node_info = loop {
     564                 :         let wake_res = match creds {
     565                 :             auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
     566                 :             auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
     567                 :             // nothing to do?
     568                 :             auth::BackendType::Link(_) => return Err(err.into()),
     569                 :             // test backend
     570                 :             auth::BackendType::Test(x) => x.wake_compute(),
     571                 :         };
     572                 : 
     573                 :         match handle_try_wake(wake_res, num_retries) {
     574                 :             Err(e) => {
     575 UBC           0 :                 error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
     576                 :                 report_error(&e, false);
     577                 :                 return Err(e.into());
     578                 :             }
     579                 :             // failed to wake up but we can continue to retry
     580                 :             Ok(ControlFlow::Continue(e)) => {
     581                 :                 report_error(&e, true);
     582               0 :                 warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
     583                 :             }
     584                 :             // successfully woke up a compute node and can break the wakeup loop
     585                 :             Ok(ControlFlow::Break(mut node_info)) => {
     586                 :                 node_info.config.reuse_password(&config);
     587                 :                 mechanism.update_connect_config(&mut node_info.config);
     588                 :                 break node_info;
     589                 :             }
     590                 :         }
     591                 : 
     592                 :         let wait_duration = retry_after(num_retries);
     593                 :         num_retries += 1;
     594                 : 
     595                 :         time::sleep(wait_duration).await;
     596                 :     };
     597                 : 
     598                 :     // now that we have a new node, try to connect to it repeatedly.
     599                 :     // this can error for a few reasons, for instance:
     600                 :     // * DNS connection settings haven't quite propagated yet
     601 CBC           2 :     info!("wake_compute success. attempting to connect");
     602                 :     loop {
     603                 :         match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
     604                 :             Ok(res) => return Ok(res),
     605                 :             Err(e) => {
     606                 :                 let retriable = e.should_retry(num_retries);
     607                 :                 if !retriable {
     608               2 :                     error!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
     609                 :                     return Err(e.into());
     610                 :                 }
     611 UBC           0 :                 warn!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
     612                 :             }
     613                 :         }
     614                 : 
     615                 :         let wait_duration = retry_after(num_retries);
     616                 :         num_retries += 1;
     617                 : 
     618                 :         time::sleep(wait_duration).await;
     619                 :     }
     620                 : }
     621                 : 
     622                 : /// Attempts to wake up the compute node.
     623                 : /// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
     624                 : /// * Returns Ok(Break(node)) if the wakeup succeeded
     625                 : /// * Returns Err(e) if there was an error and it should not be retried
     626 CBC          33 : pub fn handle_try_wake(
     627              33 :     result: Result<console::CachedNodeInfo, WakeComputeError>,
     628              33 :     num_retries: u32,
     629              33 : ) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
     630              33 :     match result {
     631               2 :         Err(err) => match &err {
     632               2 :             WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
     633               1 :                 Ok(ControlFlow::Continue(err))
     634                 :             }
     635               1 :             _ => Err(err),
     636                 :         },
     637                 :         // Ready to try again.
     638              31 :         Ok(new) => Ok(ControlFlow::Break(new)),
     639                 :     }
     640              33 : }
     641                 : 
     642                 : pub trait ShouldRetry {
     643                 :     fn could_retry(&self) -> bool;
     644              24 :     fn should_retry(&self, num_retries: u32) -> bool {
     645              24 :         match self {
     646              24 :             _ if num_retries >= NUM_RETRIES_CONNECT => false,
     647              23 :             err => err.could_retry(),
     648                 :         }
     649              24 :     }
     650                 : }
     651                 : 
     652                 : impl ShouldRetry for io::Error {
     653 UBC           0 :     fn could_retry(&self) -> bool {
     654                 :         use std::io::ErrorKind;
     655               0 :         matches!(
     656               0 :             self.kind(),
     657                 :             ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
     658                 :         )
     659               0 :     }
     660                 : }
     661                 : 
     662                 : impl ShouldRetry for tokio_postgres::error::DbError {
     663 CBC           2 :     fn could_retry(&self) -> bool {
     664                 :         use tokio_postgres::error::SqlState;
     665               2 :         matches!(
     666               2 :             self.code(),
     667                 :             &SqlState::CONNECTION_FAILURE
     668                 :                 | &SqlState::CONNECTION_EXCEPTION
     669                 :                 | &SqlState::CONNECTION_DOES_NOT_EXIST
     670                 :                 | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
     671                 :         )
     672               2 :     }
     673                 : }
     674                 : 
     675                 : impl ShouldRetry for tokio_postgres::Error {
     676                 :     fn could_retry(&self) -> bool {
     677               2 :         if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
     678 UBC           0 :             io::Error::could_retry(io_err)
     679 CBC           2 :         } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
     680               2 :             tokio_postgres::error::DbError::could_retry(db_err)
     681                 :         } else {
     682 UBC           0 :             false
     683                 :         }
     684 CBC           2 :     }
     685                 : }
     686                 : 
     687                 : impl ShouldRetry for compute::ConnectionError {
     688 UBC           0 :     fn could_retry(&self) -> bool {
     689               0 :         match self {
     690               0 :             compute::ConnectionError::Postgres(err) => err.could_retry(),
     691               0 :             compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
     692               0 :             _ => false,
     693                 :         }
     694               0 :     }
     695                 : }
     696                 : 
     697 CBC          34 : pub fn retry_after(num_retries: u32) -> time::Duration {
     698              34 :     BASE_RETRY_WAIT_DURATION.mul_f64(RETRY_WAIT_EXPONENT_BASE.powi((num_retries as i32) - 1))
     699              34 : }
     700                 : 
     701                 : /// Finish client connection initialization: confirm auth success, send params, etc.
     702             116 : #[tracing::instrument(skip_all)]
     703                 : async fn prepare_client_connection(
     704                 :     node: &compute::PostgresConnection,
     705                 :     reported_auth_ok: bool,
     706                 :     session: cancellation::Session<'_>,
     707                 :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     708                 : ) -> anyhow::Result<()> {
     709                 :     // Register compute's query cancellation token and produce a new, unique one.
     710                 :     // The new token (cancel_key_data) will be sent to the client.
     711                 :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     712                 : 
     713                 :     // Report authentication success if we haven't done this already.
     714                 :     // Note that we do this only (for the most part) after we've connected
     715                 :     // to a compute (see above) which performs its own authentication.
     716                 :     if !reported_auth_ok {
     717                 :         stream.write_message_noflush(&Be::AuthenticationOk)?;
     718                 :     }
     719                 : 
     720                 :     // Forward all postgres connection params to the client.
     721                 :     // Right now the implementation is very hacky and inefficient (ideally,
     722                 :     // we don't need an intermediate hashmap), but at least it should be correct.
     723                 :     for (name, value) in &node.params {
     724                 :         // TODO: Theoretically, this could result in a big pile of params...
     725                 :         stream.write_message_noflush(&Be::ParameterStatus {
     726                 :             name: name.as_bytes(),
     727                 :             value: value.as_bytes(),
     728                 :         })?;
     729                 :     }
     730                 : 
     731                 :     stream
     732                 :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     733                 :         .write_message(&Be::ReadyForQuery)
     734                 :         .await?;
     735                 : 
     736                 :     Ok(())
     737                 : }
     738                 : 
     739                 : /// Forward bytes in both directions (client <-> compute).
     740             120 : #[tracing::instrument(skip_all)]
     741                 : pub async fn proxy_pass(
     742                 :     client: impl AsyncRead + AsyncWrite + Unpin,
     743                 :     compute: impl AsyncRead + AsyncWrite + Unpin,
     744                 :     aux: &MetricsAuxInfo,
     745                 : ) -> anyhow::Result<()> {
     746                 :     let usage = USAGE_METRICS.register(Ids {
     747                 :         endpoint_id: aux.endpoint_id.to_string(),
     748                 :         branch_id: aux.branch_id.to_string(),
     749                 :     });
     750                 : 
     751                 :     let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
     752                 :     let mut client = MeasuredStream::new(
     753                 :         client,
     754              90 :         |_| {},
     755              60 :         |cnt| {
     756              60 :             // Number of bytes we sent to the client (outbound).
     757              60 :             m_sent.inc_by(cnt as u64);
     758              60 :             usage.record_egress(cnt as u64);
     759              60 :         },
     760                 :     );
     761                 : 
     762                 :     let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
     763                 :     let mut compute = MeasuredStream::new(
     764                 :         compute,
     765              60 :         |_| {},
     766              61 :         |cnt| {
     767              61 :             // Number of bytes the client sent to the compute node (inbound).
     768              61 :             m_recv.inc_by(cnt as u64);
     769              61 :         },
     770                 :     );
     771                 : 
     772                 :     // Starting from here we only proxy the client's traffic.
     773              30 :     info!("performing the proxy pass...");
     774                 :     let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
     775                 : 
     776                 :     Ok(())
     777                 : }
     778                 : 
     779                 : /// Thin connection context.
     780                 : struct Client<'a, S> {
     781                 :     /// The underlying libpq protocol stream.
     782                 :     stream: PqStream<S>,
     783                 :     /// Client credentials that we care about.
     784                 :     creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
     785                 :     /// KV-dictionary with PostgreSQL connection params.
     786                 :     params: &'a StartupMessageParams,
     787                 :     /// Unique connection ID.
     788                 :     session_id: uuid::Uuid,
     789                 :     /// Allow self-signed certificates (for testing).
     790                 :     allow_self_signed_compute: bool,
     791                 : }
     792                 : 
     793                 : impl<'a, S> Client<'a, S> {
     794                 :     /// Construct a new connection context.
     795              33 :     fn new(
     796              33 :         stream: PqStream<S>,
     797              33 :         creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
     798              33 :         params: &'a StartupMessageParams,
     799              33 :         session_id: uuid::Uuid,
     800              33 :         allow_self_signed_compute: bool,
     801              33 :     ) -> Self {
     802              33 :         Self {
     803              33 :             stream,
     804              33 :             creds,
     805              33 :             params,
     806              33 :             session_id,
     807              33 :             allow_self_signed_compute,
     808              33 :         }
     809              33 :     }
     810                 : }
     811                 : 
     812                 : impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
     813                 :     /// Let the client authenticate and connect to the designated compute node.
     814                 :     // Instrumentation logs endpoint name everywhere. Doesn't work for link
     815                 :     // auth; strictly speaking we don't know endpoint name in its case.
     816              66 :     #[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)]
     817                 :     async fn connect_to_db(
     818                 :         self,
     819                 :         session: cancellation::Session<'_>,
     820                 :         mode: ClientMode,
     821                 :     ) -> anyhow::Result<()> {
     822                 :         let Self {
     823                 :             mut stream,
     824                 :             mut creds,
     825                 :             params,
     826                 :             session_id,
     827                 :             allow_self_signed_compute,
     828                 :         } = self;
     829                 : 
     830                 :         let extra = console::ConsoleReqExtra {
     831                 :             session_id, // aka this connection's id
     832                 :             application_name: params.get("application_name"),
     833                 :         };
     834                 : 
     835                 :         let latency_timer = LatencyTimer::new(mode.protocol_label());
     836                 : 
     837                 :         let auth_result = match creds
     838                 :             .authenticate(&extra, &mut stream, mode.allow_cleartext())
     839                 :             .await
     840                 :         {
     841                 :             Ok(auth_result) => auth_result,
     842                 :             Err(e) => {
     843                 :                 let user = creds.get_user();
     844                 :                 let db = params.get("database");
     845                 :                 let app = params.get("application_name");
     846                 :                 let params_span = tracing::info_span!("", ?user, ?db, ?app);
     847                 : 
     848                 :                 return stream.throw_error(e).instrument(params_span).await;
     849                 :             }
     850                 :         };
     851                 : 
     852                 :         let AuthSuccess {
     853                 :             reported_auth_ok,
     854                 :             value: mut node_info,
     855                 :         } = auth_result;
     856                 : 
     857                 :         node_info.allow_self_signed_compute = allow_self_signed_compute;
     858                 : 
     859                 :         let aux = node_info.aux.clone();
     860                 :         let mut node = connect_to_compute(
     861                 :             &TcpMechanism { params },
     862                 :             node_info,
     863                 :             &extra,
     864                 :             &creds,
     865                 :             latency_timer,
     866                 :         )
     867 UBC           0 :         .or_else(|e| stream.throw_error(e))
     868                 :         .await?;
     869                 : 
     870                 :         let proto = mode.protocol_label();
     871                 :         NUM_DB_CONNECTIONS_OPENED_COUNTER
     872                 :             .with_label_values(&[proto])
     873                 :             .inc();
     874 CBC          29 :         scopeguard::defer! {
     875              29 :             NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
     876              29 :         }
     877                 : 
     878                 :         prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
     879                 :         // Before proxy passing, forward to compute whatever data is left in the
     880                 :         // PqStream input buffer. Normally there is none, but our serverless npm
     881                 :         // driver in pipeline mode sends startup, password and first query
     882                 :         // immediately after opening the connection.
     883                 :         let (stream, read_buf) = stream.into_inner();
     884                 :         node.stream.write_all(&read_buf).await?;
     885                 :         proxy_pass(stream, node.stream, &aux).await
     886                 :     }
     887                 : }
        

Generated by: LCOV version 2.1-beta