LCOV - code coverage report
Current view: top level - proxy/src/proxy - mod.rs (source / functions) Coverage Total Hit
Test: 89231e3f993a79261f961ab69a97777ade006195.info Lines: 14.9 % 215 32
Test Date: 2025-07-31 13:52:20 Functions: 28.1 % 32 9

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : pub(crate) mod connect_auth;
       5              : pub(crate) mod connect_compute;
       6              : pub(crate) mod retry;
       7              : pub(crate) mod wake_compute;
       8              : 
       9              : use std::collections::HashSet;
      10              : use std::convert::Infallible;
      11              : use std::sync::Arc;
      12              : 
      13              : use futures::TryStreamExt;
      14              : use itertools::Itertools;
      15              : use once_cell::sync::OnceCell;
      16              : use postgres_client::RawCancelToken;
      17              : use postgres_client::connect_raw::StartupStream;
      18              : use postgres_protocol::message::backend::Message;
      19              : use regex::Regex;
      20              : use serde::{Deserialize, Serialize};
      21              : use smol_str::{SmolStr, format_smolstr};
      22              : use tokio::io::{AsyncRead, AsyncWrite};
      23              : use tokio::net::TcpStream;
      24              : use tokio::sync::oneshot;
      25              : use tracing::Instrument;
      26              : 
      27              : use crate::cancellation::{CancelClosure, CancellationHandler};
      28              : use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
      29              : use crate::config::ProxyConfig;
      30              : use crate::context::RequestContext;
      31              : pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
      32              : use crate::pglb::{ClientMode, ClientRequestError};
      33              : use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
      34              : use crate::rate_limiter::EndpointRateLimiter;
      35              : use crate::stream::{PqStream, Stream};
      36              : use crate::types::EndpointCacheKey;
      37              : use crate::{auth, compute};
      38              : 
      39              : #[allow(clippy::too_many_arguments)]
      40            0 : pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
      41            0 :     config: &'static ProxyConfig,
      42            0 :     auth_backend: &'static auth::Backend<'static, ()>,
      43            0 :     ctx: &RequestContext,
      44            0 :     cancellation_handler: Arc<CancellationHandler>,
      45            0 :     client: &mut PqStream<Stream<S>>,
      46            0 :     mode: &ClientMode,
      47            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      48            0 :     common_names: Option<&HashSet<String>>,
      49            0 :     params: &StartupMessageParams,
      50            0 : ) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
      51            0 :     let hostname = mode.hostname(client.get_ref());
      52              :     // Extract credentials which we're going to use for auth.
      53            0 :     let result = auth_backend
      54            0 :         .as_ref()
      55            0 :         .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
      56            0 :         .transpose();
      57              : 
      58            0 :     let user_info = match result {
      59            0 :         Ok(user_info) => user_info,
      60            0 :         Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
      61              :     };
      62              : 
      63            0 :     let user = user_info.get_user().to_owned();
      64            0 :     let user_info = match user_info
      65            0 :         .authenticate(
      66            0 :             ctx,
      67            0 :             client,
      68            0 :             mode.allow_cleartext(),
      69            0 :             &config.authentication_config,
      70            0 :             endpoint_rate_limiter,
      71              :         )
      72            0 :         .await
      73              :     {
      74            0 :         Ok(auth_result) => auth_result,
      75            0 :         Err(e) => {
      76            0 :             let db = params.get("database");
      77            0 :             let app = params.get("application_name");
      78            0 :             let params_span = tracing::info_span!("", ?user, ?db, ?app);
      79              : 
      80            0 :             return Err(client
      81            0 :                 .throw_error(e, Some(ctx))
      82            0 :                 .instrument(params_span)
      83            0 :                 .await)?;
      84              :         }
      85              :     };
      86              : 
      87            0 :     let (cplane, creds) = match user_info {
      88            0 :         auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
      89            0 :         auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
      90              :     };
      91            0 :     let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
      92            0 :     let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
      93            0 :     auth_info.set_startup_params(params, params_compat);
      94              : 
      95            0 :     let backend = auth::Backend::ControlPlane(cplane, creds.info);
      96              : 
      97              :     // TODO: callback to pglb
      98            0 :     let res = connect_auth::connect_to_compute_and_auth(
      99            0 :         ctx,
     100            0 :         config,
     101            0 :         &backend,
     102            0 :         auth_info,
     103            0 :         connect_compute::TlsNegotiation::Postgres,
     104            0 :     )
     105            0 :     .await;
     106              : 
     107            0 :     let mut node = match res {
     108            0 :         Ok(node) => node,
     109            0 :         Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
     110              :     };
     111              : 
     112            0 :     send_client_greeting(ctx, &config.greetings, client);
     113              : 
     114            0 :     let auth::Backend::ControlPlane(_, user_info) = backend else {
     115            0 :         unreachable!("ensured above");
     116              :     };
     117              : 
     118            0 :     let session = cancellation_handler.get_key();
     119              : 
     120            0 :     let (process_id, secret_key) =
     121            0 :         forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
     122            0 :     let hostname = node.hostname.to_string();
     123              : 
     124            0 :     let session_id = ctx.session_id();
     125            0 :     let (cancel_on_shutdown, cancel) = oneshot::channel();
     126            0 :     tokio::spawn(async move {
     127            0 :         session
     128            0 :             .maintain_cancel_key(
     129            0 :                 session_id,
     130            0 :                 cancel,
     131            0 :                 &CancelClosure {
     132            0 :                     socket_addr: node.socket_addr,
     133            0 :                     cancel_token: RawCancelToken {
     134            0 :                         ssl_mode: node.ssl_mode,
     135            0 :                         process_id,
     136            0 :                         secret_key,
     137            0 :                     },
     138            0 :                     hostname,
     139            0 :                     user_info,
     140            0 :                 },
     141            0 :                 &config.connect_to_compute,
     142            0 :             )
     143            0 :             .await;
     144            0 :     });
     145              : 
     146            0 :     Ok((node, cancel_on_shutdown))
     147            0 : }
     148              : 
     149              : /// Greet the client with any useful information.
     150            0 : pub(crate) fn send_client_greeting(
     151            0 :     ctx: &RequestContext,
     152            0 :     greetings: &String,
     153            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     154            0 : ) {
     155              :     // Expose session_id to clients if we have a greeting message.
     156            0 :     if !greetings.is_empty() {
     157            0 :         let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
     158            0 :         client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
     159            0 :     }
     160              : 
     161              :     // Forward recorded latencies for probing requests
     162            0 :     if let Some(testodrome_id) = ctx.get_testodrome_id() {
     163            0 :         client.write_message(BeMessage::ParameterStatus {
     164            0 :             name: "neon.testodrome_id".as_bytes(),
     165            0 :             value: testodrome_id.as_bytes(),
     166            0 :         });
     167            0 : 
     168            0 :         let latency_measured = ctx.get_proxy_latency();
     169            0 : 
     170            0 :         client.write_message(BeMessage::ParameterStatus {
     171            0 :             name: "neon.cplane_latency".as_bytes(),
     172            0 :             value: latency_measured.cplane.as_micros().to_string().as_bytes(),
     173            0 :         });
     174            0 : 
     175            0 :         client.write_message(BeMessage::ParameterStatus {
     176            0 :             name: "neon.client_latency".as_bytes(),
     177            0 :             value: latency_measured.client.as_micros().to_string().as_bytes(),
     178            0 :         });
     179            0 : 
     180            0 :         client.write_message(BeMessage::ParameterStatus {
     181            0 :             name: "neon.compute_latency".as_bytes(),
     182            0 :             value: latency_measured.compute.as_micros().to_string().as_bytes(),
     183            0 :         });
     184            0 : 
     185            0 :         client.write_message(BeMessage::ParameterStatus {
     186            0 :             name: "neon.retry_latency".as_bytes(),
     187            0 :             value: latency_measured.retry.as_micros().to_string().as_bytes(),
     188            0 :         });
     189            0 :     }
     190            0 : }
     191              : 
     192            0 : pub(crate) async fn forward_compute_params_to_client(
     193            0 :     ctx: &RequestContext,
     194            0 :     cancel_key_data: CancelKeyData,
     195            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     196            0 :     compute: &mut StartupStream<TcpStream, RustlsStream>,
     197            0 : ) -> Result<(i32, i32), ClientRequestError> {
     198            0 :     let mut process_id = 0;
     199            0 :     let mut secret_key = 0;
     200              : 
     201            0 :     let err = loop {
     202              :         // if the client buffer is too large, let's write out some bytes now to save some space
     203            0 :         client.write_if_full().await?;
     204              : 
     205            0 :         let msg = match compute.try_next().await {
     206            0 :             Ok(msg) => msg,
     207            0 :             Err(e) => break postgres_client::Error::io(e),
     208              :         };
     209              : 
     210            0 :         match msg {
     211              :             // Send our cancellation key data instead.
     212            0 :             Some(Message::BackendKeyData(body)) => {
     213            0 :                 client.write_message(BeMessage::BackendKeyData(cancel_key_data));
     214            0 :                 process_id = body.process_id();
     215            0 :                 secret_key = body.secret_key();
     216            0 :             }
     217              :             // Forward all postgres connection params to the client.
     218            0 :             Some(Message::ParameterStatus(body)) => {
     219            0 :                 if let Ok(name) = body.name()
     220            0 :                     && let Ok(value) = body.value()
     221            0 :                 {
     222            0 :                     client.write_message(BeMessage::ParameterStatus {
     223            0 :                         name: name.as_bytes(),
     224            0 :                         value: value.as_bytes(),
     225            0 :                     });
     226            0 :                 }
     227              :             }
     228              :             // Forward all notices to the client.
     229            0 :             Some(Message::NoticeResponse(notice)) => {
     230            0 :                 client.write_raw(notice.as_bytes().len(), b'N', |buf| {
     231            0 :                     buf.extend_from_slice(notice.as_bytes());
     232            0 :                 });
     233              :             }
     234              :             Some(Message::ReadyForQuery(_)) => {
     235            0 :                 client.write_message(BeMessage::ReadyForQuery);
     236            0 :                 return Ok((process_id, secret_key));
     237              :             }
     238            0 :             Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body),
     239            0 :             Some(_) => break postgres_client::Error::unexpected_message(),
     240            0 :             None => break postgres_client::Error::closed(),
     241              :         }
     242              :     };
     243              : 
     244            0 :     Err(client
     245            0 :         .throw_error(PostgresError::Postgres(err), Some(ctx))
     246            0 :         .await)?
     247            0 : }
     248              : 
     249            0 : #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
     250              : pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
     251              : 
     252              : impl NeonOptions {
     253              :     // proxy options:
     254              : 
     255              :     /// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
     256              :     pub const PARAMS_COMPAT: &'static str = "proxy_params_compat";
     257              : 
     258              :     // cplane options:
     259              : 
     260              :     /// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN.
     261              :     const LSN: &'static str = "lsn";
     262              : 
     263              :     /// `TIMESTAMP` allows provisioning an ephemeral compute with time-travel to the provided timestamp.
     264              :     const TIMESTAMP: &'static str = "timestamp";
     265              : 
     266              :     /// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write.
     267              :     const ENDPOINT_TYPE: &'static str = "endpoint_type";
     268              : 
     269           13 :     pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
     270           13 :         params
     271           13 :             .options_raw()
     272           13 :             .map(Self::parse_from_iter)
     273           13 :             .unwrap_or_default()
     274           13 :     }
     275              : 
     276           13 :     pub(crate) fn parse_options_raw(options: &str) -> Self {
     277           13 :         Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
     278           13 :     }
     279              : 
     280            0 :     pub(crate) fn get(&self, key: &str) -> Option<SmolStr> {
     281            0 :         self.0
     282            0 :             .iter()
     283            0 :             .find_map(|(k, v)| (k == key).then_some(v))
     284            0 :             .cloned()
     285            0 :     }
     286              : 
     287            2 :     pub(crate) fn is_ephemeral(&self) -> bool {
     288            2 :         self.0.iter().any(|(k, _)| match &**k {
     289              :             // This is not a cplane option, we know it does not create ephemeral computes.
     290            0 :             Self::PARAMS_COMPAT => false,
     291            0 :             Self::LSN => true,
     292            0 :             Self::TIMESTAMP => true,
     293            0 :             Self::ENDPOINT_TYPE => true,
     294              :             // err on the side of caution. any cplane options we don't know about
     295              :             // might lead to ephemeral computes.
     296            0 :             _ => true,
     297            0 :         })
     298            2 :     }
     299              : 
     300           20 :     fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
     301           20 :         let mut options = options
     302           20 :             .filter_map(neon_option)
     303           20 :             .map(|(k, v)| (k.into(), v.into()))
     304           20 :             .collect_vec();
     305           20 :         options.sort();
     306           20 :         Self(options)
     307           20 :     }
     308              : 
     309            4 :     pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
     310              :         // prefix + format!(" {k}:{v}")
     311              :         // kinda jank because SmolStr is immutable
     312            4 :         std::iter::once(prefix)
     313            4 :             .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
     314            4 :             .collect::<SmolStr>()
     315            4 :             .into()
     316            4 :     }
     317              : 
     318              :     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
     319              :     /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
     320            0 :     pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
     321            0 :         self.0
     322            0 :             .iter()
     323            0 :             .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
     324            0 :             .collect()
     325            0 :     }
     326              : }
     327              : 
     328           34 : pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
     329              :     static RE: OnceCell<Regex> = OnceCell::new();
     330           34 :     let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").expect("regex should be correct"));
     331              : 
     332           34 :     let cap = re.captures(bytes)?;
     333            5 :     let (_, [k, v]) = cap.extract();
     334            5 :     Some((k, v))
     335           34 : }
        

Generated by: LCOV version 2.1-beta