LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 24.7 % 194 48
Test Date: 2024-04-08 10:22:05 Functions: 23.3 % 30 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::NUM_DB_CONNECTIONS_GAUGE,
       8              :     proxy::neon_option,
       9              : };
      10              : use futures::{FutureExt, TryFutureExt};
      11              : use itertools::Itertools;
      12              : use metrics::IntCounterPairGuard;
      13              : use pq_proto::StartupMessageParams;
      14              : use std::{io, net::SocketAddr, time::Duration};
      15              : use thiserror::Error;
      16              : use tokio::net::TcpStream;
      17              : use tokio_postgres::tls::MakeTlsConnect;
      18              : use tracing::{error, info, warn};
      19              : 
      20              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      21              : 
      22            0 : #[derive(Debug, Error)]
      23              : pub enum ConnectionError {
      24              :     /// This error doesn't seem to reveal any secrets; for instance,
      25              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      26              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      27              :     Postgres(#[from] tokio_postgres::Error),
      28              : 
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     CouldNotConnect(#[from] io::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     TlsError(#[from] native_tls::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     WakeComputeError(#[from] WakeComputeError),
      37              : }
      38              : 
      39              : impl UserFacingError for ConnectionError {
      40            0 :     fn to_string_client(&self) -> String {
      41            0 :         use ConnectionError::*;
      42            0 :         match self {
      43              :             // This helps us drop irrelevant library-specific prefixes.
      44              :             // TODO: propagate severity level and other parameters.
      45            0 :             Postgres(err) => match err.as_db_error() {
      46            0 :                 Some(err) => {
      47            0 :                     let msg = err.message();
      48            0 : 
      49            0 :                     if msg.starts_with("unsupported startup parameter: ")
      50            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      51              :                     {
      52            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      53              :                     } else {
      54            0 :                         msg.to_owned()
      55              :                     }
      56              :                 }
      57            0 :                 None => err.to_string(),
      58              :             },
      59            0 :             WakeComputeError(err) => err.to_string_client(),
      60            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      61              :         }
      62            0 :     }
      63              : }
      64              : 
      65              : impl ReportableError for ConnectionError {
      66            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      67            0 :         match self {
      68            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      69            0 :                 crate::error::ErrorKind::Postgres
      70              :             }
      71            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      72            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      73            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      74            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      75              :         }
      76            0 :     }
      77              : }
      78              : 
      79              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      80              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      81              : 
      82              : /// A config for establishing a connection to compute node.
      83              : /// Eventually, `tokio_postgres` will be replaced with something better.
      84              : /// Newtype allows us to implement methods on top of it.
      85              : #[derive(Clone, Default)]
      86              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      87              : 
      88              : /// Creation and initialization routines.
      89              : impl ConnCfg {
      90           20 :     pub fn new() -> Self {
      91           20 :         Self::default()
      92           20 :     }
      93              : 
      94              :     /// Reuse password or auth keys from the other config.
      95            8 :     pub fn reuse_password(&mut self, other: Self) {
      96            8 :         if let Some(password) = other.get_password() {
      97            8 :             self.password(password);
      98            8 :         }
      99              : 
     100            8 :         if let Some(keys) = other.get_auth_keys() {
     101            0 :             self.auth_keys(keys);
     102            8 :         }
     103            8 :     }
     104              : 
     105              :     /// Apply startup message params to the connection config.
     106            0 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     107              :         // Only set `user` if it's not present in the config.
     108              :         // Link auth flow takes username from the console's response.
     109            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     110            0 :             self.user(user);
     111            0 :         }
     112              : 
     113              :         // Only set `dbname` if it's not present in the config.
     114              :         // Link auth flow takes dbname from the console's response.
     115            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     116            0 :             self.dbname(dbname);
     117            0 :         }
     118              : 
     119              :         // Don't add `options` if they were only used for specifying a project.
     120              :         // Connection pools don't support `options`, because they affect backend startup.
     121            0 :         if let Some(options) = filtered_options(params) {
     122            0 :             self.options(&options);
     123            0 :         }
     124              : 
     125            0 :         if let Some(app_name) = params.get("application_name") {
     126            0 :             self.application_name(app_name);
     127            0 :         }
     128              : 
     129              :         // TODO: This is especially ugly...
     130            0 :         if let Some(replication) = params.get("replication") {
     131              :             use tokio_postgres::config::ReplicationMode;
     132            0 :             match replication {
     133            0 :                 "true" | "on" | "yes" | "1" => {
     134            0 :                     self.replication_mode(ReplicationMode::Physical);
     135            0 :                 }
     136            0 :                 "database" => {
     137            0 :                     self.replication_mode(ReplicationMode::Logical);
     138            0 :                 }
     139            0 :                 _other => {}
     140              :             }
     141            0 :         }
     142              : 
     143              :         // TODO: extend the list of the forwarded startup parameters.
     144              :         // Currently, tokio-postgres doesn't allow us to pass
     145              :         // arbitrary parameters, but the ones above are a good start.
     146              :         //
     147              :         // This and the reverse params problem can be better addressed
     148              :         // in a bespoke connection machinery (a new library for that sake).
     149            0 :     }
     150              : }
     151              : 
     152              : impl std::ops::Deref for ConnCfg {
     153              :     type Target = tokio_postgres::Config;
     154              : 
     155           16 :     fn deref(&self) -> &Self::Target {
     156           16 :         &self.0
     157           16 :     }
     158              : }
     159              : 
     160              : /// For now, let's make it easier to setup the config.
     161              : impl std::ops::DerefMut for ConnCfg {
     162           20 :     fn deref_mut(&mut self) -> &mut Self::Target {
     163           20 :         &mut self.0
     164           20 :     }
     165              : }
     166              : 
     167              : impl ConnCfg {
     168              :     /// Establish a raw TCP connection to the compute node.
     169            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     170            0 :         use tokio_postgres::config::Host;
     171            0 : 
     172            0 :         // wrap TcpStream::connect with timeout
     173            0 :         let connect_with_timeout = |host, port| {
     174            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     175            0 :                 move |res| match res {
     176            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     177            0 :                     Err(_) => Err(io::Error::new(
     178            0 :                         io::ErrorKind::TimedOut,
     179            0 :                         format!("exceeded connection timeout {timeout:?}"),
     180            0 :                     )),
     181            0 :                 },
     182            0 :             )
     183            0 :         };
     184              : 
     185            0 :         let connect_once = |host, port| {
     186            0 :             info!("trying to connect to compute node at {host}:{port}");
     187            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     188            0 :                 let socket_addr = socket.peer_addr()?;
     189              :                 // This prevents load balancer from severing the connection.
     190            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     191            0 :                 Ok((socket_addr, socket))
     192            0 :             })
     193            0 :         };
     194              : 
     195              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     196              :         // because it has no means for extracting the underlying socket which we
     197              :         // require for our business.
     198            0 :         let mut connection_error = None;
     199            0 :         let ports = self.0.get_ports();
     200            0 :         let hosts = self.0.get_hosts();
     201            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     202            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     203            0 :             return Err(io::Error::new(
     204            0 :                 io::ErrorKind::Other,
     205            0 :                 format!(
     206            0 :                     "bad compute config, \
     207            0 :                      ports and hosts entries' count does not match: {:?}",
     208            0 :                     self.0
     209            0 :                 ),
     210            0 :             ));
     211            0 :         }
     212              : 
     213            0 :         for (i, host) in hosts.iter().enumerate() {
     214            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     215            0 :             let host = match host {
     216            0 :                 Host::Tcp(host) => host.as_str(),
     217            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     218              :             };
     219              : 
     220            0 :             match connect_once(host, *port).await {
     221            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     222            0 :                 Err(err) => {
     223            0 :                     // We can't throw an error here, as there might be more hosts to try.
     224            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     225            0 :                     connection_error = Some(err);
     226              :                 }
     227              :             }
     228              :         }
     229              : 
     230            0 :         Err(connection_error.unwrap_or_else(|| {
     231            0 :             io::Error::new(
     232            0 :                 io::ErrorKind::Other,
     233            0 :                 format!("bad compute config: {:?}", self.0),
     234            0 :             )
     235            0 :         }))
     236            0 :     }
     237              : }
     238              : 
     239              : pub struct PostgresConnection {
     240              :     /// Socket connected to a compute node.
     241              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     242              :         tokio::net::TcpStream,
     243              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     244              :     >,
     245              :     /// PostgreSQL connection parameters.
     246              :     pub params: std::collections::HashMap<String, String>,
     247              :     /// Query cancellation token.
     248              :     pub cancel_closure: CancelClosure,
     249              :     /// Labels for proxy's metrics.
     250              :     pub aux: MetricsAuxInfo,
     251              : 
     252              :     _guage: IntCounterPairGuard,
     253              : }
     254              : 
     255              : impl ConnCfg {
     256              :     /// Connect to a corresponding compute node.
     257            0 :     pub async fn connect(
     258            0 :         &self,
     259            0 :         ctx: &mut RequestMonitoring,
     260            0 :         allow_self_signed_compute: bool,
     261            0 :         aux: MetricsAuxInfo,
     262            0 :         timeout: Duration,
     263            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     264            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     265              : 
     266            0 :         let tls_connector = native_tls::TlsConnector::builder()
     267            0 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     268            0 :             .build()
     269            0 :             .unwrap();
     270            0 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     271            0 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     272              : 
     273              :         // connect_raw() will not use TLS if sslmode is "disable"
     274            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     275            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     276            0 :         let stream = connection.stream.into_inner();
     277            0 : 
     278            0 :         info!(
     279            0 :             cold_start_info = ctx.cold_start_info.as_str(),
     280            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     281            0 :             self.0.get_ssl_mode()
     282            0 :         );
     283              : 
     284              :         // This is very ugly but as of now there's no better way to
     285              :         // extract the connection parameters from tokio-postgres' connection.
     286              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     287            0 :         let params = connection.parameters;
     288            0 : 
     289            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     290            0 :         // Yet another reason to rework the connection establishing code.
     291            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     292            0 : 
     293            0 :         let connection = PostgresConnection {
     294            0 :             stream,
     295            0 :             params,
     296            0 :             cancel_closure,
     297            0 :             aux,
     298            0 :             _guage: NUM_DB_CONNECTIONS_GAUGE
     299            0 :                 .with_label_values(&[ctx.protocol])
     300            0 :                 .guard(),
     301            0 :         };
     302            0 : 
     303            0 :         Ok(connection)
     304            0 :     }
     305              : }
     306              : 
     307              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     308           12 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     309              :     #[allow(unstable_name_collisions)]
     310           12 :     let options: String = params
     311           12 :         .options_raw()?
     312           26 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     313           12 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     314           12 :         .collect();
     315           12 : 
     316           12 :     // Don't even bother with empty options.
     317           12 :     if options.is_empty() {
     318            6 :         return None;
     319            6 :     }
     320            6 : 
     321            6 :     Some(options)
     322           12 : }
     323              : 
     324              : #[cfg(test)]
     325              : mod tests {
     326              :     use super::*;
     327              : 
     328              :     #[test]
     329            2 :     fn test_filtered_options() {
     330            2 :         // Empty options is unlikely to be useful anyway.
     331            2 :         let params = StartupMessageParams::new([("options", "")]);
     332            2 :         assert_eq!(filtered_options(&params), None);
     333              : 
     334              :         // It's likely that clients will only use options to specify endpoint/project.
     335            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     336            2 :         assert_eq!(filtered_options(&params), None);
     337              : 
     338              :         // Same, because unescaped whitespaces are no-op.
     339            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     340            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     341              : 
     342            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     343            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     344              : 
     345            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     346            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     347              : 
     348            2 :         let params = StartupMessageParams::new([(
     349            2 :             "options",
     350            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     351            2 :         )]);
     352            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     353            2 :     }
     354              : }
        

Generated by: LCOV version 2.1-beta