LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: b837401fb09d2d9818b70e630fdb67e9799b7b0d.info Lines: 25.0 % 192 48
Test Date: 2024-04-18 15:32:49 Functions: 23.3 % 30 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::{Metrics, NumDbConnectionsGuard},
       8              :     proxy::neon_option,
       9              : };
      10              : use futures::{FutureExt, TryFutureExt};
      11              : use itertools::Itertools;
      12              : use pq_proto::StartupMessageParams;
      13              : use std::{io, net::SocketAddr, time::Duration};
      14              : use thiserror::Error;
      15              : use tokio::net::TcpStream;
      16              : use tokio_postgres::tls::MakeTlsConnect;
      17              : use tracing::{error, info, warn};
      18              : 
      19              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      20              : 
      21            0 : #[derive(Debug, Error)]
      22              : pub enum ConnectionError {
      23              :     /// This error doesn't seem to reveal any secrets; for instance,
      24              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      25              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      26              :     Postgres(#[from] tokio_postgres::Error),
      27              : 
      28              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      29              :     CouldNotConnect(#[from] io::Error),
      30              : 
      31              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      32              :     TlsError(#[from] native_tls::Error),
      33              : 
      34              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      35              :     WakeComputeError(#[from] WakeComputeError),
      36              : }
      37              : 
      38              : impl UserFacingError for ConnectionError {
      39            0 :     fn to_string_client(&self) -> String {
      40            0 :         use ConnectionError::*;
      41            0 :         match self {
      42              :             // This helps us drop irrelevant library-specific prefixes.
      43              :             // TODO: propagate severity level and other parameters.
      44            0 :             Postgres(err) => match err.as_db_error() {
      45            0 :                 Some(err) => {
      46            0 :                     let msg = err.message();
      47            0 : 
      48            0 :                     if msg.starts_with("unsupported startup parameter: ")
      49            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      50              :                     {
      51            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      52              :                     } else {
      53            0 :                         msg.to_owned()
      54              :                     }
      55              :                 }
      56            0 :                 None => err.to_string(),
      57              :             },
      58            0 :             WakeComputeError(err) => err.to_string_client(),
      59            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      60              :         }
      61            0 :     }
      62              : }
      63              : 
      64              : impl ReportableError for ConnectionError {
      65            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      66            0 :         match self {
      67            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      68            0 :                 crate::error::ErrorKind::Postgres
      69              :             }
      70            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      71            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      72            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      73            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      74              :         }
      75            0 :     }
      76              : }
      77              : 
      78              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      79              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      80              : 
      81              : /// A config for establishing a connection to compute node.
      82              : /// Eventually, `tokio_postgres` will be replaced with something better.
      83              : /// Newtype allows us to implement methods on top of it.
      84              : #[derive(Clone, Default)]
      85              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      86              : 
      87              : /// Creation and initialization routines.
      88              : impl ConnCfg {
      89           20 :     pub fn new() -> Self {
      90           20 :         Self::default()
      91           20 :     }
      92              : 
      93              :     /// Reuse password or auth keys from the other config.
      94            8 :     pub fn reuse_password(&mut self, other: Self) {
      95            8 :         if let Some(password) = other.get_password() {
      96            8 :             self.password(password);
      97            8 :         }
      98              : 
      99            8 :         if let Some(keys) = other.get_auth_keys() {
     100            0 :             self.auth_keys(keys);
     101            8 :         }
     102            8 :     }
     103              : 
     104              :     /// Apply startup message params to the connection config.
     105            0 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     106              :         // Only set `user` if it's not present in the config.
     107              :         // Link auth flow takes username from the console's response.
     108            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     109            0 :             self.user(user);
     110            0 :         }
     111              : 
     112              :         // Only set `dbname` if it's not present in the config.
     113              :         // Link auth flow takes dbname from the console's response.
     114            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     115            0 :             self.dbname(dbname);
     116            0 :         }
     117              : 
     118              :         // Don't add `options` if they were only used for specifying a project.
     119              :         // Connection pools don't support `options`, because they affect backend startup.
     120            0 :         if let Some(options) = filtered_options(params) {
     121            0 :             self.options(&options);
     122            0 :         }
     123              : 
     124            0 :         if let Some(app_name) = params.get("application_name") {
     125            0 :             self.application_name(app_name);
     126            0 :         }
     127              : 
     128              :         // TODO: This is especially ugly...
     129            0 :         if let Some(replication) = params.get("replication") {
     130              :             use tokio_postgres::config::ReplicationMode;
     131            0 :             match replication {
     132            0 :                 "true" | "on" | "yes" | "1" => {
     133            0 :                     self.replication_mode(ReplicationMode::Physical);
     134            0 :                 }
     135            0 :                 "database" => {
     136            0 :                     self.replication_mode(ReplicationMode::Logical);
     137            0 :                 }
     138            0 :                 _other => {}
     139              :             }
     140            0 :         }
     141              : 
     142              :         // TODO: extend the list of the forwarded startup parameters.
     143              :         // Currently, tokio-postgres doesn't allow us to pass
     144              :         // arbitrary parameters, but the ones above are a good start.
     145              :         //
     146              :         // This and the reverse params problem can be better addressed
     147              :         // in a bespoke connection machinery (a new library for that sake).
     148            0 :     }
     149              : }
     150              : 
     151              : impl std::ops::Deref for ConnCfg {
     152              :     type Target = tokio_postgres::Config;
     153              : 
     154           16 :     fn deref(&self) -> &Self::Target {
     155           16 :         &self.0
     156           16 :     }
     157              : }
     158              : 
     159              : /// For now, let's make it easier to set up the config.
     160              : impl std::ops::DerefMut for ConnCfg {
     161           20 :     fn deref_mut(&mut self) -> &mut Self::Target {
     162           20 :         &mut self.0
     163           20 :     }
     164              : }
     165              : 
     166              : impl ConnCfg {
     167              :     /// Establish a raw TCP connection to the compute node.
     168            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     169            0 :         use tokio_postgres::config::Host;
     170            0 : 
     171            0 :         // wrap TcpStream::connect with timeout
     172            0 :         let connect_with_timeout = |host, port| {
     173            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     174            0 :                 move |res| match res {
     175            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     176            0 :                     Err(_) => Err(io::Error::new(
     177            0 :                         io::ErrorKind::TimedOut,
     178            0 :                         format!("exceeded connection timeout {timeout:?}"),
     179            0 :                     )),
     180            0 :                 },
     181            0 :             )
     182            0 :         };
     183              : 
     184            0 :         let connect_once = |host, port| {
     185            0 :             info!("trying to connect to compute node at {host}:{port}");
     186            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     187            0 :                 let socket_addr = socket.peer_addr()?;
     188              :                 // This prevents load balancer from severing the connection.
     189            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     190            0 :                 Ok((socket_addr, socket))
     191            0 :             })
     192            0 :         };
     193              : 
     194              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     195              :         // because it has no means for extracting the underlying socket which we
     196              :         // require for our business.
     197            0 :         let mut connection_error = None;
     198            0 :         let ports = self.0.get_ports();
     199            0 :         let hosts = self.0.get_hosts();
     200            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     201            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     202            0 :             return Err(io::Error::new(
     203            0 :                 io::ErrorKind::Other,
     204            0 :                 format!(
     205            0 :                     "bad compute config, \
     206            0 :                      ports and hosts entries' count does not match: {:?}",
     207            0 :                     self.0
     208            0 :                 ),
     209            0 :             ));
     210            0 :         }
     211              : 
     212            0 :         for (i, host) in hosts.iter().enumerate() {
     213            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     214            0 :             let host = match host {
     215            0 :                 Host::Tcp(host) => host.as_str(),
     216            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     217              :             };
     218              : 
     219            0 :             match connect_once(host, *port).await {
     220            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     221            0 :                 Err(err) => {
     222            0 :                     // We can't throw an error here, as there might be more hosts to try.
     223            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     224            0 :                     connection_error = Some(err);
     225              :                 }
     226              :             }
     227              :         }
     228              : 
     229            0 :         Err(connection_error.unwrap_or_else(|| {
     230            0 :             io::Error::new(
     231            0 :                 io::ErrorKind::Other,
     232            0 :                 format!("bad compute config: {:?}", self.0),
     233            0 :             )
     234            0 :         }))
     235            0 :     }
     236              : }
     237              : 
     238              : pub struct PostgresConnection {
     239              :     /// Socket connected to a compute node.
     240              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     241              :         tokio::net::TcpStream,
     242              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     243              :     >,
     244              :     /// PostgreSQL connection parameters.
     245              :     pub params: std::collections::HashMap<String, String>,
     246              :     /// Query cancellation token.
     247              :     pub cancel_closure: CancelClosure,
     248              :     /// Labels for proxy's metrics.
     249              :     pub aux: MetricsAuxInfo,
     250              : 
     251              :     _guage: NumDbConnectionsGuard<'static>,
     252              : }
     253              : 
     254              : impl ConnCfg {
     255              :     /// Connect to a corresponding compute node.
     256            0 :     pub async fn connect(
     257            0 :         &self,
     258            0 :         ctx: &mut RequestMonitoring,
     259            0 :         allow_self_signed_compute: bool,
     260            0 :         aux: MetricsAuxInfo,
     261            0 :         timeout: Duration,
     262            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     263            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     264              : 
     265            0 :         let tls_connector = native_tls::TlsConnector::builder()
     266            0 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     267            0 :             .build()
     268            0 :             .unwrap();
     269            0 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     270            0 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     271              : 
     272              :         // connect_raw() will not use TLS if sslmode is "disable"
     273            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     274            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     275            0 :         let stream = connection.stream.into_inner();
     276            0 : 
     277            0 :         info!(
     278            0 :             cold_start_info = ctx.cold_start_info.as_str(),
     279            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     280            0 :             self.0.get_ssl_mode()
     281            0 :         );
     282              : 
     283              :         // This is very ugly but as of now there's no better way to
     284              :         // extract the connection parameters from tokio-postgres' connection.
     285              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     286            0 :         let params = connection.parameters;
     287            0 : 
     288            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     289            0 :         // Yet another reason to rework the connection establishing code.
     290            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     291            0 : 
     292            0 :         let connection = PostgresConnection {
     293            0 :             stream,
     294            0 :             params,
     295            0 :             cancel_closure,
     296            0 :             aux,
     297            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
     298            0 :         };
     299            0 : 
     300            0 :         Ok(connection)
     301            0 :     }
     302              : }
     303              : 
     304              : /// Retrieve `options` from a startup message, dropping all proxy-specific flags.
     305           12 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     306              :     #[allow(unstable_name_collisions)]
     307           12 :     let options: String = params
     308           12 :         .options_raw()?
     309           26 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     310           12 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     311           12 :         .collect();
     312           12 : 
     313           12 :     // Don't even bother with empty options.
     314           12 :     if options.is_empty() {
     315            6 :         return None;
     316            6 :     }
     317            6 : 
     318            6 :     Some(options)
     319           12 : }
     320              : 
     321              : #[cfg(test)]
     322              : mod tests {
     323              :     use super::*;
     324              : 
     325              :     #[test]
     326            2 :     fn test_filtered_options() {
     327            2 :         // Empty options is unlikely to be useful anyway.
     328            2 :         let params = StartupMessageParams::new([("options", "")]);
     329            2 :         assert_eq!(filtered_options(&params), None);
     330              : 
     331              :         // It's likely that clients will only use options to specify endpoint/project.
     332            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     333            2 :         assert_eq!(filtered_options(&params), None);
     334              : 
     335              :         // Same, because unescaped whitespace is a no-op.
     336            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     337            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     338              : 
     339            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     340            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     341              : 
     342            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     343            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     344              : 
     345            2 :         let params = StartupMessageParams::new([(
     346            2 :             "options",
     347            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     348            2 :         )]);
     349            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     350            2 :     }
     351              : }
        

Generated by: LCOV version 2.1-beta