LCOV - code coverage report
Current view: top level - proxy/src/tls - server_config.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 31.7 % 139 44
Test Date: 2025-07-16 12:29:03 Functions: 28.6 % 14 4

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet};
       2              : use std::path::Path;
       3              : use std::sync::Arc;
       4              : 
       5              : use anyhow::{Context, bail};
       6              : use itertools::Itertools;
       7              : use rustls::crypto::ring::{self, sign};
       8              : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
       9              : use rustls::sign::CertifiedKey;
      10              : use x509_cert::der::{Reader, SliceReader};
      11              : 
      12              : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
      13              : 
      14              : pub struct TlsConfig {
      15              :     // unfortunate split since we cannot change the ALPN on demand.
      16              :     // <https://github.com/rustls/rustls/issues/2260>
      17              :     pub http_config: Arc<rustls::ServerConfig>,
      18              :     pub pg_config: Arc<rustls::ServerConfig>,
      19              :     pub common_names: HashSet<String>,
      20              :     pub cert_resolver: Arc<CertResolver>,
      21              : }
      22              : 
      23              : /// Configure TLS for the main endpoint.
      24            0 : pub fn configure_tls(
      25            0 :     key_path: &Path,
      26            0 :     cert_path: &Path,
      27            0 :     certs_dir: Option<&Path>,
      28            0 :     allow_tls_keylogfile: bool,
      29            0 : ) -> anyhow::Result<TlsConfig> {
      30              :     // add default certificate
      31            0 :     let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
      32              : 
      33              :     // add extra certificates
      34            0 :     if let Some(certs_dir) = certs_dir {
      35            0 :         for entry in std::fs::read_dir(certs_dir)? {
      36            0 :             let entry = entry?;
      37            0 :             let path = entry.path();
      38            0 :             if path.is_dir() {
      39              :                 // file names aligned with default cert-manager names
      40            0 :                 let key_path = path.join("tls.key");
      41            0 :                 let cert_path = path.join("tls.crt");
      42            0 :                 if key_path.exists() && cert_path.exists() {
      43            0 :                     cert_resolver.add_cert_path(&key_path, &cert_path)?;
      44            0 :                 }
      45            0 :             }
      46              :         }
      47            0 :     }
      48              : 
      49            0 :     let common_names = cert_resolver.get_common_names();
      50              : 
      51            0 :     let cert_resolver = Arc::new(cert_resolver);
      52              : 
      53              :     // allow TLS 1.2 to be compatible with older client libraries
      54            0 :     let mut config =
      55            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      56            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
      57            0 :             .context("ring should support TLS1.2 and TLS1.3")?
      58            0 :             .with_no_client_auth()
      59            0 :             .with_cert_resolver(cert_resolver.clone());
      60              : 
      61            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
      62              : 
      63            0 :     if allow_tls_keylogfile {
      64            0 :         // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
      65            0 :         config.key_log = Arc::new(rustls::KeyLogFile::new());
      66            0 :     }
      67              : 
      68            0 :     let mut http_config = config.clone();
      69            0 :     let mut pg_config = config;
      70              : 
      71            0 :     http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
      72            0 :     pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
      73              : 
      74            0 :     Ok(TlsConfig {
      75            0 :         http_config: Arc::new(http_config),
      76            0 :         pg_config: Arc::new(pg_config),
      77            0 :         common_names,
      78            0 :         cert_resolver,
      79            0 :     })
      80            0 : }
      81              : 
      82              : #[derive(Debug)]
      83              : pub struct CertResolver {
      84              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      85              :     default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
      86              : }
      87              : 
      88              : impl CertResolver {
      89            0 :     fn parse_new(key_path: &Path, cert_path: &Path) -> anyhow::Result<Self> {
      90            0 :         let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
      91            0 :         Self::new(priv_key, cert_chain)
      92            0 :     }
      93              : 
      94           21 :     pub fn new(
      95           21 :         priv_key: PrivateKeyDer<'static>,
      96           21 :         cert_chain: Vec<CertificateDer<'static>>,
      97           21 :     ) -> anyhow::Result<Self> {
      98           21 :         let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
      99              : 
     100           21 :         let mut certs = HashMap::new();
     101           21 :         let default = (cert.clone(), tls_server_end_point);
     102           21 :         certs.insert(common_name, (cert, tls_server_end_point));
     103           21 :         Ok(Self { certs, default })
     104           21 :     }
     105              : 
     106            0 :     fn add_cert_path(&mut self, key_path: &Path, cert_path: &Path) -> anyhow::Result<()> {
     107            0 :         let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
     108            0 :         self.add_cert(priv_key, cert_chain)
     109            0 :     }
     110              : 
     111            0 :     fn add_cert(
     112            0 :         &mut self,
     113            0 :         priv_key: PrivateKeyDer<'static>,
     114            0 :         cert_chain: Vec<CertificateDer<'static>>,
     115            0 :     ) -> anyhow::Result<()> {
     116            0 :         let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
     117            0 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     118            0 :         Ok(())
     119            0 :     }
     120              : 
     121           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     122           21 :         self.certs.keys().cloned().collect()
     123           21 :     }
     124              : }
     125              : 
     126            0 : fn parse_key_cert(
     127            0 :     key_path: &Path,
     128            0 :     cert_path: &Path,
     129            0 : ) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
     130            0 :     let priv_key = {
     131            0 :         let key_bytes = std::fs::read(key_path)
     132            0 :             .with_context(|| format!("Failed to read TLS keys at '{}'", key_path.display()))?;
     133            0 :         rustls_pemfile::private_key(&mut &key_bytes[..])
     134            0 :             .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
     135            0 :             .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
     136              :     };
     137              : 
     138            0 :     let cert_chain_bytes = std::fs::read(cert_path).context(format!(
     139            0 :         "Failed to read TLS cert file at '{}.'",
     140            0 :         cert_path.display()
     141            0 :     ))?;
     142              : 
     143            0 :     let cert_chain = {
     144            0 :         rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     145            0 :             .try_collect()
     146            0 :             .with_context(|| {
     147            0 :                 format!(
     148            0 :                     "Failed to read TLS certificate chain from bytes from file at '{}'.",
     149            0 :                     cert_path.display()
     150              :                 )
     151            0 :             })?
     152              :     };
     153              : 
     154            0 :     Ok((priv_key, cert_chain))
     155            0 : }
     156              : 
     157           21 : fn process_key_cert(
     158           21 :     priv_key: PrivateKeyDer<'static>,
     159           21 :     cert_chain: Vec<CertificateDer<'static>>,
     160           21 : ) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
     161           21 :     let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     162              : 
     163           21 :     let first_cert = &cert_chain[0];
     164           21 :     let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     165              : 
     166           21 :     let certificate = SliceReader::new(first_cert)
     167           21 :         .context("Failed to parse cerficiate")?
     168           21 :         .decode::<x509_cert::Certificate>()
     169           21 :         .context("Failed to parse cerficiate")?;
     170              : 
     171           21 :     let common_name = certificate.tbs_certificate.subject.to_string();
     172              : 
     173              :     // We need to get the canonical name for this certificate so we can match them against any domain names
     174              :     // seen within the proxy codebase.
     175              :     //
     176              :     // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
     177              :     // We need to remove the wildcard prefix for the purposes of certificate selection.
     178              :     //
     179              :     // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
     180              :     // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
     181              :     //
     182              :     // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
     183              :     // validation, so let's we can continue with any common-name
     184           21 :     let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
     185            0 :         s.to_string()
     186           21 :     } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
     187            0 :         s.to_string()
     188           21 :     } else if let Some(s) = common_name.strip_prefix("CN=") {
     189           21 :         s.to_string()
     190              :     } else {
     191            0 :         bail!("Failed to parse common name from certificate")
     192              :     };
     193              : 
     194           21 :     let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     195              : 
     196           21 :     Ok((common_name, cert, tls_server_end_point))
     197           21 : }
     198              : 
     199              : impl rustls::server::ResolvesServerCert for CertResolver {
     200            0 :     fn resolve(
     201            0 :         &self,
     202            0 :         client_hello: rustls::server::ClientHello<'_>,
     203            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     204            0 :         Some(self.resolve(client_hello.server_name()).0)
     205            0 :     }
     206              : }
     207              : 
     208              : impl CertResolver {
     209           20 :     pub fn resolve(
     210           20 :         &self,
     211           20 :         server_name: Option<&str>,
     212           20 :     ) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
     213              :         // loop here and cut off more and more subdomains until we find
     214              :         // a match to get a proper wildcard support. OTOH, we now do not
     215              :         // use nested domains, so keep this simple for now.
     216              :         //
     217              :         // With the current coding foo.com will match *.foo.com and that
     218              :         // repeats behavior of the old code.
     219           20 :         if let Some(mut sni_name) = server_name {
     220              :             loop {
     221           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     222           20 :                     return cert.clone();
     223           20 :                 }
     224           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     225           20 :                     sni_name = rest;
     226           20 :                 } else {
     227              :                     // The customer has some custom DNS mapping - just return
     228              :                     // a default certificate.
     229              :                     //
     230              :                     // This will error if the customer uses anything stronger
     231              :                     // than sslmode=require. That's a choice they can make.
     232            0 :                     return self.default.clone();
     233              :                 }
     234              :             }
     235              :         } else {
     236              :             // No SNI, use the default certificate, otherwise we can't get to
     237              :             // options parameter which can be used to set endpoint name too.
     238              :             // That means that non-SNI flow will not work for CNAME domains in
     239              :             // verify-full mode.
     240              :             //
     241              :             // If that will be a problem we can:
     242              :             //
     243              :             // a) Instead of multi-cert approach use single cert with extra
     244              :             //    domains listed in Subject Alternative Name (SAN).
     245              :             // b) Deploy separate proxy instances for extra domains.
     246            0 :             self.default.clone()
     247              :         }
     248           20 :     }
     249              : }
        

Generated by: LCOV version 2.1-beta