LCOV - code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 88.8 % 152 135
Test Date: 2023-09-06 10:18:01 Functions: 81.2 % 16 13

            Line data    Source code
       1              : use crate::auth;
       2              : use anyhow::{bail, ensure, Context, Ok};
       3              : use rustls::sign;
       4              : use std::{
       5              :     collections::{HashMap, HashSet},
       6              :     str::FromStr,
       7              :     sync::Arc,
       8              :     time::Duration,
       9              : };
      10              : 
/// Top-level configuration of the proxy.
pub struct ProxyConfig {
    /// TLS settings, if any.
    pub tls_config: Option<TlsConfig>,
    /// Authentication backend (see [`auth::BackendType`]).
    pub auth_backend: auth::BackendType<'static, ()>,
    /// Metric collection settings, if any.
    pub metric_collection: Option<MetricCollectionConfig>,
    /// NOTE(review): presumably permits compute nodes to present self-signed
    /// certificates — confirm against the code that reads this flag.
    pub allow_self_signed_compute: bool,
}
      17              : 
/// Settings for metric collection.
#[derive(Debug)]
pub struct MetricCollectionConfig {
    /// URL of the metric collection endpoint.
    /// NOTE(review): presumably the destination metrics are sent to — confirm.
    pub endpoint: reqwest::Url,
    /// Collection interval.
    /// NOTE(review): presumably the period between two reports — confirm.
    pub interval: Duration,
}
      23              : 
/// TLS configuration, as built by [`configure_tls`].
pub struct TlsConfig {
    /// Shared rustls server configuration.
    pub config: Arc<rustls::ServerConfig>,
    /// Common names of the loaded certificates. When built by
    /// [`configure_tls`] this is always `Some`, and each name has its
    /// `CN=` / `CN=*.` prefix stripped.
    pub common_names: Option<HashSet<String>>,
}
      28              : 
      29              : impl TlsConfig {
      30           50 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      31           50 :         self.config.clone()
      32           50 :     }
      33              : }
      34              : 
      35              : /// Configure TLS for the main endpoint.
      36           14 : pub fn configure_tls(
      37           14 :     key_path: &str,
      38           14 :     cert_path: &str,
      39           14 :     certs_dir: Option<&String>,
      40           14 : ) -> anyhow::Result<TlsConfig> {
      41           14 :     let mut cert_resolver = CertResolver::new();
      42           14 : 
      43           14 :     // add default certificate
      44           14 :     cert_resolver.add_cert(key_path, cert_path, true)?;
      45              : 
      46              :     // add extra certificates
      47           14 :     if let Some(certs_dir) = certs_dir {
      48            0 :         for entry in std::fs::read_dir(certs_dir)? {
      49            0 :             let entry = entry?;
      50            0 :             let path = entry.path();
      51            0 :             if path.is_dir() {
      52              :                 // file names aligned with default cert-manager names
      53            0 :                 let key_path = path.join("tls.key");
      54            0 :                 let cert_path = path.join("tls.crt");
      55            0 :                 if key_path.exists() && cert_path.exists() {
      56            0 :                     cert_resolver.add_cert(
      57            0 :                         &key_path.to_string_lossy(),
      58            0 :                         &cert_path.to_string_lossy(),
      59            0 :                         false,
      60            0 :                     )?;
      61            0 :                 }
      62            0 :             }
      63              :         }
      64           14 :     }
      65              : 
      66           14 :     let common_names = cert_resolver.get_common_names();
      67              : 
      68           14 :     let config = rustls::ServerConfig::builder()
      69           14 :         .with_safe_default_cipher_suites()
      70           14 :         .with_safe_default_kx_groups()
      71           14 :         // allow TLS 1.2 to be compatible with older client libraries
      72           14 :         .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
      73           14 :         .with_no_client_auth()
      74           14 :         .with_cert_resolver(Arc::new(cert_resolver))
      75           14 :         .into();
      76           14 : 
      77           14 :     Ok(TlsConfig {
      78           14 :         config,
      79           14 :         common_names: Some(common_names),
      80           14 :     })
      81           14 : }
      82              : 
