LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 72.4 % 174 126
Test Date: 2023-09-06 10:18:01 Functions: 56.2 % 32 18

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::errors::WakeComputeError,
       5              :     error::{io_error, UserFacingError},
       6              : };
       7              : use futures::{FutureExt, TryFutureExt};
       8              : use itertools::Itertools;
       9              : use pq_proto::StartupMessageParams;
      10              : use std::{io, net::SocketAddr, time::Duration};
      11              : use thiserror::Error;
      12              : use tokio::net::TcpStream;
      13              : use tokio_postgres::tls::MakeTlsConnect;
      14              : use tracing::{error, info, warn};
      15              : 
      16              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      17              : 
      18            0 : #[derive(Debug, Error)]
      19              : pub enum ConnectionError {
      20              :     /// This error doesn't seem to reveal any secrets; for instance,
      21              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      22              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      23              :     Postgres(#[from] tokio_postgres::Error),
      24              : 
      25              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      26              :     CouldNotConnect(#[from] io::Error),
      27              : 
      28              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      29              :     TlsError(#[from] native_tls::Error),
      30              : }
      31              : 
      32              : impl From<WakeComputeError> for ConnectionError {
      33            0 :     fn from(value: WakeComputeError) -> Self {
      34            0 :         io_error(value).into()
      35            0 :     }
      36              : }
      37              : 
      38              : impl UserFacingError for ConnectionError {
      39            0 :     fn to_string_client(&self) -> String {
      40            0 :         use ConnectionError::*;
      41            0 :         match self {
      42              :             // This helps us drop irrelevant library-specific prefixes.
      43              :             // TODO: propagate severity level and other parameters.
      44            0 :             Postgres(err) => match err.as_db_error() {
      45            0 :                 Some(err) => err.message().to_owned(),
      46            0 :                 None => err.to_string(),
      47              :             },
      48            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      49              :         }
      50            0 :     }
      51              : }
      52              : 
      53              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      54              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      55              : 
      56              : /// A config for establishing a connection to compute node.
      57              : /// Eventually, `tokio_postgres` will be replaced with something better.
      58              : /// Newtype allows us to implement methods on top of it.
      59            0 : #[derive(Clone)]
      60              : #[repr(transparent)]
      61              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      62              : 
      63              : /// Creation and initialization routines.
      64              : impl ConnCfg {
      65           61 :     pub fn new() -> Self {
      66           61 :         Self(Default::default())
      67           61 :     }
      68              : 
      69              :     /// Reuse password or auth keys from the other config.
      70              :     pub fn reuse_password(&mut self, other: &Self) {
      71            7 :         if let Some(password) = other.get_password() {
      72            0 :             self.password(password);
      73            7 :         }
      74              : 
      75            7 :         if let Some(keys) = other.get_auth_keys() {
      76            0 :             self.auth_keys(keys);
      77            7 :         }
      78            7 :     }
      79              : 
      80              :     /// Apply startup message params to the connection config.
      81              :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
      82              :         // Only set `user` if it's not present in the config.
      83              :         // Link auth flow takes username from the console's response.
      84           27 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
      85           24 :             self.user(user);
      86           24 :         }
      87              : 
      88              :         // Only set `dbname` if it's not present in the config.
      89              :         // Link auth flow takes dbname from the console's response.
      90           27 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
      91           24 :             self.dbname(dbname);
      92           24 :         }
      93              : 
      94              :         // Don't add `options` if they were only used for specifying a project.
      95              :         // Connection pools don't support `options`, because they affect backend startup.
      96           27 :         if let Some(options) = filtered_options(params) {
      97           24 :             self.options(&options);
      98           24 :         }
      99              : 
     100           27 :         if let Some(app_name) = params.get("application_name") {
     101            3 :             self.application_name(app_name);
     102           24 :         }
     103              : 
     104              :         // TODO: This is especially ugly...
     105           27 :         if let Some(replication) = params.get("replication") {
     106              :             use tokio_postgres::config::ReplicationMode;
     107            0 :             match replication {
     108            0 :                 "true" | "on" | "yes" | "1" => {
     109            0 :                     self.replication_mode(ReplicationMode::Physical);
     110            0 :                 }
     111            0 :                 "database" => {
     112            0 :                     self.replication_mode(ReplicationMode::Logical);
     113            0 :                 }
     114            0 :                 _other => {}
     115              :             }
     116           27 :         }
     117              : 
     118              :         // TODO: extend the list of the forwarded startup parameters.
     119              :         // Currently, tokio-postgres doesn't allow us to pass
     120              :         // arbitrary parameters, but the ones above are a good start.
     121              :         //
     122              :         // This and the reverse params problem can be better addressed
     123              :         // in a bespoke connection machinery (a new library for that sake).
     124           27 :     }
     125              : }
     126              : 
     127              : impl std::ops::Deref for ConnCfg {
     128              :     type Target = tokio_postgres::Config;
     129              : 
     130           90 :     fn deref(&self) -> &Self::Target {
     131           90 :         &self.0
     132           90 :     }
     133              : }
     134              : 
     135              : /// For now, let's make it easier to setup the config.
     136              : impl std::ops::DerefMut for ConnCfg {
     137          151 :     fn deref_mut(&mut self) -> &mut Self::Target {
     138          151 :         &mut self.0
     139          151 :     }
     140              : }
     141              : 
     142              : impl Default for ConnCfg {
     143            0 :     fn default() -> Self {
     144            0 :         Self::new()
     145            0 :     }
     146              : }
     147              : 
     148              : impl ConnCfg {
     149              :     /// Establish a raw TCP connection to the compute node.
     150           27 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     151           27 :         use tokio_postgres::config::Host;
     152           27 : 
     153           27 :         // wrap TcpStream::connect with timeout
     154           27 :         let connect_with_timeout = |host, port| {
     155           27 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     156           27 :                 move |res| match res {
     157           27 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     158            0 :                     Err(_) => Err(io::Error::new(
     159            0 :                         io::ErrorKind::TimedOut,
     160            0 :                         format!("exceeded connection timeout {timeout:?}"),
     161            0 :                     )),
     162           27 :                 },
     163           27 :             )
     164           27 :         };
     165           27 : 
     166           27 :         let connect_once = |host, port| {
     167           27 :             info!("trying to connect to compute node at {host}:{port}");
     168           27 :             connect_with_timeout(host, port).and_then(|socket| async {
     169           27 :                 let socket_addr = socket.peer_addr()?;
     170              :                 // This prevents load balancer from severing the connection.
     171           27 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     172           27 :                 Ok((socket_addr, socket))
     173           27 :             })
     174           27 :         };
     175           27 : 
     176           27 :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     177           27 :         // because it has no means for extracting the underlying socket which we
     178           27 :         // require for our business.
     179           27 :         let mut connection_error = None;
     180           27 :         let ports = self.0.get_ports();
     181           27 :         let hosts = self.0.get_hosts();
     182           27 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     183           27 :         if ports.len() > 1 && ports.len() != hosts.len() {
     184            0 :             return Err(io::Error::new(
     185            0 :                 io::ErrorKind::Other,
     186            0 :                 format!(
     187            0 :                     "bad compute config, \
     188            0 :                      ports and hosts entries' count does not match: {:?}",
     189            0 :                     self.0
     190            0 :                 ),
     191            0 :             ));
     192           27 :         }
     193              : 
     194           27 :         for (i, host) in hosts.iter().enumerate() {
     195           27 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     196           27 :             let host = match host {
     197           27 :                 Host::Tcp(host) => host.as_str(),
     198            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     199              :             };
     200              : 
     201           53 :             match connect_once(host, *port).await {
     202           27 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     203            0 :                 Err(err) => {
     204              :                     // We can't throw an error here, as there might be more hosts to try.
     205            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     206            0 :                     connection_error = Some(err);
     207              :                 }
     208              :             }
     209              :         }
     210              : 
     211            0 :         Err(connection_error.unwrap_or_else(|| {
     212            0 :             io::Error::new(
     213            0 :                 io::ErrorKind::Other,
     214            0 :                 format!("bad compute config: {:?}", self.0),
     215            0 :             )
     216            0 :         }))
     217           27 :     }
     218              : }
     219              : 
     220              : pub struct PostgresConnection {
     221              :     /// Socket connected to a compute node.
     222              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     223              :         tokio::net::TcpStream,
     224              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     225              :     >,
     226              :     /// PostgreSQL connection parameters.
     227              :     pub params: std::collections::HashMap<String, String>,
     228              :     /// Query cancellation token.
     229              :     pub cancel_closure: CancelClosure,
     230              : }
     231              : 
     232              : impl ConnCfg {
     233              :     /// Connect to a corresponding compute node.
     234           27 :     pub async fn connect(
     235           27 :         &self,
     236           27 :         allow_self_signed_compute: bool,
     237           27 :         timeout: Duration,
     238           27 :     ) -> Result<PostgresConnection, ConnectionError> {
     239           53 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     240              : 
     241           27 :         let tls_connector = native_tls::TlsConnector::builder()
     242           27 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     243           27 :             .build()
     244           27 :             .unwrap();
     245           27 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     246           27 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     247              : 
     248              :         // connect_raw() will not use TLS if sslmode is "disable"
     249           27 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     250           27 :         let stream = connection.stream.into_inner();
     251              : 
     252           27 :         info!(
     253           27 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     254           27 :             self.0.get_ssl_mode()
     255           27 :         );
     256              : 
     257              :         // This is very ugly but as of now there's no better way to
     258              :         // extract the connection parameters from tokio-postgres' connection.
     259              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     260           27 :         let params = connection.parameters;
     261           27 : 
     262           27 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     263           27 :         // Yet another reason to rework the connection establishing code.
     264           27 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     265           27 : 
     266           27 :         let connection = PostgresConnection {
     267           27 :             stream,
     268           27 :             params,
     269           27 :             cancel_closure,
     270           27 :         };
     271           27 : 
     272           27 :         Ok(connection)
     273           27 :     }
     274              : }
     275              : 
     276              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     277           32 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     278              :     #[allow(unstable_name_collisions)]
     279           32 :     let options: String = params
     280           32 :         .options_raw()?
     281           53 :         .filter(|opt| parse_endpoint_param(opt).is_none())
     282           32 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     283           32 :         .collect();
     284           32 : 
     285           32 :     // Don't even bother with empty options.
     286           32 :     if options.is_empty() {
     287            6 :         return None;
     288           26 :     }
     289           26 : 
     290           26 :     Some(options)
     291           32 : }
     292              : 
     293              : #[cfg(test)]
     294              : mod tests {
     295              :     use super::*;
     296              : 
     297            1 :     #[test]
     298            1 :     fn test_filtered_options() {
     299            1 :         // Empty options is unlikely to be useful anyway.
     300            1 :         let params = StartupMessageParams::new([("options", "")]);
     301            1 :         assert_eq!(filtered_options(&params), None);
     302              : 
     303              :         // It's likely that clients will only use options to specify endpoint/project.
     304            1 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     305            1 :         assert_eq!(filtered_options(&params), None);
     306              : 
     307              :         // Same, because unescaped whitespaces are no-op.
     308            1 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     309            1 :         assert_eq!(filtered_options(&params).as_deref(), None);
     310              : 
     311            1 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     312            1 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     313              : 
     314            1 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     315            1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     316            1 :     }
     317              : }
        

Generated by: LCOV version 2.1-beta