LCOV - code coverage report
Current view: top level - proxy/src/tls - server_config.rs (source / functions) Coverage Total Hit
Test: d5cbf29d97fdd23f6bda3d96cc57e517a06c5084.info Lines: 32.8 % 131 43
Test Date: 2025-03-18 14:05:41 Functions: 38.5 % 13 5

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet};
       2              : use std::sync::Arc;
       3              : 
       4              : use anyhow::{Context, bail};
       5              : use itertools::Itertools;
       6              : use rustls::crypto::ring::{self, sign};
       7              : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
       8              : 
       9              : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
      10              : 
      11              : pub struct TlsConfig {
      12              :     // unfortunate split since we cannot change the ALPN on demand.
      13              :     // <https://github.com/rustls/rustls/issues/2260>
      14              :     pub http_config: Arc<rustls::ServerConfig>,
      15              :     pub pg_config: Arc<rustls::ServerConfig>,
      16              :     pub common_names: HashSet<String>,
      17              :     pub cert_resolver: Arc<CertResolver>,
      18              : }
      19              : 
      20              : /// Configure TLS for the main endpoint.
      21            0 : pub fn configure_tls(
      22            0 :     key_path: &str,
      23            0 :     cert_path: &str,
      24            0 :     certs_dir: Option<&String>,
      25            0 :     allow_tls_keylogfile: bool,
      26            0 : ) -> anyhow::Result<TlsConfig> {
      27            0 :     let mut cert_resolver = CertResolver::new();
      28            0 : 
      29            0 :     // add default certificate
      30            0 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
      31              : 
      32              :     // add extra certificates
      33            0 :     if let Some(certs_dir) = certs_dir {
      34            0 :         for entry in std::fs::read_dir(certs_dir)? {
      35            0 :             let entry = entry?;
      36            0 :             let path = entry.path();
      37            0 :             if path.is_dir() {
      38              :                 // file names aligned with default cert-manager names
      39            0 :                 let key_path = path.join("tls.key");
      40            0 :                 let cert_path = path.join("tls.crt");
      41            0 :                 if key_path.exists() && cert_path.exists() {
      42            0 :                     cert_resolver.add_cert_path(
      43            0 :                         &key_path.to_string_lossy(),
      44            0 :                         &cert_path.to_string_lossy(),
      45            0 :                         false,
      46            0 :                     )?;
      47            0 :                 }
      48            0 :             }
      49              :         }
      50            0 :     }
      51              : 
      52            0 :     let common_names = cert_resolver.get_common_names();
      53            0 : 
      54            0 :     let cert_resolver = Arc::new(cert_resolver);
      55              : 
      56              :     // allow TLS 1.2 to be compatible with older client libraries
      57            0 :     let mut config =
      58            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      59            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
      60            0 :             .context("ring should support TLS1.2 and TLS1.3")?
      61            0 :             .with_no_client_auth()
      62            0 :             .with_cert_resolver(cert_resolver.clone());
      63            0 : 
      64            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
      65            0 : 
      66            0 :     if allow_tls_keylogfile {
      67            0 :         // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
      68            0 :         config.key_log = Arc::new(rustls::KeyLogFile::new());
      69            0 :     }
      70              : 
      71            0 :     let mut http_config = config.clone();
      72            0 :     let mut pg_config = config;
      73            0 : 
      74            0 :     http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
      75            0 :     pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
      76            0 : 
      77            0 :     Ok(TlsConfig {
      78            0 :         http_config: Arc::new(http_config),
      79            0 :         pg_config: Arc::new(pg_config),
      80            0 :         common_names,
      81            0 :         cert_resolver,
      82            0 :     })
      83            0 : }
      84              : 
      85              : #[derive(Default, Debug)]
      86              : pub struct CertResolver {
      87              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      88              :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
      89              : }
      90              : 
      91              : impl CertResolver {
      92           21 :     pub fn new() -> Self {
      93           21 :         Self::default()
      94           21 :     }
      95              : 
      96            0 :     fn add_cert_path(
      97            0 :         &mut self,
      98            0 :         key_path: &str,
      99            0 :         cert_path: &str,
     100            0 :         is_default: bool,
     101            0 :     ) -> anyhow::Result<()> {
     102            0 :         let priv_key = {
     103            0 :             let key_bytes = std::fs::read(key_path)
     104            0 :                 .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
     105            0 :             rustls_pemfile::private_key(&mut &key_bytes[..])
     106            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     107            0 :                 .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
     108              :         };
     109              : 
     110            0 :         let cert_chain_bytes = std::fs::read(cert_path)
     111            0 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     112              : 
     113            0 :         let cert_chain = {
     114            0 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     115            0 :                 .try_collect()
     116            0 :                 .with_context(|| {
     117            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     118            0 :                 })?
     119              :         };
     120              : 
     121            0 :         self.add_cert(priv_key, cert_chain, is_default)
     122            0 :     }
     123              : 
     124           21 :     pub fn add_cert(
     125           21 :         &mut self,
     126           21 :         priv_key: PrivateKeyDer<'static>,
     127           21 :         cert_chain: Vec<CertificateDer<'static>>,
     128           21 :         is_default: bool,
     129           21 :     ) -> anyhow::Result<()> {
     130           21 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     131              : 
     132           21 :         let first_cert = &cert_chain[0];
     133           21 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     134           21 :         let pem = x509_parser::parse_x509_certificate(first_cert)
     135           21 :             .context("Failed to parse PEM object from cerficiate")?
     136              :             .1;
     137              : 
     138           21 :         let common_name = pem.subject().to_string();
     139              : 
     140              :         // We need to get the canonical name for this certificate so we can match them against any domain names
     141              :         // seen within the proxy codebase.
     142              :         //
     143              :         // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
     144              :         // We need to remove the wildcard prefix for the purposes of certificate selection.
     145              :         //
     146              :         // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
     147              :         // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
     148              :         //
     149              :         // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
     150              :         // validation, so let's we can continue with any common-name
     151           21 :         let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
     152            0 :             s.to_string()
     153           21 :         } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
     154            0 :             s.to_string()
     155           21 :         } else if let Some(s) = common_name.strip_prefix("CN=") {
     156           21 :             s.to_string()
     157              :         } else {
     158            0 :             bail!("Failed to parse common name from certificate")
     159              :         };
     160              : 
     161           21 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     162           21 : 
     163           21 :         if is_default {
     164           21 :             self.default = Some((cert.clone(), tls_server_end_point));
     165           21 :         }
     166              : 
     167           21 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     168           21 : 
     169           21 :         Ok(())
     170           21 :     }
     171              : 
     172           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     173           21 :         self.certs.keys().map(|s| s.to_string()).collect()
     174           21 :     }
     175              : }
     176              : 
     177              : impl rustls::server::ResolvesServerCert for CertResolver {
     178            0 :     fn resolve(
     179            0 :         &self,
     180            0 :         client_hello: rustls::server::ClientHello<'_>,
     181            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     182            0 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     183            0 :     }
     184              : }
     185              : 
     186              : impl CertResolver {
     187           20 :     pub fn resolve(
     188           20 :         &self,
     189           20 :         server_name: Option<&str>,
     190           20 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     191              :         // loop here and cut off more and more subdomains until we find
     192              :         // a match to get a proper wildcard support. OTOH, we now do not
     193              :         // use nested domains, so keep this simple for now.
     194              :         //
     195              :         // With the current coding foo.com will match *.foo.com and that
     196              :         // repeats behavior of the old code.
     197           20 :         if let Some(mut sni_name) = server_name {
     198              :             loop {
     199           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     200           20 :                     return Some(cert.clone());
     201           20 :                 }
     202           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     203           20 :                     sni_name = rest;
     204           20 :                 } else {
     205            0 :                     return None;
     206              :                 }
     207              :             }
     208              :         } else {
     209              :             // No SNI, use the default certificate, otherwise we can't get to
     210              :             // options parameter which can be used to set endpoint name too.
     211              :             // That means that non-SNI flow will not work for CNAME domains in
     212              :             // verify-full mode.
     213              :             //
     214              :             // If that will be a problem we can:
     215              :             //
     216              :             // a) Instead of multi-cert approach use single cert with extra
     217              :             //    domains listed in Subject Alternative Name (SAN).
     218              :             // b) Deploy separate proxy instances for extra domains.
     219            0 :             self.default.clone()
     220              :         }
     221           20 :     }
     222              : }
        

Generated by: LCOV version 2.1-beta