/// Holds the loaded certificates and selects one per TLS handshake.
struct CertResolver {
    /// Certificates keyed by common name, with the `CN=` / `CN=*.` prefix
    /// stripped.
    certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
    /// Certificate returned when the client sends no SNI name.
    default: Option<Arc<rustls::sign::CertifiedKey>>,
}
      87              : 
      88              : impl CertResolver {
      89           14 :     fn new() -> Self {
      90           14 :         Self {
      91           14 :             certs: HashMap::new(),
      92           14 :             default: None,
      93           14 :         }
      94           14 :     }
      95              : 
      96           14 :     fn add_cert(
      97           14 :         &mut self,
      98           14 :         key_path: &str,
      99           14 :         cert_path: &str,
     100           14 :         is_default: bool,
     101           14 :     ) -> anyhow::Result<()> {
     102           14 :         let priv_key = {
     103           14 :             let key_bytes = std::fs::read(key_path)
     104           14 :                 .context(format!("Failed to read TLS keys at '{key_path}'"))?;
     105           14 :             let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
     106           14 :                 .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
     107              : 
     108           14 :             ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     109           14 :             keys.pop().map(rustls::PrivateKey).unwrap()
     110              :         };
     111              : 
     112           14 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     113              : 
     114           14 :         let cert_chain_bytes = std::fs::read(cert_path)
     115           14 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     116              : 
     117           14 :         let cert_chain = {
     118           14 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     119           14 :                 .context(format!(
     120           14 :                     "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
     121           14 :                 ))?
     122           14 :                 .into_iter()
     123           14 :                 .map(rustls::Certificate)
     124           14 :                 .collect()
     125              :         };
     126              : 
     127           14 :         let common_name = {
     128           14 :             let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
     129           14 :                 .context(format!(
     130           14 :                     "Failed to parse PEM object from bytes from file at '{cert_path}'."
     131           14 :                 ))?
     132              :                 .1;
     133           14 :             let common_name = pem.parse_x509()?.subject().to_string();
     134           14 : 
     135           14 :             // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
     136           14 :             // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
     137           14 :             // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
     138           14 :             // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
     139           14 :             // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
     140           14 :             // of cutting off '*.' parts.
     141           14 :             if common_name.starts_with("CN=*.") {
     142           14 :                 common_name.strip_prefix("CN=*.").map(|s| s.to_string())
     143              :             } else {
     144            0 :                 common_name.strip_prefix("CN=").map(|s| s.to_string())
     145              :             }
     146              :         }
     147           14 :         .context(format!(
     148           14 :             "Failed to parse common name from certificate at '{cert_path}'."
     149           14 :         ))?;
     150              : 
     151           14 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     152           14 : 
     153           14 :         if is_default {
     154           14 :             self.default = Some(cert.clone());
     155           14 :         }
     156              : 
     157           14 :         self.certs.insert(common_name, cert);
     158           14 : 
     159           14 :         Ok(())
     160           14 :     }
     161              : 
     162           14 :     fn get_common_names(&self) -> HashSet<String> {
     163           14 :         self.certs.keys().map(|s| s.to_string()).collect()
     164           14 :     }
     165              : }
     166              : 
     167              : impl rustls::server::ResolvesServerCert for CertResolver {
     168              :     fn resolve(
     169              :         &self,
     170              :         _client_hello: rustls::server::ClientHello,
     171              :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     172              :         // loop here and cut off more and more subdomains until we find
     173              :         // a match to get a proper wildcard support. OTOH, we now do not
     174              :         // use nested domains, so keep this simple for now.
     175              :         //
     176              :         // With the current coding foo.com will match *.foo.com and that
     177              :         // repeats behavior of the old code.
     178           53 :         if let Some(mut sni_name) = _client_hello.server_name() {
     179              :             loop {
     180           82 :                 if let Some(cert) = self.certs.get(sni_name) {
     181           41 :                     return Some(cert.clone());
     182           41 :                 }
     183           41 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     184           41 :                     sni_name = rest;
     185           41 :                 } else {
     186            0 :                     return None;
     187              :                 }
     188              :             }
     189              :         } else {
     190              :             // No SNI, use the default certificate, otherwise we can't get to
     191              :             // options parameter which can be used to set endpoint name too.
     192              :             // That means that non-SNI flow will not work for CNAME domains in
     193              :             // verify-full mode.
     194              :             //
     195              :             // If that will be a problem we can:
     196              :             //
     197              :             // a) Instead of multi-cert approach use single cert with extra
     198              :             //    domains listed in Subject Alternative Name (SAN).
     199              :             // b) Deploy separate proxy instances for extra domains.
     200           12 :             self.default.as_ref().cloned()
     201              :         }
     202           53 :     }
     203              : }
     204              : 
/// Helper for cmdline cache options parsing.
///
/// Parsed from a comma-separated `key=value` string such as
/// `size=4000,ttl=4m` via its [`FromStr`] implementation.
pub struct CacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
}
     212              : 
     213              : impl CacheOptions {
     214              :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     215              :     pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=4m";
     216              : 
     217              :     /// Parse cache options passed via cmdline.
     218              :     /// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
     219            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     220            4 :         let mut size = None;
     221            4 :         let mut ttl = None;
     222              : 
     223            7 :         for option in options.split(',') {
     224            7 :             let (key, value) = option
     225            7 :                 .split_once('=')
     226            7 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     227              : 
     228            7 :             match key {
     229            7 :                 "size" => size = Some(value.parse()?),
     230            3 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     231            0 :                 unknown => bail!("unknown key: {unknown}"),
     232              :             }
     233              :         }
     234              : 
     235              :         // TTL doesn't matter if cache is always empty.
     236            4 :         if let Some(0) = size {
     237            2 :             ttl.get_or_insert(Duration::default());
     238            2 :         }
     239              : 
     240              :         Ok(Self {
     241            4 :             size: size.context("missing `size`")?,
     242            4 :             ttl: ttl.context("missing `ttl`")?,
     243              :         })
     244            4 :     }
     245              : }
     246              : 
     247              : impl FromStr for CacheOptions {
     248              :     type Err = anyhow::Error;
     249              : 
     250            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     251            4 :         let error = || format!("failed to parse cache options '{options}'");
     252            4 :         Self::parse(options).with_context(error)
     253            4 :     }
     254              : }
     255              : 
     256              : #[cfg(test)]
     257              : mod tests {
     258              :     use super::*;
     259              : 
     260            1 :     #[test]
     261            1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     262            1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     263            1 :         assert_eq!(size, 4096);
     264            1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     265              : 
     266            1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     267            1 :         assert_eq!(size, 2);
     268            1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     269              : 
     270            1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     271            1 :         assert_eq!(size, 0);
     272            1 :         assert_eq!(ttl, Duration::from_secs(1));
     273              : 
     274            1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     275            1 :         assert_eq!(size, 0);
     276            1 :         assert_eq!(ttl, Duration::default());
     277              : 
     278            1 :         Ok(())
     279            1 :     }
     280              : }
        

Generated by: LCOV version 2.1-beta