LCOV - code coverage report
Current view: top level - proxy/src - proxy.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 88.7 % 248 220
Test Date: 2024-02-12 20:26:03 Functions: 50.7 % 71 36

            Line data    Source code
       1              : #[cfg(test)]
       2              : mod tests;
       3              : 
       4              : pub mod connect_compute;
       5              : mod copy_bidirectional;
       6              : pub mod handshake;
       7              : pub mod passthrough;
       8              : pub mod retry;
       9              : pub mod wake_compute;
      10              : 
      11              : use crate::{
      12              :     auth,
      13              :     cancellation::{self, CancelMap},
      14              :     compute,
      15              :     config::{ProxyConfig, TlsConfig},
      16              :     context::RequestMonitoring,
      17              :     error::ReportableError,
      18              :     metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
      19              :     protocol2::WithClientIp,
      20              :     proxy::handshake::{handshake, HandshakeData},
      21              :     rate_limiter::EndpointRateLimiter,
      22              :     stream::{PqStream, Stream},
      23              :     EndpointCacheKey,
      24              : };
      25              : use anyhow::{bail, Context};
      26              : use futures::TryFutureExt;
      27              : use itertools::Itertools;
      28              : use once_cell::sync::OnceCell;
      29              : use pq_proto::{BeMessage as Be, StartupMessageParams};
      30              : use regex::Regex;
      31              : use smol_str::{format_smolstr, SmolStr};
      32              : use std::sync::Arc;
      33              : use thiserror::Error;
      34              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
      35              : use tokio_util::sync::CancellationToken;
      36              : use tracing::{error, info, info_span, Instrument};
      37              : 
      38              : use self::{
      39              :     connect_compute::{connect_to_compute, TcpMechanism},
      40              :     passthrough::ProxyPassthrough,
      41              : };
      42              : 
      43              : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
      44              : 
                        /// Runs `f` until it completes or `cancellation_token` is cancelled,
                        /// whichever happens first.
                        ///
                        /// Returns `Some(output)` when `f` finished first and `None` when the
                        /// cancellation future resolved first.
      45           91 : pub async fn run_until_cancelled<F: std::future::Future>(
      46           91 :     f: F,
      47           91 :     cancellation_token: &CancellationToken,
      48           91 : ) -> Option<F::Output> {
                            // Both futures are pinned as locals so they can be handed to `select`.
      49           91 :     match futures::future::select(
      50           91 :         std::pin::pin!(f),
      51           91 :         std::pin::pin!(cancellation_token.cancelled()),
      52           91 :     )
      53           90 :     .await
      54              :     {
                                // `Left` carries the output of `f`; the unfinished cancellation future is ignored.
      55           65 :         futures::future::Either::Left((f, _)) => Some(f),
                                // `Right` means the token was cancelled; the unfinished `f` is ignored.
      56           26 :         futures::future::Either::Right(((), _)) => None,
      57              :     }
      58           91 : }
      59              : 
                        /// Accept loop for the TCP listener.
                        ///
                        /// Accepts connections from `listener` until `cancellation_token` is cancelled.
                        /// Each accepted socket is handled in its own tracked task that runs
                        /// [`handle_client`] with `ClientMode::Tcp` and, on success, drives the returned
                        /// passthrough via `proxy_pass`. Once the loop ends, the listener is dropped and
                        /// the function waits for all spawned connection tasks to finish.
                        ///
                        /// # Errors
                        ///
                        /// Returns an error if enabling keepalive on the listener fails or if
                        /// `accept` fails. Errors from individual connection tasks are logged inside
                        /// the task and are not returned from this function.
      60           25 : pub async fn task_main(
      61           25 :     config: &'static ProxyConfig,
      62           25 :     listener: tokio::net::TcpListener,
      63           25 :     cancellation_token: CancellationToken,
      64           25 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      65           25 : ) -> anyhow::Result<()> {
                            // Logs on every exit path from this function, including early `?` returns.
      66           25 :     scopeguard::defer! {
      67           25 :         info!("proxy has shut down");
      68              :     }
      69              : 
      70              :     // When set for the server socket, the keepalive setting
      71              :     // will be inherited by all accepted client sockets.
      72           25 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
      73              : 
      74           25 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
      75           25 :     let cancel_map = Arc::new(CancelMap::default());
      76              : 
                            // `run_until_cancelled` yields `None` once the token is cancelled, which ends the loop.
      77           63 :     while let Some(accept_result) =
      78           88 :         run_until_cancelled(listener.accept(), &cancellation_token).await
      79              :     {
                                // NOTE(review): a failed `accept` returns from `task_main` here, before
                                // `connections.wait()` below is reached — confirm that skipping the drain
                                // on this path is intended.
      80           63 :         let (socket, peer_addr) = accept_result?;
      81              : 
      82           63 :         let session_id = uuid::Uuid::new_v4();
      83           63 :         let cancel_map = Arc::clone(&cancel_map);
      84           63 :         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
      85              : 
                                // `peer_addr` and `ep` are declared empty so they can be recorded later.
                                // Within this function only `peer_addr` is recorded, and only when
                                // `wait_for_addr` yields an address.
      86           63 :         let session_span = info_span!(
      87           63 :             "handle_client",
      88           63 :             ?session_id,
      89           63 :             peer_addr = tracing::field::Empty,
      90           63 :             ep = tracing::field::Empty,
      91           63 :         );
      92              : 
      93           63 :         connections.spawn(
      94           63 :             async move {
      95           63 :                 info!("accepted postgres client connection");
      96              : 
                                        // Start from the TCP peer address; an address yielded by
                                        // `wait_for_addr` overrides it.
      97           63 :                 let mut socket = WithClientIp::new(socket);
      98           63 :                 let mut peer_addr = peer_addr.ip();
      99           63 :                 if let Some(addr) = socket.wait_for_addr().await? {
     100            0 :                     peer_addr = addr.ip();
     101            0 :                     tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
     102           63 :                 } else if config.require_client_ip {
     103            0 :                     bail!("missing required client IP");
     104           63 :                 }
     105              : 
     106           63 :                 socket
     107           63 :                     .inner
     108           63 :                     .set_nodelay(true)
     109           63 :                     .context("failed to set socket option")?;
     110              : 
     111           63 :                 let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
     112              : 
     113           63 :                 let res = handle_client(
     114           63 :                     config,
     115           63 :                     &mut ctx,
     116           63 :                     cancel_map,
     117           63 :                     socket,
     118           63 :                     ClientMode::Tcp,
     119           63 :                     endpoint_rate_limiter,
     120           63 :                 )
     121          894 :                 .await;
     122              : 
                                        // In every arm the outcome is stored in `ctx` and `ctx.log()` is called
                                        // before the arm finishes or starts proxying.
     123           41 :                 match res {
     124           22 :                     Err(e) => {
     125           22 :                         // todo: log and push to ctx the error kind
     126           22 :                         ctx.set_error_kind(e.get_error_kind());
     127           22 :                         ctx.log();
     128           22 :                         Err(e.into())
     129              :                     }
                                            // No passthrough to run; `handle_client` returns `Ok(None)` after
                                            // serving a cancel request.
     130              :                     Ok(None) => {
     131            0 :                         ctx.set_success();
     132            0 :                         ctx.log();
     133            0 :                         Ok(())
     134              :                     }
     135           41 :                     Ok(Some(p)) => {
     136           41 :                         ctx.set_success();
     137           41 :                         ctx.log();
     138          121 :                         p.proxy_pass().await
     139              :                     }
     140              :                 }
     141           63 :             }
     142           63 :             .unwrap_or_else(move |e| {
     143           63 :                 // Acknowledge that the task has finished with an error.
     144           63 :                 error!("per-client task finished with an error: {e:#}");
     145           63 :             })
     146           63 :             .instrument(session_span),
     147           63 :         );
     148              :     }
     149              : 
                            // Close the tracker and drop the listener before waiting on the spawned tasks.
     150           25 :     connections.close();
     151           25 :     drop(listener);
     152           25 : 
     153           25 :     // Drain connections
     154           25 :     connections.wait().await;
     155              : 
     156           25 :     Ok(())
     157           25 : }
     158              : 
                        /// The transport a client connection arrived over; selects the per-transport
                        /// policy implemented by the methods below.
     159              : pub enum ClientMode {
                            /// Connection handled via the TCP listener (`task_main` passes this variant).
     160              :     Tcp,
                            /// Websocket connection; `hostname`, when present, is what
                            /// `ClientMode::hostname` returns instead of the stream's SNI hostname.
     161              :     Websockets { hostname: Option<String> },
     162              : }
     163              : 
     164              : /// Abstracts the logic of handling TCP vs WS clients
     165              : impl ClientMode {
                            /// Value passed to `authenticate` as its cleartext flag:
                            /// `false` for TCP, `true` for websockets.
     166           52 :     fn allow_cleartext(&self) -> bool {
     167           52 :         match self {
     168           52 :             ClientMode::Tcp => false,
     169            0 :             ClientMode::Websockets { .. } => true,
     170              :         }
     171           52 :     }
     172              : 
                            /// Whether self-signed compute certificates are allowed: TCP defers to
                            /// `config.allow_self_signed_compute`, websockets never allow it.
     173           41 :     fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
     174           41 :         match self {
     175           41 :             ClientMode::Tcp => config.allow_self_signed_compute,
     176            0 :             ClientMode::Websockets { .. } => false,
     177              :         }
     178           41 :     }
     179              : 
                            /// Hostname the client connected to: the stream's SNI hostname for TCP,
                            /// the stored `hostname` for websockets. `None` if neither is available.
     180           52 :     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
     181           52 :         match self {
     182           52 :             ClientMode::Tcp => s.sni_hostname(),
     183            0 :             ClientMode::Websockets { hostname } => hostname.as_deref(),
     184              :         }
     185           52 :     }
     186              : 
                            /// TLS config to hand to the handshake: passed through unchanged for TCP,
                            /// always `None` for websockets.
     187           63 :     fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
     188           63 :         match self {
     189           63 :             ClientMode::Tcp => tls,
     190              :             // TLS is None here if using websockets, because the connection is already encrypted.
     191            0 :             ClientMode::Websockets { .. } => None,
     192              :         }
     193           63 :     }
     194              : }
     195              : 
                        /// Error type returned by [`handle_client`].
                        ///
                        /// Every variant wraps a source error, displays it unchanged (`"{0}"`), and has a
                        /// `#[from]` conversion so `?` can be used directly on the underlying results.
     196           88 : #[derive(Debug, Error)]
     197              : // almost all errors should be reported to the user, but there's a few cases where we cannot
     198              : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
     199              : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
     200              : //    we cannot be sure the client even understands our error message
     201              : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
     202              : pub enum ClientRequestError {
     203              :     #[error("{0}")]
     204              :     Cancellation(#[from] cancellation::CancelError),
     205              :     #[error("{0}")]
     206              :     Handshake(#[from] handshake::HandshakeError),
                            // Produced when `tokio::time::timeout` around the handshake elapses.
     207              :     #[error("{0}")]
     208              :     HandshakeTimeout(#[from] tokio::time::error::Elapsed),
                            // Any `std::io::Error` propagated with `?` in `handle_client` lands here,
                            // e.g. from `prepare_client_connection` or the final `write_all` to compute.
     209              :     #[error("{0}")]
     210              :     PrepareClient(#[from] std::io::Error),
     211              :     #[error("{0}")]
     212              :     ReportedError(#[from] crate::stream::ReportedError),
     213              : }
     214              : 
                        /// Classifies a [`ClientRequestError`] for reporting: variants whose source
                        /// error is itself reportable delegate to it, the others map to a fixed kind.
     215              : impl ReportableError for ClientRequestError {
     216           22 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     217           22 :         match self {
     218            0 :             ClientRequestError::Cancellation(e) => e.get_error_kind(),
     219           11 :             ClientRequestError::Handshake(e) => e.get_error_kind(),
                                    // NOTE(review): a handshake timeout is classified as `RateLimit` —
                                    // confirm this is the intended kind.
     220            0 :             ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
     221           11 :             ClientRequestError::ReportedError(e) => e.get_error_kind(),
     222            0 :             ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
     223              :         }
     224           22 :     }
     225              : }
     226              : 
                        /// Takes a freshly accepted client `stream` through handshake, authentication
                        /// and connection to compute.
                        ///
                        /// Returns `Ok(Some(passthrough))` when the client is ready to be proxied,
                        /// `Ok(None)` when the handshake turned out to be a cancel request (which is
                        /// served here), and `Err` for any failure along the way. `mode` selects the
                        /// TCP vs websocket policy (see [`ClientMode`]).
     227           63 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     228           63 :     config: &'static ProxyConfig,
     229           63 :     ctx: &mut RequestMonitoring,
     230           63 :     cancel_map: Arc<CancelMap>,
     231           63 :     stream: S,
     232           63 :     mode: ClientMode,
     233           63 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     234           63 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
     235           63 :     info!(
     236           63 :         protocol = ctx.protocol,
     237           63 :         "handling interactive connection from client"
     238           63 :     );
     239              : 
                            // Both gauge guards are held for the rest of this function and, on success,
                            // moved into the returned `ProxyPassthrough` (`conn` / `req` fields).
     240           63 :     let proto = ctx.protocol;
     241           63 :     let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
     242           63 :         .with_label_values(&[proto])
     243           63 :         .guard();
     244           63 :     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
     245           63 :         .with_label_values(&[proto])
     246           63 :         .guard();
     247           63 : 
     248           63 :     let tls = config.tls_config.as_ref();
     249           63 : 
                            // The latency timer is paused for the duration of the handshake;
                            // the `pause` guard is dropped right after it.
     250           63 :     let pause = ctx.latency_timer.pause();
     251           63 :     let do_handshake = handshake(stream, mode.handshake_tls(tls));
     252           52 :     let (mut stream, params) =
                                // First `?` handles the timeout (`Elapsed`), second the handshake error.
     253          108 :         match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
     254           52 :             HandshakeData::Startup(stream, params) => (stream, params),
                                    // Cancel request: serve it and finish without a passthrough.
     255            0 :             HandshakeData::Cancel(cancel_key_data) => {
     256            0 :                 return Ok(cancel_map
     257            0 :                     .cancel_session(cancel_key_data)
     258            0 :                     .await
     259            0 :                     .map(|()| None)?)
     260              :             }
     261              :         };
     262           52 :     drop(pause);
     263           52 : 
     264           52 :     let hostname = mode.hostname(stream.get_ref());
     265           52 : 
     266           52 :     let common_names = tls.map(|tls| &tls.common_names);
     267           52 : 
     268           52 :     // Extract credentials which we're going to use for auth.
     269           52 :     let result = config
     270           52 :         .auth_backend
     271           52 :         .as_ref()
     272           52 :         .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
     273           52 :         .transpose();
     274              : 
                            // NOTE(review): `throw_error` presumably reports `e` to the client before
                            // returning it as an error — confirm in `stream`.
     275           52 :     let user_info = match result {
     276           52 :         Ok(user_info) => user_info,
     277            0 :         Err(e) => stream.throw_error(e).await?,
     278              :     };
     279              : 
     280              :     // check rate limit
                            // Only applies when an endpoint is available from `user_info`.
     281           52 :     if let Some(ep) = user_info.get_endpoint() {
     282           49 :         if !endpoint_rate_limiter.check(ep) {
     283            0 :             return stream
     284            0 :                 .throw_error(auth::AuthError::too_many_connections())
     285            0 :                 .await?;
     286           49 :         }
     287            3 :     }
     288              : 
                            // `user` is copied out up front because `user_info` is rebound by the
                            // `authenticate` result below; it is only used for the error span.
     289           52 :     let user = user_info.get_user().to_owned();
     290           52 :     let (mut node_info, user_info) = match user_info
     291           52 :         .authenticate(
     292           52 :             ctx,
     293           52 :             &mut stream,
     294           52 :             mode.allow_cleartext(),
     295           52 :             &config.authentication_config,
     296           52 :         )
     297          666 :         .await
     298              :     {
     299           41 :         Ok(auth_result) => auth_result,
     300           11 :         Err(e) => {
                                    // Attach user / database / application name to whatever
                                    // `throw_error` emits while handling the auth failure.
     301           11 :             let db = params.get("database");
     302           11 :             let app = params.get("application_name");
     303           11 :             let params_span = tracing::info_span!("", ?user, ?db, ?app);
     304              : 
     305           11 :             return stream.throw_error(e).instrument(params_span).await?;
     306              :         }
     307              :     };
     308              : 
     309           41 :     node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
     310           41 : 
                            // `aux` is cloned first because `node_info` is moved into `connect_to_compute`.
     311           41 :     let aux = node_info.aux.clone();
     312           41 :     let mut node = connect_to_compute(
     313           41 :         ctx,
     314           41 :         &TcpMechanism { params: &params },
     315           41 :         node_info,
     316           41 :         &user_info,
     317           41 :     )
                            // A connect failure is routed through `throw_error` like the other failures above.
     318           41 :     .or_else(|e| stream.throw_error(e))
     319          120 :     .await?;
     320              : 
     321           41 :     let session = cancel_map.get_session();
     322           41 :     prepare_client_connection(&node, &session, &mut stream).await?;
     323              : 
     324              :     // Before proxy passing, forward to compute whatever data is left in the
     325              :     // PqStream input buffer. Normally there is none, but our serverless npm
     326              :     // driver in pipeline mode sends startup, password and first query
     327              :     // immediately after opening the connection.
     328           41 :     let (stream, read_buf) = stream.into_inner();
     329           41 :     node.stream.write_all(&read_buf).await?;
     330              : 
     331           41 :     Ok(Some(ProxyPassthrough {
     332           41 :         client: stream,
     333           41 :         compute: node,
     334           41 :         aux,
     335           41 :         req: _request_gauge,
     336           41 :         conn: _client_gauge,
     337           41 :     }))
     338           63 : }
     339              : 
     340              : /// Finish client connection initialization: confirm auth success, send params, etc.
                        ///
                        /// Writes one `ParameterStatus` message per entry in `node.params`, then
                        /// `BackendKeyData` carrying the newly issued cancel key, then `ReadyForQuery`.
                        ///
                        /// # Errors
                        ///
                        /// Returns the `std::io::Error` from any of the message writes.
     341           82 : #[tracing::instrument(skip_all)]
     342              : async fn prepare_client_connection(
     343              :     node: &compute::PostgresConnection,
     344              :     session: &cancellation::Session,
     345              :     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     346              : ) -> Result<(), std::io::Error> {
     347              :     // Register compute's query cancellation token and produce a new, unique one.
     348              :     // The new token (cancel_key_data) will be sent to the client.
     349              :     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
     350              : 
     351              :     // Forward all postgres connection params to the client.
     352              :     // Right now the implementation is very hacky and inefficient (ideally,
     353              :     // we don't need an intermediate hashmap), but at least it should be correct.
     354              :     for (name, value) in &node.params {
     355              :         // TODO: Theoretically, this could result in a big pile of params...
     356              :         stream.write_message_noflush(&Be::ParameterStatus {
     357              :             name: name.as_bytes(),
     358              :             value: value.as_bytes(),
     359              :         })?;
     360              :     }
     361              : 
                            // Only the final `write_message` is awaited; the earlier messages use the
                            // `_noflush` variant.
     362              :     stream
     363              :         .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
     364              :         .write_message(&Be::ReadyForQuery)
     365              :         .await?;
     366              : 
     367              :     Ok(())
     368              : }
     369              : 
                        /// The `neon_<key>:<value>` options found in a client's startup options, stored
                        /// as `(key, value)` pairs in sorted order (the `neon_` prefix is not stored).
     370          248 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
     371              : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
     372              : 
     373              : impl NeonOptions {
                            /// Extracts the neon options from startup `params`; yields an empty set
                            /// when `options_raw()` returns `None`.
     374           71 :     pub fn parse_params(params: &StartupMessageParams) -> Self {
     375           71 :         params
     376           71 :             .options_raw()
     377           71 :             .map(Self::parse_from_iter)
     378           71 :             .unwrap_or_default()
     379           71 :     }
                            /// Same as [`NeonOptions::parse_params`], but starting from a raw options
                            /// string that is split by `StartupMessageParams::parse_options_raw`.
     380           14 :     pub fn parse_options_raw(options: &str) -> Self {
     381           14 :         Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
     382           14 :     }
     383              : 
                            /// Keeps only the items accepted by [`neon_option`] and sorts the resulting
                            /// pairs, so equal option sets compare equal regardless of input order.
     384           75 :     fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
     385           75 :         let mut options = options
     386           75 :             .filter_map(neon_option)
     387           75 :             .map(|(k, v)| (k.into(), v.into()))
     388           75 :             .collect_vec();
     389           75 :         options.sort();
     390           75 :         Self(options)
     391           75 :     }
     392              : 
                            /// Builds a cache key of the form `prefix` followed by ` key:value` for each
                            /// stored option, in stored (sorted) order.
     393          328 :     pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
     394          328 :         // prefix + format!(" {k}:{v}")
     395          328 :         // kinda jank because SmolStr is immutable
     396          328 :         std::iter::once(prefix)
     397          328 :             .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
     398          328 :             .collect::<SmolStr>()
     399          328 :             .into()
     400          328 :     }
     401              : 
     402              :     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
     403              :     /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
                            ///
                            /// Each stored `(key, value)` becomes `("options[key]", value)`.
     404            0 :     pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
     405            0 :         self.0
     406            0 :             .iter()
     407            0 :             .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
     408            0 :             .collect()
     409            0 :     }
     410              : }
     411              : 
                        /// Parses a single option of the form `neon_<key>:<value>`.
                        ///
                        /// Returns `Some((key, value))` when `bytes` starts with `neon_`, followed by one
                        /// or more word characters (the key), a `:`, and a non-empty value. Returns
                        /// `None` for anything else. The returned slices borrow from `bytes`.
     412          185 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
                            // The regex is compiled on first use and reused by every later call.
     413          185 :     static RE: OnceCell<Regex> = OnceCell::new();
     414          185 :     let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
     415              : 
     416          185 :     let cap = re.captures(bytes)?;
                            // `extract` yields the whole match plus exactly the two capture groups.
     417            8 :     let (_, [k, v]) = cap.extract();
     418            8 :     Some((k, v))
     419          185 : }
        

Generated by: LCOV version 2.1-beta