LCOV - code coverage report
Current view: top level - proxy/src/proxy - mod.rs (source / functions) Coverage Total Hit
Test: 6fa910d1c9aea142e54ede6987809ef55544c500.info Lines: 13.0 % 269 35
Test Date: 2024-11-19 23:07:42 Functions: 16.4 % 55 9

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : pub(crate) mod connect_compute;
       5              : mod copy_bidirectional;
       6              : pub(crate) mod handshake;
       7              : pub(crate) mod passthrough;
       8              : pub(crate) mod retry;
       9              : pub(crate) mod wake_compute;
      10              : use std::sync::Arc;
      11              : 
      12              : pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource};
      13              : use futures::TryFutureExt;
      14              : use itertools::Itertools;
      15              : use once_cell::sync::OnceCell;
      16              : use pq_proto::{BeMessage as Be, StartupMessageParams};
      17              : use regex::Regex;
      18              : use smol_str::{format_smolstr, SmolStr};
      19              : use thiserror::Error;
      20              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
      21              : use tokio_util::sync::CancellationToken;
      22              : use tracing::{debug, error, info, warn, Instrument};
      23              : 
      24              : use self::connect_compute::{connect_to_compute, TcpMechanism};
      25              : use self::passthrough::ProxyPassthrough;
      26              : use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
      27              : use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
      28              : use crate::context::RequestMonitoring;
      29              : use crate::error::ReportableError;
      30              : use crate::metrics::{Metrics, NumClientConnectionsGuard};
      31              : use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo};
      32              : use crate::proxy::handshake::{handshake, HandshakeData};
      33              : use crate::rate_limiter::EndpointRateLimiter;
      34              : use crate::stream::{PqStream, Stream};
      35              : use crate::types::EndpointCacheKey;
      36              : use crate::{auth, compute};
      37              : 
/// Error text for clients that connect without encryption when encryption is expected;
/// it suggests `sslmode=require` as the remedy.
/// NOTE(review): not referenced by the code visible in this module — presumably used by a
/// child module such as `handshake`; confirm before changing the wording.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      39              : 
      40            0 : pub async fn run_until_cancelled<F: std::future::Future>(
      41            0 :     f: F,
      42            0 :     cancellation_token: &CancellationToken,
      43            0 : ) -> Option<F::Output> {
      44            0 :     match futures::future::select(
      45            0 :         std::pin::pin!(f),
      46            0 :         std::pin::pin!(cancellation_token.cancelled()),
      47            0 :     )
      48            0 :     .await
      49              :     {
      50            0 :         futures::future::Either::Left((f, _)) => Some(f),
      51            0 :         futures::future::Either::Right(((), _)) => None,
      52              :     }
      53            0 : }
      54              : 
