LCOV - code coverage report
Current view: top level - proxy/src/binary - pg_sni_router.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 236 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 14 0

            Line data    Source code
       1              : //! A stand-alone program that routes connections, e.g. from
       2              : //! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
       3              : //!
       4              : //! This allows connecting to pods/services running in the same Kubernetes cluster from
       5              : //! the outside. Similar to an ingress controller for HTTPS.
       6              : 
       7              : use std::io;
       8              : use std::net::SocketAddr;
       9              : use std::path::Path;
      10              : use std::sync::Arc;
      11              : 
      12              : use anyhow::{Context, anyhow, bail, ensure};
      13              : use clap::Arg;
      14              : use futures::future::Either;
      15              : use futures::{FutureExt, TryFutureExt};
      16              : use itertools::Itertools;
      17              : use rustls::crypto::ring;
      18              : use rustls::pki_types::{DnsName, PrivateKeyDer};
      19              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
      20              : use tokio::net::TcpListener;
      21              : use tokio_rustls::TlsConnector;
      22              : use tokio_rustls::server::TlsStream;
      23              : use tokio_util::sync::CancellationToken;
      24              : use tracing::{Instrument, error, info};
      25              : use utils::project_git_version;
      26              : use utils::sentry_init::init_sentry;
      27              : 
      28              : use crate::context::RequestContext;
      29              : use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
      30              : use crate::pglb::TlsRequired;
      31              : use crate::pqproto::FeStartupPacket;
      32              : use crate::protocol2::ConnectionInfo;
      33              : use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
      34              : use crate::stream::{PqStream, Stream};
      35              : use crate::util::run_until_cancelled;
      36              : 
// Defines the `GIT_VERSION` constant, reported by `--version` (see `cli()`) and
// passed to `init_sentry` in `run()`.
project_git_version!(GIT_VERSION);
      38              : 
      39            0 : fn cli() -> clap::Command {
      40            0 :     clap::Command::new("Neon proxy/router")
      41            0 :         .version(GIT_VERSION)
      42            0 :         .arg(
      43            0 :             Arg::new("listen")
      44            0 :                 .short('l')
      45            0 :                 .long("listen")
      46            0 :                 .help("listen for incoming client connections on ip:port")
      47            0 :                 .default_value("127.0.0.1:4432"),
      48              :         )
      49            0 :         .arg(
      50            0 :             Arg::new("listen-tls")
      51            0 :                 .long("listen-tls")
      52            0 :                 .help("listen for incoming client connections on ip:port, requiring TLS to compute")
      53            0 :                 .default_value("127.0.0.1:4433"),
      54              :         )
      55            0 :         .arg(
      56            0 :             Arg::new("tls-key")
      57            0 :                 .short('k')
      58            0 :                 .long("tls-key")
      59            0 :                 .help("path to TLS key for client postgres connections")
      60            0 :                 .required(true),
      61              :         )
      62            0 :         .arg(
      63            0 :             Arg::new("tls-cert")
      64            0 :                 .short('c')
      65            0 :                 .long("tls-cert")
      66            0 :                 .help("path to TLS cert for client postgres connections")
      67            0 :                 .required(true),
      68              :         )
      69            0 :         .arg(
      70            0 :             Arg::new("dest")
      71            0 :                 .short('d')
      72            0 :                 .long("destination")
      73            0 :                 .help("append this domain zone to the SNI hostname to get the destination address")
      74            0 :                 .required(true),
      75              :         )
      76            0 : }
      77              : 
