LCOV - differential code coverage report
Current view: top level - proxy/src/bin - pg_sni_router.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 90.8 % 163 148 15 148
Current Date: 2023-10-19 02:04:12 Functions: 85.7 % 21 18 3 18
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : /// A stand-alone program that routes connections, e.g. from
       2                 : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
       3                 : ///
       4                 : /// This allows connecting to pods/services running in the same Kubernetes cluster from
       5                 : /// the outside. Similar to an ingress controller for HTTPS.
       6                 : use std::{net::SocketAddr, sync::Arc};
       7                 : 
       8                 : use futures::future::Either;
       9                 : use tokio::net::TcpListener;
      10                 : 
      11                 : use anyhow::{anyhow, bail, ensure, Context};
      12                 : use clap::{self, Arg};
      13                 : use futures::TryFutureExt;
      14                 : use proxy::console::messages::MetricsAuxInfo;
      15                 : use proxy::stream::{PqStream, Stream};
      16                 : 
      17                 : use tokio::io::{AsyncRead, AsyncWrite};
      18                 : use tokio_util::sync::CancellationToken;
      19                 : use utils::{project_git_version, sentry_init::init_sentry};
      20                 : 
      21                 : use tracing::{error, info, warn, Instrument};
      22                 : 
      23                 : project_git_version!(GIT_VERSION);
      24                 : 
      25 CBC           1 : fn cli() -> clap::Command {
      26               1 :     clap::Command::new("Neon proxy/router")
      27               1 :         .version(GIT_VERSION)
      28               1 :         .arg(
      29               1 :             Arg::new("listen")
      30               1 :                 .short('l')
      31               1 :                 .long("listen")
      32               1 :                 .help("listen for incoming client connections on ip:port")
      33               1 :                 .default_value("127.0.0.1:4432"),
      34               1 :         )
      35               1 :         .arg(
      36               1 :             Arg::new("tls-key")
      37               1 :                 .short('k')
      38               1 :                 .long("tls-key")
      39               1 :                 .help("path to TLS key for client postgres connections")
      40               1 :                 .required(true),
      41               1 :         )
      42               1 :         .arg(
      43               1 :             Arg::new("tls-cert")
      44               1 :                 .short('c')
      45               1 :                 .long("tls-cert")
      46               1 :                 .help("path to TLS cert for client postgres connections")
      47               1 :                 .required(true),
      48               1 :         )
      49               1 :         .arg(
      50               1 :             Arg::new("dest")
      51               1 :                 .short('d')
      52               1 :                 .long("destination")
      53               1 :                 .help("append this domain zone to the SNI hostname to get the destination address")
      54               1 :                 .required(true),
      55               1 :         )
      56               1 : }
      57                 : 
      58                 : #[tokio::main]
      59               1 : async fn main() -> anyhow::Result<()> {
      60               1 :     let _logging_guard = proxy::logging::init().await?;
      61               1 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      62               1 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      63               1 : 
      64               1 :     let args = cli().get_matches();
      65               1 :     let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
      66                 : 
      67                 :     // Configure TLS
      68               1 :     let tls_config: Arc<rustls::ServerConfig> = match (
      69               1 :         args.get_one::<String>("tls-key"),
      70               1 :         args.get_one::<String>("tls-cert"),
      71                 :     ) {
      72               1 :         (Some(key_path), Some(cert_path)) => {
      73               1 :             let key = {
      74               1 :                 let key_bytes = std::fs::read(key_path).context("TLS key file")?;
      75               1 :                 let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
      76               1 :                     .context(format!("Failed to read TLS keys at '{key_path}'"))?;
      77                 : 
      78               1 :                 ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
      79               1 :                 keys.pop().map(rustls::PrivateKey).unwrap()
      80                 :             };
      81                 : 
      82               1 :             let cert_chain_bytes = std::fs::read(cert_path)
      83               1 :                 .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
      84                 : 
      85               1 :             let cert_chain = {
      86               1 :                 rustls_pemfile::certs(&mut &cert_chain_bytes[..])
      87               1 :                     .context(format!(
      88               1 :                         "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
      89               1 :                     ))?
      90               1 :                     .into_iter()
      91               1 :                     .map(rustls::Certificate)
      92               1 :                     .collect()
      93               1 :             };
      94               1 : 
      95               1 :             rustls::ServerConfig::builder()
      96               1 :                 .with_safe_default_cipher_suites()
      97               1 :                 .with_safe_default_kx_groups()
      98               1 :                 .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
      99               1 :                 .with_no_client_auth()
     100               1 :                 .with_single_cert(cert_chain, key)?
     101               1 :                 .into()
     102                 :         }
     103 UBC           0 :         _ => bail!("tls-key and tls-cert must be specified"),
     104                 :     };
     105                 : 
     106                 :     // Start listening for incoming client connections
     107 CBC           1 :     let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
     108               1 :     info!("Starting sni router on {proxy_address}");
     109               1 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     110                 : 
     111               1 :     let cancellation_token = CancellationToken::new();
     112               1 : 
     113               1 :     let main = tokio::spawn(task_main(
     114               1 :         Arc::new(destination),
     115               1 :         tls_config,
     116               1 :         proxy_listener,
     117               1 :         cancellation_token.clone(),
     118               1 :     ));
     119               1 :     let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
     120                 : 
     121                 :     // the signal task cant ever succeed.
     122                 :     // the main task can error, or can succeed on cancellation.
     123                 :     // we want to immediately exit on either of these cases
     124               1 :     let signal = match futures::future::select(signals_task, main).await {
     125 UBC           0 :         Either::Left((res, _)) => proxy::flatten_err(res)?,
     126 CBC           1 :         Either::Right((res, _)) => return proxy::flatten_err(res),
     127                 :     };
     128                 : 
     129                 :     // maintenance tasks return `Infallible` success values, this is an impossible value
     130                 :     // so this match statically ensures that there are no possibilities for that value
     131                 :     match signal {}
     132                 : }
     133                 : 
     134               1 : async fn task_main(
     135               1 :     dest_suffix: Arc<String>,
     136               1 :     tls_config: Arc<rustls::ServerConfig>,
     137               1 :     listener: tokio::net::TcpListener,
     138               1 :     cancellation_token: CancellationToken,
     139               1 : ) -> anyhow::Result<()> {
     140                 :     // When set for the server socket, the keepalive setting
     141                 :     // will be inherited by all accepted client sockets.
     142               1 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     143                 : 
     144               1 :     let mut connections = tokio::task::JoinSet::new();
     145                 : 
     146                 :     loop {
     147               3 :         tokio::select! {
     148               2 :             accept_result = listener.accept() => {
     149                 :                 let (socket, peer_addr) = accept_result?;
     150                 : 
     151                 :                 let session_id = uuid::Uuid::new_v4();
     152                 :                 let tls_config = Arc::clone(&tls_config);
     153                 :                 let dest_suffix = Arc::clone(&dest_suffix);
     154                 : 
     155                 :                 connections.spawn(
     156                 :                     async move {
     157               2 :                         socket
     158               2 :                             .set_nodelay(true)
     159               2 :                             .context("failed to set socket option")?;
     160                 : 
     161               2 :                         info!(%peer_addr, "serving");
     162              11 :                         handle_client(dest_suffix, tls_config, socket).await
     163               2 :                     }
     164               2 :                     .unwrap_or_else(|e| {
     165               2 :                         // Acknowledge that the task has finished with an error.
     166               2 :                         error!("per-client task finished with an error: {e:#}");
     167               2 :                     })
     168                 :                     .instrument(tracing::info_span!("handle_client", ?session_id))
     169                 :                 );
     170                 :             }
     171 UBC           0 :             Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
     172                 :                 if !e.is_panic() && !e.is_cancelled() {
     173               0 :                     warn!("unexpected error from joined connection task: {e:?}");
     174                 :                 }
     175                 :             }
     176                 :             _ = cancellation_token.cancelled() => {
     177                 :                 drop(listener);
     178                 :                 break;
     179                 :             }
     180                 :         }
     181                 :     }
     182                 : 
     183                 :     // Drain connections
     184 CBC           1 :     info!("waiting for all client connections to finish");
     185               2 :     while let Some(res) = connections.join_next().await {
     186               1 :         if let Err(e) = res {
     187 UBC           0 :             if !e.is_panic() && !e.is_cancelled() {
     188               0 :                 warn!("unexpected error from joined connection task: {e:?}");
     189               0 :             }
     190 CBC           1 :         }
     191                 :     }
     192               1 :     info!("all client connections have finished");
     193               1 :     Ok(())
     194               1 : }
     195                 : 
     196                 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
     197                 : 
     198               2 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     199               2 :     raw_stream: S,
     200               2 :     tls_config: Arc<rustls::ServerConfig>,
     201               2 : ) -> anyhow::Result<Stream<S>> {
     202               2 :     let mut stream = PqStream::new(Stream::from_raw(raw_stream));
     203                 : 
     204               2 :     let msg = stream.read_startup_packet().await?;
     205                 :     use pq_proto::FeStartupPacket::*;
     206                 : 
     207               1 :     match msg {
     208                 :         SslRequest => {
     209               1 :             stream
     210               1 :                 .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
     211 UBC           0 :                 .await?;
     212                 :             // Upgrade raw stream into a secure TLS-backed stream.
     213                 :             // NOTE: We've consumed `tls`; this fact will be used later.
     214                 : 
     215 CBC           1 :             let (raw, read_buf) = stream.into_inner();
     216               1 :             // TODO: Normally, client doesn't send any data before
     217               1 :             // server says TLS handshake is ok and read_buf is empy.
     218               1 :             // However, you could imagine pipelining of postgres
     219               1 :             // SSLRequest + TLS ClientHello in one hunk similar to
     220               1 :             // pipelining in our node js driver. We should probably
     221               1 :             // support that by chaining read_buf with the stream.
     222               1 :             if !read_buf.is_empty() {
     223 UBC           0 :                 bail!("data is sent before server replied with EncryptionResponse");
     224 CBC           1 :             }
     225               2 :             Ok(raw.upgrade(tls_config).await?)
     226                 :         }
     227 UBC           0 :         unexpected => {
     228               0 :             info!(
     229               0 :                 ?unexpected,
     230               0 :                 "unexpected startup packet, rejecting connection"
     231               0 :             );
     232               0 :             stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
     233                 :         }
     234                 :     }
     235 CBC           2 : }
     236                 : 
     237               2 : async fn handle_client(
     238               2 :     dest_suffix: Arc<String>,
     239               2 :     tls_config: Arc<rustls::ServerConfig>,
     240               2 :     stream: impl AsyncRead + AsyncWrite + Unpin,
     241               2 : ) -> anyhow::Result<()> {
     242               4 :     let tls_stream = ssl_handshake(stream, tls_config).await?;
     243                 : 
     244                 :     // Cut off first part of the SNI domain
     245                 :     // We receive required destination details in the format of
     246                 :     //   `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
     247               1 :     let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
     248               1 :     let dest: Vec<&str> = sni
     249               1 :         .split_once('.')
     250               1 :         .context("invalid SNI")?
     251                 :         .0
     252               1 :         .splitn(3, "--")
     253               1 :         .collect();
     254               1 :     let port = dest[2].parse::<u16>().context("invalid port")?;
     255               1 :     let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
     256                 : 
     257               1 :     info!("destination: {}", destination);
     258                 : 
     259               2 :     let client = tokio::net::TcpStream::connect(destination).await?;
     260                 : 
     261               1 :     let metrics_aux: MetricsAuxInfo = Default::default();
     262               5 :     proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
     263               2 : }
        

Generated by: LCOV version 2.1-beta