/// Accepts TCP client connections on `listener` until `cancellation_token` is cancelled.
/// It spawns one task per connection, then waits for in-flight connections to finish
/// before returning.
///
/// Each spawned per-connection task:
/// 1. reads an optional PROXY protocol header. It returns early on read errors, on a
///    `Local` (health-check) header, and on headers that violate
///    `config.proxy_protocol_v2`.
/// 2. enables `TCP_NODELAY` on the client socket.
/// 3. runs [`handle_client`] and, if that yields a passthrough, proxies traffic between
///    client and compute, logging how the session ended.
///
/// "proxy has shut down" is logged on every exit path, including error returns.
///
/// # Errors
/// Returns an error if enabling keepalive on the listening socket fails, or if
/// `listener.accept()` returns an error.
pub async fn task_main(
    config: &'static ProxyConfig,
    auth_backend: &'static auth::Backend<'static, ()>,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandlerMain>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    // Tracks every spawned per-connection task so shutdown can wait for them below.
    let connections = tokio_util::task::task_tracker::TaskTracker::new();

    // `run_until_cancelled` yields `None` once the token is cancelled, which ends the loop.
    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        // NOTE(review): any accept() error returns from task_main here. That stops
        // listening and skips the connection drain below. Confirm that one failed
        // accept (e.g. file-descriptor exhaustion) is meant to stop the whole listener.
        let (socket, peer_addr) = accept_result?;

        // Created per accepted socket and moved into `handle_client`.
        let conn_gauge = Metrics::get()
            .proxy
            .client_connections
            .guard(crate::metrics::Protocol::Tcp);

        let session_id = uuid::Uuid::new_v4();
        let cancellation_handler = Arc::clone(&cancellation_handler);

        debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
        let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();

        connections.spawn(async move {
            // Apply the PROXY protocol policy. On success we get the (possibly wrapped)
            // socket plus the client address: from the header, or the TCP peer address
            // when no header was sent.
            let (socket, conn_info) = match read_proxy_protocol(socket).await {
                Err(e) => {
                    warn!("per-client task finished with an error: {e:#}");
                    return;
                }
                // our load balancers will not send any more data. let's just exit immediately
                Ok((_socket, ConnectHeader::Local)) => {
                    debug!("healthcheck received");
                    return;
                }
                Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
                    warn!("missing required proxy protocol header");
                    return;
                }
                Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
                    warn!("proxy protocol header not supported");
                    return;
                }
                Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
                Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
            };

            match socket.inner.set_nodelay(true) {
                Ok(()) => {}
                Err(e) => {
                    error!("per-client task finished with an error: failed to set socket option: {e:#}");
                    return;
                }
            };

            let ctx = RequestMonitoring::new(
                session_id,
                conn_info,
                crate::metrics::Protocol::Tcp,
                &config.region,
            );
            let span = ctx.span();

            // NOTE(review): the handshake/auth future is boxed, presumably to keep a large
            // future off the task's inline state; confirm whether that is still needed.
            let startup = Box::pin(
                handle_client(
                    config,
                    auth_backend,
                    &ctx,
                    cancellation_handler,
                    socket,
                    ClientMode::Tcp,
                    endpoint_rate_limiter2,
                    conn_gauge,
                )
                .instrument(span.clone()),
            );
            let res = startup.await;

            match res {
                Err(e) => {
                    // todo: log and push to ctx the error kind
                    ctx.set_error_kind(e.get_error_kind());
                    warn!(parent: &span, "per-client task finished with an error: {e:#}");
                }
                // `handle_client` returns `Ok(None)` after serving a cancel request.
                Ok(None) => {
                    ctx.set_success();
                }
                // Setup succeeded: relay traffic until one side finishes or fails.
                Ok(Some(p)) => {
                    ctx.set_success();
                    ctx.log_connect();
                    match p.proxy_pass().instrument(span.clone()).await {
                        Ok(()) => {}
                        Err(ErrorSource::Client(e)) => {
                            warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
                        }
                        Err(ErrorSource::Compute(e)) => {
                            error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
                        }
                    }
                }
            }
        });
    }

    // Shutdown sequence:
    // - close the tracker so that `wait()` can complete once the tracked tasks are done;
    // - drop the listener so no further connections are accepted.
    connections.close();
    drop(listener);

    // Drain connections
    connections.wait().await;

    Ok(())
}
     177              : 
     178              : pub(crate) enum ClientMode {
     179              :     Tcp,
     180              :     Websockets { hostname: Option<String> },
     181              : }
     182              : 
     183              : /// Abstracts the logic of handling TCP vs WS clients
     184              : impl ClientMode {
     185            0 :     pub(crate) fn allow_cleartext(&self) -> bool {
     186            0 :         match self {
     187            0 :             ClientMode::Tcp => false,
     188            0 :             ClientMode::Websockets { .. } => true,
     189              :         }
     190            0 :     }
     191              : 
     192            0 :     pub(crate) fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     193            0 :         match self {
     194            0 :             ClientMode::Tcp => config.allow_self_signed_compute,
     195            0 :             ClientMode::Websockets { .. } => false,
     196              :         }
     197            0 :     }
     198              : 
     199            0 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     200            0 :         match self {
     201            0 :             ClientMode::Tcp => s.sni_hostname(),
     202            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     203              :         }
     204            0 :     }
     205              : 
     206            0 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     207            0 :         match self {
     208            0 :             ClientMode::Tcp => tls,
     209              :             // TLS is None here if using websockets, because the connection is already encrypted.
     210            0 :             ClientMode::Websockets { .. } => None,
     211              :         }
     212            0 :     }
     213              : }
     214              : 
