LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 68.6 % 194 133
Test Date: 2024-02-12 20:26:03 Functions: 58.8 % 34 20

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::errors::WakeComputeError,
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::NUM_DB_CONNECTIONS_GAUGE,
       8              :     proxy::neon_option,
       9              : };
      10              : use futures::{FutureExt, TryFutureExt};
      11              : use itertools::Itertools;
      12              : use metrics::IntCounterPairGuard;
      13              : use pq_proto::StartupMessageParams;
      14              : use std::{io, net::SocketAddr, time::Duration};
      15              : use thiserror::Error;
      16              : use tokio::net::TcpStream;
      17              : use tokio_postgres::tls::MakeTlsConnect;
      18              : use tracing::{error, info, warn};
      19              : 
      20              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      21              : 
      22            0 : #[derive(Debug, Error)]
      23              : pub enum ConnectionError {
      24              :     /// This error doesn't seem to reveal any secrets; for instance,
      25              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      26              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      27              :     Postgres(#[from] tokio_postgres::Error),
      28              : 
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     CouldNotConnect(#[from] io::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     TlsError(#[from] native_tls::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     WakeComputeError(#[from] WakeComputeError),
      37              : }
      38              : 
      39              : impl UserFacingError for ConnectionError {
      40            0 :     fn to_string_client(&self) -> String {
      41            0 :         use ConnectionError::*;
      42            0 :         match self {
      43              :             // This helps us drop irrelevant library-specific prefixes.
      44              :             // TODO: propagate severity level and other parameters.
      45            0 :             Postgres(err) => match err.as_db_error() {
      46            0 :                 Some(err) => {
      47            0 :                     let msg = err.message();
      48            0 : 
      49            0 :                     if msg.starts_with("unsupported startup parameter: ")
      50            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      51              :                     {
      52            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      53              :                     } else {
      54            0 :                         msg.to_owned()
      55              :                     }
      56              :                 }
      57            0 :                 None => err.to_string(),
      58              :             },
      59            0 :             WakeComputeError(err) => err.to_string_client(),
      60            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      61              :         }
      62            0 :     }
      63              : }
      64              : 
      65              : impl ReportableError for ConnectionError {
      66            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      67            0 :         match self {
      68            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      69            0 :                 crate::error::ErrorKind::Postgres
      70              :             }
      71            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      72            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      73            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      74            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      75              :         }
      76            0 :     }
      77              : }
      78              : 
      79              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      80              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      81              : 
      82              : /// A config for establishing a connection to compute node.
      83              : /// Eventually, `tokio_postgres` will be replaced with something better.
      84              : /// Newtype allows us to implement methods on top of it.
      85            0 : #[derive(Clone)]
      86              : #[repr(transparent)]
      87              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      88              : 
      89              : /// Creation and initialization routines.
      90              : impl ConnCfg {
      91          105 :     pub fn new() -> Self {
      92          105 :         Self(Default::default())
      93          105 :     }
      94              : 
      95              :     /// Reuse password or auth keys from the other config.
      96           10 :     pub fn reuse_password(&mut self, other: &Self) {
      97           10 :         if let Some(password) = other.get_password() {
      98            0 :             self.password(password);
      99           10 :         }
     100              : 
     101           10 :         if let Some(keys) = other.get_auth_keys() {
     102            0 :             self.auth_keys(keys);
     103           10 :         }
     104           10 :     }
     105              : 
     106              :     /// Apply startup message params to the connection config.
     107           41 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     108              :         // Only set `user` if it's not present in the config.
     109              :         // Link auth flow takes username from the console's response.
     110           41 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     111           38 :             self.user(user);
     112           38 :         }
     113              : 
     114              :         // Only set `dbname` if it's not present in the config.
     115              :         // Link auth flow takes dbname from the console's response.
     116           41 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     117           38 :             self.dbname(dbname);
     118           38 :         }
     119              : 
     120              :         // Don't add `options` if they were only used for specifying a project.
     121              :         // Connection pools don't support `options`, because they affect backend startup.
     122           41 :         if let Some(options) = filtered_options(params) {
     123           38 :             self.options(&options);
     124           38 :         }
     125              : 
     126           41 :         if let Some(app_name) = params.get("application_name") {
     127            3 :             self.application_name(app_name);
     128           38 :         }
     129              : 
     130              :         // TODO: This is especially ugly...
     131           41 :         if let Some(replication) = params.get("replication") {
     132              :             use tokio_postgres::config::ReplicationMode;
     133            0 :             match replication {
     134            0 :                 "true" | "on" | "yes" | "1" => {
     135            0 :                     self.replication_mode(ReplicationMode::Physical);
     136            0 :                 }
     137            0 :                 "database" => {
     138            0 :                     self.replication_mode(ReplicationMode::Logical);
     139            0 :                 }
     140            0 :                 _other => {}
     141              :             }
     142           41 :         }
     143              : 
     144              :         // TODO: extend the list of the forwarded startup parameters.
     145              :         // Currently, tokio-postgres doesn't allow us to pass
     146              :         // arbitrary parameters, but the ones above are a good start.
     147              :         //
     148              :         // This and the reverse params problem can be better addressed
     149              :         // in a bespoke connection machinery (a new library for that sake).
     150           41 :     }
     151              : }
     152              : 
     153              : impl std::ops::Deref for ConnCfg {
     154              :     type Target = tokio_postgres::Config;
     155              : 
     156          142 :     fn deref(&self) -> &Self::Target {
     157          142 :         &self.0
     158          142 :     }
     159              : }
     160              : 
     161              : /// For now, let's make it easier to setup the config.
     162              : impl std::ops::DerefMut for ConnCfg {
     163          279 :     fn deref_mut(&mut self) -> &mut Self::Target {
     164          279 :         &mut self.0
     165          279 :     }
     166              : }
     167              : 
     168              : impl Default for ConnCfg {
     169            0 :     fn default() -> Self {
     170            0 :         Self::new()
     171            0 :     }
     172              : }
     173              : 
     174              : impl ConnCfg {
     175              :     /// Establish a raw TCP connection to the compute node.
     176           41 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     177           41 :         use tokio_postgres::config::Host;
     178           41 : 
     179           41 :         // wrap TcpStream::connect with timeout
     180           41 :         let connect_with_timeout = |host, port| {
     181           41 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     182           41 :                 move |res| match res {
     183           41 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     184            0 :                     Err(_) => Err(io::Error::new(
     185            0 :                         io::ErrorKind::TimedOut,
     186            0 :                         format!("exceeded connection timeout {timeout:?}"),
     187            0 :                     )),
     188           41 :                 },
     189           41 :             )
     190           41 :         };
     191              : 
     192           41 :         let connect_once = |host, port| {
     193           41 :             info!("trying to connect to compute node at {host}:{port}");
     194           41 :             connect_with_timeout(host, port).and_then(|socket| async {
     195           41 :                 let socket_addr = socket.peer_addr()?;
     196              :                 // This prevents load balancer from severing the connection.
     197           41 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     198           41 :                 Ok((socket_addr, socket))
     199           82 :             })
     200           41 :         };
     201              : 
     202              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     203              :         // because it has no means for extracting the underlying socket which we
     204              :         // require for our business.
     205           41 :         let mut connection_error = None;
     206           41 :         let ports = self.0.get_ports();
     207           41 :         let hosts = self.0.get_hosts();
     208           41 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     209           41 :         if ports.len() > 1 && ports.len() != hosts.len() {
     210            0 :             return Err(io::Error::new(
     211            0 :                 io::ErrorKind::Other,
     212            0 :                 format!(
     213            0 :                     "bad compute config, \
     214            0 :                      ports and hosts entries' count does not match: {:?}",
     215            0 :                     self.0
     216            0 :                 ),
     217            0 :             ));
     218           41 :         }
     219              : 
     220           41 :         for (i, host) in hosts.iter().enumerate() {
     221           41 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     222           41 :             let host = match host {
     223           41 :                 Host::Tcp(host) => host.as_str(),
     224            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     225              :             };
     226              : 
     227           77 :             match connect_once(host, *port).await {
     228           41 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     229            0 :                 Err(err) => {
     230              :                     // We can't throw an error here, as there might be more hosts to try.
     231            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     232            0 :                     connection_error = Some(err);
     233              :                 }
     234              :             }
     235              :         }
     236              : 
     237            0 :         Err(connection_error.unwrap_or_else(|| {
     238            0 :             io::Error::new(
     239            0 :                 io::ErrorKind::Other,
     240            0 :                 format!("bad compute config: {:?}", self.0),
     241            0 :             )
     242            0 :         }))
     243           41 :     }
     244              : }
     245              : 
     246              : pub struct PostgresConnection {
     247              :     /// Socket connected to a compute node.
     248              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     249              :         tokio::net::TcpStream,
     250              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     251              :     >,
     252              :     /// PostgreSQL connection parameters.
     253              :     pub params: std::collections::HashMap<String, String>,
     254              :     /// Query cancellation token.
     255              :     pub cancel_closure: CancelClosure,
     256              : 
     257              :     _guage: IntCounterPairGuard,
     258              : }
     259              : 
     260              : impl ConnCfg {
     261              :     /// Connect to a corresponding compute node.
     262           41 :     pub async fn connect(
     263           41 :         &self,
     264           41 :         ctx: &mut RequestMonitoring,
     265           41 :         allow_self_signed_compute: bool,
     266           41 :         timeout: Duration,
     267           41 :     ) -> Result<PostgresConnection, ConnectionError> {
     268           77 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     269              : 
     270           41 :         let tls_connector = native_tls::TlsConnector::builder()
     271           41 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     272           41 :             .build()
     273           41 :             .unwrap();
     274           41 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     275           41 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     276              : 
     277              :         // connect_raw() will not use TLS if sslmode is "disable"
     278           43 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     279           41 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     280           41 :         let stream = connection.stream.into_inner();
     281              : 
     282           41 :         info!(
     283           41 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     284           41 :             self.0.get_ssl_mode()
     285           41 :         );
     286              : 
     287              :         // This is very ugly but as of now there's no better way to
     288              :         // extract the connection parameters from tokio-postgres' connection.
     289              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     290           41 :         let params = connection.parameters;
     291           41 : 
     292           41 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     293           41 :         // Yet another reason to rework the connection establishing code.
     294           41 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     295           41 : 
     296           41 :         let connection = PostgresConnection {
     297           41 :             stream,
     298           41 :             params,
     299           41 :             cancel_closure,
     300           41 :             _guage: NUM_DB_CONNECTIONS_GAUGE
     301           41 :                 .with_label_values(&[ctx.protocol])
     302           41 :                 .guard(),
     303           41 :         };
     304           41 : 
     305           41 :         Ok(connection)
     306           41 :     }
     307              : }
     308              : 
     309              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     310           53 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     311              :     #[allow(unstable_name_collisions)]
     312           53 :     let options: String = params
     313           53 :         .options_raw()?
     314           87 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     315           53 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     316           53 :         .collect();
     317           53 : 
     318           53 :     // Don't even bother with empty options.
     319           53 :     if options.is_empty() {
     320            9 :         return None;
     321           44 :     }
     322           44 : 
     323           44 :     Some(options)
     324           53 : }
     325              : 
     326              : #[cfg(test)]
     327              : mod tests {
     328              :     use super::*;
     329              : 
     330            2 :     #[test]
     331            2 :     fn test_filtered_options() {
     332            2 :         // Empty options is unlikely to be useful anyway.
     333            2 :         let params = StartupMessageParams::new([("options", "")]);
     334            2 :         assert_eq!(filtered_options(&params), None);
     335              : 
     336              :         // It's likely that clients will only use options to specify endpoint/project.
     337            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     338            2 :         assert_eq!(filtered_options(&params), None);
     339              : 
     340              :         // Same, because unescaped whitespaces are no-op.
     341            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     342            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     343              : 
     344            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     345            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     346              : 
     347            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     348            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     349              : 
     350            2 :         let params = StartupMessageParams::new([(
     351            2 :             "options",
     352            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     353            2 :         )]);
     354            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     355            2 :     }
     356              : }
        

Generated by: LCOV version 2.1-beta