LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 691a4c28fe7169edd60b367c52d448a0a6605f1f.info Lines: 23.3 % 206 48
Test Date: 2024-05-10 13:18:37 Functions: 24.1 % 29 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::{Metrics, NumDbConnectionsGuard},
       8              :     proxy::neon_option,
       9              :     Host,
      10              : };
      11              : use futures::{FutureExt, TryFutureExt};
      12              : use itertools::Itertools;
      13              : use pq_proto::StartupMessageParams;
      14              : use std::{io, net::SocketAddr, time::Duration};
      15              : use thiserror::Error;
      16              : use tokio::net::TcpStream;
      17              : use tokio_postgres::tls::MakeTlsConnect;
      18              : use tracing::{error, info, warn};
      19              : 
      20              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      21              : 
      22            0 : #[derive(Debug, Error)]
      23              : pub enum ConnectionError {
      24              :     /// This error doesn't seem to reveal any secrets; for instance,
      25              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      26              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      27              :     Postgres(#[from] tokio_postgres::Error),
      28              : 
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     CouldNotConnect(#[from] io::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     TlsError(#[from] native_tls::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     WakeComputeError(#[from] WakeComputeError),
      37              : 
      38              :     #[error("error acquiring resource permit: {0}")]
      39              :     TooManyConnectionAttempts(#[from] ApiLockError),
      40              : }
      41              : 
      42              : impl UserFacingError for ConnectionError {
      43            0 :     fn to_string_client(&self) -> String {
      44            0 :         use ConnectionError::*;
      45            0 :         match self {
      46              :             // This helps us drop irrelevant library-specific prefixes.
      47              :             // TODO: propagate severity level and other parameters.
      48            0 :             Postgres(err) => match err.as_db_error() {
      49            0 :                 Some(err) => {
      50            0 :                     let msg = err.message();
      51            0 : 
      52            0 :                     if msg.starts_with("unsupported startup parameter: ")
      53            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      54              :                     {
      55            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      56              :                     } else {
      57            0 :                         msg.to_owned()
      58              :                     }
      59              :                 }
      60            0 :                 None => err.to_string(),
      61              :             },
      62            0 :             WakeComputeError(err) => err.to_string_client(),
      63              :             TooManyConnectionAttempts(_) => {
      64            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      65              :             }
      66            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      67              :         }
      68            0 :     }
      69              : }
      70              : 
      71              : impl ReportableError for ConnectionError {
      72            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      73            0 :         match self {
      74            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      75            0 :                 crate::error::ErrorKind::Postgres
      76              :             }
      77            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      78            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      79            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      80            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      81            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      82              :         }
      83            0 :     }
      84              : }
      85              : 
      86              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      87              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      88              : 
      89              : /// A config for establishing a connection to compute node.
      90              : /// Eventually, `tokio_postgres` will be replaced with something better.
      91              : /// Newtype allows us to implement methods on top of it.
      92              : #[derive(Clone, Default)]
      93              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      94              : 
      95              : /// Creation and initialization routines.
      96              : impl ConnCfg {
      97           20 :     pub fn new() -> Self {
      98           20 :         Self::default()
      99           20 :     }
     100              : 
     101              :     /// Reuse password or auth keys from the other config.
     102            8 :     pub fn reuse_password(&mut self, other: Self) {
     103            8 :         if let Some(password) = other.get_password() {
     104            8 :             self.password(password);
     105            8 :         }
     106              : 
     107            8 :         if let Some(keys) = other.get_auth_keys() {
     108            0 :             self.auth_keys(keys);
     109            8 :         }
     110            8 :     }
     111              : 
     112            0 :     pub fn get_host(&self) -> Result<Host, WakeComputeError> {
     113            0 :         match self.0.get_hosts() {
     114            0 :             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
     115              :             // we should not have multiple address or unix addresses.
     116            0 :             _ => Err(WakeComputeError::BadComputeAddress(
     117            0 :                 "invalid compute address".into(),
     118            0 :             )),
     119              :         }
     120            0 :     }
     121              : 
     122              :     /// Apply startup message params to the connection config.
     123            0 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     124              :         // Only set `user` if it's not present in the config.
     125              :         // Link auth flow takes username from the console's response.
     126            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     127            0 :             self.user(user);
     128            0 :         }
     129              : 
     130              :         // Only set `dbname` if it's not present in the config.
     131              :         // Link auth flow takes dbname from the console's response.
     132            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     133            0 :             self.dbname(dbname);
     134            0 :         }
     135              : 
     136              :         // Don't add `options` if they were only used for specifying a project.
     137              :         // Connection pools don't support `options`, because they affect backend startup.
     138            0 :         if let Some(options) = filtered_options(params) {
     139            0 :             self.options(&options);
     140            0 :         }
     141              : 
     142            0 :         if let Some(app_name) = params.get("application_name") {
     143            0 :             self.application_name(app_name);
     144            0 :         }
     145              : 
     146              :         // TODO: This is especially ugly...
     147            0 :         if let Some(replication) = params.get("replication") {
     148              :             use tokio_postgres::config::ReplicationMode;
     149            0 :             match replication {
     150            0 :                 "true" | "on" | "yes" | "1" => {
     151            0 :                     self.replication_mode(ReplicationMode::Physical);
     152            0 :                 }
     153            0 :                 "database" => {
     154            0 :                     self.replication_mode(ReplicationMode::Logical);
     155            0 :                 }
     156            0 :                 _other => {}
     157              :             }
     158            0 :         }
     159              : 
     160              :         // TODO: extend the list of the forwarded startup parameters.
     161              :         // Currently, tokio-postgres doesn't allow us to pass
     162              :         // arbitrary parameters, but the ones above are a good start.
     163              :         //
     164              :         // This and the reverse params problem can be better addressed
     165              :         // in a bespoke connection machinery (a new library for that sake).
     166            0 :     }
     167              : }
     168              : 
     169              : impl std::ops::Deref for ConnCfg {
     170              :     type Target = tokio_postgres::Config;
     171              : 
     172           16 :     fn deref(&self) -> &Self::Target {
     173           16 :         &self.0
     174           16 :     }
     175              : }
     176              : 
     177              : /// For now, let's make it easier to setup the config.
     178              : impl std::ops::DerefMut for ConnCfg {
     179           20 :     fn deref_mut(&mut self) -> &mut Self::Target {
     180           20 :         &mut self.0
     181           20 :     }
     182              : }
     183              : 
     184              : impl ConnCfg {
     185              :     /// Establish a raw TCP connection to the compute node.
     186            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     187            0 :         use tokio_postgres::config::Host;
     188            0 : 
     189            0 :         // wrap TcpStream::connect with timeout
     190            0 :         let connect_with_timeout = |host, port| {
     191            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     192            0 :                 move |res| match res {
     193            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     194            0 :                     Err(_) => Err(io::Error::new(
     195            0 :                         io::ErrorKind::TimedOut,
     196            0 :                         format!("exceeded connection timeout {timeout:?}"),
     197            0 :                     )),
     198            0 :                 },
     199            0 :             )
     200            0 :         };
     201              : 
     202            0 :         let connect_once = |host, port| {
     203            0 :             info!("trying to connect to compute node at {host}:{port}");
     204            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     205            0 :                 let socket_addr = socket.peer_addr()?;
     206            0 :                 // This prevents load balancer from severing the connection.
     207            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     208            0 :                 Ok((socket_addr, socket))
     209            0 :             })
     210            0 :         };
     211              : 
     212              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     213              :         // because it has no means for extracting the underlying socket which we
     214              :         // require for our business.
     215            0 :         let mut connection_error = None;
     216            0 :         let ports = self.0.get_ports();
     217            0 :         let hosts = self.0.get_hosts();
     218            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     219            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     220            0 :             return Err(io::Error::new(
     221            0 :                 io::ErrorKind::Other,
     222            0 :                 format!(
     223            0 :                     "bad compute config, \
     224            0 :                      ports and hosts entries' count does not match: {:?}",
     225            0 :                     self.0
     226            0 :                 ),
     227            0 :             ));
     228            0 :         }
     229              : 
     230            0 :         for (i, host) in hosts.iter().enumerate() {
     231            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     232            0 :             let host = match host {
     233            0 :                 Host::Tcp(host) => host.as_str(),
     234            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     235              :             };
     236              : 
     237            0 :             match connect_once(host, *port).await {
     238            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     239            0 :                 Err(err) => {
     240            0 :                     // We can't throw an error here, as there might be more hosts to try.
     241            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     242            0 :                     connection_error = Some(err);
     243              :                 }
     244              :             }
     245              :         }
     246              : 
     247            0 :         Err(connection_error.unwrap_or_else(|| {
     248            0 :             io::Error::new(
     249            0 :                 io::ErrorKind::Other,
     250            0 :                 format!("bad compute config: {:?}", self.0),
     251            0 :             )
     252            0 :         }))
     253            0 :     }
     254              : }
     255              : 
     256              : pub struct PostgresConnection {
     257              :     /// Socket connected to a compute node.
     258              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     259              :         tokio::net::TcpStream,
     260              :         postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     261              :     >,
     262              :     /// PostgreSQL connection parameters.
     263              :     pub params: std::collections::HashMap<String, String>,
     264              :     /// Query cancellation token.
     265              :     pub cancel_closure: CancelClosure,
     266              :     /// Labels for proxy's metrics.
     267              :     pub aux: MetricsAuxInfo,
     268              : 
     269              :     _guage: NumDbConnectionsGuard<'static>,
     270              : }
     271              : 
     272              : impl ConnCfg {
     273              :     /// Connect to a corresponding compute node.
     274            0 :     pub async fn connect(
     275            0 :         &self,
     276            0 :         ctx: &mut RequestMonitoring,
     277            0 :         allow_self_signed_compute: bool,
     278            0 :         aux: MetricsAuxInfo,
     279            0 :         timeout: Duration,
     280            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     281            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     282            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     283            0 :         drop(pause);
     284            0 : 
     285            0 :         let tls_connector = native_tls::TlsConnector::builder()
     286            0 :             .danger_accept_invalid_certs(allow_self_signed_compute)
     287            0 :             .build()
     288            0 :             .unwrap();
     289            0 :         let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
     290            0 :         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
     291              : 
     292              :         // connect_raw() will not use TLS if sslmode is "disable"
     293            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     294            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     295            0 :         drop(pause);
     296            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     297            0 :         let stream = connection.stream.into_inner();
     298            0 : 
     299            0 :         info!(
     300            0 :             cold_start_info = ctx.cold_start_info.as_str(),
     301            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     302            0 :             self.0.get_ssl_mode()
     303              :         );
     304              : 
     305              :         // This is very ugly but as of now there's no better way to
     306              :         // extract the connection parameters from tokio-postgres' connection.
     307              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     308            0 :         let params = connection.parameters;
     309            0 : 
     310            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     311            0 :         // Yet another reason to rework the connection establishing code.
     312            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     313            0 : 
     314            0 :         let connection = PostgresConnection {
     315            0 :             stream,
     316            0 :             params,
     317            0 :             cancel_closure,
     318            0 :             aux,
     319            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
     320            0 :         };
     321            0 : 
     322            0 :         Ok(connection)
     323            0 :     }
     324              : }
     325              : 
     326              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     327           12 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     328              :     #[allow(unstable_name_collisions)]
     329           12 :     let options: String = params
     330           12 :         .options_raw()?
     331           26 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     332           12 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     333           12 :         .collect();
     334           12 : 
     335           12 :     // Don't even bother with empty options.
     336           12 :     if options.is_empty() {
     337            6 :         return None;
     338            6 :     }
     339            6 : 
     340            6 :     Some(options)
     341           12 : }
     342              : 
     343              : #[cfg(test)]
     344              : mod tests {
     345              :     use super::*;
     346              : 
     347              :     #[test]
     348            2 :     fn test_filtered_options() {
     349            2 :         // Empty options is unlikely to be useful anyway.
     350            2 :         let params = StartupMessageParams::new([("options", "")]);
     351            2 :         assert_eq!(filtered_options(&params), None);
     352              : 
     353              :         // It's likely that clients will only use options to specify endpoint/project.
     354            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     355            2 :         assert_eq!(filtered_options(&params), None);
     356              : 
     357              :         // Same, because unescaped whitespaces are no-op.
     358            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     359            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     360              : 
     361            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     362            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     363              : 
     364            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     365            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     366              : 
     367            2 :         let params = StartupMessageParams::new([(
     368            2 :             "options",
     369            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     370            2 :         )]);
     371            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     372            2 :     }
     373              : }
        

Generated by: LCOV version 2.1-beta