/// Entry point of the `pg_sni_router` binary.
///
/// Initializes logging, the tracing panic hook, Sentry and metrics, parses the command
/// line (see `cli()`), loads the client-facing TLS server config and the TLS client
/// config used towards compute, binds both listeners and spawns one `task_main` accept
/// loop per listener:
/// * `--listen`: connections to compute are plain TCP;
/// * `--listen-tls`: connections to compute are upgraded to TLS.
///
/// Returns `Ok(())` once both accept loops have finished successfully. Returns an error
/// if either accept loop fails, or if the signal-handling task finishes with an error.
pub async fn run() -> anyhow::Result<()> {
    // Guards are kept alive for the whole lifetime of the process.
    let _logging_guard = crate::logging::init()?;
    let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
    let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);

    Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));

    let args = cli().get_matches();
    // `String: FromStr` is infallible; the `?` here can never trigger.
    let destination: String = args
        .get_one::<String>("dest")
        .expect("string argument defined")
        .parse()?;

    // Configure TLS
    // Both arguments are `required(true)` in `cli()`, so the `bail!` arm is defensive.
    let tls_config = match (
        args.get_one::<String>("tls-key"),
        args.get_one::<String>("tls-cert"),
    ) {
        (Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
        _ => bail!("tls-key and tls-cert must be specified"),
    };

    // Client-side TLS config, only used by the `--listen-tls` listener.
    let compute_tls_config =
        Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);

    // Start listening for incoming client connections
    let proxy_address: SocketAddr = args
        .get_one::<String>("listen")
        .expect("listen argument defined")
        .parse()?;
    let proxy_address_compute_tls: SocketAddr = args
        .get_one::<String>("listen-tls")
        .expect("listen-tls argument defined")
        .parse()?;

    // NOTE(review): both lines log the same text; consider saying which listener
    // forwards to compute over TLS.
    info!("Starting sni router on {proxy_address}");
    info!("Starting sni router on {proxy_address_compute_tls}");
    let proxy_listener = TcpListener::bind(proxy_address).await?;
    let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;

    // Shared by both accept loops and the signal handler.
    let cancellation_token = CancellationToken::new();
    let dest = Arc::new(destination);

    // Accept loop that connects to compute over plain TCP (`None` compute TLS config).
    let main = tokio::spawn(task_main(
        dest.clone(),
        tls_config.clone(),
        None,
        proxy_listener,
        cancellation_token.clone(),
    ))
    .map(crate::error::flatten_err);

    // Accept loop that upgrades the connection to compute to TLS.
    let main_tls = tokio::spawn(task_main(
        dest,
        tls_config,
        Some(compute_tls_config),
        proxy_listener_compute_tls,
        cancellation_token.clone(),
    ))
    .map(crate::error::flatten_err);

    Metrics::get()
        .service
        .info
        .set_label(ServiceInfo::running());

    // Presumably cancels `cancellation_token` on shutdown signals, which ends both
    // accept loops (TODO confirm against `crate::signals::handle`).
    let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));

    // The signal task can't ever succeed.
    // The main tasks can error, or can succeed on cancellation.
    // We want to exit immediately in either of these cases.
    let main = futures::future::try_join(main, main_tls);
    let signal = match futures::future::select(signals_task, main).await {
        Either::Left((res, _)) => crate::error::flatten_err(res)?,
        Either::Right((res, _)) => {
            res?;
            return Ok(());
        }
    };

    // maintenance tasks return `Infallible` success values, this is an impossible value
    // so this match statically ensures that there are no possibilities for that value
    match signal {}
}
     162              : 
     163            0 : pub(super) fn parse_tls(
     164            0 :     key_path: &Path,
     165            0 :     cert_path: &Path,
     166            0 : ) -> anyhow::Result<Arc<rustls::ServerConfig>> {
     167            0 :     let key = {
     168            0 :         let key_bytes = std::fs::read(key_path).context("TLS key file")?;
     169              : 
     170            0 :         let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
     171              : 
     172            0 :         ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     173              :         PrivateKeyDer::Pkcs8(
     174            0 :             keys.pop()
     175            0 :                 .expect("keys should not be empty")
     176            0 :                 .context(format!(
     177            0 :                     "Failed to read TLS keys at '{}'",
     178            0 :                     key_path.display()
     179            0 :                 ))?,
     180              :         )
     181              :     };
     182              : 
     183            0 :     let cert_chain_bytes = std::fs::read(cert_path).context(format!(
     184            0 :         "Failed to read TLS cert file at '{}.'",
     185            0 :         cert_path.display()
     186            0 :     ))?;
     187              : 
     188            0 :     let cert_chain: Vec<_> = {
     189            0 :         rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     190            0 :             .try_collect()
     191            0 :             .with_context(|| {
     192            0 :                 format!(
     193            0 :                     "Failed to read TLS certificate chain from bytes from file at '{}'.",
     194            0 :                     cert_path.display()
     195              :                 )
     196            0 :             })?
     197              :     };
     198              : 
     199            0 :     let tls_config =
     200            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
     201            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
     202            0 :             .context("ring should support TLS1.2 and TLS1.3")?
     203            0 :             .with_no_client_auth()
     204            0 :             .with_single_cert(cert_chain, key)?
     205            0 :             .into();
     206              : 
     207            0 :     Ok(tls_config)
     208            0 : }
     209              : 
