LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 72.6 % 190 138
Test Date: 2024-02-07 07:37:29 Functions: 59.4 % 32 19

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
       3              :     context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
       4              :     proxy::neon_option,
       5              : };
       6              : use futures::{FutureExt, TryFutureExt};
       7              : use itertools::Itertools;
       8              : use metrics::IntCounterPairGuard;
       9              : use pq_proto::StartupMessageParams;
      10              : use std::{io, net::SocketAddr, time::Duration};
      11              : use thiserror::Error;
      12              : use tokio::net::TcpStream;
      13              : use tokio_postgres::tls::MakeTlsConnect;
      14              : use tracing::{error, info, warn};
      15              : 
      16              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      17              : 
      18            0 : #[derive(Debug, Error)]
      19              : pub enum ConnectionError {
      20              :     /// This error doesn't seem to reveal any secrets; for instance,
      21              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      22              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      23              :     Postgres(#[from] tokio_postgres::Error),
      24              : 
      25              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      26              :     CouldNotConnect(#[from] io::Error),
      27              : 
      28              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      29              :     TlsError(#[from] native_tls::Error),
      30              : 
      31              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      32              :     WakeComputeError(#[from] WakeComputeError),
      33              : }
      34              : 
      35              : impl UserFacingError for ConnectionError {
      36            0 :     fn to_string_client(&self) -> String {
      37            0 :         use ConnectionError::*;
      38            0 :         match self {
      39              :             // This helps us drop irrelevant library-specific prefixes.
      40              :             // TODO: propagate severity level and other parameters.
      41            0 :             Postgres(err) => match err.as_db_error() {
      42            0 :                 Some(err) => {
      43            0 :                     let msg = err.message();
      44            0 : 
      45            0 :                     if msg.starts_with("unsupported startup parameter: ")
      46            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      47              :                     {
      48            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      49              :                     } else {
      50            0 :                         msg.to_owned()
      51              :                     }
      52              :                 }
      53            0 :                 None => err.to_string(),
      54              :             },
      55            0 :             WakeComputeError(err) => err.to_string_client(),
      56            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      57              :         }
      58            0 :     }
      59              : }
      60              : 
      61              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      62              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      63              : 
      64              : /// A config for establishing a connection to compute node.
      65              : /// Eventually, `tokio_postgres` will be replaced with something better.
      66              : /// Newtype allows us to implement methods on top of it.
      67            0 : #[derive(Clone)]
      68              : #[repr(transparent)]
      69              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      70              : 
      71              : /// Creation and initialization routines.
      72              : impl ConnCfg {
      73          106 :     pub fn new() -> Self {
      74          106 :         Self(Default::default())
      75          106 :     }
      76              : 
      77              :     /// Reuse password or auth keys from the other config.
      78           12 :     pub fn reuse_password(&mut self, other: &Self) {
      79           12 :         if let Some(password) = other.get_password() {
      80            0 :             self.password(password);
      81           12 :         }
      82              : 
      83           12 :         if let Some(keys) = other.get_auth_keys() {
      84            0 :             self.auth_keys(keys);
      85           12 :         }
      86           12 :     }
      87              : 
      88              :     /// Apply startup message params to the connection config.
      89           39 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
      90              :         // Only set `user` if it's not present in the config.
      91              :         // Link auth flow takes username from the console's response.
      92           39 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
      93           36 :             self.user(user);
      94           36 :         }
      95              : 
      96              :         // Only set `dbname` if it's not present in the config.
      97              :         // Link auth flow takes dbname from the console's response.
      98           39 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
      99           36 :             self.dbname(dbname);
     100           36 :         }
     101              : 
     102              :         // Don't add `options` if they were only used for specifying a project.
     103              :         // Connection pools don't support `options`, because they affect backend startup.
     104           39 :         if let Some(options) = filtered_options(params) {
     105           36 :             self.options(&options);
     106           36 :         }
     107              : 
     108           39 :         if let Some(app_name) = params.get("application_name") {
     109            3 :             self.application_name(app_name);
     110           36 :         }
     111              : 
     112              :         // TODO: This is especially ugly...
     113           39 :         if let Some(replication) = params.get("replication") {
     114              :             use tokio_postgres::config::ReplicationMode;
     115            0 :             match replication {
     116            0 :                 "true" | "on" | "yes" | "1" => {
     117            0 :                     self.replication_mode(ReplicationMode::Physical);
     118            0 :                 }
     119            0 :                 "database" => {
     120            0 :                     self.replication_mode(ReplicationMode::Logical);
     121            0 :                 }
     122            0 :                 _other => {}
     123              :             }
     124           39 :         }
     125              : 
     126              :         // TODO: extend the list of the forwarded startup parameters.
     127              :         // Currently, tokio-postgres doesn't allow us to pass
     128              :         // arbitrary parameters, but the ones above are a good start.
     129              :         //
     130              :         // This and the reverse params problem can be better addressed
     131              :         // in a bespoke connection machinery (a new library for that sake).
     132           39 :     }
     133              : }
     134              : 
     135              : impl std::ops::Deref for ConnCfg {
     136              :     type Target = tokio_postgres::Config;
     137              : 
     138          145 :     fn deref(&self) -> &Self::Target {
     139          145 :         &self.0
     140          145 :     }
     141              : }
     142              : 
     143              : /// For now, let's make it easier to setup the config.
     144              : impl std::ops::DerefMut for ConnCfg {
     145          232 :     fn deref_mut(&mut self) -> &mut Self::Target {
     146          232 :         &mut self.0
     147          232 :     }
     148              : }
     149              : 
     150              : impl Default for ConnCfg {
     151            0 :     fn default() -> Self {
     152            0 :         Self::new()
     153            0 :     }
     154              : }
     155              : 
     156              : impl ConnCfg {
     157              :     /// Establish a raw TCP connection to the compute node.
     158           39 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     159           39 :         use tokio_postgres::config::Host;
     160           39 : 
     161           39 :         // wrap TcpStream::connect with timeout
     162           39 :         let connect_with_timeout = |host, port| {
     163           39 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     164           39 :                 move |res| match res {
     165           39 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     166            0 :                     Err(_) => Err(io::Error::new(
     167            0 :                         io::ErrorKind::TimedOut,
     168            0 :                         format!("exceeded connection timeout {timeout:?}"),
     169            0 :                     )),
     170           39 :                 },
     171           39 :             )
     172           39 :         };
     173           39 : 
     174           39 :         let connect_once = |host, port| {
     175           39 :             info!("trying to connect to compute node at {host}:{port}");
     176           39 :             connect_with_timeout(host, port).and_then(|socket| async {
     177           39 :                 let socket_addr = socket.peer_addr()?;
     178              :                 // This prevents load balancer from severing the connection.
     179           39 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     180           39 :                 Ok((socket_addr, socket))
     181           39 :             })
     182           39 :         };
     183           39 : 
     184           39 :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     185           39 :         // because it has no means for extracting the underlying socket which we
     186           39 :         // require for our business.
     187           39 :         let mut connection_error = None;
     188           39 :         let ports = self.0.get_ports();
     189           39 :         let hosts = self.0.get_hosts();
     190           39 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     191           39 :         if ports.len() > 1 && ports.len() != hosts.len() {
     192            0 :             return Err(io::Error::new(
     193            0 :                 io::ErrorKind::Other,
     194            0 :                 format!(
     195            0 :                     "bad compute config, \
     196            0 :                      ports and hosts entries' count does not match: {:?}",
     197            0 :                     self.0
     198            0 :                 ),
     199            0 :             ));
     200           39 :         }
     201              : 
     202           39 :         for (i, host) in hosts.iter().enumerate() {
     203           39 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     204           39 :             let host = match host {
     205           39 :                 Host::Tcp(host) => host.as_str(),
     206            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     207              :             };
     208              : 
     209           72 :             match connect_once(host, *port).await {
     210           39 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     211            0 :                 Err(err) => {
     212              :                     // We can't throw an error here, as there might be more hosts to try.
     213            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     214            0 :                     connection_error = Some(err);
     215              :                 }
     216              :             }
     217              :         }
     218              : 
     219            0 :         Err(connection_error.unwrap_or_else(|| {
     220            0 :             io::Error::new(
     221            0 :                 io::ErrorKind::Other,
     222            0 :                 format!("bad compute config: {:?}", self.0),
     223            0 :             )
     224            0 :         }))
     225           39 :     }
     226              : }
     227              : 
     228              : pub struct PostgresConnection {
     229              :     /// Socket connected to a compute node.
     230              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     231              :         tokio::net::TcpStream,
     232              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     233              :     >,
     234              :     /// PostgreSQL connection parameters.
     235              :     pub params: std::collections::HashMap<String, String>,
     236              :     /// Query cancellation token.
     237              :     pub cancel_closure: CancelClosure,
     238              : 
     239              :     _guage: IntCounterPairGuard,
     240              : }
     241              : 
     242              : impl ConnCfg {
     243              :     /// Connect to a corresponding compute node.
     244           39 :     pub async fn connect(
     245           39 :         &self,
     246           39 :         ctx: &mut RequestMonitoring,
     247           39 :         allow_self_signed_compute: bool,
     248           39 :         timeout: Duration,
     249           39 :     ) -> Result<PostgresConnection, ConnectionError> {
     250           72 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     251              : 
     252           39 :         let tls_connector = native_tls::TlsConnector::builder()
     253           39 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     254           39 :             .build()
     255           39 :             .unwrap();
     256           39 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     257           39 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     258              : 
     259              :         // connect_raw() will not use TLS if sslmode is "disable"
     260           42 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     261           39 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     262           39 :         let stream = connection.stream.into_inner();
     263              : 
     264           39 :         info!(
     265           39 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     266           39 :             self.0.get_ssl_mode()
     267           39 :         );
     268              : 
     269              :         // This is very ugly but as of now there's no better way to
     270              :         // extract the connection parameters from tokio-postgres' connection.
     271              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     272           39 :         let params = connection.parameters;
     273           39 : 
     274           39 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     275           39 :         // Yet another reason to rework the connection establishing code.
     276           39 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     277           39 : 
     278           39 :         let connection = PostgresConnection {
     279           39 :             stream,
     280           39 :             params,
     281           39 :             cancel_closure,
     282           39 :             _guage: NUM_DB_CONNECTIONS_GAUGE
     283           39 :                 .with_label_values(&[ctx.protocol])
     284           39 :                 .guard(),
     285           39 :         };
     286           39 : 
     287           39 :         Ok(connection)
     288           39 :     }
     289              : }
     290              : 
     291              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     292           51 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     293              :     #[allow(unstable_name_collisions)]
     294           51 :     let options: String = params
     295           51 :         .options_raw()?
     296           85 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     297           51 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     298           51 :         .collect();
     299           51 : 
     300           51 :     // Don't even bother with empty options.
     301           51 :     if options.is_empty() {
     302            9 :         return None;
     303           42 :     }
     304           42 : 
     305           42 :     Some(options)
     306           51 : }
     307              : 
     308              : #[cfg(test)]
     309              : mod tests {
     310              :     use super::*;
     311              : 
     312            2 :     #[test]
     313            2 :     fn test_filtered_options() {
     314            2 :         // Empty options is unlikely to be useful anyway.
     315            2 :         let params = StartupMessageParams::new([("options", "")]);
     316            2 :         assert_eq!(filtered_options(&params), None);
     317              : 
     318              :         // It's likely that clients will only use options to specify endpoint/project.
     319            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     320            2 :         assert_eq!(filtered_options(&params), None);
     321              : 
     322              :         // Same, because unescaped whitespaces are no-op.
     323            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     324            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     325              : 
     326            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     327            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     328              : 
     329            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     330            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     331              : 
     332            2 :         let params = StartupMessageParams::new([(
     333            2 :             "options",
     334            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     335            2 :         )]);
     336            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     337            2 :     }
     338              : }
        

Generated by: LCOV version 2.1-beta