LCOV - differential code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 75.0 % 184 138 46 138
Current Date: 2024-01-09 02:06:09 Functions: 59.4 % 32 19 13 19
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
       3                 :     context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
       4                 :     proxy::neon_option,
       5                 : };
       6                 : use futures::{FutureExt, TryFutureExt};
       7                 : use itertools::Itertools;
       8                 : use metrics::IntCounterPairGuard;
       9                 : use pq_proto::StartupMessageParams;
      10                 : use std::{io, net::SocketAddr, time::Duration};
      11                 : use thiserror::Error;
      12                 : use tokio::net::TcpStream;
      13                 : use tokio_postgres::tls::MakeTlsConnect;
      14                 : use tracing::{error, info, warn};
      15                 : 
      16                 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      17                 : 
      18 UBC           0 : #[derive(Debug, Error)]
      19                 : pub enum ConnectionError {
      20                 :     /// This error doesn't seem to reveal any secrets; for instance,
      21                 :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      22                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      23                 :     Postgres(#[from] tokio_postgres::Error),
      24                 : 
      25                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      26                 :     CouldNotConnect(#[from] io::Error),
      27                 : 
      28                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      29                 :     TlsError(#[from] native_tls::Error),
      30                 : 
      31                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      32                 :     WakeComputeError(#[from] WakeComputeError),
      33                 : }
      34                 : 
      35                 : impl UserFacingError for ConnectionError {
      36               0 :     fn to_string_client(&self) -> String {
      37               0 :         use ConnectionError::*;
      38               0 :         match self {
      39                 :             // This helps us drop irrelevant library-specific prefixes.
      40                 :             // TODO: propagate severity level and other parameters.
      41               0 :             Postgres(err) => match err.as_db_error() {
      42               0 :                 Some(err) => err.message().to_owned(),
      43               0 :                 None => err.to_string(),
      44                 :             },
      45               0 :             WakeComputeError(err) => err.to_string_client(),
      46               0 :             _ => COULD_NOT_CONNECT.to_owned(),
      47                 :         }
      48               0 :     }
      49                 : }
      50                 : 
      51                 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      52                 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      53                 : 
      54                 : /// A config for establishing a connection to compute node.
      55                 : /// Eventually, `tokio_postgres` will be replaced with something better.
      56                 : /// Newtype allows us to implement methods on top of it.
      57               0 : #[derive(Clone)]
      58                 : #[repr(transparent)]
      59                 : pub struct ConnCfg(Box<tokio_postgres::Config>);
      60                 : 
      61                 : /// Creation and initialization routines.
      62                 : impl ConnCfg {
      63 CBC          92 :     pub fn new() -> Self {
      64              92 :         Self(Default::default())
      65              92 :     }
      66                 : 
      67                 :     /// Reuse password or auth keys from the other config.
      68               7 :     pub fn reuse_password(&mut self, other: &Self) {
      69               7 :         if let Some(password) = other.get_password() {
      70 UBC           0 :             self.password(password);
      71 CBC           7 :         }
      72                 : 
      73               7 :         if let Some(keys) = other.get_auth_keys() {
      74 UBC           0 :             self.auth_keys(keys);
      75 CBC           7 :         }
      76               7 :     }
      77                 : 
      78                 :     /// Apply startup message params to the connection config.
      79              38 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
      80                 :         // Only set `user` if it's not present in the config.
      81                 :         // Link auth flow takes username from the console's response.
      82              38 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
      83              35 :             self.user(user);
      84              35 :         }
      85                 : 
      86                 :         // Only set `dbname` if it's not present in the config.
      87                 :         // Link auth flow takes dbname from the console's response.
      88              38 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
      89              35 :             self.dbname(dbname);
      90              35 :         }
      91                 : 
      92                 :         // Don't add `options` if they were only used for specifying a project.
      93                 :         // Connection pools don't support `options`, because they affect backend startup.
      94              38 :         if let Some(options) = filtered_options(params) {
      95              35 :             self.options(&options);
      96              35 :         }
      97                 : 
      98              38 :         if let Some(app_name) = params.get("application_name") {
      99               3 :             self.application_name(app_name);
     100              35 :         }
     101                 : 
     102                 :         // TODO: This is especially ugly...
     103              38 :         if let Some(replication) = params.get("replication") {
     104                 :             use tokio_postgres::config::ReplicationMode;
     105 UBC           0 :             match replication {
     106               0 :                 "true" | "on" | "yes" | "1" => {
     107               0 :                     self.replication_mode(ReplicationMode::Physical);
     108               0 :                 }
     109               0 :                 "database" => {
     110               0 :                     self.replication_mode(ReplicationMode::Logical);
     111               0 :                 }
     112               0 :                 _other => {}
     113                 :             }
     114 CBC          38 :         }
     115                 : 
     116                 :         // TODO: extend the list of the forwarded startup parameters.
     117                 :         // Currently, tokio-postgres doesn't allow us to pass
     118                 :         // arbitrary parameters, but the ones above are a good start.
     119                 :         //
     120                 :         // This and the reverse params problem can be better addressed
     121                 :         // in a bespoke connection machinery (a new library for that sake).
     122              38 :     }
     123                 : }
     124                 : 
     125                 : impl std::ops::Deref for ConnCfg {
     126                 :     type Target = tokio_postgres::Config;
     127                 : 
     128             132 :     fn deref(&self) -> &Self::Target {
     129             132 :         &self.0
     130             132 :     }
     131                 : }
     132                 : 
     133                 : /// For now, let's make it easier to setup the config.
     134                 : impl std::ops::DerefMut for ConnCfg {
     135             226 :     fn deref_mut(&mut self) -> &mut Self::Target {
     136             226 :         &mut self.0
     137             226 :     }
     138                 : }
     139                 : 
     140                 : impl Default for ConnCfg {
     141 UBC           0 :     fn default() -> Self {
     142               0 :         Self::new()
     143               0 :     }
     144                 : }
     145                 : 
     146                 : impl ConnCfg {
     147                 :     /// Establish a raw TCP connection to the compute node.
     148 CBC          38 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     149              38 :         use tokio_postgres::config::Host;
     150              38 : 
     151              38 :         // wrap TcpStream::connect with timeout
     152              38 :         let connect_with_timeout = |host, port| {
     153              38 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     154              38 :                 move |res| match res {
     155              38 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     156 UBC           0 :                     Err(_) => Err(io::Error::new(
     157               0 :                         io::ErrorKind::TimedOut,
     158               0 :                         format!("exceeded connection timeout {timeout:?}"),
     159               0 :                     )),
     160 CBC          38 :                 },
     161              38 :             )
     162              38 :         };
     163              38 : 
     164              38 :         let connect_once = |host, port| {
     165              38 :             info!("trying to connect to compute node at {host}:{port}");
     166              38 :             connect_with_timeout(host, port).and_then(|socket| async {
     167              38 :                 let socket_addr = socket.peer_addr()?;
     168                 :                 // This prevents load balancer from severing the connection.
     169              38 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     170              38 :                 Ok((socket_addr, socket))
     171              38 :             })
     172              38 :         };
     173              38 : 
     174              38 :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     175              38 :         // because it has no means for extracting the underlying socket which we
     176              38 :         // require for our business.
     177              38 :         let mut connection_error = None;
     178              38 :         let ports = self.0.get_ports();
     179              38 :         let hosts = self.0.get_hosts();
     180              38 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     181              38 :         if ports.len() > 1 && ports.len() != hosts.len() {
     182 UBC           0 :             return Err(io::Error::new(
     183               0 :                 io::ErrorKind::Other,
     184               0 :                 format!(
     185               0 :                     "bad compute config, \
     186               0 :                      ports and hosts entries' count does not match: {:?}",
     187               0 :                     self.0
     188               0 :                 ),
     189               0 :             ));
     190 CBC          38 :         }
     191                 : 
     192              38 :         for (i, host) in hosts.iter().enumerate() {
     193              38 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     194              38 :             let host = match host {
     195              38 :                 Host::Tcp(host) => host.as_str(),
     196 UBC           0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     197                 :             };
     198                 : 
     199 CBC          68 :             match connect_once(host, *port).await {
     200              38 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     201 UBC           0 :                 Err(err) => {
     202                 :                     // We can't throw an error here, as there might be more hosts to try.
     203               0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     204               0 :                     connection_error = Some(err);
     205                 :                 }
     206                 :             }
     207                 :         }
     208                 : 
     209               0 :         Err(connection_error.unwrap_or_else(|| {
     210               0 :             io::Error::new(
     211               0 :                 io::ErrorKind::Other,
     212               0 :                 format!("bad compute config: {:?}", self.0),
     213               0 :             )
     214               0 :         }))
     215 CBC          38 :     }
     216                 : }
     217                 : 
     218                 : pub struct PostgresConnection {
     219                 :     /// Socket connected to a compute node.
     220                 :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     221                 :         tokio::net::TcpStream,
     222                 :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     223                 :     >,
     224                 :     /// PostgreSQL connection parameters.
     225                 :     pub params: std::collections::HashMap<String, String>,
     226                 :     /// Query cancellation token.
     227                 :     pub cancel_closure: CancelClosure,
     228                 : 
     229                 :     _guage: IntCounterPairGuard,
     230                 : }
     231                 : 
     232                 : impl ConnCfg {
     233                 :     /// Connect to a corresponding compute node.
     234              38 :     pub async fn connect(
     235              38 :         &self,
     236              38 :         ctx: &mut RequestMonitoring,
     237              38 :         allow_self_signed_compute: bool,
     238              38 :         timeout: Duration,
     239              38 :     ) -> Result<PostgresConnection, ConnectionError> {
     240              68 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     241                 : 
     242              38 :         let tls_connector = native_tls::TlsConnector::builder()
     243              38 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     244              38 :             .build()
     245              38 :             .unwrap();
     246              38 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     247              38 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     248                 : 
     249                 :         // connect_raw() will not use TLS if sslmode is "disable"
     250              42 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     251              38 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     252              38 :         let stream = connection.stream.into_inner();
     253                 : 
     254              38 :         info!(
     255              38 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     256              38 :             self.0.get_ssl_mode()
     257              38 :         );
     258                 : 
     259                 :         // This is very ugly but as of now there's no better way to
     260                 :         // extract the connection parameters from tokio-postgres' connection.
     261                 :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     262              38 :         let params = connection.parameters;
     263              38 : 
     264              38 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     265              38 :         // Yet another reason to rework the connection establishing code.
     266              38 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     267              38 : 
     268              38 :         let connection = PostgresConnection {
     269              38 :             stream,
     270              38 :             params,
     271              38 :             cancel_closure,
     272              38 :             _guage: NUM_DB_CONNECTIONS_GAUGE
     273              38 :                 .with_label_values(&[ctx.protocol])
     274              38 :                 .guard(),
     275              38 :         };
     276              38 : 
     277              38 :         Ok(connection)
     278              38 :     }
     279                 : }
     280                 : 
     281                 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     282              44 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     283                 :     #[allow(unstable_name_collisions)]
     284              44 :     let options: String = params
     285              44 :         .options_raw()?
     286              71 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     287              44 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     288              44 :         .collect();
     289              44 : 
     290              44 :     // Don't even bother with empty options.
     291              44 :     if options.is_empty() {
     292               6 :         return None;
     293              38 :     }
     294              38 : 
     295              38 :     Some(options)
     296              44 : }
     297                 : 
     298                 : #[cfg(test)]
     299                 : mod tests {
     300                 :     use super::*;
     301                 : 
     302               1 :     #[test]
     303               1 :     fn test_filtered_options() {
     304               1 :         // Empty options is unlikely to be useful anyway.
     305               1 :         let params = StartupMessageParams::new([("options", "")]);
     306               1 :         assert_eq!(filtered_options(&params), None);
     307                 : 
     308                 :         // It's likely that clients will only use options to specify endpoint/project.
     309               1 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     310               1 :         assert_eq!(filtered_options(&params), None);
     311                 : 
     312                 :         // Same, because unescaped whitespaces are no-op.
     313               1 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     314               1 :         assert_eq!(filtered_options(&params).as_deref(), None);
     315                 : 
     316               1 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     317               1 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     318                 : 
     319               1 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     320               1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     321                 : 
     322               1 :         let params = StartupMessageParams::new([(
     323               1 :             "options",
     324               1 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     325               1 :         )]);
     326               1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     327               1 :     }
     328                 : }
        

Generated by: LCOV version 2.1-beta