/// Accepts client connections on `listener` and serves each one in its own task.
///
/// Each accepted socket gets TCP_NODELAY and a fresh `RequestContext` tagged with
/// `Protocol::SniRouter`, and is then handed to `handle_client`. A per-client error is
/// logged, except for `UnexpectedEof` while reading the very first message.
///
/// The loop stops when `run_until_cancelled` yields `None`, presumably once
/// `cancellation_token` is cancelled. The listener is then closed and the function
/// waits for all in-flight connections before returning `Ok(())`.
///
/// `compute_tls_config`: when `Some`, the connection to compute is upgraded to TLS.
///
/// # Errors
///
/// Returns an error if enabling keepalive on the listener fails or if `accept()` fails.
pub(super) async fn task_main(
    dest_suffix: Arc<String>,
    tls_config: Arc<rustls::ServerConfig>,
    compute_tls_config: Option<Arc<rustls::ClientConfig>>,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    // Tracks spawned per-connection tasks so shutdown can wait for them.
    let connections = tokio_util::task::task_tracker::TaskTracker::new();

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        // NOTE(review): any `accept()` error ends the whole loop here and returns
        // without waiting on `connections`; consider logging and continuing on
        // transient errors.
        let (socket, peer_addr) = accept_result?;

        // Per-connection copies of the shared state moved into the task below.
        let session_id = uuid::Uuid::new_v4();
        let tls_config = Arc::clone(&tls_config);
        let dest_suffix = Arc::clone(&dest_suffix);
        let compute_tls_config = compute_tls_config.clone();

        connections.spawn(
            async move {
                socket
                    .set_nodelay(true)
                    .context("failed to set socket option")?;

                let ctx = RequestContext::new(
                    session_id,
                    ConnectionInfo {
                        addr: peer_addr,
                        extra: None,
                    },
                    crate::metrics::Protocol::SniRouter,
                );
                handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
            }
            .unwrap_or_else(|e| {
                // `FirstMessage` marks a failure to read the startup packet (see `ssl_handshake`).
                if let Some(FirstMessage(io_error)) = e.downcast_ref() {
                    // this is noisy. if we get EOF on the very first message that's likely
                    // just NLB doing a healthcheck.
                    if io_error.kind() == io::ErrorKind::UnexpectedEof {
                        return;
                    }
                }

                // Acknowledge that the task has finished with an error.
                error!("per-client task finished with an error: {e:#}");
            })
            .instrument(tracing::info_span!("handle_client", ?session_id)),
        );
    }

    // Stop tracking new tasks and release the listening socket before waiting, so no
    // new connections are accepted while in-flight ones drain.
    connections.close();
    drop(listener);

    connections.wait().await;

    info!("all client connections have finished");
    Ok(())
}
     273              : 
