LCOV - differential code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 72.4 % 174 126 48 126
Current Date: 2023-10-19 02:04:12 Functions: 56.2 % 32 18 14 18
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     auth::parse_endpoint_param,
       3                 :     cancellation::CancelClosure,
       4                 :     console::errors::WakeComputeError,
       5                 :     error::{io_error, UserFacingError},
       6                 : };
       7                 : use futures::{FutureExt, TryFutureExt};
       8                 : use itertools::Itertools;
       9                 : use pq_proto::StartupMessageParams;
      10                 : use std::{io, net::SocketAddr, time::Duration};
      11                 : use thiserror::Error;
      12                 : use tokio::net::TcpStream;
      13                 : use tokio_postgres::tls::MakeTlsConnect;
      14                 : use tracing::{error, info, warn};
      15                 : 
      16                 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      17                 : 
      18 UBC           0 : #[derive(Debug, Error)]
      19                 : pub enum ConnectionError {
      20                 :     /// This error doesn't seem to reveal any secrets; for instance,
      21                 :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      22                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      23                 :     Postgres(#[from] tokio_postgres::Error),
      24                 : 
      25                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      26                 :     CouldNotConnect(#[from] io::Error),
      27                 : 
      28                 :     #[error("{COULD_NOT_CONNECT}: {0}")]
      29                 :     TlsError(#[from] native_tls::Error),
      30                 : }
      31                 : 
      32                 : impl From<WakeComputeError> for ConnectionError {
      33               0 :     fn from(value: WakeComputeError) -> Self {
      34               0 :         io_error(value).into()
      35               0 :     }
      36                 : }
      37                 : 
      38                 : impl UserFacingError for ConnectionError {
      39               0 :     fn to_string_client(&self) -> String {
      40               0 :         use ConnectionError::*;
      41               0 :         match self {
      42                 :             // This helps us drop irrelevant library-specific prefixes.
      43                 :             // TODO: propagate severity level and other parameters.
      44               0 :             Postgres(err) => match err.as_db_error() {
      45               0 :                 Some(err) => err.message().to_owned(),
      46               0 :                 None => err.to_string(),
      47                 :             },
      48               0 :             _ => COULD_NOT_CONNECT.to_owned(),
      49                 :         }
      50               0 :     }
      51                 : }
      52                 : 
      53                 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      54                 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      55                 : 
      56                 : /// A config for establishing a connection to compute node.
      57                 : /// Eventually, `tokio_postgres` will be replaced with something better.
      58                 : /// Newtype allows us to implement methods on top of it.
      59               0 : #[derive(Clone)]
      60                 : #[repr(transparent)]
      61                 : pub struct ConnCfg(Box<tokio_postgres::Config>);
      62                 : 
      63                 : /// Creation and initialization routines.
      64                 : impl ConnCfg {
      65 CBC          70 :     pub fn new() -> Self {
      66              70 :         Self(Default::default())
      67              70 :     }
      68                 : 
      69                 :     /// Reuse password or auth keys from the other config.
      70                 :     pub fn reuse_password(&mut self, other: &Self) {
      71               7 :         if let Some(password) = other.get_password() {
      72 UBC           0 :             self.password(password);
      73 CBC           7 :         }
      74                 : 
      75               7 :         if let Some(keys) = other.get_auth_keys() {
      76 UBC           0 :             self.auth_keys(keys);
      77 CBC           7 :         }
      78               7 :     }
      79                 : 
      80                 :     /// Apply startup message params to the connection config.
      81                 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
      82                 :         // Only set `user` if it's not present in the config.
      83                 :         // Link auth flow takes username from the console's response.
      84              29 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
      85              26 :             self.user(user);
      86              26 :         }
      87                 : 
      88                 :         // Only set `dbname` if it's not present in the config.
      89                 :         // Link auth flow takes dbname from the console's response.
      90              29 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
      91              26 :             self.dbname(dbname);
      92              26 :         }
      93                 : 
      94                 :         // Don't add `options` if they were only used for specifying a project.
      95                 :         // Connection pools don't support `options`, because they affect backend startup.
      96              29 :         if let Some(options) = filtered_options(params) {
      97              26 :             self.options(&options);
      98              26 :         }
      99                 : 
     100              29 :         if let Some(app_name) = params.get("application_name") {
     101               3 :             self.application_name(app_name);
     102              26 :         }
     103                 : 
     104                 :         // TODO: This is especially ugly...
     105              29 :         if let Some(replication) = params.get("replication") {
     106                 :             use tokio_postgres::config::ReplicationMode;
     107 UBC           0 :             match replication {
     108               0 :                 "true" | "on" | "yes" | "1" => {
     109               0 :                     self.replication_mode(ReplicationMode::Physical);
     110               0 :                 }
     111               0 :                 "database" => {
     112               0 :                     self.replication_mode(ReplicationMode::Logical);
     113               0 :                 }
     114               0 :                 _other => {}
     115                 :             }
     116 CBC          29 :         }
     117                 : 
     118                 :         // TODO: extend the list of the forwarded startup parameters.
     119                 :         // Currently, tokio-postgres doesn't allow us to pass
     120                 :         // arbitrary parameters, but the ones above are a good start.
     121                 :         //
     122                 :         // This and the reverse params problem can be better addressed
     123                 :         // in a bespoke connection machinery (a new library for that sake).
     124              29 :     }
     125                 : }
     126                 : 
     127                 : impl std::ops::Deref for ConnCfg {
     128                 :     type Target = tokio_postgres::Config;
     129                 : 
     130             101 :     fn deref(&self) -> &Self::Target {
     131             101 :         &self.0
     132             101 :     }
     133                 : }
     134                 : 
     135                 : /// For now, let's make it easier to setup the config.
     136                 : impl std::ops::DerefMut for ConnCfg {
     137             168 :     fn deref_mut(&mut self) -> &mut Self::Target {
     138             168 :         &mut self.0
     139             168 :     }
     140                 : }
     141                 : 
     142                 : impl Default for ConnCfg {
     143 UBC           0 :     fn default() -> Self {
     144               0 :         Self::new()
     145               0 :     }
     146                 : }
     147                 : 
     148                 : impl ConnCfg {
     149                 :     /// Establish a raw TCP connection to the compute node.
     150 CBC          29 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     151              29 :         use tokio_postgres::config::Host;
     152              29 : 
     153              29 :         // wrap TcpStream::connect with timeout
     154              29 :         let connect_with_timeout = |host, port| {
     155              29 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     156              29 :                 move |res| match res {
     157              29 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     158 UBC           0 :                     Err(_) => Err(io::Error::new(
     159               0 :                         io::ErrorKind::TimedOut,
     160               0 :                         format!("exceeded connection timeout {timeout:?}"),
     161               0 :                     )),
     162 CBC          29 :                 },
     163              29 :             )
     164              29 :         };
     165              29 : 
     166              29 :         let connect_once = |host, port| {
     167              29 :             info!("trying to connect to compute node at {host}:{port}");
     168              29 :             connect_with_timeout(host, port).and_then(|socket| async {
     169              29 :                 let socket_addr = socket.peer_addr()?;
     170                 :                 // This prevents load balancer from severing the connection.
     171              29 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     172              29 :                 Ok((socket_addr, socket))
     173              29 :             })
     174              29 :         };
     175              29 : 
     176              29 :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     177              29 :         // because it has no means for extracting the underlying socket which we
     178              29 :         // require for our business.
     179              29 :         let mut connection_error = None;
     180              29 :         let ports = self.0.get_ports();
     181              29 :         let hosts = self.0.get_hosts();
     182              29 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     183              29 :         if ports.len() > 1 && ports.len() != hosts.len() {
     184 UBC           0 :             return Err(io::Error::new(
     185               0 :                 io::ErrorKind::Other,
     186               0 :                 format!(
     187               0 :                     "bad compute config, \
     188               0 :                      ports and hosts entries' count does not match: {:?}",
     189               0 :                     self.0
     190               0 :                 ),
     191               0 :             ));
     192 CBC          29 :         }
     193                 : 
     194              29 :         for (i, host) in hosts.iter().enumerate() {
     195              29 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     196              29 :             let host = match host {
     197              29 :                 Host::Tcp(host) => host.as_str(),
     198 UBC           0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     199                 :             };
     200                 : 
     201 CBC          58 :             match connect_once(host, *port).await {
     202              29 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     203 UBC           0 :                 Err(err) => {
     204                 :                     // We can't throw an error here, as there might be more hosts to try.
     205               0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     206               0 :                     connection_error = Some(err);
     207                 :                 }
     208                 :             }
     209                 :         }
     210                 : 
     211               0 :         Err(connection_error.unwrap_or_else(|| {
     212               0 :             io::Error::new(
     213               0 :                 io::ErrorKind::Other,
     214               0 :                 format!("bad compute config: {:?}", self.0),
     215               0 :             )
     216               0 :         }))
     217 CBC          29 :     }
     218                 : }
     219                 : 
     220                 : pub struct PostgresConnection {
     221                 :     /// Socket connected to a compute node.
     222                 :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     223                 :         tokio::net::TcpStream,
     224                 :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     225                 :     >,
     226                 :     /// PostgreSQL connection parameters.
     227                 :     pub params: std::collections::HashMap<String, String>,
     228                 :     /// Query cancellation token.
     229                 :     pub cancel_closure: CancelClosure,
     230                 : }
     231                 : 
     232                 : impl ConnCfg {
     233                 :     /// Connect to a corresponding compute node.
     234              29 :     pub async fn connect(
     235              29 :         &self,
     236              29 :         allow_self_signed_compute: bool,
     237              29 :         timeout: Duration,
     238              29 :     ) -> Result<PostgresConnection, ConnectionError> {
     239              58 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     240                 : 
     241              29 :         let tls_connector = native_tls::TlsConnector::builder()
     242              29 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     243              29 :             .build()
     244              29 :             .unwrap();
     245              29 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     246              29 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     247                 : 
     248                 :         // connect_raw() will not use TLS if sslmode is "disable"
     249              29 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     250              29 :         let stream = connection.stream.into_inner();
     251                 : 
     252              29 :         info!(
     253              29 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     254              29 :             self.0.get_ssl_mode()
     255              29 :         );
     256                 : 
     257                 :         // This is very ugly but as of now there's no better way to
     258                 :         // extract the connection parameters from tokio-postgres' connection.
     259                 :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     260              29 :         let params = connection.parameters;
     261              29 : 
     262              29 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     263              29 :         // Yet another reason to rework the connection establishing code.
     264              29 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     265              29 : 
     266              29 :         let connection = PostgresConnection {
     267              29 :             stream,
     268              29 :             params,
     269              29 :             cancel_closure,
     270              29 :         };
     271              29 : 
     272              29 :         Ok(connection)
     273              29 :     }
     274                 : }
     275                 : 
     276                 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     277              34 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     278                 :     #[allow(unstable_name_collisions)]
     279              34 :     let options: String = params
     280              34 :         .options_raw()?
     281              55 :         .filter(|opt| parse_endpoint_param(opt).is_none())
     282              34 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     283              34 :         .collect();
     284              34 : 
     285              34 :     // Don't even bother with empty options.
     286              34 :     if options.is_empty() {
     287               6 :         return None;
     288              28 :     }
     289              28 : 
     290              28 :     Some(options)
     291              34 : }
     292                 : 
     293                 : #[cfg(test)]
     294                 : mod tests {
     295                 :     use super::*;
     296                 : 
     297               1 :     #[test]
     298               1 :     fn test_filtered_options() {
     299               1 :         // Empty options is unlikely to be useful anyway.
     300               1 :         let params = StartupMessageParams::new([("options", "")]);
     301               1 :         assert_eq!(filtered_options(&params), None);
     302                 : 
     303                 :         // It's likely that clients will only use options to specify endpoint/project.
     304               1 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     305               1 :         assert_eq!(filtered_options(&params), None);
     306                 : 
     307                 :         // Same, because unescaped whitespaces are no-op.
     308               1 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     309               1 :         assert_eq!(filtered_options(&params).as_deref(), None);
     310                 : 
     311               1 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     312               1 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     313                 : 
     314               1 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     315               1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     316               1 :     }
     317                 : }
        

Generated by: LCOV version 2.1-beta