LCOV - code coverage report
Current view: top level - proxy/src - proxy.rs (source / functions) Coverage Total Hit
Test: 322b88762cba8ea666f63cda880cccab6936bf37.info Lines: 13.7 % 241 33
Test Date: 2024-02-29 11:57:12 Functions: 13.7 % 73 10

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : pub mod connect_compute;
       5              : mod copy_bidirectional;
       6              : pub mod handshake;
       7              : pub mod passthrough;
       8              : pub mod retry;
       9              : pub mod wake_compute;
      10              : 
      11              : use crate::{
      12              :     auth,
      13              :     cancellation::{self, CancellationHandler},
      14              :     compute,
      15              :     config::{ProxyConfig, TlsConfig},
      16              :     context::RequestMonitoring,
      17              :     error::ReportableError,
      18              :     metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
      19              :     protocol2::WithClientIp,
      20              :     proxy::handshake::{handshake, HandshakeData},
      21              :     rate_limiter::EndpointRateLimiter,
      22              :     stream::{PqStream, Stream},
      23              :     EndpointCacheKey,
      24              : };
      25              : use futures::TryFutureExt;
      26              : use itertools::Itertools;
      27              : use once_cell::sync::OnceCell;
      28              : use pq_proto::{BeMessage as Be, StartupMessageParams};
      29              : use regex::Regex;
      30              : use smol_str::{format_smolstr, SmolStr};
      31              : use std::sync::Arc;
      32              : use thiserror::Error;
      33              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
      34              : use tokio_util::sync::CancellationToken;
      35              : use tracing::{error, info, Instrument};
      36              : 
      37              : use self::{
      38              :     connect_compute::{connect_to_compute, TcpMechanism},
      39              :     passthrough::ProxyPassthrough,
      40              : };
      41              : 
