LCOV - code coverage report
Current view: top level - proxy/src/tls - server_config.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 36.2 % 127 46
Test Date: 2025-03-12 18:05:02 Functions: 42.9 % 14 6

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet};
       2              : use std::sync::Arc;
       3              : 
       4              : use anyhow::{Context, bail};
       5              : use itertools::Itertools;
       6              : use rustls::crypto::ring::{self, sign};
       7              : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
       8              : 
       9              : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
      10              : 
      11              : pub struct TlsConfig {
      12              :     pub config: Arc<rustls::ServerConfig>,
      13              :     pub common_names: HashSet<String>,
      14              :     pub cert_resolver: Arc<CertResolver>,
      15              : }
      16              : 
      17              : impl TlsConfig {
      18           20 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      19           20 :         self.config.clone()
      20           20 :     }
      21              : }
      22              : 
      23              : /// Configure TLS for the main endpoint.
      24            0 : pub fn configure_tls(
      25            0 :     key_path: &str,
      26            0 :     cert_path: &str,
      27            0 :     certs_dir: Option<&String>,
      28            0 :     allow_tls_keylogfile: bool,
      29            0 : ) -> anyhow::Result<TlsConfig> {
      30            0 :     let mut cert_resolver = CertResolver::new();
      31            0 : 
      32            0 :     // add default certificate
      33            0 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
      34              : 
      35              :     // add extra certificates
      36            0 :     if let Some(certs_dir) = certs_dir {
      37            0 :         for entry in std::fs::read_dir(certs_dir)? {
      38            0 :             let entry = entry?;
      39            0 :             let path = entry.path();
      40            0 :             if path.is_dir() {
      41              :                 // file names aligned with default cert-manager names
      42            0 :                 let key_path = path.join("tls.key");
      43            0 :                 let cert_path = path.join("tls.crt");
      44            0 :                 if key_path.exists() && cert_path.exists() {
      45            0 :                     cert_resolver.add_cert_path(
      46            0 :                         &key_path.to_string_lossy(),
      47            0 :                         &cert_path.to_string_lossy(),
      48            0 :                         false,
      49            0 :                     )?;
      50            0 :                 }
      51            0 :             }
      52              :         }
      53            0 :     }
      54              : 
      55            0 :     let common_names = cert_resolver.get_common_names();
      56            0 : 
      57            0 :     let cert_resolver = Arc::new(cert_resolver);
      58              : 
      59              :     // allow TLS 1.2 to be compatible with older client libraries
      60            0 :     let mut config =
      61            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      62            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
      63            0 :             .context("ring should support TLS1.2 and TLS1.3")?
      64            0 :             .with_no_client_auth()
      65            0 :             .with_cert_resolver(cert_resolver.clone());
      66            0 : 
      67            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
      68            0 : 
      69            0 :     if allow_tls_keylogfile {
      70            0 :         // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
      71            0 :         config.key_log = Arc::new(rustls::KeyLogFile::new());
      72            0 :     }
      73              : 
      74            0 :     Ok(TlsConfig {
      75            0 :         config: Arc::new(config),
      76            0 :         common_names,
      77            0 :         cert_resolver,
      78            0 :     })
      79            0 : }
      80              : 
      81              : #[derive(Default, Debug)]
      82              : pub struct CertResolver {
      83              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      84              :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      85              : }
      86              : 
      87              : impl CertResolver {
      88           21 :     pub fn new() -> Self {
      89           21 :         Self::default()
      90           21 :     }
      91              : 
      92            0 :     fn add_cert_path(
      93            0 :         &mut self,
      94            0 :         key_path: &str,
      95            0 :         cert_path: &str,
      96            0 :         is_default: bool,
      97            0 :     ) -> anyhow::Result<()> {
      98            0 :         let priv_key = {
      99            0 :             let key_bytes = std::fs::read(key_path)
     100            0 :                 .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
     101            0 :             rustls_pemfile::private_key(&mut &key_bytes[..])
     102            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     103            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     104              :         };
     105              : 
     106            0 :         let cert_chain_bytes = std::fs::read(cert_path)
     107            0 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     108              : 
     109            0 :         let cert_chain = {
     110            0 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     111            0 :                 .try_collect()
     112            0 :                 .with_context(|| {
     113            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     114            0 :                 })?
     115              :         };
     116              : 
     117            0 :         self.add_cert(priv_key, cert_chain, is_default)
     118            0 :     }
     119              : 
     120           21 :     pub fn add_cert(
     121           21 :         &mut self,
     122           21 :         priv_key: PrivateKeyDer<'static>,
     123           21 :         cert_chain: Vec<CertificateDer<'static>>,
     124           21 :         is_default: bool,
     125           21 :     ) -> anyhow::Result<()> {
     126           21 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     127              : 
     128           21 :         let first_cert = &cert_chain[0];
     129           21 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     130           21 :         let pem = x509_parser::parse_x509_certificate(first_cert)
     131           21 :             .context("Failed to parse PEM object from cerficiate")?
     132              :             .1;
     133              : 
     134           21 :         let common_name = pem.subject().to_string();
     135              : 
     136              :         // We need to get the canonical name for this certificate so we can match them against any domain names
     137              :         // seen within the proxy codebase.
     138              :         //
     139              :         // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
     140              :         // We need to remove the wildcard prefix for the purposes of certificate selection.
     141              :         //
     142              :         // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
     143              :         // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
     144              :         //
     145              :         // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
     146              :         // validation, so let's we can continue with any common-name
     147           21 :         let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
     148            0 :             s.to_string()
     149           21 :         } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
     150            0 :             s.to_string()
     151           21 :         } else if let Some(s) = common_name.strip_prefix("CN=") {
     152           21 :             s.to_string()
     153              :         } else {
     154            0 :             bail!("Failed to parse common name from certificate")
     155              :         };
     156              : 
     157           21 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     158           21 : 
     159           21 :         if is_default {
     160           21 :             self.default = Some((cert.clone(), tls_server_end_point));
     161           21 :         }
     162              : 
     163           21 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     164           21 : 
     165           21 :         Ok(())
     166           21 :     }
     167              : 
     168           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     169           21 :         self.certs.keys().map(|s| s.to_string()).collect()
     170           21 :     }
     171              : }
     172              : 
     173              : impl rustls::server::ResolvesServerCert for CertResolver {
     174            0 :     fn resolve(
     175            0 :         &self,
     176            0 :         client_hello: rustls::server::ClientHello<'_>,
     177            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     178            0 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     179            0 :     }
     180              : }
     181              : 
     182              : impl CertResolver {
     183           20 :     pub fn resolve(
     184           20 :         &self,
     185           20 :         server_name: Option<&str>,
     186           20 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     187              :         // loop here and cut off more and more subdomains until we find
     188              :         // a match to get a proper wildcard support. OTOH, we now do not
     189              :         // use nested domains, so keep this simple for now.
     190              :         //
     191              :         // With the current coding foo.com will match *.foo.com and that
     192              :         // repeats behavior of the old code.
     193           20 :         if let Some(mut sni_name) = server_name {
     194              :             loop {
     195           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     196           20 :                     return Some(cert.clone());
     197           20 :                 }
     198           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     199           20 :                     sni_name = rest;
     200           20 :                 } else {
     201            0 :                     return None;
     202              :                 }
     203              :             }
     204              :         } else {
     205              :             // No SNI, use the default certificate, otherwise we can't get to
     206              :             // options parameter which can be used to set endpoint name too.
     207              :             // That means that non-SNI flow will not work for CNAME domains in
     208              :             // verify-full mode.
     209              :             //
     210              :             // If that will be a problem we can:
     211              :             //
     212              :             // a) Instead of multi-cert approach use single cert with extra
     213              :             //    domains listed in Subject Alternative Name (SAN).
     214              :             // b) Deploy separate proxy instances for extra domains.
     215            0 :             self.default.clone()
     216              :         }
     217           20 :     }
     218              : }
        

Generated by: LCOV version 2.1-beta