LCOV - code coverage report
Current view: top level - proxy/src/tls - server_config.rs (source / functions) Coverage Total Hit
Test: ae4948feae6a1d420c855050eb8c189119446a71.info Lines: 33.8 % 133 45
Test Date: 2025-03-18 18:33:46 Functions: 38.5 % 13 5

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet};
       2              : use std::sync::Arc;
       3              : 
       4              : use anyhow::{Context, bail};
       5              : use itertools::Itertools;
       6              : use rustls::crypto::ring::{self, sign};
       7              : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
       8              : use x509_cert::der::{Reader, SliceReader};
       9              : 
      10              : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
      11              : 
      12              : pub struct TlsConfig {
      13              :     // unfortunate split since we cannot change the ALPN on demand.
      14              :     // <https://github.com/rustls/rustls/issues/2260>
      15              :     pub http_config: Arc<rustls::ServerConfig>,
      16              :     pub pg_config: Arc<rustls::ServerConfig>,
      17              :     pub common_names: HashSet<String>,
      18              :     pub cert_resolver: Arc<CertResolver>,
      19              : }
      20              : 
      21              : /// Configure TLS for the main endpoint.
      22            0 : pub fn configure_tls(
      23            0 :     key_path: &str,
      24            0 :     cert_path: &str,
      25            0 :     certs_dir: Option<&String>,
      26            0 :     allow_tls_keylogfile: bool,
      27            0 : ) -> anyhow::Result<TlsConfig> {
      28            0 :     let mut cert_resolver = CertResolver::new();
      29            0 : 
      30            0 :     // add default certificate
      31            0 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
      32              : 
      33              :     // add extra certificates
      34            0 :     if let Some(certs_dir) = certs_dir {
      35            0 :         for entry in std::fs::read_dir(certs_dir)? {
      36            0 :             let entry = entry?;
      37            0 :             let path = entry.path();
      38            0 :             if path.is_dir() {
      39              :                 // file names aligned with default cert-manager names
      40            0 :                 let key_path = path.join("tls.key");
      41            0 :                 let cert_path = path.join("tls.crt");
      42            0 :                 if key_path.exists() && cert_path.exists() {
      43            0 :                     cert_resolver.add_cert_path(
      44            0 :                         &key_path.to_string_lossy(),
      45            0 :                         &cert_path.to_string_lossy(),
      46            0 :                         false,
      47            0 :                     )?;
      48            0 :                 }
      49            0 :             }
      50              :         }
      51            0 :     }
      52              : 
      53            0 :     let common_names = cert_resolver.get_common_names();
      54            0 : 
      55            0 :     let cert_resolver = Arc::new(cert_resolver);
      56              : 
      57              :     // allow TLS 1.2 to be compatible with older client libraries
      58            0 :     let mut config =
      59            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      60            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
      61            0 :             .context("ring should support TLS1.2 and TLS1.3")?
      62            0 :             .with_no_client_auth()
      63            0 :             .with_cert_resolver(cert_resolver.clone());
      64            0 : 
      65            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
      66            0 : 
      67            0 :     if allow_tls_keylogfile {
      68            0 :         // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
      69            0 :         config.key_log = Arc::new(rustls::KeyLogFile::new());
      70            0 :     }
      71              : 
      72            0 :     let mut http_config = config.clone();
      73            0 :     let mut pg_config = config;
      74            0 : 
      75            0 :     http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
      76            0 :     pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
      77            0 : 
      78            0 :     Ok(TlsConfig {
      79            0 :         http_config: Arc::new(http_config),
      80            0 :         pg_config: Arc::new(pg_config),
      81            0 :         common_names,
      82            0 :         cert_resolver,
      83            0 :     })
      84            0 : }
      85              : 
      86              : #[derive(Default, Debug)]
      87              : pub struct CertResolver {
      88              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      89              :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      90              : }
      91              : 
      92              : impl CertResolver {
      93           21 :     pub fn new() -> Self {
      94           21 :         Self::default()
      95           21 :     }
      96              : 
      97            0 :     fn add_cert_path(
      98            0 :         &mut self,
      99            0 :         key_path: &str,
     100            0 :         cert_path: &str,
     101            0 :         is_default: bool,
     102            0 :     ) -> anyhow::Result<()> {
     103            0 :         let priv_key = {
     104            0 :             let key_bytes = std::fs::read(key_path)
     105            0 :                 .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
     106            0 :             rustls_pemfile::private_key(&mut &key_bytes[..])
     107            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     108            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     109              :         };
     110              : 
     111            0 :         let cert_chain_bytes = std::fs::read(cert_path)
     112            0 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     113              : 
     114            0 :         let cert_chain = {
     115            0 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     116            0 :                 .try_collect()
     117            0 :                 .with_context(|| {
     118            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     119            0 :                 })?
     120              :         };
     121              : 
     122            0 :         self.add_cert(priv_key, cert_chain, is_default)
     123            0 :     }
     124              : 
     125           21 :     pub fn add_cert(
     126           21 :         &mut self,
     127           21 :         priv_key: PrivateKeyDer<'static>,
     128           21 :         cert_chain: Vec<CertificateDer<'static>>,
     129           21 :         is_default: bool,
     130           21 :     ) -> anyhow::Result<()> {
     131           21 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     132              : 
     133           21 :         let first_cert = &cert_chain[0];
     134           21 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     135              : 
     136           21 :         let certificate = SliceReader::new(first_cert)
     137           21 :             .context("Failed to parse cerficiate")?
     138           21 :             .decode::<x509_cert::Certificate>()
     139           21 :             .context("Failed to parse cerficiate")?;
     140              : 
     141           21 :         let common_name = certificate.tbs_certificate.subject.to_string();
     142              : 
     143              :         // We need to get the canonical name for this certificate so we can match them against any domain names
     144              :         // seen within the proxy codebase.
     145              :         //
     146              :         // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
     147              :         // We need to remove the wildcard prefix for the purposes of certificate selection.
     148              :         //
     149              :         // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
     150              :         // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
     151              :         //
     152              :         // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
     153              :         // validation, so let's we can continue with any common-name
     154           21 :         let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
     155            0 :             s.to_string()
     156           21 :         } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
     157            0 :             s.to_string()
     158           21 :         } else if let Some(s) = common_name.strip_prefix("CN=") {
     159           21 :             s.to_string()
     160              :         } else {
     161            0 :             bail!("Failed to parse common name from certificate")
     162              :         };
     163              : 
     164           21 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     165           21 : 
     166           21 :         if is_default {
     167           21 :             self.default = Some((cert.clone(), tls_server_end_point));
     168           21 :         }
     169              : 
     170           21 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     171           21 : 
     172           21 :         Ok(())
     173           21 :     }
     174              : 
     175           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     176           21 :         self.certs.keys().map(|s| s.to_string()).collect()
     177           21 :     }
     178              : }
     179              : 
     180              : impl rustls::server::ResolvesServerCert for CertResolver {
     181            0 :     fn resolve(
     182            0 :         &self,
     183            0 :         client_hello: rustls::server::ClientHello<'_>,
     184            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     185            0 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     186            0 :     }
     187              : }
     188              : 
     189              : impl CertResolver {
     190           20 :     pub fn resolve(
     191           20 :         &self,
     192           20 :         server_name: Option<&str>,
     193           20 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     194              :         // loop here and cut off more and more subdomains until we find
     195              :         // a match to get a proper wildcard support. OTOH, we now do not
     196              :         // use nested domains, so keep this simple for now.
     197              :         //
     198              :         // With the current coding foo.com will match *.foo.com and that
     199              :         // repeats behavior of the old code.
     200           20 :         if let Some(mut sni_name) = server_name {
     201              :             loop {
     202           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     203           20 :                     return Some(cert.clone());
     204           20 :                 }
     205           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     206           20 :                     sni_name = rest;
     207           20 :                 } else {
     208            0 :                     return None;
     209              :                 }
     210              :             }
     211              :         } else {
     212              :             // No SNI, use the default certificate, otherwise we can't get to
     213              :             // options parameter which can be used to set endpoint name too.
     214              :             // That means that non-SNI flow will not work for CNAME domains in
     215              :             // verify-full mode.
     216              :             //
     217              :             // If that will be a problem we can:
     218              :             //
     219              :             // a) Instead of multi-cert approach use single cert with extra
     220              :             //    domains listed in Subject Alternative Name (SAN).
     221              :             // b) Deploy separate proxy instances for extra domains.
     222            0 :             self.default.clone()
     223              :         }
     224           20 :     }
     225              : }
        

Generated by: LCOV version 2.1-beta