LCOV - code coverage report
Current view: top level - proxy/src/binary - pg_sni_router.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 0.0 % 212 0
Test Date: 2025-02-20 13:11:02 Functions: 0.0 % 17 0

            Line data    Source code
       1              : /// A stand-alone program that routes connections, e.g. from
       2              : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
       3              : ///
       4              : /// This allows connecting to pods/services running in the same Kubernetes cluster from
       5              : /// the outside. Similar to an ingress controller for HTTPS.
       6              : use std::{net::SocketAddr, sync::Arc};
       7              : 
       8              : use anyhow::{anyhow, bail, ensure, Context};
       9              : use clap::Arg;
      10              : use futures::future::Either;
      11              : use futures::TryFutureExt;
      12              : use itertools::Itertools;
      13              : use rustls::crypto::ring;
      14              : use rustls::pki_types::PrivateKeyDer;
      15              : use tokio::io::{AsyncRead, AsyncWrite};
      16              : use tokio::net::TcpListener;
      17              : use tokio_util::sync::CancellationToken;
      18              : use tracing::{error, info, Instrument};
      19              : use utils::project_git_version;
      20              : use utils::sentry_init::init_sentry;
      21              : 
      22              : use crate::context::RequestContext;
      23              : use crate::metrics::{Metrics, ThreadPoolMetrics};
      24              : use crate::protocol2::ConnectionInfo;
      25              : use crate::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
      26              : use crate::stream::{PqStream, Stream};
      27              : use crate::tls::TlsServerEndPoint;
      28              : 
      29              : project_git_version!(GIT_VERSION);
      30              : 
      31            0 : fn cli() -> clap::Command {
      32            0 :     clap::Command::new("Neon proxy/router")
      33            0 :         .version(GIT_VERSION)
      34            0 :         .arg(
      35            0 :             Arg::new("listen")
      36            0 :                 .short('l')
      37            0 :                 .long("listen")
      38            0 :                 .help("listen for incoming client connections on ip:port")
      39            0 :                 .default_value("127.0.0.1:4432"),
      40            0 :         )
      41            0 :         .arg(
      42            0 :             Arg::new("tls-key")
      43            0 :                 .short('k')
      44            0 :                 .long("tls-key")
      45            0 :                 .help("path to TLS key for client postgres connections")
      46            0 :                 .required(true),
      47            0 :         )
      48            0 :         .arg(
      49            0 :             Arg::new("tls-cert")
      50            0 :                 .short('c')
      51            0 :                 .long("tls-cert")
      52            0 :                 .help("path to TLS cert for client postgres connections")
      53            0 :                 .required(true),
      54            0 :         )
      55            0 :         .arg(
      56            0 :             Arg::new("dest")
      57            0 :                 .short('d')
      58            0 :                 .long("destination")
      59            0 :                 .help("append this domain zone to the SNI hostname to get the destination address")
      60            0 :                 .required(true),
      61            0 :         )
      62            0 : }
      63              : 
      64            0 : pub async fn run() -> anyhow::Result<()> {
      65            0 :     let _logging_guard = crate::logging::init().await?;
      66            0 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      67            0 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      68            0 : 
      69            0 :     Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
      70            0 : 
      71            0 :     let args = cli().get_matches();
      72            0 :     let destination: String = args
      73            0 :         .get_one::<String>("dest")
      74            0 :         .expect("string argument defined")
      75            0 :         .parse()?;
      76              : 
      77              :     // Configure TLS
      78            0 :     let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
      79            0 :         args.get_one::<String>("tls-key"),
      80            0 :         args.get_one::<String>("tls-cert"),
      81              :     ) {
      82            0 :         (Some(key_path), Some(cert_path)) => {
      83            0 :             let key = {
      84            0 :                 let key_bytes = std::fs::read(key_path).context("TLS key file")?;
      85              : 
      86            0 :                 let mut keys =
      87            0 :                     rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
      88            0 : 
      89            0 :                 ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
      90              :                 PrivateKeyDer::Pkcs8(
      91            0 :                     keys.pop()
      92            0 :                         .expect("keys should not be empty")
      93            0 :                         .context(format!("Failed to read TLS keys at '{key_path}'"))?,
      94              :                 )
      95              :             };
      96              : 
      97            0 :             let cert_chain_bytes = std::fs::read(cert_path)
      98            0 :                 .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
      99              : 
     100            0 :             let cert_chain: Vec<_> = {
     101            0 :                 rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     102            0 :                 .try_collect()
     103            0 :                 .with_context(|| {
     104            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     105            0 :                 })?
     106              :             };
     107              : 
     108              :             // needed for channel bindings
     109            0 :             let first_cert = cert_chain.first().context("missing certificate")?;
     110            0 :             let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     111              : 
     112            0 :             let tls_config =
     113            0 :                 rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
     114            0 :                     .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
     115            0 :                     .context("ring should support TLS1.2 and TLS1.3")?
     116            0 :                     .with_no_client_auth()
     117            0 :                     .with_single_cert(cert_chain, key)?
     118            0 :                     .into();
     119            0 : 
     120            0 :             (tls_config, tls_server_end_point)
     121              :         }
     122            0 :         _ => bail!("tls-key and tls-cert must be specified"),
     123              :     };
     124              : 
     125              :     // Start listening for incoming client connections
     126            0 :     let proxy_address: SocketAddr = args
     127            0 :         .get_one::<String>("listen")
     128            0 :         .expect("string argument defined")
     129            0 :         .parse()?;
     130            0 :     info!("Starting sni router on {proxy_address}");
     131            0 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     132              : 
     133            0 :     let cancellation_token = CancellationToken::new();
     134            0 : 
     135            0 :     let main = tokio::spawn(task_main(
     136            0 :         Arc::new(destination),
     137            0 :         tls_config,
     138            0 :         tls_server_end_point,
     139            0 :         proxy_listener,
     140            0 :         cancellation_token.clone(),
     141            0 :     ));
     142            0 :     let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
     143            0 : 
     144            0 :     // the signal task cant ever succeed.
     145            0 :     // the main task can error, or can succeed on cancellation.
     146            0 :     // we want to immediately exit on either of these cases
     147            0 :     let signal = match futures::future::select(signals_task, main).await {
     148            0 :         Either::Left((res, _)) => crate::error::flatten_err(res)?,
     149            0 :         Either::Right((res, _)) => return crate::error::flatten_err(res),
     150              :     };
     151              : 
     152              :     // maintenance tasks return `Infallible` success values, this is an impossible value
     153              :     // so this match statically ensures that there are no possibilities for that value
     154              :     match signal {}
     155            0 : }
     156              : 
     157            0 : async fn task_main(
     158            0 :     dest_suffix: Arc<String>,
     159            0 :     tls_config: Arc<rustls::ServerConfig>,
     160            0 :     tls_server_end_point: TlsServerEndPoint,
     161            0 :     listener: tokio::net::TcpListener,
     162            0 :     cancellation_token: CancellationToken,
     163            0 : ) -> anyhow::Result<()> {
     164            0 :     // When set for the server socket, the keepalive setting
     165            0 :     // will be inherited by all accepted client sockets.
     166            0 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     167              : 
     168            0 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     169              : 
     170            0 :     while let Some(accept_result) =
     171            0 :         run_until_cancelled(listener.accept(), &cancellation_token).await
     172              :     {
     173            0 :         let (socket, peer_addr) = accept_result?;
     174              : 
     175            0 :         let session_id = uuid::Uuid::new_v4();
     176            0 :         let tls_config = Arc::clone(&tls_config);
     177            0 :         let dest_suffix = Arc::clone(&dest_suffix);
     178            0 : 
     179            0 :         connections.spawn(
     180            0 :             async move {
     181            0 :                 socket
     182            0 :                     .set_nodelay(true)
     183            0 :                     .context("failed to set socket option")?;
     184              : 
     185            0 :                 info!(%peer_addr, "serving");
     186            0 :                 let ctx = RequestContext::new(
     187            0 :                     session_id,
     188            0 :                     ConnectionInfo {
     189            0 :                         addr: peer_addr,
     190            0 :                         extra: None,
     191            0 :                     },
     192            0 :                     crate::metrics::Protocol::SniRouter,
     193            0 :                     "sni",
     194            0 :                 );
     195            0 :                 handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
     196            0 :             }
     197            0 :             .unwrap_or_else(|e| {
     198            0 :                 // Acknowledge that the task has finished with an error.
     199            0 :                 error!("per-client task finished with an error: {e:#}");
     200            0 :             })
     201            0 :             .instrument(tracing::info_span!("handle_client", ?session_id)),
     202              :         );
     203              :     }
     204              : 
     205            0 :     connections.close();
     206            0 :     drop(listener);
     207            0 : 
     208            0 :     connections.wait().await;
     209              : 
     210            0 :     info!("all client connections have finished");
     211            0 :     Ok(())
     212            0 : }
     213              : 
     214              : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
     215              : 
     216            0 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     217            0 :     ctx: &RequestContext,
     218            0 :     raw_stream: S,
     219            0 :     tls_config: Arc<rustls::ServerConfig>,
     220            0 :     tls_server_end_point: TlsServerEndPoint,
     221            0 : ) -> anyhow::Result<Stream<S>> {
     222            0 :     let mut stream = PqStream::new(Stream::from_raw(raw_stream));
     223              : 
     224            0 :     let msg = stream.read_startup_packet().await?;
     225              :     use pq_proto::FeStartupPacket::SslRequest;
     226              : 
     227            0 :     match msg {
     228              :         SslRequest { direct: false } => {
     229            0 :             stream
     230            0 :                 .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
     231            0 :                 .await?;
     232              : 
     233              :             // Upgrade raw stream into a secure TLS-backed stream.
     234              :             // NOTE: We've consumed `tls`; this fact will be used later.
     235              : 
     236            0 :             let (raw, read_buf) = stream.into_inner();
     237            0 :             // TODO: Normally, client doesn't send any data before
     238            0 :             // server says TLS handshake is ok and read_buf is empty.
     239            0 :             // However, you could imagine pipelining of postgres
     240            0 :             // SSLRequest + TLS ClientHello in one hunk similar to
     241            0 :             // pipelining in our node js driver. We should probably
     242            0 :             // support that by chaining read_buf with the stream.
     243            0 :             if !read_buf.is_empty() {
     244            0 :                 bail!("data is sent before server replied with EncryptionResponse");
     245            0 :             }
     246            0 : 
     247            0 :             Ok(Stream::Tls {
     248            0 :                 tls: Box::new(
     249            0 :                     raw.upgrade(tls_config, !ctx.has_private_peer_addr())
     250            0 :                         .await?,
     251              :                 ),
     252            0 :                 tls_server_end_point,
     253              :             })
     254              :         }
     255            0 :         unexpected => {
     256            0 :             info!(
     257              :                 ?unexpected,
     258            0 :                 "unexpected startup packet, rejecting connection"
     259              :             );
     260            0 :             stream
     261            0 :                 .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
     262            0 :                 .await?
     263              :         }
     264              :     }
     265            0 : }
     266              : 
     267            0 : async fn handle_client(
     268            0 :     ctx: RequestContext,
     269            0 :     dest_suffix: Arc<String>,
     270            0 :     tls_config: Arc<rustls::ServerConfig>,
     271            0 :     tls_server_end_point: TlsServerEndPoint,
     272            0 :     stream: impl AsyncRead + AsyncWrite + Unpin,
     273            0 : ) -> anyhow::Result<()> {
     274            0 :     let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
     275              : 
     276              :     // Cut off first part of the SNI domain
     277              :     // We receive required destination details in the format of
     278              :     //   `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
     279            0 :     let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
     280            0 :     let dest: Vec<&str> = sni
     281            0 :         .split_once('.')
     282            0 :         .context("invalid SNI")?
     283              :         .0
     284            0 :         .splitn(3, "--")
     285            0 :         .collect();
     286            0 :     let port = dest[2].parse::<u16>().context("invalid port")?;
     287            0 :     let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
     288            0 : 
     289            0 :     info!("destination: {}", destination);
     290              : 
     291            0 :     let mut client = tokio::net::TcpStream::connect(destination).await?;
     292              : 
     293              :     // doesn't yet matter as pg-sni-router doesn't report analytics logs
     294            0 :     ctx.set_success();
     295            0 :     ctx.log_connect();
     296            0 : 
     297            0 :     // Starting from here we only proxy the client's traffic.
     298            0 :     info!("performing the proxy pass...");
     299              : 
     300            0 :     match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
     301            0 :         Ok(_) => Ok(()),
     302            0 :         Err(ErrorSource::Client(err)) => Err(err).context("client"),
     303            0 :         Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
     304              :     }
     305            0 : }
        

Generated by: LCOV version 2.1-beta