LCOV - differential code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 88.8 % 152 135 17 135
Current Date: 2023-10-19 02:04:12 Functions: 81.2 % 16 13 3 13
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use crate::auth;
       2                 : use anyhow::{bail, ensure, Context, Ok};
       3                 : use rustls::sign;
       4                 : use std::{
       5                 :     collections::{HashMap, HashSet},
       6                 :     str::FromStr,
       7                 :     sync::Arc,
       8                 :     time::Duration,
       9                 : };
      10                 : 
      11                 : pub struct ProxyConfig {
      12                 :     pub tls_config: Option<TlsConfig>,
      13                 :     pub auth_backend: auth::BackendType<'static, ()>,
      14                 :     pub metric_collection: Option<MetricCollectionConfig>,
      15                 :     pub allow_self_signed_compute: bool,
      16                 :     pub http_config: HttpConfig,
      17                 :     pub require_client_ip: bool,
      18                 : }
      19                 : 
      20 CBC           1 : #[derive(Debug)]
      21                 : pub struct MetricCollectionConfig {
      22                 :     pub endpoint: reqwest::Url,
      23                 :     pub interval: Duration,
      24                 : }
      25                 : 
      26                 : pub struct TlsConfig {
      27                 :     pub config: Arc<rustls::ServerConfig>,
      28                 :     pub common_names: Option<HashSet<String>>,
      29                 : }
      30                 : 
      31                 : pub struct HttpConfig {
      32                 :     pub sql_over_http_timeout: tokio::time::Duration,
      33                 : }
      34                 : 
      35                 : impl TlsConfig {
      36              54 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      37              54 :         self.config.clone()
      38              54 :     }
      39                 : }
      40                 : 
      41                 : /// Configure TLS for the main endpoint.
      42              16 : pub fn configure_tls(
      43              16 :     key_path: &str,
      44              16 :     cert_path: &str,
      45              16 :     certs_dir: Option<&String>,
      46              16 : ) -> anyhow::Result<TlsConfig> {
      47              16 :     let mut cert_resolver = CertResolver::new();
      48              16 : 
      49              16 :     // add default certificate
      50              16 :     cert_resolver.add_cert(key_path, cert_path, true)?;
      51                 : 
      52                 :     // add extra certificates
      53              16 :     if let Some(certs_dir) = certs_dir {
      54 UBC           0 :         for entry in std::fs::read_dir(certs_dir)? {
      55               0 :             let entry = entry?;
      56               0 :             let path = entry.path();
      57               0 :             if path.is_dir() {
      58                 :                 // file names aligned with default cert-manager names
      59               0 :                 let key_path = path.join("tls.key");
      60               0 :                 let cert_path = path.join("tls.crt");
      61               0 :                 if key_path.exists() && cert_path.exists() {
      62               0 :                     cert_resolver.add_cert(
      63               0 :                         &key_path.to_string_lossy(),
      64               0 :                         &cert_path.to_string_lossy(),
      65               0 :                         false,
      66               0 :                     )?;
      67               0 :                 }
      68               0 :             }
      69                 :         }
      70 CBC          16 :     }
      71                 : 
      72              16 :     let common_names = cert_resolver.get_common_names();
      73                 : 
      74              16 :     let config = rustls::ServerConfig::builder()
      75              16 :         .with_safe_default_cipher_suites()
      76              16 :         .with_safe_default_kx_groups()
      77              16 :         // allow TLS 1.2 to be compatible with older client libraries
      78              16 :         .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
      79              16 :         .with_no_client_auth()
      80              16 :         .with_cert_resolver(Arc::new(cert_resolver))
      81              16 :         .into();
      82              16 : 
      83              16 :     Ok(TlsConfig {
      84              16 :         config,
      85              16 :         common_names: Some(common_names),
      86              16 :     })
      87              16 : }
      88                 : 
      89                 : struct CertResolver {
      90                 :     certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
      91                 :     default: Option<Arc<rustls::sign::CertifiedKey>>,
      92                 : }
      93                 : 
      94                 : impl CertResolver {
      95              16 :     fn new() -> Self {
      96              16 :         Self {
      97              16 :             certs: HashMap::new(),
      98              16 :             default: None,
      99              16 :         }
     100              16 :     }
     101                 : 
     102              16 :     fn add_cert(
     103              16 :         &mut self,
     104              16 :         key_path: &str,
     105              16 :         cert_path: &str,
     106              16 :         is_default: bool,
     107              16 :     ) -> anyhow::Result<()> {
     108              16 :         let priv_key = {
     109              16 :             let key_bytes = std::fs::read(key_path)
     110              16 :                 .context(format!("Failed to read TLS keys at '{key_path}'"))?;
     111              16 :             let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
     112              16 :                 .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
     113                 : 
     114              16 :             ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     115              16 :             keys.pop().map(rustls::PrivateKey).unwrap()
     116                 :         };
     117                 : 
     118              16 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     119                 : 
     120              16 :         let cert_chain_bytes = std::fs::read(cert_path)
     121              16 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     122                 : 
     123              16 :         let cert_chain = {
     124              16 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     125              16 :                 .context(format!(
     126              16 :                     "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
     127              16 :                 ))?
     128              16 :                 .into_iter()
     129              16 :                 .map(rustls::Certificate)
     130              16 :                 .collect()
     131                 :         };
     132                 : 
     133              16 :         let common_name = {
     134              16 :             let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
     135              16 :                 .context(format!(
     136              16 :                     "Failed to parse PEM object from bytes from file at '{cert_path}'."
     137              16 :                 ))?
     138                 :                 .1;
     139              16 :             let common_name = pem.parse_x509()?.subject().to_string();
     140              16 : 
     141              16 :             // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
     142              16 :             // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
     143              16 :             // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
     144              16 :             // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
     145              16 :             // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
     146              16 :             // of cutting off '*.' parts.
     147              16 :             if common_name.starts_with("CN=*.") {
     148              16 :                 common_name.strip_prefix("CN=*.").map(|s| s.to_string())
     149                 :             } else {
     150 UBC           0 :                 common_name.strip_prefix("CN=").map(|s| s.to_string())
     151                 :             }
     152                 :         }
     153 CBC          16 :         .context(format!(
     154              16 :             "Failed to parse common name from certificate at '{cert_path}'."
     155              16 :         ))?;
     156                 : 
     157              16 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     158              16 : 
     159              16 :         if is_default {
     160              16 :             self.default = Some(cert.clone());
     161              16 :         }
     162                 : 
     163              16 :         self.certs.insert(common_name, cert);
     164              16 : 
     165              16 :         Ok(())
     166              16 :     }
     167                 : 
     168              16 :     fn get_common_names(&self) -> HashSet<String> {
     169              16 :         self.certs.keys().map(|s| s.to_string()).collect()
     170              16 :     }
     171                 : }
     172                 : 
     173                 : impl rustls::server::ResolvesServerCert for CertResolver {
     174                 :     fn resolve(
     175                 :         &self,
     176                 :         _client_hello: rustls::server::ClientHello,
     177                 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     178                 :         // loop here and cut off more and more subdomains until we find
     179                 :         // a match to get a proper wildcard support. OTOH, we now do not
     180                 :         // use nested domains, so keep this simple for now.
     181                 :         //
     182                 :         // With the current coding foo.com will match *.foo.com and that
     183                 :         // repeats behavior of the old code.
     184              63 :         if let Some(mut sni_name) = _client_hello.server_name() {
     185                 :             loop {
     186             102 :                 if let Some(cert) = self.certs.get(sni_name) {
     187              51 :                     return Some(cert.clone());
     188              51 :                 }
     189              51 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     190              51 :                     sni_name = rest;
     191              51 :                 } else {
     192 UBC           0 :                     return None;
     193                 :                 }
     194                 :             }
     195                 :         } else {
     196                 :             // No SNI, use the default certificate, otherwise we can't get to
     197                 :             // options parameter which can be used to set endpoint name too.
     198                 :             // That means that non-SNI flow will not work for CNAME domains in
     199                 :             // verify-full mode.
     200                 :             //
     201                 :             // If that will be a problem we can:
     202                 :             //
     203                 :             // a) Instead of multi-cert approach use single cert with extra
     204                 :             //    domains listed in Subject Alternative Name (SAN).
     205                 :             // b) Deploy separate proxy instances for extra domains.
     206 CBC          12 :             self.default.as_ref().cloned()
     207                 :         }
     208              63 :     }
     209                 : }
     210                 : 
     211                 : /// Helper for cmdline cache options parsing.
     212                 : pub struct CacheOptions {
     213                 :     /// Max number of entries.
     214                 :     pub size: usize,
     215                 :     /// Entry's time-to-live.
     216                 :     pub ttl: Duration,
     217                 : }
     218                 : 
     219                 : impl CacheOptions {
     220                 :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     221                 :     pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=4m";
     222                 : 
     223                 :     /// Parse cache options passed via cmdline.
     224                 :     /// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
     225               4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     226               4 :         let mut size = None;
     227               4 :         let mut ttl = None;
     228                 : 
     229               7 :         for option in options.split(',') {
     230               7 :             let (key, value) = option
     231               7 :                 .split_once('=')
     232               7 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     233                 : 
     234               7 :             match key {
     235               7 :                 "size" => size = Some(value.parse()?),
     236               3 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     237 UBC           0 :                 unknown => bail!("unknown key: {unknown}"),
     238                 :             }
     239                 :         }
     240                 : 
     241                 :         // TTL doesn't matter if cache is always empty.
     242 CBC           4 :         if let Some(0) = size {
     243               2 :             ttl.get_or_insert(Duration::default());
     244               2 :         }
     245                 : 
     246                 :         Ok(Self {
     247               4 :             size: size.context("missing `size`")?,
     248               4 :             ttl: ttl.context("missing `ttl`")?,
     249                 :         })
     250               4 :     }
     251                 : }
     252                 : 
     253                 : impl FromStr for CacheOptions {
     254                 :     type Err = anyhow::Error;
     255                 : 
     256               4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     257               4 :         let error = || format!("failed to parse cache options '{options}'");
     258               4 :         Self::parse(options).with_context(error)
     259               4 :     }
     260                 : }
     261                 : 
     262                 : #[cfg(test)]
     263                 : mod tests {
     264                 :     use super::*;
     265                 : 
     266               1 :     #[test]
     267               1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     268               1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     269               1 :         assert_eq!(size, 4096);
     270               1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     271                 : 
     272               1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     273               1 :         assert_eq!(size, 2);
     274               1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     275                 : 
     276               1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     277               1 :         assert_eq!(size, 0);
     278               1 :         assert_eq!(ttl, Duration::from_secs(1));
     279                 : 
     280               1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     281               1 :         assert_eq!(size, 0);
     282               1 :         assert_eq!(ttl, Duration::default());
     283                 : 
     284               1 :         Ok(())
     285               1 :     }
     286                 : }
        

Generated by: LCOV version 2.1-beta