LCOV - code coverage report
Current view: top level - proxy/src - proxy.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 91.8 % 220 202
Test Date: 2024-02-07 07:37:29 Functions: 53.6 % 56 30

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : pub mod connect_compute;
       5              : pub mod handshake;
       6              : pub mod passthrough;
       7              : pub mod retry;
       8              : pub mod wake_compute;
       9              : 
      10              : use crate::{
      11              :     auth,
      12              :     cancellation::{self, CancelMap},
      13              :     compute,
      14              :     config::{ProxyConfig, TlsConfig},
      15              :     context::RequestMonitoring,
      16              :     metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
      17              :     protocol2::WithClientIp,
      18              :     proxy::{handshake::handshake, passthrough::proxy_pass},
      19              :     rate_limiter::EndpointRateLimiter,
      20              :     stream::{PqStream, Stream},
      21              :     EndpointCacheKey,
      22              : };
      23              : use anyhow::{bail, Context};
      24              : use futures::TryFutureExt;
      25              : use itertools::Itertools;
      26              : use once_cell::sync::OnceCell;
      27              : use pq_proto::{BeMessage as Be, StartupMessageParams};
      28              : use regex::Regex;
      29              : use smol_str::{format_smolstr, SmolStr};
      30              : use std::sync::Arc;
      31              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
      32              : use tokio_util::sync::CancellationToken;
      33              : use tracing::{error, info, info_span, Instrument};
      34              : 
      35              : use self::connect_compute::{connect_to_compute, TcpMechanism};
      36              : 
      37              : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      38              : const ERR_PROTO_VIOLATION: &str = "protocol violation";
      39              : 
      40           87 : pub async fn run_until_cancelled<F: std::future::Future>(
      41           87 :     f: F,
      42           87 :     cancellation_token: &CancellationToken,
      43           87 : ) -> Option<F::Output> {
      44           87 :     match futures::future::select(
      45           87 :         std::pin::pin!(f),
      46           87 :         std::pin::pin!(cancellation_token.cancelled()),
      47           87 :     )
      48           86 :     .await
      49              :     {
      50           63 :         futures::future::Either::Left((f, _)) => Some(f),
      51           24 :         futures::future::Either::Right(((), _)) => None,
      52              :     }
      53           87 : }
      54              : 
      55           23 : pub async fn task_main(
      56           23 :     config: &'static ProxyConfig,
      57           23 :     listener: tokio::net::TcpListener,
      58           23 :     cancellation_token: CancellationToken,
      59           23 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      60           23 : ) -> anyhow::Result<()> {
      61           23 :     scopeguard::defer! {
      62           23 :         info!("proxy has shut down");
      63           23 :     }
      64           23 : 
      65           23 :     // When set for the server socket, the keepalive setting
      66           23 :     // will be inherited by all accepted client sockets.
      67           23 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
      68              : 
      69           23 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
      70           23 :     let cancel_map = Arc::new(CancelMap::default());
      71              : 
      72           61 :     while let Some(accept_result) =
      73           84 :         run_until_cancelled(listener.accept(), &cancellation_token).await
      74              :     {
      75           61 :         let (socket, peer_addr) = accept_result?;
      76              : 
      77           61 :         let session_id = uuid::Uuid::new_v4();
      78           61 :         let cancel_map = Arc::clone(&cancel_map);
      79           61 :         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
      80              : 
      81           61 :         let session_span = info_span!(
      82           61 :             "handle_client",
      83           61 :             ?session_id,
      84           61 :             peer_addr = tracing::field::Empty,
      85           61 :             ep = tracing::field::Empty,
      86           61 :         );
      87              : 
      88           61 :         connections.spawn(
      89           61 :             async move {
      90           61 :                 info!("accepted postgres client connection");
      91              : 
      92           61 :                 let mut socket = WithClientIp::new(socket);
      93           61 :                 let mut peer_addr = peer_addr.ip();
      94           61 :                 if let Some(addr) = socket.wait_for_addr().await? {
      95            0 :                     peer_addr = addr.ip();
      96            0 :                     tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
      97           61 :                 } else if config.require_client_ip {
      98            0 :                     bail!("missing required client IP");
      99           61 :                 }
     100              : 
     101           61 :                 let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
     102           61 : 
     103           61 :                 socket
     104           61 :                     .inner
     105           61 :                     .set_nodelay(true)
     106           61 :                     .context("failed to set socket option")?;
     107              : 
     108           61 :                 handle_client(
     109           61 :                     config,
     110           61 :                     &mut ctx,
     111           61 :                     cancel_map,
     112           61 :                     socket,
     113           61 :                     ClientMode::Tcp,
     114           61 :                     endpoint_rate_limiter,
     115           61 :                 )
     116         1010 :                 .await
     117           61 :             }
     118           61 :             .unwrap_or_else(move |e| {
     119           61 :                 // Acknowledge that the task has finished with an error.
     120           61 :                 error!("per-client task finished with an error: {e:#}");
     121           61 :             })
     122           61 :             .instrument(session_span),
     123           61 :         );
     124              :     }
     125              : 
     126           23 :     connections.close();
     127           23 :     drop(listener);
     128           23 : 
     129           23 :     // Drain connections
     130           23 :     connections.wait().await;
     131              : 
     132           23 :     Ok(())
     133           23 : }
     134              : 
     135              : pub enum ClientMode {
     136              :     Tcp,
     137              :     Websockets { hostname: Option<String> },
     138              : }
     139              : 
     140              : /// Abstracts the logic of handling TCP vs WS clients
     141              : impl ClientMode {
     142           50 :     fn allow_cleartext(&self) -> bool {
     143           50 :         match self {
     144           50 :             ClientMode::Tcp => false,
     145            0 :             ClientMode::Websockets { .. } => true,
     146              :         }
     147           50 :     }
     148              : 
     149           39 :     fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     150           39 :         match self {
     151           39 :             ClientMode::Tcp => config.allow_self_signed_compute,
     152            0 :             ClientMode::Websockets { .. } => false,
     153              :         }
     154           39 :     }
     155              : 
     156           50 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     157           50 :         match self {
     158           50 :             ClientMode::Tcp => s.sni_hostname(),
     159            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     160              :         }
     161           50 :     }
     162              : 
     163           61 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     164           61 :         match self {
     165           61 :             ClientMode::Tcp => tls,
     166              :             // TLS is None here if using websockets, because the connection is already encrypted.
     167            0 :             ClientMode::Websockets { .. } => None,
     168              :         }
     169           61 :     }
     170              : }
     171              : 
     172           61 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     173           61 :     config: &'static ProxyConfig,
     174           61 :     ctx: &mut RequestMonitoring,
     175           61 :     cancel_map: Arc<CancelMap>,
     176           61 :     stream: S,
     177           61 :     mode: ClientMode,
     178           61 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     179           61 : ) -> anyhow::Result<()> {
     180           61 :     info!(
     181           61 :         protocol = ctx.protocol,
     182           61 :         "handling interactive connection from client"
     183           61 :     );
     184              : 
     185           61 :     let proto = ctx.protocol;
     186           61 :     let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
     187           61 :         .with_label_values(&[proto])
     188           61 :         .guard();
     189           61 :     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
     190           61 :         .with_label_values(&[proto])
     191           61 :         .guard();
     192           61 : 
     193           61 :     let tls = config.tls_config.as_ref();
     194           61 : 
     195           61 :     let pause = ctx.latency_timer.pause();
     196           61 :     let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
     197          105 :     let (mut stream, params) = match do_handshake.await? {
     198           50 :         Some(x) => x,
     199            0 :         None => return Ok(()), // it's a cancellation request
     200              :     };
     201           50 :     drop(pause);
     202           50 : 
     203           50 :     let hostname = mode.hostname(stream.get_ref());
     204           50 : 
     205           50 :     let common_names = tls.map(|tls| &tls.common_names);
     206           50 : 
     207           50 :     // Extract credentials which we're going to use for auth.
     208           50 :     let result = config
     209           50 :         .auth_backend
     210           50 :         .as_ref()
     211           50 :         .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
     212           50 :         .transpose();
     213              : 
     214           50 :     let user_info = match result {
     215           50 :         Ok(user_info) => user_info,
     216            0 :         Err(e) => stream.throw_error(e).await?,
     217              :     };
     218              : 
     219              :     // check rate limit
     220           50 :     if let Some(ep) = user_info.get_endpoint() {
     221           47 :         if !endpoint_rate_limiter.check(ep) {
     222            0 :             return stream
     223            0 :                 .throw_error(auth::AuthError::too_many_connections())
     224            0 :                 .await;
     225           47 :         }
     226            3 :     }
     227              : 
     228           50 :     let user = user_info.get_user().to_owned();
     229           50 :     let (mut node_info, user_info) = match user_info
     230           50 :         .authenticate(
     231           50 :             ctx,
     232           50 :             &mut stream,
     233           50 :             mode.allow_cleartext(),
     234           50 :             &config.authentication_config,
     235           50 :         )
     236          637 :         .await
     237              :     {
     238           39 :         Ok(auth_result) => auth_result,
     239           11 :         Err(e) => {
     240           11 :             let db = params.get("database");
     241           11 :             let app = params.get("application_name");
     242           11 :             let params_span = tracing::info_span!("", ?user, ?db, ?app);
     243              : 
     244           11 :             return stream.throw_error(e).instrument(params_span).await;
     245              :         }
     246              :     };
     247              : 
     248           39 :     node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
     249           39 : 
     250           39 :     let aux = node_info.aux.clone();
     251           39 :     let mut node = connect_to_compute(
     252           39 :         ctx,
     253           39 :         &TcpMechanism { params: &params },
     254           39 :         node_info,
     255           39 :         &user_info,
     256           39 :     )
     257           39 :     .or_else(|e| stream.throw_error(e))
     258          114 :     .await?;
     259              : 
     260           39 :     let session = cancel_map.get_session();
     261           39 :     prepare_client_connection(&node, &session, &mut stream).await?;
     262              : 
     263              :     // Before proxy passing, forward to compute whatever data is left in the
     264              :     // PqStream input buffer. Normally there is none, but our serverless npm
     265              :     // driver in pipeline mode sends startup, password and first query
     266              :     // immediately after opening the connection.
     267           39 :     let (stream, read_buf) = stream.into_inner();
     268           39 :     node.stream.write_all(&read_buf).await?;
     269              : 
     270          154 :     proxy_pass(ctx, stream, node.stream, aux).await
     271           61 : }
     272              : 
     273              : /// Finish client connection initialization: confirm auth success, send params, etc.
     274           39 : #[tracing::instrument(skip_all)]
     275              : async fn prepare_client_connection(
     276              :     node: &compute::PostgresConnection,
     277              :     session: &cancellation::Session,
     278              :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     279              : ) -> anyhow::Result<()> {
     280              :     // Register compute's query cancellation token and produce a new, unique one.
     281              :     // The new token (cancel_key_data) will be sent to the client.
     282              :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     283              : 
     284              :     // Forward all postgres connection params to the client.
     285              :     // Right now the implementation is very hacky and inefficent (ideally,
     286              :     // we don't need an intermediate hashmap), but at least it should be correct.
     287              :     for (name, value) in &node.params {
     288              :         // TODO: Theoretically, this could result in a big pile of params...
     289              :         stream.write_message_noflush(&Be::ParameterStatus {
     290              :             name: name.as_bytes(),
     291              :             value: value.as_bytes(),
     292              :         })?;
     293              :     }
     294              : 
     295              :     stream
     296              :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     297              :         .write_message(&Be::ReadyForQuery)
     298              :         .await?;
     299              : 
     300              :     Ok(())
     301              : }
     302              : 
     303           85 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
     304              : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
     305              : 
     306              : impl NeonOptions {
     307           69 :     pub fn parse_params(params: &StartupMessageParams) -> Self {
     308           69 :         params
     309           69 :             .options_raw()
     310           69 :             .map(Self::parse_from_iter)
     311           69 :             .unwrap_or_default()
     312           69 :     }
     313           14 :     pub fn parse_options_raw(options: &str) -> Self {
     314           14 :         Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
     315           14 :     }
     316              : 
     317           73 :     fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
     318           73 :         let mut options = options
     319           73 :             .filter_map(neon_option)
     320           73 :             .map(|(k, v)| (k.into(), v.into()))
     321           73 :             .collect_vec();
     322           73 :         options.sort();
     323           73 :         Self(options)
     324           73 :     }
     325              : 
     326          260 :     pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
     327          260 :         // prefix + format!(" {k}:{v}")
     328          260 :         // kinda jank because SmolStr is immutable
     329          260 :         std::iter::once(prefix)
     330          260 :             .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
     331          260 :             .collect::<SmolStr>()
     332          260 :             .into()
     333          260 :     }
     334              : 
     335              :     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
     336              :     /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
     337            0 :     pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
     338            0 :         self.0
     339            0 :             .iter()
     340            0 :             .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
     341            0 :             .collect()
     342            0 :     }
     343              : }
     344              : 
     345          181 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
     346          181 :     static RE: OnceCell<Regex> = OnceCell::new();
     347          181 :     let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
     348              : 
     349          181 :     let cap = re.captures(bytes)?;
     350            8 :     let (_, [k, v]) = cap.extract();
     351            8 :     Some((k, v))
     352          181 : }
        

Generated by: LCOV version 2.1-beta