/// Errors that can end [`handle_client`] before the connection is handed off for proxying.
///
/// Almost all errors should be reported to the user, but there are a few cases where we
/// cannot:
/// 1. Cancellation: we are not allowed to tell the client any cancellation statuses, for
///    security reasons.
/// 2. Handshake: the handshake reports errors if it can. If the handshake fails due to a
///    protocol violation, we cannot be sure the client even understands our error message.
/// 3. PrepareClient: the client disconnected, so we can't tell them anyway.
#[derive(Debug, Error)]
pub(crate) enum ClientRequestError {
    /// Serving a client's cancel request failed.
    #[error("{0}")]
    Cancellation(#[from] cancellation::CancelError),
    /// The startup handshake failed.
    #[error("{0}")]
    Handshake(#[from] handshake::HandshakeError),
    /// The startup handshake did not finish within the configured handshake timeout.
    #[error("{0}")]
    HandshakeTimeout(#[from] tokio::time::error::Elapsed),
    /// An I/O error while finishing connection setup (any `std::io::Error` converts here).
    /// NOTE(review): in `handle_client` this also covers write errors toward compute, not
    /// only client-side failures.
    #[error("{0}")]
    PrepareClient(#[from] std::io::Error),
    /// An error produced after it was sent to the client, presumably via
    /// `PqStream::throw_error`. TODO confirm against `crate::stream`.
    #[error("{0}")]
    ReportedError(#[from] crate::stream::ReportedError),
}
     233              : 
     234              : impl ReportableError for ClientRequestError {
     235            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     236            0 :         match self {
     237            0 :             ClientRequestError::Cancellation(e) => e.get_error_kind(),
     238            0 :             ClientRequestError::Handshake(e) => e.get_error_kind(),
     239            0 :             ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
     240            0 :             ClientRequestError::ReportedError(e) => e.get_error_kind(),
     241            0 :             ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
     242              :         }
     243            0 :     }
     244              : }
     245              : 
/// Runs the client-facing part of a connection, in order:
/// 1. the startup handshake (bounded by `config.handshake_timeout`);
/// 2. credential parsing and authentication;
/// 3. connecting to compute;
/// 4. sending the session parameters to the client.
///
/// # Returns
/// - `Ok(None)` if the client sent a cancel request instead of a startup message; the
///   cancellation is processed here.
/// - `Ok(Some(passthrough))` once the client has been told it is ready for queries.
///   `passthrough` owns both streams, the connection/request metric guards and the
///   cancellation session, for use during the proxying phase.
///
/// # Errors
/// - [`ClientRequestError::HandshakeTimeout`] if the handshake exceeds the timeout.
/// - [`ClientRequestError::Handshake`] or [`ClientRequestError::Cancellation`] from
///   those steps.
/// - Presumably [`ClientRequestError::ReportedError`] after `stream.throw_error` has sent
///   a parse, auth or connect failure to the client. TODO confirm the conversion.
/// - [`ClientRequestError::PrepareClient`] for I/O errors while finishing setup.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    auth_backend: &'static auth::Backend<'static, ()>,
    ctx: &RequestMonitoring,
    cancellation_handler: Arc<CancellationHandlerMain>,
    stream: S,
    mode: ClientMode,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
    info!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    let metrics = &Metrics::get().proxy;
    let proto = ctx.protocol();
    // Moved into the returned passthrough on success.
    let request_gauge = metrics.connection_requests.guard(proto);

    let tls = config.tls_config.as_ref();

    // Handshake errors are recorded only when the peer address is not private.
    let record_handshake_error = !ctx.has_private_peer_addr();
    // Presumably excludes time spent waiting on the client from latency metrics until
    // `pause` is dropped below. TODO confirm against `latency_timer_pause`.
    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
    let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
    // Outer `?` converts a timeout (`Elapsed`) and inner `?` a handshake error, each into
    // its `ClientRequestError` variant.
    let (mut stream, params) =
        match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
            HandshakeData::Startup(stream, params) => (stream, params),
            HandshakeData::Cancel(cancel_key_data) => {
                return Ok(cancellation_handler
                    .cancel_session(cancel_key_data, ctx.session_id())
                    .await
                    .map(|()| None)?)
            }
        };
    drop(pause);

    ctx.set_db_options(params.clone());

    let hostname = mode.hostname(stream.get_ref());

    let common_names = tls.map(|tls| &tls.common_names);

    // Extract credentials which we're going to use for auth.
    let result = auth_backend
        .as_ref()
        .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
        .transpose();

    let user_info = match result {
        Ok(user_info) => user_info,
        Err(e) => stream.throw_error(e).await?,
    };

    // Copy the user name now so it can go into the error span below; `user_info` is
    // rebound by the `authenticate` call (presumably consumed by it).
    let user = user_info.get_user().to_owned();
    let user_info = match user_info
        .authenticate(
            ctx,
            &mut stream,
            mode.allow_cleartext(),
            &config.authentication_config,
            endpoint_rate_limiter,
        )
        .await
    {
        Ok(auth_result) => auth_result,
        Err(e) => {
            let db = params.get("database");
            let app = params.get("application_name");
            let params_span = tracing::info_span!("", ?user, ?db, ?app);

            return stream.throw_error(e).instrument(params_span).await?;
        }
    };

    let mut node = connect_to_compute(
        ctx,
        &TcpMechanism {
            params: &params,
            locks: &config.connect_compute_locks,
        },
        &user_info,
        mode.allow_self_signed_compute(config),
        config.wake_compute_retry_config,
        config.connect_to_compute_retry_config,
    )
    .or_else(|e| stream.throw_error(e))
    .await?;

    let session = cancellation_handler.get_session();
    prepare_client_connection(&node, &session, &mut stream).await?;

    // Before proxy passing, forward to compute whatever data is left in the
    // PqStream input buffer. Normally there is none, but our serverless npm
    // driver in pipeline mode sends startup, password and first query
    // immediately after opening the connection.
    // NOTE(review): an I/O error writing to compute here converts into
    // `ClientRequestError::PrepareClient`, which is reported as `ClientDisconnect`,
    // although the failure is on the compute side.
    let (stream, read_buf) = stream.into_inner();
    node.stream.write_all(&read_buf).await?;

    Ok(Some(ProxyPassthrough {
        client: stream,
        aux: node.aux.clone(),
        compute: node,
        _req: request_gauge,
        _conn: conn_gauge,
        _cancel: session,
    }))
}
     354              : 
     355              : /// Finish client connection initialization: confirm auth success, send params, etc.
     356            0 : #[tracing::instrument(skip_all)]
     357              : pub(crate) async fn prepare_client_connection<P>(
     358              :     node: &compute::PostgresConnection,
     359              :     session: &cancellation::Session<P>,
     360              :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     361              : ) -> Result<(), std::io::Error> {
     362              :     // Register compute's query cancellation token and produce a new, unique one.
     363              :     // The new token (cancel_key_data) will be sent to the client.
     364              :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     365              : 
     366              :     // Forward all postgres connection params to the client.
     367              :     // Right now the implementation is very hacky and inefficent (ideally,
     368              :     // we don't need an intermediate hashmap), but at least it should be correct.
     369              :     for (name, value) in &node.params {
     370              :         // TODO: Theoretically, this could result in a big pile of params...
     371              :         stream.write_message_noflush(&Be::ParameterStatus {
     372              :             name: name.as_bytes(),
     373              :             value: value.as_bytes(),
     374              :         })?;
     375              :     }
     376              : 
     377              :     stream
     378              :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     379              :         .write_message(&Be::ReadyForQuery)
     380              :         .await?;
     381              : 
     382              :     Ok(())
     383              : }
     384              : 
     385              : #[derive(Debug, Clone, PartialEq, Eq, Default)]
     386              : pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
     387              : 
     388              : impl NeonOptions {
     389           11 :     pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
     390           11 :         params
     391           11 :             .options_raw()
     392           11 :             .map(Self::parse_from_iter)
     393           11 :             .unwrap_or_default()
     394           11 :     }
     395            7 :     pub(crate) fn parse_options_raw(options: &str) -> Self {
     396            7 :         Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
     397            7 :     }
     398              : 
     399            2 :     pub(crate) fn is_ephemeral(&self) -> bool {
     400            2 :         // Currently, neon endpoint options are all reserved for ephemeral endpoints.
     401            2 :         !self.0.is_empty()
     402            2 :     }
     403              : 
     404           13 :     fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
     405           13 :         let mut options = options
     406           13 :             .filter_map(neon_option)
     407           13 :             .map(|(k, v)| (k.into(), v.into()))
     408           13 :             .collect_vec();
     409           13 :         options.sort();
     410           13 :         Self(options)
     411           13 :     }
     412              : 
     413            4 :     pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
     414            4 :         // prefix + format!(" {k}:{v}")
     415            4 :         // kinda jank because SmolStr is immutable
     416            4 :         std::iter::once(prefix)
     417            4 :             .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
     418            4 :             .collect::<SmolStr>()
     419            4 :             .into()
     420            4 :     }
     421              : 
     422              :     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
     423              :     /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
     424            0 :     pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
     425            0 :         self.0
     426            0 :             .iter()
     427            0 :             .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
     428            0 :             .collect()
     429            0 :     }
     430              : }
     431              : 
     432           32 : pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
     433              :     static RE: OnceCell<Regex> = OnceCell::new();
     434           32 :     let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
     435              : 
     436           32 :     let cap = re.captures(bytes)?;
     437            4 :     let (_, [k, v]) = cap.extract();
     438            4 :     Some((k, v))
     439           32 : }
        

Generated by: LCOV version 2.1-beta