/// Marks an I/O error that occurred while reading the client's very first (startup)
/// message.
///
/// `task_main` downcasts to this type so it can stay quiet about `UnexpectedEof` on the
/// first message, which is likely just a load-balancer health check.
/// `#[error(transparent)]` forwards `Display` and `source` to the wrapped `io::Error`.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct FirstMessage(io::Error);
     277              : 
     278            0 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     279            0 :     ctx: &RequestContext,
     280            0 :     raw_stream: S,
     281            0 :     tls_config: Arc<rustls::ServerConfig>,
     282            0 : ) -> anyhow::Result<TlsStream<S>> {
     283            0 :     let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
     284            0 :         .await
     285            0 :         .map_err(FirstMessage)?;
     286              : 
     287            0 :     match msg {
     288              :         FeStartupPacket::SslRequest { direct: None } => {
     289            0 :             let raw = stream.accept_tls().await?;
     290              : 
     291            0 :             Ok(raw
     292            0 :                 .upgrade(tls_config, !ctx.has_private_peer_addr())
     293            0 :                 .await?)
     294              :         }
     295            0 :         unexpected => {
     296            0 :             info!(
     297              :                 ?unexpected,
     298            0 :                 "unexpected startup packet, rejecting connection"
     299              :             );
     300            0 :             Err(stream.throw_error(TlsRequired, None).await)?
     301              :         }
     302              :     }
     303            0 : }
     304              : 
     305            0 : async fn handle_client(
     306            0 :     ctx: RequestContext,
     307            0 :     dest_suffix: Arc<String>,
     308            0 :     tls_config: Arc<rustls::ServerConfig>,
     309            0 :     compute_tls_config: Option<Arc<rustls::ClientConfig>>,
     310            0 :     stream: impl AsyncRead + AsyncWrite + Unpin,
     311            0 : ) -> anyhow::Result<()> {
     312            0 :     let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
     313              : 
     314              :     // Cut off first part of the SNI domain
     315              :     // We receive required destination details in the format of
     316              :     //   `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
     317            0 :     let sni = tls_stream
     318            0 :         .get_ref()
     319            0 :         .1
     320            0 :         .server_name()
     321            0 :         .ok_or(anyhow!("SNI missing"))?;
     322            0 :     let dest: Vec<&str> = sni
     323            0 :         .split_once('.')
     324            0 :         .context("invalid SNI")?
     325              :         .0
     326            0 :         .splitn(3, "--")
     327            0 :         .collect();
     328            0 :     let port = dest[2].parse::<u16>().context("invalid port")?;
     329            0 :     let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
     330              : 
     331            0 :     info!("destination: {}", destination);
     332              : 
     333            0 :     let mut client = tokio::net::TcpStream::connect(&destination).await?;
     334              : 
     335            0 :     let client = if let Some(compute_tls_config) = compute_tls_config {
     336            0 :         info!("upgrading TLS");
     337              : 
     338              :         // send SslRequest
     339            0 :         client
     340            0 :             .write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
     341            0 :             .await?;
     342              : 
     343              :         // wait for S/N respons
     344            0 :         let mut resp = b'N';
     345            0 :         client.read_exact(std::slice::from_mut(&mut resp)).await?;
     346              : 
     347              :         // error if not S
     348            0 :         ensure!(resp == b'S', "compute refused TLS");
     349              : 
     350              :         // upgrade to TLS.
     351            0 :         let domain = DnsName::try_from(destination)?;
     352            0 :         let domain = rustls::pki_types::ServerName::DnsName(domain);
     353            0 :         let client = TlsConnector::from(compute_tls_config)
     354            0 :             .connect(domain, client)
     355            0 :             .await?;
     356            0 :         Connection::Tls(client)
     357              :     } else {
     358            0 :         Connection::Raw(client)
     359              :     };
     360              : 
     361              :     // doesn't yet matter as pg-sni-router doesn't report analytics logs
     362            0 :     ctx.set_success();
     363            0 :     ctx.log_connect();
     364              : 
     365              :     // Starting from here we only proxy the client's traffic.
     366            0 :     info!("performing the proxy pass...");
     367              : 
     368            0 :     let res = match client {
     369            0 :         Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
     370            0 :         Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
     371              :     };
     372              : 
     373            0 :     match res {
     374            0 :         Ok(_) => Ok(()),
     375            0 :         Err(ErrorSource::Client(err)) => Err(err).context("client"),
     376            0 :         Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
     377              :     }
     378            0 : }
     379              : 
/// Connection from the router to compute, as established by `handle_client`.
// The TLS variant is presumably much larger than the raw socket (hence the lint). Only
// one value exists per client connection and it is matched on immediately, so boxing
// seems not worth it.
#[allow(clippy::large_enum_variant)]
enum Connection {
    /// Plain TCP connection, used when no compute TLS config is set.
    Raw(tokio::net::TcpStream),
    /// TCP connection upgraded to TLS after compute accepted an `SSLRequest`.
    Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
}
        

Generated by: LCOV version 2.1-beta