LCOV - code coverage report
Current view: top level - proxy/src/bin - pg_sni_router.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 91.9 % 161 148
Test Date: 2023-09-06 10:18:01 Functions: 90.0 % 20 18

            Line data    Source code
       1              : /// A stand-alone program that routes connections, e.g. from
       2              : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
       3              : ///
       4              : /// This allows connecting to pods/services running in the same Kubernetes cluster from
       5              : /// the outside. Similar to an ingress controller for HTTPS.
       6              : use std::{net::SocketAddr, sync::Arc};
       7              : 
       8              : use futures::future::Either;
       9              : use tokio::net::TcpListener;
      10              : 
      11              : use anyhow::{anyhow, bail, ensure, Context};
      12              : use clap::{self, Arg};
      13              : use futures::TryFutureExt;
      14              : use proxy::console::messages::MetricsAuxInfo;
      15              : use proxy::stream::{PqStream, Stream};
      16              : 
      17              : use tokio::io::{AsyncRead, AsyncWrite};
      18              : use tokio_util::sync::CancellationToken;
      19              : use utils::{project_git_version, sentry_init::init_sentry};
      20              : 
      21              : use tracing::{error, info, warn, Instrument};
      22              : 
      23              : project_git_version!(GIT_VERSION);
      24              : 
      25            1 : fn cli() -> clap::Command {
      26            1 :     clap::Command::new("Neon proxy/router")
      27            1 :         .version(GIT_VERSION)
      28            1 :         .arg(
      29            1 :             Arg::new("listen")
      30            1 :                 .short('l')
      31            1 :                 .long("listen")
      32            1 :                 .help("listen for incoming client connections on ip:port")
      33            1 :                 .default_value("127.0.0.1:4432"),
      34            1 :         )
      35            1 :         .arg(
      36            1 :             Arg::new("tls-key")
      37            1 :                 .short('k')
      38            1 :                 .long("tls-key")
      39            1 :                 .help("path to TLS key for client postgres connections")
      40            1 :                 .required(true),
      41            1 :         )
      42            1 :         .arg(
      43            1 :             Arg::new("tls-cert")
      44            1 :                 .short('c')
      45            1 :                 .long("tls-cert")
      46            1 :                 .help("path to TLS cert for client postgres connections")
      47            1 :                 .required(true),
      48            1 :         )
      49            1 :         .arg(
      50            1 :             Arg::new("dest")
      51            1 :                 .short('d')
      52            1 :                 .long("destination")
      53            1 :                 .help("append this domain zone to the SNI hostname to get the destination address")
      54            1 :                 .required(true),
      55            1 :         )
      56            1 : }
      57              : 
      58              : #[tokio::main]
      59            1 : async fn main() -> anyhow::Result<()> {
      60            1 :     let _logging_guard = proxy::logging::init().await?;
      61            1 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      62            1 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      63            1 : 
      64            1 :     let args = cli().get_matches();
      65            1 :     let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
      66              : 
      67              :     // Configure TLS
      68            1 :     let tls_config: Arc<rustls::ServerConfig> = match (
      69            1 :         args.get_one::<String>("tls-key"),
      70            1 :         args.get_one::<String>("tls-cert"),
      71              :     ) {
      72            1 :         (Some(key_path), Some(cert_path)) => {
      73            1 :             let key = {
      74            1 :                 let key_bytes = std::fs::read(key_path).context("TLS key file")?;
      75            1 :                 let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
      76            1 :                     .context(format!("Failed to read TLS keys at '{key_path}'"))?;
      77              : 
      78            1 :                 ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
      79            1 :                 keys.pop().map(rustls::PrivateKey).unwrap()
      80              :             };
      81              : 
      82            1 :             let cert_chain_bytes = std::fs::read(cert_path)
      83            1 :                 .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
      84              : 
      85            1 :             let cert_chain = {
      86            1 :                 rustls_pemfile::certs(&mut &cert_chain_bytes[..])
      87            1 :                     .context(format!(
      88            1 :                         "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
      89            1 :                     ))?
      90            1 :                     .into_iter()
      91            1 :                     .map(rustls::Certificate)
      92            1 :                     .collect()
      93            1 :             };
      94            1 : 
      95            1 :             rustls::ServerConfig::builder()
      96            1 :                 .with_safe_default_cipher_suites()
      97            1 :                 .with_safe_default_kx_groups()
      98            1 :                 .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
      99            1 :                 .with_no_client_auth()
     100            1 :                 .with_single_cert(cert_chain, key)?
     101            1 :                 .into()
     102              :         }
     103            0 :         _ => bail!("tls-key and tls-cert must be specified"),
     104              :     };
     105              : 
     106              :     // Start listening for incoming client connections
     107            1 :     let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
     108            1 :     info!("Starting sni router on {proxy_address}");
     109            1 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     110              : 
     111            1 :     let cancellation_token = CancellationToken::new();
     112            1 : 
     113            1 :     let main = tokio::spawn(task_main(
     114            1 :         Arc::new(destination),
     115            1 :         tls_config,
     116            1 :         proxy_listener,
     117            1 :         cancellation_token.clone(),
     118            1 :     ));
     119            1 :     let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
     120              : 
     121              :     // the signal task cant ever succeed.
     122              :     // the main task can error, or can succeed on cancellation.
     123              :     // we want to immediately exit on either of these cases
     124            1 :     let signal = match futures::future::select(signals_task, main).await {
     125            0 :         Either::Left((res, _)) => proxy::flatten_err(res)?,
     126            1 :         Either::Right((res, _)) => return proxy::flatten_err(res),
     127              :     };
     128              : 
     129              :     // maintenance tasks return `Infallible` success values, this is an impossible value
     130              :     // so this match statically ensures that there are no possibilities for that value
     131              :     match signal {}
     132              : }
     133              : 
     134            1 : async fn task_main(
     135            1 :     dest_suffix: Arc<String>,
     136            1 :     tls_config: Arc<rustls::ServerConfig>,
     137            1 :     listener: tokio::net::TcpListener,
     138            1 :     cancellation_token: CancellationToken,
     139            1 : ) -> anyhow::Result<()> {
     140              :     // When set for the server socket, the keepalive setting
     141              :     // will be inherited by all accepted client sockets.
     142            1 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     143              : 
     144            1 :     let mut connections = tokio::task::JoinSet::new();
     145              : 
     146              :     loop {
     147            3 :         tokio::select! {
     148            2 :             accept_result = listener.accept() => {
     149              :                 let (socket, peer_addr) = accept_result?;
     150              : 
     151              :                 let session_id = uuid::Uuid::new_v4();
     152              :                 let tls_config = Arc::clone(&tls_config);
     153              :                 let dest_suffix = Arc::clone(&dest_suffix);
     154              : 
     155              :                 connections.spawn(
     156              :                     async move {
     157            2 :                         socket
     158            2 :                             .set_nodelay(true)
     159            2 :                             .context("failed to set socket option")?;
     160              : 
     161            2 :                         info!(%peer_addr, "serving");
     162           12 :                         handle_client(dest_suffix, tls_config, socket).await
     163            2 :                     }
     164            2 :                     .unwrap_or_else(|e| {
     165            2 :                         // Acknowledge that the task has finished with an error.
     166            2 :                         error!("per-client task finished with an error: {e:#}");
     167            2 :                     })
     168              :                     .instrument(tracing::info_span!("handle_client", ?session_id))
     169              :                 );
     170              :             }
     171              :             _ = cancellation_token.cancelled() => {
     172              :                 drop(listener);
     173              :                 break;
     174              :             }
     175              :         }
     176              :     }
     177              : 
     178              :     // Drain connections
     179            1 :     info!("waiting for all client connections to finish");
     180            3 :     while let Some(res) = connections.join_next().await {
     181            2 :         if let Err(e) = res {
     182            0 :             if !e.is_panic() && !e.is_cancelled() {
     183            0 :                 warn!("unexpected error from joined connection task: {e:?}");
     184            0 :             }
     185            2 :         }
     186              :     }
     187            1 :     info!("all client connections have finished");
     188            1 :     Ok(())
     189            1 : }
     190              : 
     191              : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
     192              : 
     193            2 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     194            2 :     raw_stream: S,
     195            2 :     tls_config: Arc<rustls::ServerConfig>,
     196            2 : ) -> anyhow::Result<Stream<S>> {
     197            2 :     let mut stream = PqStream::new(Stream::from_raw(raw_stream));
     198              : 
     199            2 :     let msg = stream.read_startup_packet().await?;
     200              :     use pq_proto::FeStartupPacket::*;
     201              : 
     202            1 :     match msg {
     203              :         SslRequest => {
     204            1 :             stream
     205            1 :                 .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
     206            0 :                 .await?;
     207              :             // Upgrade raw stream into a secure TLS-backed stream.
     208              :             // NOTE: We've consumed `tls`; this fact will be used later.
     209              : 
     210            1 :             let (raw, read_buf) = stream.into_inner();
     211            1 :             // TODO: Normally, client doesn't send any data before
     212            1 :             // server says TLS handshake is ok and read_buf is empy.
     213            1 :             // However, you could imagine pipelining of postgres
     214            1 :             // SSLRequest + TLS ClientHello in one hunk similar to
     215            1 :             // pipelining in our node js driver. We should probably
     216            1 :             // support that by chaining read_buf with the stream.
     217            1 :             if !read_buf.is_empty() {
     218            0 :                 bail!("data is sent before server replied with EncryptionResponse");
     219            1 :             }
     220            2 :             Ok(raw.upgrade(tls_config).await?)
     221              :         }
     222            0 :         unexpected => {
     223            0 :             info!(
     224            0 :                 ?unexpected,
     225            0 :                 "unexpected startup packet, rejecting connection"
     226            0 :             );
     227            0 :             stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
     228              :         }
     229              :     }
     230            2 : }
     231              : 
     232            2 : async fn handle_client(
     233            2 :     dest_suffix: Arc<String>,
     234            2 :     tls_config: Arc<rustls::ServerConfig>,
     235            2 :     stream: impl AsyncRead + AsyncWrite + Unpin,
     236            2 : ) -> anyhow::Result<()> {
     237            4 :     let tls_stream = ssl_handshake(stream, tls_config).await?;
     238              : 
     239              :     // Cut off first part of the SNI domain
     240              :     // We receive required destination details in the format of
     241              :     //   `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
     242            1 :     let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
     243            1 :     let dest: Vec<&str> = sni
     244            1 :         .split_once('.')
     245            1 :         .context("invalid SNI")?
     246              :         .0
     247            1 :         .splitn(3, "--")
     248            1 :         .collect();
     249            1 :     let port = dest[2].parse::<u16>().context("invalid port")?;
     250            1 :     let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
     251              : 
     252            1 :     info!("destination: {}", destination);
     253              : 
     254            3 :     let client = tokio::net::TcpStream::connect(destination).await?;
     255              : 
     256            1 :     let metrics_aux: MetricsAuxInfo = Default::default();
     257            5 :     proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
     258            2 : }
        

Generated by: LCOV version 2.1-beta