LCOV - differential code coverage report
Current view: top level - proxy/src/bin - pg_sni_router.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 94.7 % 187 177 10 177
Current Date: 2024-01-09 02:06:09 Functions: 94.1 % 17 16 1 16
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : /// A stand-alone program that routes connections, e.g. from
       2                 : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
       3                 : ///
       4                 : /// This allows connecting to pods/services running in the same Kubernetes cluster from
       5                 : /// the outside. Similar to an ingress controller for HTTPS.
       6                 : use std::{net::SocketAddr, sync::Arc};
       7                 : 
       8                 : use futures::future::Either;
       9                 : use itertools::Itertools;
      10                 : use proxy::config::TlsServerEndPoint;
      11                 : use proxy::context::RequestMonitoring;
      12                 : use proxy::proxy::run_until_cancelled;
      13                 : use tokio::net::TcpListener;
      14                 : 
      15                 : use anyhow::{anyhow, bail, ensure, Context};
      16                 : use clap::{self, Arg};
      17                 : use futures::TryFutureExt;
      18                 : use proxy::console::messages::MetricsAuxInfo;
      19                 : use proxy::stream::{PqStream, Stream};
      20                 : 
      21                 : use tokio::io::{AsyncRead, AsyncWrite};
      22                 : use tokio_util::sync::CancellationToken;
      23                 : use utils::{project_git_version, sentry_init::init_sentry};
      24                 : 
      25                 : use tracing::{error, info, Instrument};
      26                 : 
      27                 : project_git_version!(GIT_VERSION);
      28                 : 
      29 CBC           1 : fn cli() -> clap::Command {
      30               1 :     clap::Command::new("Neon proxy/router")
      31               1 :         .version(GIT_VERSION)
      32               1 :         .arg(
      33               1 :             Arg::new("listen")
      34               1 :                 .short('l')
      35               1 :                 .long("listen")
      36               1 :                 .help("listen for incoming client connections on ip:port")
      37               1 :                 .default_value("127.0.0.1:4432"),
      38               1 :         )
      39               1 :         .arg(
      40               1 :             Arg::new("tls-key")
      41               1 :                 .short('k')
      42               1 :                 .long("tls-key")
      43               1 :                 .help("path to TLS key for client postgres connections")
      44               1 :                 .required(true),
      45               1 :         )
      46               1 :         .arg(
      47               1 :             Arg::new("tls-cert")
      48               1 :                 .short('c')
      49               1 :                 .long("tls-cert")
      50               1 :                 .help("path to TLS cert for client postgres connections")
      51               1 :                 .required(true),
      52               1 :         )
      53               1 :         .arg(
      54               1 :             Arg::new("dest")
      55               1 :                 .short('d')
      56               1 :                 .long("destination")
      57               1 :                 .help("append this domain zone to the SNI hostname to get the destination address")
      58               1 :                 .required(true),
      59               1 :         )
      60               1 : }
      61                 : 
      62                 : #[tokio::main]
      63               1 : async fn main() -> anyhow::Result<()> {
      64               1 :     let _logging_guard = proxy::logging::init().await?;
      65               1 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      66               1 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      67               1 : 
      68               1 :     let args = cli().get_matches();
      69               1 :     let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
      70                 : 
      71                 :     // Configure TLS
      72               1 :     let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
      73               1 :         args.get_one::<String>("tls-key"),
      74               1 :         args.get_one::<String>("tls-cert"),
      75                 :     ) {
      76               1 :         (Some(key_path), Some(cert_path)) => {
      77               1 :             let key = {
      78               1 :                 let key_bytes = std::fs::read(key_path).context("TLS key file")?;
      79               1 :                 let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
      80               1 :                     .context(format!("Failed to read TLS keys at '{key_path}'"))?;
      81                 : 
      82               1 :                 ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
      83               1 :                 keys.pop().map(rustls::PrivateKey).unwrap()
      84                 :             };
      85                 : 
      86               1 :             let cert_chain_bytes = std::fs::read(cert_path)
      87               1 :                 .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
      88                 : 
      89               1 :             let cert_chain = {
      90               1 :                 rustls_pemfile::certs(&mut &cert_chain_bytes[..])
      91               1 :                     .context(format!(
      92               1 :                         "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
      93               1 :                     ))?
      94               1 :                     .into_iter()
      95               1 :                     .map(rustls::Certificate)
      96               1 :                     .collect_vec()
      97                 :             };
      98                 : 
      99                 :             // needed for channel bindings
     100               1 :             let first_cert = cert_chain.first().context("missing certificate")?;
     101               1 :             let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     102                 : 
     103               1 :             let tls_config = rustls::ServerConfig::builder()
     104               1 :                 .with_safe_default_cipher_suites()
     105               1 :                 .with_safe_default_kx_groups()
     106               1 :                 .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
     107               1 :                 .with_no_client_auth()
     108               1 :                 .with_single_cert(cert_chain, key)?
     109               1 :                 .into();
     110               1 : 
     111               1 :             (tls_config, tls_server_end_point)
     112                 :         }
     113 UBC           0 :         _ => bail!("tls-key and tls-cert must be specified"),
     114                 :     };
     115                 : 
     116                 :     // Start listening for incoming client connections
     117 CBC           1 :     let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
     118               1 :     info!("Starting sni router on {proxy_address}");
     119               1 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     120                 : 
     121               1 :     let cancellation_token = CancellationToken::new();
     122               1 : 
     123               1 :     let main = tokio::spawn(task_main(
     124               1 :         Arc::new(destination),
     125               1 :         tls_config,
     126               1 :         tls_server_end_point,
     127               1 :         proxy_listener,
     128               1 :         cancellation_token.clone(),
     129               1 :     ));
     130               1 :     let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
     131                 : 
     132                 :     // the signal task cant ever succeed.
     133                 :     // the main task can error, or can succeed on cancellation.
     134                 :     // we want to immediately exit on either of these cases
     135               1 :     let signal = match futures::future::select(signals_task, main).await {
     136 UBC           0 :         Either::Left((res, _)) => proxy::flatten_err(res)?,
     137 CBC           1 :         Either::Right((res, _)) => return proxy::flatten_err(res),
     138                 :     };
     139                 : 
     140                 :     // maintenance tasks return `Infallible` success values, this is an impossible value
     141                 :     // so this match statically ensures that there are no possibilities for that value
     142                 :     match signal {}
     143                 : }
     144                 : 
     145               1 : async fn task_main(
     146               1 :     dest_suffix: Arc<String>,
     147               1 :     tls_config: Arc<rustls::ServerConfig>,
     148               1 :     tls_server_end_point: TlsServerEndPoint,
     149               1 :     listener: tokio::net::TcpListener,
     150               1 :     cancellation_token: CancellationToken,
     151               1 : ) -> anyhow::Result<()> {
     152               1 :     // When set for the server socket, the keepalive setting
     153               1 :     // will be inherited by all accepted client sockets.
     154               1 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     155                 : 
     156               1 :     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     157                 : 
     158               2 :     while let Some(accept_result) =
     159               3 :         run_until_cancelled(listener.accept(), &cancellation_token).await
     160                 :     {
     161               2 :         let (socket, peer_addr) = accept_result?;
     162                 : 
     163               2 :         let session_id = uuid::Uuid::new_v4();
     164               2 :         let tls_config = Arc::clone(&tls_config);
     165               2 :         let dest_suffix = Arc::clone(&dest_suffix);
     166               2 : 
     167               2 :         connections.spawn(
     168               2 :             async move {
     169               2 :                 socket
     170               2 :                     .set_nodelay(true)
     171               2 :                     .context("failed to set socket option")?;
     172                 : 
     173               2 :                 info!(%peer_addr, "serving");
     174               2 :                 let mut ctx =
     175               2 :                     RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
     176               2 :                 handle_client(
     177               2 :                     &mut ctx,
     178               2 :                     dest_suffix,
     179               2 :                     tls_config,
     180               2 :                     tls_server_end_point,
     181               2 :                     socket,
     182               2 :                 )
     183               8 :                 .await
     184               2 :             }
     185               2 :             .unwrap_or_else(|e| {
     186               2 :                 // Acknowledge that the task has finished with an error.
     187               2 :                 error!("per-client task finished with an error: {e:#}");
     188               2 :             })
     189               2 :             .instrument(tracing::info_span!("handle_client", ?session_id)),
     190                 :         );
     191                 :     }
     192                 : 
     193               1 :     connections.close();
     194               1 :     drop(listener);
     195               1 : 
     196               1 :     connections.wait().await;
     197                 : 
     198               1 :     info!("all client connections have finished");
     199               1 :     Ok(())
     200               1 : }
     201                 : 
     202                 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
     203                 : 
     204               2 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     205               2 :     raw_stream: S,
     206               2 :     tls_config: Arc<rustls::ServerConfig>,
     207               2 :     tls_server_end_point: TlsServerEndPoint,
     208               2 : ) -> anyhow::Result<Stream<S>> {
     209               2 :     let mut stream = PqStream::new(Stream::from_raw(raw_stream));
     210                 : 
     211               2 :     let msg = stream.read_startup_packet().await?;
     212                 :     use pq_proto::FeStartupPacket::*;
     213                 : 
     214               1 :     match msg {
     215                 :         SslRequest => {
     216               1 :             stream
     217               1 :                 .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
     218 UBC           0 :                 .await?;
     219                 :             // Upgrade raw stream into a secure TLS-backed stream.
     220                 :             // NOTE: We've consumed `tls`; this fact will be used later.
     221                 : 
     222 CBC           1 :             let (raw, read_buf) = stream.into_inner();
     223               1 :             // TODO: Normally, client doesn't send any data before
     224               1 :             // server says TLS handshake is ok and read_buf is empy.
     225               1 :             // However, you could imagine pipelining of postgres
     226               1 :             // SSLRequest + TLS ClientHello in one hunk similar to
     227               1 :             // pipelining in our node js driver. We should probably
     228               1 :             // support that by chaining read_buf with the stream.
     229               1 :             if !read_buf.is_empty() {
     230 UBC           0 :                 bail!("data is sent before server replied with EncryptionResponse");
     231 CBC           1 :             }
     232               1 : 
     233               1 :             Ok(Stream::Tls {
     234               2 :                 tls: Box::new(raw.upgrade(tls_config).await?),
     235               1 :                 tls_server_end_point,
     236                 :             })
     237                 :         }
     238 UBC           0 :         unexpected => {
     239               0 :             info!(
     240               0 :                 ?unexpected,
     241               0 :                 "unexpected startup packet, rejecting connection"
     242               0 :             );
     243               0 :             stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
     244                 :         }
     245                 :     }
     246 CBC           2 : }
     247                 : 
     248               2 : async fn handle_client(
     249               2 :     ctx: &mut RequestMonitoring,
     250               2 :     dest_suffix: Arc<String>,
     251               2 :     tls_config: Arc<rustls::ServerConfig>,
     252               2 :     tls_server_end_point: TlsServerEndPoint,
     253               2 :     stream: impl AsyncRead + AsyncWrite + Unpin,
     254               2 : ) -> anyhow::Result<()> {
     255               2 :     let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
     256                 : 
     257                 :     // Cut off first part of the SNI domain
     258                 :     // We receive required destination details in the format of
     259                 :     //   `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
     260               1 :     let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
     261               1 :     let dest: Vec<&str> = sni
     262               1 :         .split_once('.')
     263               1 :         .context("invalid SNI")?
     264                 :         .0
     265               1 :         .splitn(3, "--")
     266               1 :         .collect();
     267               1 :     let port = dest[2].parse::<u16>().context("invalid port")?;
     268               1 :     let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
     269                 : 
     270               1 :     info!("destination: {}", destination);
     271                 : 
     272               1 :     let client = tokio::net::TcpStream::connect(destination).await?;
     273                 : 
     274               1 :     let metrics_aux: MetricsAuxInfo = Default::default();
     275               5 :     proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
     276               2 : }
        

Generated by: LCOV version 2.1-beta