LCOV - differential code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 91.0 % 255 232 23 232
Current Date: 2024-01-09 02:06:09 Functions: 74.4 % 39 29 10 29
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
       2                 : use anyhow::{bail, ensure, Context, Ok};
       3                 : use rustls::{sign, Certificate, PrivateKey};
       4                 : use sha2::{Digest, Sha256};
       5                 : use std::{
       6                 :     collections::{HashMap, HashSet},
       7                 :     str::FromStr,
       8                 :     sync::Arc,
       9                 :     time::Duration,
      10                 : };
      11                 : use tracing::{error, info};
      12                 : use x509_parser::oid_registry;
      13                 : 
      14                 : pub struct ProxyConfig {
      15                 :     pub tls_config: Option<TlsConfig>,
      16                 :     pub auth_backend: auth::BackendType<'static, ()>,
      17                 :     pub metric_collection: Option<MetricCollectionConfig>,
      18                 :     pub allow_self_signed_compute: bool,
      19                 :     pub http_config: HttpConfig,
      20                 :     pub authentication_config: AuthenticationConfig,
      21                 :     pub require_client_ip: bool,
      22                 :     pub disable_ip_check_for_http: bool,
      23                 :     pub endpoint_rps_limit: Vec<RateBucketInfo>,
      24                 :     pub region: String,
      25                 : }
      26                 : 
      27 CBC           1 : #[derive(Debug)]
      28                 : pub struct MetricCollectionConfig {
      29                 :     pub endpoint: reqwest::Url,
      30                 :     pub interval: Duration,
      31                 : }
      32                 : 
      33                 : pub struct TlsConfig {
      34                 :     pub config: Arc<rustls::ServerConfig>,
      35                 :     pub common_names: Option<HashSet<String>>,
      36                 :     pub cert_resolver: Arc<CertResolver>,
      37                 : }
      38                 : 
      39                 : pub struct HttpConfig {
      40                 :     pub request_timeout: tokio::time::Duration,
      41                 :     pub pool_options: GlobalConnPoolOptions,
      42                 : }
      43                 : 
      44                 : pub struct AuthenticationConfig {
      45                 :     pub scram_protocol_timeout: tokio::time::Duration,
      46                 : }
      47                 : 
      48                 : impl TlsConfig {
      49              91 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      50              91 :         self.config.clone()
      51              91 :     }
      52                 : }
      53                 : 
      54                 : /// Configure TLS for the main endpoint.
      55              22 : pub fn configure_tls(
      56              22 :     key_path: &str,
      57              22 :     cert_path: &str,
      58              22 :     certs_dir: Option<&String>,
      59              22 : ) -> anyhow::Result<TlsConfig> {
      60              22 :     let mut cert_resolver = CertResolver::new();
      61              22 : 
      62              22 :     // add default certificate
      63              22 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
      64                 : 
      65                 :     // add extra certificates
      66              22 :     if let Some(certs_dir) = certs_dir {
      67 UBC           0 :         for entry in std::fs::read_dir(certs_dir)? {
      68               0 :             let entry = entry?;
      69               0 :             let path = entry.path();
      70               0 :             if path.is_dir() {
      71                 :                 // file names aligned with default cert-manager names
      72               0 :                 let key_path = path.join("tls.key");
      73               0 :                 let cert_path = path.join("tls.crt");
      74               0 :                 if key_path.exists() && cert_path.exists() {
      75               0 :                     cert_resolver.add_cert_path(
      76               0 :                         &key_path.to_string_lossy(),
      77               0 :                         &cert_path.to_string_lossy(),
      78               0 :                         false,
      79               0 :                     )?;
      80               0 :                 }
      81               0 :             }
      82                 :         }
      83 CBC          22 :     }
      84                 : 
      85              22 :     let common_names = cert_resolver.get_common_names();
      86              22 : 
      87              22 :     let cert_resolver = Arc::new(cert_resolver);
      88                 : 
      89              22 :     let config = rustls::ServerConfig::builder()
      90              22 :         .with_safe_default_cipher_suites()
      91              22 :         .with_safe_default_kx_groups()
      92              22 :         // allow TLS 1.2 to be compatible with older client libraries
      93              22 :         .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
      94              22 :         .with_no_client_auth()
      95              22 :         .with_cert_resolver(cert_resolver.clone())
      96              22 :         .into();
      97              22 : 
      98              22 :     Ok(TlsConfig {
      99              22 :         config,
     100              22 :         common_names: Some(common_names),
     101              22 :         cert_resolver,
     102              22 :     })
     103              22 : }
     104                 : 
     105                 : /// Channel binding parameter
     106                 : ///
     107                 : /// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
     108                 : /// Description: The hash of the TLS server's certificate as it
     109                 : /// appears, octet for octet, in the server's Certificate message.  Note
     110                 : /// that the Certificate message contains a certificate_list, in which
     111                 : /// the first element is the server's certificate.
     112                 : ///
     113                 : /// The hash function is to be selected as follows:
     114                 : ///
     115                 : /// * if the certificate's signatureAlgorithm uses a single hash
     116                 : ///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
     117                 : ///
     118                 : /// * if the certificate's signatureAlgorithm uses a single hash
     119                 : ///   function and that hash function neither MD5 nor SHA-1, then use
     120                 : ///   the hash function associated with the certificate's
     121                 : ///   signatureAlgorithm;
     122                 : ///
     123                 : /// * if the certificate's signatureAlgorithm uses no hash functions or
     124                 : ///   uses multiple hash functions, then this channel binding type's
     125                 : ///   channel bindings are undefined at this time (updates to is channel
     126                 : ///   binding type may occur to address this issue if it ever arises).
     127             163 : #[derive(Debug, Clone, Copy)]
     128                 : pub enum TlsServerEndPoint {
     129                 :     Sha256([u8; 32]),
     130                 :     Undefined,
     131                 : }
     132                 : 
     133                 : impl TlsServerEndPoint {
     134              44 :     pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
     135              44 :         let sha256_oids = [
     136              44 :             // I'm explicitly not adding MD5 or SHA1 here... They're bad.
     137              44 :             oid_registry::OID_SIG_ECDSA_WITH_SHA256,
     138              44 :             oid_registry::OID_PKCS1_SHA256WITHRSA,
     139              44 :         ];
     140                 : 
     141              44 :         let pem = x509_parser::parse_x509_certificate(&cert.0)
     142              44 :             .context("Failed to parse PEM object from cerficiate")?
     143                 :             .1;
     144                 : 
     145              44 :         info!(subject = %pem.subject, "parsing TLS certificate");
     146                 : 
     147              44 :         let reg = oid_registry::OidRegistry::default().with_all_crypto();
     148              44 :         let oid = pem.signature_algorithm.oid();
     149              44 :         let alg = reg.get(oid);
     150              44 :         if sha256_oids.contains(oid) {
     151              44 :             let tls_server_end_point: [u8; 32] =
     152              44 :                 Sha256::new().chain_update(&cert.0).finalize().into();
     153              44 :             info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
     154              44 :             Ok(Self::Sha256(tls_server_end_point))
     155                 :         } else {
     156 UBC           0 :             error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
     157               0 :             Ok(Self::Undefined)
     158                 :         }
     159 CBC          44 :     }
     160                 : 
     161              52 :     pub fn supported(&self) -> bool {
     162              52 :         !matches!(self, TlsServerEndPoint::Undefined)
     163              52 :     }
     164                 : }
     165                 : 
     166              43 : #[derive(Default)]
     167                 : pub struct CertResolver {
     168                 :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     169                 :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     170                 : }
     171                 : 
     172                 : impl CertResolver {
     173              43 :     pub fn new() -> Self {
     174              43 :         Self::default()
     175              43 :     }
     176                 : 
     177              22 :     fn add_cert_path(
     178              22 :         &mut self,
     179              22 :         key_path: &str,
     180              22 :         cert_path: &str,
     181              22 :         is_default: bool,
     182              22 :     ) -> anyhow::Result<()> {
     183              22 :         let priv_key = {
     184              22 :             let key_bytes = std::fs::read(key_path)
     185              22 :                 .context(format!("Failed to read TLS keys at '{key_path}'"))?;
     186              22 :             let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
     187              22 :                 .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
     188                 : 
     189              22 :             ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     190              22 :             keys.pop().map(rustls::PrivateKey).unwrap()
     191                 :         };
     192                 : 
     193              22 :         let cert_chain_bytes = std::fs::read(cert_path)
     194              22 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     195                 : 
     196              22 :         let cert_chain = {
     197              22 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     198              22 :                 .with_context(|| {
     199 UBC           0 :                     format!(
     200               0 :                     "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
     201               0 :                 )
     202 CBC          22 :                 })?
     203              22 :                 .into_iter()
     204              22 :                 .map(rustls::Certificate)
     205              22 :                 .collect()
     206              22 :         };
     207              22 : 
     208              22 :         self.add_cert(priv_key, cert_chain, is_default)
     209              22 :     }
     210                 : 
     211              43 :     pub fn add_cert(
     212              43 :         &mut self,
     213              43 :         priv_key: PrivateKey,
     214              43 :         cert_chain: Vec<Certificate>,
     215              43 :         is_default: bool,
     216              43 :     ) -> anyhow::Result<()> {
     217              43 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     218                 : 
     219              43 :         let first_cert = &cert_chain[0];
     220              43 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     221              43 :         let pem = x509_parser::parse_x509_certificate(&first_cert.0)
     222              43 :             .context("Failed to parse PEM object from cerficiate")?
     223                 :             .1;
     224                 : 
     225              43 :         let common_name = pem.subject().to_string();
     226                 : 
     227                 :         // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
     228                 :         // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
     229                 :         // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
     230                 :         // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
     231                 :         // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
     232                 :         // of cutting off '*.' parts.
     233              43 :         let common_name = if common_name.starts_with("CN=*.") {
     234              22 :             common_name.strip_prefix("CN=*.").map(|s| s.to_string())
     235                 :         } else {
     236              21 :             common_name.strip_prefix("CN=").map(|s| s.to_string())
     237                 :         }
     238              43 :         .context("Failed to parse common name from certificate")?;
     239                 : 
     240              43 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     241              43 : 
     242              43 :         if is_default {
     243              43 :             self.default = Some((cert.clone(), tls_server_end_point));
     244              43 :         }
     245                 : 
     246              43 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     247              43 : 
     248              43 :         Ok(())
     249              43 :     }
     250                 : 
     251              43 :     pub fn get_common_names(&self) -> HashSet<String> {
     252              43 :         self.certs.keys().map(|s| s.to_string()).collect()
     253              43 :     }
     254                 : }
     255                 : 
     256                 : impl rustls::server::ResolvesServerCert for CertResolver {
     257              94 :     fn resolve(
     258              94 :         &self,
     259              94 :         client_hello: rustls::server::ClientHello,
     260              94 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     261              94 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     262              94 :     }
     263                 : }
     264                 : 
     265                 : impl CertResolver {
     266             163 :     pub fn resolve(
     267             163 :         &self,
     268             163 :         server_name: Option<&str>,
     269             163 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     270                 :         // loop here and cut off more and more subdomains until we find
     271                 :         // a match to get a proper wildcard support. OTOH, we now do not
     272                 :         // use nested domains, so keep this simple for now.
     273                 :         //
     274                 :         // With the current coding foo.com will match *.foo.com and that
     275                 :         // repeats behavior of the old code.
     276             163 :         if let Some(mut sni_name) = server_name {
     277                 :             loop {
     278             246 :                 if let Some(cert) = self.certs.get(sni_name) {
     279             123 :                     return Some(cert.clone());
     280             123 :                 }
     281             123 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     282             123 :                     sni_name = rest;
     283             123 :                 } else {
     284 UBC           0 :                     return None;
     285                 :                 }
     286                 :             }
     287                 :         } else {
     288                 :             // No SNI, use the default certificate, otherwise we can't get to
     289                 :             // options parameter which can be used to set endpoint name too.
     290                 :             // That means that non-SNI flow will not work for CNAME domains in
     291                 :             // verify-full mode.
     292                 :             //
     293                 :             // If that will be a problem we can:
     294                 :             //
     295                 :             // a) Instead of multi-cert approach use single cert with extra
     296                 :             //    domains listed in Subject Alternative Name (SAN).
     297                 :             // b) Deploy separate proxy instances for extra domains.
     298 CBC          40 :             self.default.as_ref().cloned()
     299                 :         }
     300             163 :     }
     301                 : }
     302                 : 
     303                 : /// Helper for cmdline cache options parsing.
     304               3 : #[derive(Debug)]
     305                 : pub struct CacheOptions {
     306                 :     /// Max number of entries.
     307                 :     pub size: usize,
     308                 :     /// Entry's time-to-live.
     309                 :     pub ttl: Duration,
     310                 : }
     311                 : 
     312                 : impl CacheOptions {
     313                 :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     314                 :     pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
     315                 : 
     316                 :     /// Parse cache options passed via cmdline.
     317                 :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     318               7 :     fn parse(options: &str) -> anyhow::Result<Self> {
     319               7 :         let mut size = None;
     320               7 :         let mut ttl = None;
     321                 : 
     322              12 :         for option in options.split(',') {
     323              12 :             let (key, value) = option
     324              12 :                 .split_once('=')
     325              12 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     326                 : 
     327              12 :             match key {
     328              12 :                 "size" => size = Some(value.parse()?),
     329               5 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     330 UBC           0 :                 unknown => bail!("unknown key: {unknown}"),
     331                 :             }
     332                 :         }
     333                 : 
     334                 :         // TTL doesn't matter if cache is always empty.
     335 CBC           7 :         if let Some(0) = size {
     336               3 :             ttl.get_or_insert(Duration::default());
     337               4 :         }
     338                 : 
     339                 :         Ok(Self {
     340               7 :             size: size.context("missing `size`")?,
     341               7 :             ttl: ttl.context("missing `ttl`")?,
     342                 :         })
     343               7 :     }
     344                 : }
     345                 : 
     346                 : impl FromStr for CacheOptions {
     347                 :     type Err = anyhow::Error;
     348                 : 
     349               7 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     350               7 :         let error = || format!("failed to parse cache options '{options}'");
     351               7 :         Self::parse(options).with_context(error)
     352               7 :     }
     353                 : }
     354                 : 
     355                 : /// Helper for cmdline cache options parsing.
     356                 : pub struct WakeComputeLockOptions {
     357                 :     /// The number of shards the lock map should have
     358                 :     pub shards: usize,
     359                 :     /// The number of allowed concurrent requests for each endpoitn
     360                 :     pub permits: usize,
     361                 :     /// Garbage collection epoch
     362                 :     pub epoch: Duration,
     363                 :     /// Lock timeout
     364                 :     pub timeout: Duration,
     365                 : }
     366                 : 
     367                 : impl WakeComputeLockOptions {
     368                 :     /// Default options for [`crate::console::provider::ApiLocks`].
     369                 :     pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
     370                 : 
     371                 :     // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
     372                 : 
     373                 :     /// Parse lock options passed via cmdline.
     374                 :     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     375               4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     376               4 :         let mut shards = None;
     377               4 :         let mut permits = None;
     378               4 :         let mut epoch = None;
     379               4 :         let mut timeout = None;
     380                 : 
     381              10 :         for option in options.split(',') {
     382              10 :             let (key, value) = option
     383              10 :                 .split_once('=')
     384              10 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     385                 : 
     386              10 :             match key {
     387              10 :                 "shards" => shards = Some(value.parse()?),
     388               8 :                 "permits" => permits = Some(value.parse()?),
     389               4 :                 "epoch" => epoch = Some(humantime::parse_duration(value)?),
     390               2 :                 "timeout" => timeout = Some(humantime::parse_duration(value)?),
     391 UBC           0 :                 unknown => bail!("unknown key: {unknown}"),
     392                 :             }
     393                 :         }
     394                 : 
     395                 :         // these dont matter if lock is disabled
     396 CBC           4 :         if let Some(0) = permits {
     397               2 :             timeout = Some(Duration::default());
     398               2 :             epoch = Some(Duration::default());
     399               2 :             shards = Some(2);
     400               2 :         }
     401                 : 
     402               4 :         let out = Self {
     403               4 :             shards: shards.context("missing `shards`")?,
     404               4 :             permits: permits.context("missing `permits`")?,
     405               4 :             epoch: epoch.context("missing `epoch`")?,
     406               4 :             timeout: timeout.context("missing `timeout`")?,
     407                 :         };
     408                 : 
     409               4 :         ensure!(out.shards > 1, "shard count must be > 1");
     410                 :         ensure!(
     411               4 :             out.shards.is_power_of_two(),
     412 UBC           0 :             "shard count must be a power of two"
     413                 :         );
     414                 : 
     415 CBC           4 :         Ok(out)
     416               4 :     }
     417                 : }
     418                 : 
     419                 : impl FromStr for WakeComputeLockOptions {
     420                 :     type Err = anyhow::Error;
     421                 : 
     422               4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     423               4 :         let error = || format!("failed to parse cache lock options '{options}'");
     424               4 :         Self::parse(options).with_context(error)
     425               4 :     }
     426                 : }
     427                 : 
     428                 : #[cfg(test)]
     429                 : mod tests {
     430                 :     use super::*;
     431                 : 
     432               1 :     #[test]
     433               1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     434               1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     435               1 :         assert_eq!(size, 4096);
     436               1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     437                 : 
     438               1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     439               1 :         assert_eq!(size, 2);
     440               1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     441                 : 
     442               1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     443               1 :         assert_eq!(size, 0);
     444               1 :         assert_eq!(ttl, Duration::from_secs(1));
     445                 : 
     446               1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     447               1 :         assert_eq!(size, 0);
     448               1 :         assert_eq!(ttl, Duration::default());
     449                 : 
     450               1 :         Ok(())
     451               1 :     }
     452                 : 
     453               1 :     #[test]
     454               1 :     fn test_parse_lock_options() -> anyhow::Result<()> {
     455                 :         let WakeComputeLockOptions {
     456               1 :             epoch,
     457               1 :             permits,
     458               1 :             shards,
     459               1 :             timeout,
     460               1 :         } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
     461               1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     462               1 :         assert_eq!(timeout, Duration::from_secs(1));
     463               1 :         assert_eq!(shards, 32);
     464               1 :         assert_eq!(permits, 4);
     465                 : 
     466                 :         let WakeComputeLockOptions {
     467               1 :             epoch,
     468               1 :             permits,
     469               1 :             shards,
     470               1 :             timeout,
     471               1 :         } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
     472               1 :         assert_eq!(epoch, Duration::from_secs(60));
     473               1 :         assert_eq!(timeout, Duration::from_millis(100));
     474               1 :         assert_eq!(shards, 16);
     475               1 :         assert_eq!(permits, 8);
     476                 : 
     477                 :         let WakeComputeLockOptions {
     478               1 :             epoch,
     479               1 :             permits,
     480               1 :             shards,
     481               1 :             timeout,
     482               1 :         } = "permits=0".parse()?;
     483               1 :         assert_eq!(epoch, Duration::ZERO);
     484               1 :         assert_eq!(timeout, Duration::ZERO);
     485               1 :         assert_eq!(shards, 2);
     486               1 :         assert_eq!(permits, 0);
     487                 : 
     488               1 :         Ok(())
     489               1 :     }
     490                 : }
        

Generated by: LCOV version 2.1-beta