/// Client-facing error text that advises using `sslmode=require`.
/// NOTE(review): not referenced in this part of the file; presumably used by a
/// submodule (e.g. `handshake`) when a client connects without TLS — confirm.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      43              : 
      44            0 : pub async fn run_until_cancelled<F: std::future::Future>(
      45            0 :     f: F,
      46            0 :     cancellation_token: &CancellationToken,
      47            0 : ) -> Option<F::Output> {
      48            0 :     match futures::future::select(
      49            0 :         std::pin::pin!(f),
      50            0 :         std::pin::pin!(cancellation_token.cancelled()),
      51            0 :     )
      52            0 :     .await
      53              :     {
      54            0 :         futures::future::Either::Left((f, _)) => Some(f),
      55            0 :         futures::future::Either::Right(((), _)) => None,
      56              :     }
      57            0 : }
      58              : 
      59            0 : pub async fn task_main(
      60            0 :     config: &'static ProxyConfig,
      61            0 :     listener: tokio::net::TcpListener,
      62            0 :     cancellation_token: CancellationToken,
      63            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      64            0 :     cancellation_handler: Arc<CancellationHandler>,
      65            0 : ) -> anyhow::Result<()> {
      66            0 :     scopeguard::defer! {
      67            0 :         info!("proxy has shut down");
      68              :     }
      69              : 
      70              :     // When set for the server socket, the keepalive setting
      71              :     // will be inherited by all accepted client sockets.
      72            0 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
      73              : 
      74            0 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
      75              : 
      76            0 :     while let Some(accept_result) =
      77            0 :         run_until_cancelled(listener.accept(), &cancellation_token).await
      78            0 :     {
      79            0 :         let (socket, peer_addr) = accept_result?;
      80              : 
      81            0 :         let session_id = uuid::Uuid::new_v4();
      82            0 :         let cancellation_handler = Arc::clone(&cancellation_handler);
      83            0 :         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
      84            0 : 
      85            0 :         connections.spawn(async move {
      86            0 :             let mut socket = WithClientIp::new(socket);
      87            0 :             let mut peer_addr = peer_addr.ip();
      88            0 :             match socket.wait_for_addr().await {
      89            0 :                 Ok(Some(addr)) => peer_addr = addr.ip(),
      90            0 :                 Err(e) => {
      91            0 :                     error!("per-client task finished with an error: {e:#}");
      92            0 :                     return;
      93              :                 }
      94            0 :                 Ok(None) if config.require_client_ip => {
      95            0 :                     error!("missing required client IP");
      96            0 :                     return;
      97              :                 }
      98            0 :                 Ok(None) => {}
      99              :             }
     100              : 
     101            0 :             match socket.inner.set_nodelay(true) {
     102            0 :                 Ok(()) => {},
     103            0 :                 Err(e) => {
     104            0 :                     error!("per-client task finished with an error: failed to set socket option: {e:#}");
     105            0 :                     return;
     106              :                 },
     107              :             };
     108              : 
     109            0 :             let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
     110            0 :             let span = ctx.span.clone();
     111              : 
     112            0 :             let res = handle_client(
     113            0 :                 config,
     114            0 :                 &mut ctx,
     115            0 :                 cancellation_handler,
     116            0 :                 socket,
     117            0 :                 ClientMode::Tcp,
     118            0 :                 endpoint_rate_limiter,
     119            0 :             )
     120            0 :             .instrument(span.clone())
     121            0 :             .await;
     122              : 
     123            0 :             match res {
     124            0 :                 Err(e) => {
     125            0 :                     // todo: log and push to ctx the error kind
     126            0 :                     ctx.set_error_kind(e.get_error_kind());
     127            0 :                     ctx.log();
     128            0 :                     error!(parent: &span, "per-client task finished with an error: {e:#}");
     129              :                 }
     130            0 :                 Ok(None) => {
     131            0 :                     ctx.set_success();
     132            0 :                     ctx.log();
     133            0 :                 }
     134            0 :                 Ok(Some(p)) => {
     135            0 :                     ctx.set_success();
     136            0 :                     ctx.log();
     137            0 :                     match p.proxy_pass().instrument(span.clone()).await {
     138            0 :                         Ok(()) => {}
     139            0 :                         Err(e) => {
     140            0 :                             error!(parent: &span, "per-client task finished with an error: {e:#}");
     141              :                         }
     142              :                     }
     143              :                 }
     144              :             }
     145            0 :         });
     146              :     }
     147              : 
     148            0 :     connections.close();
     149            0 :     drop(listener);
     150            0 : 
     151            0 :     // Drain connections
     152            0 :     connections.wait().await;
     153              : 
     154            0 :     Ok(())
     155            0 : }
     156              : 
     157              : pub enum ClientMode {
     158              :     Tcp,
     159              :     Websockets { hostname: Option<String> },
     160              : }
     161              : 
     162              : /// Abstracts the logic of handling TCP vs WS clients
     163              : impl ClientMode {
     164            0 :     pub fn allow_cleartext(&self) -> bool {
     165            0 :         match self {
     166            0 :             ClientMode::Tcp => false,
     167            0 :             ClientMode::Websockets { .. } => true,
     168              :         }
     169            0 :     }
     170              : 
     171            0 :     pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     172            0 :         match self {
     173            0 :             ClientMode::Tcp => config.allow_self_signed_compute,
     174            0 :             ClientMode::Websockets { .. } => false,
     175              :         }
     176            0 :     }
     177              : 
     178            0 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     179            0 :         match self {
     180            0 :             ClientMode::Tcp => s.sni_hostname(),
     181            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     182              :         }
     183            0 :     }
     184              : 
     185            0 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     186            0 :         match self {
     187            0 :             ClientMode::Tcp => tls,
     188              :             // TLS is None here if using websockets, because the connection is already encrypted.
     189            0 :             ClientMode::Websockets { .. } => None,
     190              :         }
     191            0 :     }
     192              : }
     193              : 
     194            0 : #[derive(Debug, Error)]
     195              : // almost all errors should be reported to the user, but there's a few cases where we cannot
     196              : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
     197              : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
     198              : //    we cannot be sure the client even understands our error message
     199              : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
     200              : pub enum ClientRequestError {
     201              :     #[error("{0}")]
     202              :     Cancellation(#[from] cancellation::CancelError),
     203              :     #[error("{0}")]
     204              :     Handshake(#[from] handshake::HandshakeError),
     205              :     #[error("{0}")]
     206              :     HandshakeTimeout(#[from] tokio::time::error::Elapsed),
     207              :     #[error("{0}")]
     208              :     PrepareClient(#[from] std::io::Error),
     209              :     #[error("{0}")]
     210              :     ReportedError(#[from] crate::stream::ReportedError),
     211              : }
     212              : 
     213              : impl ReportableError for ClientRequestError {
     214            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     215            0 :         match self {
     216            0 :             ClientRequestError::Cancellation(e) => e.get_error_kind(),
     217            0 :             ClientRequestError::Handshake(e) => e.get_error_kind(),
     218            0 :             ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
     219            0 :             ClientRequestError::ReportedError(e) => e.get_error_kind(),
     220            0 :             ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
     221              :         }
     222            0 :     }
     223              : }
     224              : 
     225            0 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     226            0 :     config: &'static ProxyConfig,
     227            0 :     ctx: &mut RequestMonitoring,
     228            0 :     cancellation_handler: Arc<CancellationHandler>,
     229            0 :     stream: S,
     230            0 :     mode: ClientMode,
     231            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     232            0 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
     233            0 :     info!("handling interactive connection from client");
     234              : 
     235            0 :     let proto = ctx.protocol;
     236            0 :     let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
     237            0 :         .with_label_values(&[proto])
     238            0 :         .guard();
     239            0 :     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
     240            0 :         .with_label_values(&[proto])
     241            0 :         .guard();
     242            0 : 
     243            0 :     let tls = config.tls_config.as_ref();
     244            0 : 
     245            0 :     let pause = ctx.latency_timer.pause();
     246            0 :     let do_handshake = handshake(stream, mode.handshake_tls(tls));
     247            0 :     let (mut stream, params) =
     248            0 :         match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
     249            0 :             HandshakeData::Startup(stream, params) => (stream, params),
     250            0 :             HandshakeData::Cancel(cancel_key_data) => {
     251            0 :                 return Ok(cancellation_handler
     252            0 :                     .cancel_session(cancel_key_data, ctx.session_id)
     253            0 :                     .await
     254            0 :                     .map(|()| None)?)
     255              :             }
     256              :         };
     257            0 :     drop(pause);
     258            0 : 
     259            0 :     let hostname = mode.hostname(stream.get_ref());
     260            0 : 
     261            0 :     let common_names = tls.map(|tls| &tls.common_names);
     262            0 : 
     263            0 :     // Extract credentials which we're going to use for auth.
     264            0 :     let result = config
     265            0 :         .auth_backend
     266            0 :         .as_ref()
     267            0 :         .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
     268            0 :         .transpose();
     269              : 
     270            0 :     let user_info = match result {
     271            0 :         Ok(user_info) => user_info,
     272            0 :         Err(e) => stream.throw_error(e).await?,
     273              :     };
     274              : 
     275              :     // check rate limit
     276            0 :     if let Some(ep) = user_info.get_endpoint() {
     277            0 :         if !endpoint_rate_limiter.check(ep) {
     278            0 :             return stream
     279            0 :                 .throw_error(auth::AuthError::too_many_connections())
     280            0 :                 .await?;
     281            0 :         }
     282            0 :     }
     283              : 
     284            0 :     let user = user_info.get_user().to_owned();
     285            0 :     let user_info = match user_info
     286            0 :         .authenticate(
     287            0 :             ctx,
     288            0 :             &mut stream,
     289            0 :             mode.allow_cleartext(),
     290            0 :             &config.authentication_config,
     291            0 :         )
     292            0 :         .await
     293              :     {
     294            0 :         Ok(auth_result) => auth_result,
     295            0 :         Err(e) => {
     296            0 :             let db = params.get("database");
     297            0 :             let app = params.get("application_name");
     298            0 :             let params_span = tracing::info_span!("", ?user, ?db, ?app);
     299              : 
     300            0 :             return stream.throw_error(e).instrument(params_span).await?;
     301              :         }
     302              :     };
     303              : 
     304            0 :     let mut node = connect_to_compute(
     305            0 :         ctx,
     306            0 :         &TcpMechanism { params: &params },
     307            0 :         &user_info,
     308            0 :         mode.allow_self_signed_compute(config),
     309            0 :     )
     310            0 :     .or_else(|e| stream.throw_error(e))
     311            0 :     .await?;
     312              : 
     313            0 :     let session = cancellation_handler.get_session();
     314            0 :     prepare_client_connection(&node, &session, &mut stream).await?;
     315              : 
     316              :     // Before proxy passing, forward to compute whatever data is left in the
     317              :     // PqStream input buffer. Normally there is none, but our serverless npm
     318              :     // driver in pipeline mode sends startup, password and first query
     319              :     // immediately after opening the connection.
     320            0 :     let (stream, read_buf) = stream.into_inner();
     321            0 :     node.stream.write_all(&read_buf).await?;
     322              : 
     323            0 :     Ok(Some(ProxyPassthrough {
     324            0 :         client: stream,
     325            0 :         aux: node.aux.clone(),
     326            0 :         compute: node,
     327            0 :         req: _request_gauge,
     328            0 :         conn: _client_gauge,
     329            0 :         cancel: session,
     330            0 :     }))
     331            0 : }
     332              : 
     333              : /// Finish client connection initialization: confirm auth success, send params, etc.
     334            0 : #[tracing::instrument(skip_all)]
     335              : async fn prepare_client_connection(
     336              :     node: &compute::PostgresConnection,
     337              :     session: &cancellation::Session,
     338              :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     339              : ) -> Result<(), std::io::Error> {
     340              :     // Register compute's query cancellation token and produce a new, unique one.
     341              :     // The new token (cancel_key_data) will be sent to the client.
     342              :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     343              : 
     344              :     // Forward all postgres connection params to the client.
     345              :     // Right now the implementation is very hacky and inefficent (ideally,
     346              :     // we don't need an intermediate hashmap), but at least it should be correct.
     347              :     for (name, value) in &node.params {
     348              :         // TODO: Theoretically, this could result in a big pile of params...
     349              :         stream.write_message_noflush(&Be::ParameterStatus {
     350              :             name: name.as_bytes(),
     351              :             value: value.as_bytes(),
     352              :         })?;
     353              :     }
     354              : 
     355              :     stream
     356              :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     357              :         .write_message(&Be::ReadyForQuery)
     358              :         .await?;
     359              : 
     360              :     Ok(())
     361              : }
     362              : 
     363           26 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
     364              : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
     365              : 
     366              : impl NeonOptions {
     367           22 :     pub fn parse_params(params: &StartupMessageParams) -> Self {
     368           22 :         params
     369           22 :             .options_raw()
     370           22 :             .map(Self::parse_from_iter)
     371           22 :             .unwrap_or_default()
     372           22 :     }
     373           14 :     pub fn parse_options_raw(options: &str) -> Self {
     374           14 :         Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
     375           14 :     }
     376              : 
     377           26 :     fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
     378           26 :         let mut options = options
     379           26 :             .filter_map(neon_option)
     380           26 :             .map(|(k, v)| (k.into(), v.into()))
     381           26 :             .collect_vec();
     382           26 :         options.sort();
     383           26 :         Self(options)
     384           26 :     }
     385              : 
     386            8 :     pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
     387            8 :         // prefix + format!(" {k}:{v}")
     388            8 :         // kinda jank because SmolStr is immutable
     389            8 :         std::iter::once(prefix)
     390            8 :             .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
     391            8 :             .collect::<SmolStr>()
     392            8 :             .into()
     393            8 :     }
     394              : 
     395              :     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
     396              :     /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
     397            0 :     pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
     398            0 :         self.0
     399            0 :             .iter()
     400            0 :             .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
     401            0 :             .collect()
     402            0 :     }
     403              : }
     404              : 
     405           64 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
     406           64 :     static RE: OnceCell<Regex> = OnceCell::new();
     407           64 :     let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
     408              : 
     409           64 :     let cap = re.captures(bytes)?;
     410            8 :     let (_, [k, v]) = cap.extract();
     411            8 :     Some((k, v))
     412           64 : }
        

Generated by: LCOV version 2.1-beta