LCOV - code coverage report
Current view: top level - proxy/src/compute - mod.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 11.5 % 209 24
Test Date: 2025-07-16 12:29:03 Functions: 12.5 % 24 3

            Line data    Source code
       1              : mod tls;
       2              : 
       3              : use std::fmt::Debug;
       4              : use std::io;
       5              : use std::net::{IpAddr, SocketAddr};
       6              : 
       7              : use futures::{FutureExt, TryFutureExt};
       8              : use itertools::Itertools;
       9              : use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
      10              : use postgres_client::maybe_tls_stream::MaybeTlsStream;
      11              : use postgres_client::tls::MakeTlsConnect;
      12              : use postgres_client::{NoTls, RawCancelToken, RawConnection};
      13              : use postgres_protocol::message::backend::NoticeResponseBody;
      14              : use thiserror::Error;
      15              : use tokio::net::{TcpStream, lookup_host};
      16              : use tracing::{debug, error, info, warn};
      17              : 
      18              : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
      19              : use crate::auth::parse_endpoint_param;
      20              : use crate::cancellation::CancelClosure;
      21              : use crate::compute::tls::TlsError;
      22              : use crate::config::ComputeConfig;
      23              : use crate::context::RequestContext;
      24              : use crate::control_plane::client::ApiLockError;
      25              : use crate::control_plane::errors::WakeComputeError;
      26              : use crate::control_plane::messages::MetricsAuxInfo;
      27              : use crate::error::{ReportableError, UserFacingError};
      28              : use crate::metrics::{Metrics, NumDbConnectionsGuard};
      29              : use crate::pqproto::StartupMessageParams;
      30              : use crate::proxy::neon_option;
      31              : use crate::types::Host;
      32              : 
      33              : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      34              : 
      35              : #[derive(Debug, Error)]
      36              : pub(crate) enum PostgresError {
      37              :     /// This error doesn't seem to reveal any secrets; for instance,
      38              :     /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
      39              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      40              :     Postgres(#[from] postgres_client::Error),
      41              : }
      42              : 
      43              : impl UserFacingError for PostgresError {
      44            0 :     fn to_string_client(&self) -> String {
      45            0 :         match self {
      46              :             // This helps us drop irrelevant library-specific prefixes.
      47              :             // TODO: propagate severity level and other parameters.
      48            0 :             PostgresError::Postgres(err) => match err.as_db_error() {
      49            0 :                 Some(err) => {
      50            0 :                     let msg = err.message();
      51              : 
      52            0 :                     if msg.starts_with("unsupported startup parameter: ")
      53            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      54              :                     {
      55            0 :                         format!(
      56            0 :                             "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
      57              :                         )
      58              :                     } else {
      59            0 :                         msg.to_owned()
      60              :                     }
      61              :                 }
      62            0 :                 None => err.to_string(),
      63              :             },
      64              :         }
      65            0 :     }
      66              : }
      67              : 
      68              : impl ReportableError for PostgresError {
      69            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      70            0 :         match self {
      71            0 :             PostgresError::Postgres(e) if e.as_db_error().is_some() => {
      72            0 :                 crate::error::ErrorKind::Postgres
      73              :             }
      74            0 :             PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
      75              :         }
      76            0 :     }
      77              : }
      78              : 
      79              : #[derive(Debug, Error)]
      80              : pub(crate) enum ConnectionError {
      81              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      82              :     TlsError(#[from] TlsError),
      83              : 
      84              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      85              :     WakeComputeError(#[from] WakeComputeError),
      86              : 
      87              :     #[error("error acquiring resource permit: {0}")]
      88              :     TooManyConnectionAttempts(#[from] ApiLockError),
      89              : }
      90              : 
      91              : impl UserFacingError for ConnectionError {
      92            0 :     fn to_string_client(&self) -> String {
      93            0 :         match self {
      94            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      95              :             ConnectionError::TooManyConnectionAttempts(_) => {
      96            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      97              :             }
      98            0 :             ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
      99              :         }
     100            0 :     }
     101              : }
     102              : 
     103              : impl ReportableError for ConnectionError {
     104            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     105            0 :         match self {
     106            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
     107            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
     108            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
     109              :         }
     110            0 :     }
     111              : }
     112              : 
     113              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
     114              : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
     115              : 
     116              : #[derive(Clone)]
     117              : pub enum Auth {
     118              :     /// Only used during console-redirect.
     119              :     Password(Vec<u8>),
     120              :     /// Used by sql-over-http, ws, tcp.
     121              :     Scram(Box<ScramKeys>),
     122              : }
     123              : 
     124              : /// A config for authenticating to the compute node.
     125              : pub(crate) struct AuthInfo {
     126              :     /// None for local-proxy, as we use trust-based localhost auth.
     127              :     /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
     128              :     /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
     129              :     auth: Option<Auth>,
     130              :     server_params: StartupMessageParams,
     131              : 
     132              :     channel_binding: ChannelBinding,
     133              : 
     134              :     /// Console redirect sets user and database, we shouldn't re-use those from the params.
     135              :     skip_db_user: bool,
     136              : }
     137              : 
     138              : /// Contains only the data needed to establish a secure connection to compute.
     139              : #[derive(Clone)]
     140              : pub struct ConnectInfo {
     141              :     pub host_addr: Option<IpAddr>,
     142              :     pub host: Host,
     143              :     pub port: u16,
     144              :     pub ssl_mode: SslMode,
     145              : }
     146              : 
     147              : /// Creation and initialization routines.
     148              : impl AuthInfo {
     149            0 :     pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
     150            0 :         let mut server_params = StartupMessageParams::default();
     151            0 :         server_params.insert("database", db);
     152            0 :         server_params.insert("user", user);
     153              :         Self {
     154            0 :             auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
     155            0 :             server_params,
     156              :             skip_db_user: true,
     157              :             // pg-sni-router is a mitm so this would fail.
     158            0 :             channel_binding: ChannelBinding::Disable,
     159              :         }
     160            0 :     }
     161              : 
     162            0 :     pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
     163              :         Self {
     164            0 :             auth: match keys {
     165            0 :                 ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
     166            0 :                     Some(Auth::Scram(Box::new(auth_keys)))
     167              :                 }
     168            0 :                 ComputeCredentialKeys::JwtPayload(_) => None,
     169              :             },
     170            0 :             server_params: StartupMessageParams::default(),
     171              :             skip_db_user: false,
     172            0 :             channel_binding: ChannelBinding::Prefer,
     173              :         }
     174            0 :     }
     175              : }
     176              : 
     177              : impl ConnectInfo {
     178            0 :     pub fn to_postgres_client_config(&self) -> postgres_client::Config {
     179            0 :         let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
     180            0 :         config.ssl_mode(self.ssl_mode);
     181            0 :         if let Some(host_addr) = self.host_addr {
     182            0 :             config.set_host_addr(host_addr);
     183            0 :         }
     184            0 :         config
     185            0 :     }
     186              : }
     187              : 
     188              : impl AuthInfo {
     189            0 :     fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
     190            0 :         match &self.auth {
     191            0 :             Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
     192            0 :             Some(Auth::Password(pw)) => config.password(pw),
     193            0 :             None => &mut config,
     194              :         };
     195            0 :         config.channel_binding(self.channel_binding);
     196            0 :         for (k, v) in self.server_params.iter() {
     197            0 :             config.set_param(k, v);
     198            0 :         }
     199            0 :         config
     200            0 :     }
     201              : 
     202              :     /// Apply startup message params to the connection config.
     203            0 :     pub(crate) fn set_startup_params(
     204            0 :         &mut self,
     205            0 :         params: &StartupMessageParams,
     206            0 :         arbitrary_params: bool,
     207            0 :     ) {
     208            0 :         if !arbitrary_params {
     209            0 :             self.server_params.insert("client_encoding", "UTF8");
     210            0 :         }
     211            0 :         for (k, v) in params.iter() {
     212            0 :             match k {
     213              :                 // Only set `user` if it's not present in the config.
     214              :                 // Console redirect auth flow takes username from the console's response.
     215            0 :                 "user" | "database" if self.skip_db_user => {}
     216            0 :                 "options" => {
     217            0 :                     if let Some(options) = filtered_options(v) {
     218            0 :                         self.server_params.insert(k, &options);
     219            0 :                     }
     220              :                 }
     221            0 :                 "user" | "database" | "application_name" | "replication" => {
     222            0 :                     self.server_params.insert(k, v);
     223            0 :                 }
     224              : 
     225              :                 // if we allow arbitrary params, then we forward them through.
     226              :                 // this is a flag for a period of backwards compatibility
     227            0 :                 k if arbitrary_params => {
     228            0 :                     self.server_params.insert(k, v);
     229            0 :                 }
     230            0 :                 _ => {}
     231              :             }
     232              :         }
     233            0 :     }
     234              : 
     235            0 :     pub async fn authenticate(
     236            0 :         &self,
     237            0 :         ctx: &RequestContext,
     238            0 :         compute: &mut ComputeConnection,
     239            0 :         user_info: &ComputeUserInfo,
     240            0 :     ) -> Result<PostgresSettings, PostgresError> {
     241              :         // client config with stubbed connect info.
     242              :         // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
     243              :         // utilising pqproto.rs.
     244            0 :         let mut tmp_config = postgres_client::Config::new(String::new(), 0);
     245              :         // We have already established SSL if necessary.
     246            0 :         tmp_config.ssl_mode(SslMode::Disable);
     247            0 :         let tmp_config = self.enrich(tmp_config);
     248              : 
     249            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     250            0 :         let connection = tmp_config
     251            0 :             .tls_and_authenticate(&mut compute.stream, NoTls)
     252            0 :             .await?;
     253            0 :         drop(pause);
     254              : 
     255              :         let RawConnection {
     256              :             stream: _,
     257            0 :             parameters,
     258            0 :             delayed_notice,
     259            0 :             process_id,
     260            0 :             secret_key,
     261            0 :         } = connection;
     262              : 
     263            0 :         tracing::Span::current().record("pid", tracing::field::display(process_id));
     264              : 
     265              :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     266              :         // Yet another reason to rework the connection establishing code.
     267            0 :         let cancel_closure = CancelClosure::new(
     268            0 :             compute.socket_addr,
     269            0 :             RawCancelToken {
     270            0 :                 ssl_mode: compute.ssl_mode,
     271            0 :                 process_id,
     272            0 :                 secret_key,
     273            0 :             },
     274            0 :             compute.hostname.to_string(),
     275            0 :             user_info.clone(),
     276              :         );
     277              : 
     278            0 :         Ok(PostgresSettings {
     279            0 :             params: parameters,
     280            0 :             cancel_closure,
     281            0 :             delayed_notice,
     282            0 :         })
     283            0 :     }
     284              : }
     285              : 
     286              : impl ConnectInfo {
     287              :     /// Establish a raw TCP+TLS connection to the compute node.
     288            0 :     async fn connect_raw(
     289            0 :         &self,
     290            0 :         config: &ComputeConfig,
     291            0 :     ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
     292            0 :         let timeout = config.timeout;
     293              : 
     294              :         // wrap TcpStream::connect with timeout
     295            0 :         let connect_with_timeout = |addrs| {
     296            0 :             tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
     297            0 :                 Ok(tcpstream_connect_res) => tcpstream_connect_res,
     298            0 :                 Err(_) => Err(io::Error::new(
     299            0 :                     io::ErrorKind::TimedOut,
     300            0 :                     format!("exceeded connection timeout {timeout:?}"),
     301            0 :                 )),
     302            0 :             })
     303            0 :         };
     304              : 
     305            0 :         let connect_once = |addrs| {
     306            0 :             debug!("trying to connect to compute node at {addrs:?}");
     307            0 :             connect_with_timeout(addrs).and_then(|stream| async {
     308            0 :                 let socket_addr = stream.peer_addr()?;
     309            0 :                 let socket = socket2::SockRef::from(&stream);
     310              :                 // Disable Nagle's algorithm to not introduce latency between
     311              :                 // client and compute.
     312            0 :                 socket.set_nodelay(true)?;
     313              :                 // This prevents load balancer from severing the connection.
     314            0 :                 socket.set_keepalive(true)?;
     315            0 :                 Ok((socket_addr, stream))
     316            0 :             })
     317            0 :         };
     318              : 
     319              :         // We can't reuse connection establishing logic from `postgres_client` here,
     320              :         // because it has no means for extracting the underlying socket which we
     321              :         // require for our business.
     322            0 :         let port = self.port;
     323            0 :         let host = &*self.host;
     324              : 
     325            0 :         let addrs = match self.host_addr {
     326            0 :             Some(addr) => vec![SocketAddr::new(addr, port)],
     327            0 :             None => lookup_host((host, port)).await?.collect(),
     328              :         };
     329              : 
     330            0 :         match connect_once(&*addrs).await {
     331            0 :             Ok((sockaddr, stream)) => Ok((
     332            0 :                 sockaddr,
     333            0 :                 tls::connect_tls(stream, self.ssl_mode, config, host).await?,
     334              :             )),
     335            0 :             Err(err) => {
     336            0 :                 warn!("couldn't connect to compute node at {host}:{port}: {err}");
     337            0 :                 Err(TlsError::Connection(err))
     338              :             }
     339              :         }
     340            0 :     }
     341              : }
     342              : 
     343              : pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
     344              : pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
     345              : 
     346              : // TODO(conrad): we don't need to parse these.
     347              : // These are just immediately forwarded back to the client.
     348              : // We could instead stream them out instead of reading them into memory.
     349              : pub struct PostgresSettings {
     350              :     /// PostgreSQL connection parameters.
     351              :     pub params: std::collections::HashMap<String, String>,
     352              :     /// Query cancellation token.
     353              :     pub cancel_closure: CancelClosure,
     354              :     /// Notices received from compute after authenticating
     355              :     pub delayed_notice: Vec<NoticeResponseBody>,
     356              : }
     357              : 
     358              : pub struct ComputeConnection {
     359              :     /// Socket connected to a compute node.
     360              :     pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
     361              :     /// Labels for proxy's metrics.
     362              :     pub aux: MetricsAuxInfo,
     363              :     pub hostname: Host,
     364              :     pub ssl_mode: SslMode,
     365              :     pub socket_addr: SocketAddr,
     366              :     pub guage: NumDbConnectionsGuard<'static>,
     367              : }
     368              : 
     369              : impl ConnectInfo {
     370              :     /// Connect to a corresponding compute node.
     371            0 :     pub async fn connect(
     372            0 :         &self,
     373            0 :         ctx: &RequestContext,
     374            0 :         aux: &MetricsAuxInfo,
     375            0 :         config: &ComputeConfig,
     376            0 :     ) -> Result<ComputeConnection, ConnectionError> {
     377            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     378            0 :         let (socket_addr, stream) = self.connect_raw(config).await?;
     379            0 :         drop(pause);
     380              : 
     381            0 :         tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
     382              : 
     383              :         // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
     384            0 :         info!(
     385            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     386            0 :             "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
     387              :             self.host,
     388              :             self.ssl_mode,
     389            0 :             ctx.get_proxy_latency(),
     390            0 :             ctx.get_testodrome_id().unwrap_or_default(),
     391              :         );
     392              : 
     393            0 :         let connection = ComputeConnection {
     394            0 :             stream,
     395            0 :             socket_addr,
     396            0 :             hostname: self.host.clone(),
     397            0 :             ssl_mode: self.ssl_mode,
     398            0 :             aux: aux.clone(),
     399            0 :             guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     400            0 :         };
     401              : 
     402            0 :         Ok(connection)
     403            0 :     }
     404              : }
     405              : 
     406              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     407            6 : fn filtered_options(options: &str) -> Option<String> {
     408              :     #[allow(unstable_name_collisions)]
     409            6 :     let options: String = StartupMessageParams::parse_options_raw(options)
     410           14 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     411            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     412            6 :         .collect();
     413              : 
     414              :     // Don't even bother with empty options.
     415            6 :     if options.is_empty() {
     416            3 :         return None;
     417            3 :     }
     418              : 
     419            3 :     Some(options)
     420            6 : }
     421              : 
     422              : #[cfg(test)]
     423              : mod tests {
     424              :     use super::*;
     425              : 
     426              :     #[test]
     427            1 :     fn test_filtered_options() {
     428              :         // Empty options is unlikely to be useful anyway.
     429            1 :         let params = "";
     430            1 :         assert_eq!(filtered_options(params), None);
     431              : 
     432              :         // It's likely that clients will only use options to specify endpoint/project.
     433            1 :         let params = "project=foo";
     434            1 :         assert_eq!(filtered_options(params), None);
     435              : 
     436              :         // Same, because unescaped whitespaces are no-op.
     437            1 :         let params = " project=foo ";
     438            1 :         assert_eq!(filtered_options(params).as_deref(), None);
     439              : 
     440            1 :         let params = r"\  project=foo \ ";
     441            1 :         assert_eq!(filtered_options(params).as_deref(), Some(r"\  \ "));
     442              : 
     443            1 :         let params = "project = foo";
     444            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     445              : 
     446            1 :         let params = "project = foo neon_endpoint_type:read_write   neon_lsn:0/2 neon_proxy_params_compat:true";
     447            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     448            1 :     }
     449              : }
        

Generated by: LCOV version 2.1-beta