LCOV - code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 55.4 % 368 204
Test Date: 2024-09-20 13:14:58 Functions: 33.3 % 63 21

            Line data    Source code
       1              : use crate::{
       2              :     auth::{self, backend::AuthRateLimiter},
       3              :     console::locks::ApiLocks,
       4              :     rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
       5              :     scram::threadpool::ThreadPool,
       6              :     serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
       7              :     Host,
       8              : };
       9              : use anyhow::{bail, ensure, Context, Ok};
      10              : use itertools::Itertools;
      11              : use remote_storage::RemoteStorageConfig;
      12              : use rustls::{
      13              :     crypto::ring::sign,
      14              :     pki_types::{CertificateDer, PrivateKeyDer},
      15              : };
      16              : use sha2::{Digest, Sha256};
      17              : use std::{
      18              :     collections::{HashMap, HashSet},
      19              :     str::FromStr,
      20              :     sync::Arc,
      21              :     time::Duration,
      22              : };
      23              : use tracing::{error, info};
      24              : use x509_parser::oid_registry;
      25              : 
      26              : pub struct ProxyConfig {
      27              :     pub tls_config: Option<TlsConfig>,
      28              :     pub auth_backend: auth::Backend<'static, (), ()>,
      29              :     pub metric_collection: Option<MetricCollectionConfig>,
      30              :     pub allow_self_signed_compute: bool,
      31              :     pub http_config: HttpConfig,
      32              :     pub authentication_config: AuthenticationConfig,
      33              :     pub require_client_ip: bool,
      34              :     pub region: String,
      35              :     pub handshake_timeout: Duration,
      36              :     pub wake_compute_retry_config: RetryConfig,
      37              :     pub connect_compute_locks: ApiLocks<Host>,
      38              :     pub connect_to_compute_retry_config: RetryConfig,
      39              : }
      40              : 
      41              : #[derive(Debug)]
      42              : pub struct MetricCollectionConfig {
      43              :     pub endpoint: reqwest::Url,
      44              :     pub interval: Duration,
      45              :     pub backup_metric_collection_config: MetricBackupCollectionConfig,
      46              : }
      47              : 
      48              : pub struct TlsConfig {
      49              :     pub config: Arc<rustls::ServerConfig>,
      50              :     pub common_names: HashSet<String>,
      51              :     pub cert_resolver: Arc<CertResolver>,
      52              : }
      53              : 
      54              : pub struct HttpConfig {
      55              :     pub accept_websockets: bool,
      56              :     pub pool_options: GlobalConnPoolOptions,
      57              :     pub cancel_set: CancelSet,
      58              :     pub client_conn_threshold: u64,
      59              :     pub max_request_size_bytes: u64,
      60              :     pub max_response_size_bytes: usize,
      61              : }
      62              : 
      63              : pub struct AuthenticationConfig {
      64              :     pub thread_pool: Arc<ThreadPool>,
      65              :     pub scram_protocol_timeout: tokio::time::Duration,
      66              :     pub rate_limiter_enabled: bool,
      67              :     pub rate_limiter: AuthRateLimiter,
      68              :     pub rate_limit_ip_subnet: u8,
      69              :     pub ip_allowlist_check_enabled: bool,
      70              : }
      71              : 
      72              : impl TlsConfig {
      73           20 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      74           20 :         self.config.clone()
      75           20 :     }
      76              : }
      77              : 
      78              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
      79              : pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
      80              : 
      81              : /// Configure TLS for the main endpoint.
      82            0 : pub fn configure_tls(
      83            0 :     key_path: &str,
      84            0 :     cert_path: &str,
      85            0 :     certs_dir: Option<&String>,
      86            0 : ) -> anyhow::Result<TlsConfig> {
      87            0 :     let mut cert_resolver = CertResolver::new();
      88            0 : 
      89            0 :     // add default certificate
      90            0 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
      91              : 
      92              :     // add extra certificates
      93            0 :     if let Some(certs_dir) = certs_dir {
      94            0 :         for entry in std::fs::read_dir(certs_dir)? {
      95            0 :             let entry = entry?;
      96            0 :             let path = entry.path();
      97            0 :             if path.is_dir() {
      98              :                 // file names aligned with default cert-manager names
      99            0 :                 let key_path = path.join("tls.key");
     100            0 :                 let cert_path = path.join("tls.crt");
     101            0 :                 if key_path.exists() && cert_path.exists() {
     102            0 :                     cert_resolver.add_cert_path(
     103            0 :                         &key_path.to_string_lossy(),
     104            0 :                         &cert_path.to_string_lossy(),
     105            0 :                         false,
     106            0 :                     )?;
     107            0 :                 }
     108            0 :             }
     109              :         }
     110            0 :     }
     111              : 
     112            0 :     let common_names = cert_resolver.get_common_names();
     113            0 : 
     114            0 :     let cert_resolver = Arc::new(cert_resolver);
     115            0 : 
     116            0 :     // allow TLS 1.2 to be compatible with older client libraries
     117            0 :     let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
     118            0 :         &rustls::version::TLS13,
     119            0 :         &rustls::version::TLS12,
     120            0 :     ])
     121            0 :     .with_no_client_auth()
     122            0 :     .with_cert_resolver(cert_resolver.clone());
     123            0 : 
     124            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
     125            0 : 
     126            0 :     Ok(TlsConfig {
     127            0 :         config: Arc::new(config),
     128            0 :         common_names,
     129            0 :         cert_resolver,
     130            0 :     })
     131            0 : }
     132              : 
     133              : /// Channel binding parameter
     134              : ///
     135              : /// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
     136              : /// Description: The hash of the TLS server's certificate as it
     137              : /// appears, octet for octet, in the server's Certificate message.  Note
     138              : /// that the Certificate message contains a certificate_list, in which
     139              : /// the first element is the server's certificate.
     140              : ///
     141              : /// The hash function is to be selected as follows:
     142              : ///
     143              : /// * if the certificate's signatureAlgorithm uses a single hash
     144              : ///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
     145              : ///
     146              : /// * if the certificate's signatureAlgorithm uses a single hash
     147              : ///   function and that hash function neither MD5 nor SHA-1, then use
     148              : ///   the hash function associated with the certificate's
     149              : ///   signatureAlgorithm;
     150              : ///
     151              : /// * if the certificate's signatureAlgorithm uses no hash functions or
     152              : ///   uses multiple hash functions, then this channel binding type's
     153              : ///   channel bindings are undefined at this time (updates to is channel
     154              : ///   binding type may occur to address this issue if it ever arises).
     155              : #[derive(Debug, Clone, Copy)]
     156              : pub enum TlsServerEndPoint {
     157              :     Sha256([u8; 32]),
     158              :     Undefined,
     159              : }
     160              : 
     161              : impl TlsServerEndPoint {
     162           21 :     pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
     163           21 :         let sha256_oids = [
     164           21 :             // I'm explicitly not adding MD5 or SHA1 here... They're bad.
     165           21 :             oid_registry::OID_SIG_ECDSA_WITH_SHA256,
     166           21 :             oid_registry::OID_PKCS1_SHA256WITHRSA,
     167           21 :         ];
     168              : 
     169           21 :         let pem = x509_parser::parse_x509_certificate(cert)
     170           21 :             .context("Failed to parse PEM object from cerficiate")?
     171              :             .1;
     172              : 
     173           21 :         info!(subject = %pem.subject, "parsing TLS certificate");
     174              : 
     175           21 :         let reg = oid_registry::OidRegistry::default().with_all_crypto();
     176           21 :         let oid = pem.signature_algorithm.oid();
     177           21 :         let alg = reg.get(oid);
     178           21 :         if sha256_oids.contains(oid) {
     179           21 :             let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
     180           21 :             info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
     181           21 :             Ok(Self::Sha256(tls_server_end_point))
     182              :         } else {
     183            0 :             error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
     184            0 :             Ok(Self::Undefined)
     185              :         }
     186           21 :     }
     187              : 
     188           16 :     pub fn supported(&self) -> bool {
     189           16 :         !matches!(self, TlsServerEndPoint::Undefined)
     190           16 :     }
     191              : }
     192              : 
     193              : #[derive(Default, Debug)]
     194              : pub struct CertResolver {
     195              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     196              :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     197              : }
     198              : 
     199              : impl CertResolver {
     200           21 :     pub fn new() -> Self {
     201           21 :         Self::default()
     202           21 :     }
     203              : 
     204            0 :     fn add_cert_path(
     205            0 :         &mut self,
     206            0 :         key_path: &str,
     207            0 :         cert_path: &str,
     208            0 :         is_default: bool,
     209            0 :     ) -> anyhow::Result<()> {
     210            0 :         let priv_key = {
     211            0 :             let key_bytes = std::fs::read(key_path)
     212            0 :                 .context(format!("Failed to read TLS keys at '{key_path}'"))?;
     213            0 :             let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
     214            0 : 
     215            0 :             ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     216              :             PrivateKeyDer::Pkcs8(
     217            0 :                 keys.pop()
     218            0 :                     .unwrap()
     219            0 :                     .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
     220              :             )
     221              :         };
     222              : 
     223            0 :         let cert_chain_bytes = std::fs::read(cert_path)
     224            0 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     225              : 
     226            0 :         let cert_chain = {
     227            0 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     228            0 :                 .try_collect()
     229            0 :                 .with_context(|| {
     230            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     231            0 :                 })?
     232              :         };
     233              : 
     234            0 :         self.add_cert(priv_key, cert_chain, is_default)
     235            0 :     }
     236              : 
     237           21 :     pub fn add_cert(
     238           21 :         &mut self,
     239           21 :         priv_key: PrivateKeyDer<'static>,
     240           21 :         cert_chain: Vec<CertificateDer<'static>>,
     241           21 :         is_default: bool,
     242           21 :     ) -> anyhow::Result<()> {
     243           21 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     244              : 
     245           21 :         let first_cert = &cert_chain[0];
     246           21 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     247           21 :         let pem = x509_parser::parse_x509_certificate(first_cert)
     248           21 :             .context("Failed to parse PEM object from cerficiate")?
     249              :             .1;
     250              : 
     251           21 :         let common_name = pem.subject().to_string();
     252              : 
     253              :         // We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
     254              :         // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
     255              :         // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
     256              :         // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
     257              :         // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
     258              :         // of cutting off '*.' parts.
     259           21 :         let common_name = if common_name.starts_with("CN=*.") {
     260            0 :             common_name.strip_prefix("CN=*.").map(|s| s.to_string())
     261              :         } else {
     262           21 :             common_name.strip_prefix("CN=").map(|s| s.to_string())
     263              :         }
     264           21 :         .context("Failed to parse common name from certificate")?;
     265              : 
     266           21 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     267           21 : 
     268           21 :         if is_default {
     269           21 :             self.default = Some((cert.clone(), tls_server_end_point));
     270           21 :         }
     271              : 
     272           21 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     273           21 : 
     274           21 :         Ok(())
     275           21 :     }
     276              : 
     277           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     278           21 :         self.certs.keys().map(|s| s.to_string()).collect()
     279           21 :     }
     280              : }
     281              : 
     282              : impl rustls::server::ResolvesServerCert for CertResolver {
     283            0 :     fn resolve(
     284            0 :         &self,
     285            0 :         client_hello: rustls::server::ClientHello<'_>,
     286            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     287            0 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     288            0 :     }
     289              : }
     290              : 
     291              : impl CertResolver {
     292           20 :     pub fn resolve(
     293           20 :         &self,
     294           20 :         server_name: Option<&str>,
     295           20 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     296              :         // loop here and cut off more and more subdomains until we find
     297              :         // a match to get a proper wildcard support. OTOH, we now do not
     298              :         // use nested domains, so keep this simple for now.
     299              :         //
     300              :         // With the current coding foo.com will match *.foo.com and that
     301              :         // repeats behavior of the old code.
     302           20 :         if let Some(mut sni_name) = server_name {
     303              :             loop {
     304           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     305           20 :                     return Some(cert.clone());
     306           20 :                 }
     307           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     308           20 :                     sni_name = rest;
     309           20 :                 } else {
     310            0 :                     return None;
     311              :                 }
     312              :             }
     313              :         } else {
     314              :             // No SNI, use the default certificate, otherwise we can't get to
     315              :             // options parameter which can be used to set endpoint name too.
     316              :             // That means that non-SNI flow will not work for CNAME domains in
     317              :             // verify-full mode.
     318              :             //
     319              :             // If that will be a problem we can:
     320              :             //
     321              :             // a) Instead of multi-cert approach use single cert with extra
     322              :             //    domains listed in Subject Alternative Name (SAN).
     323              :             // b) Deploy separate proxy instances for extra domains.
     324            0 :             self.default.clone()
     325              :         }
     326           20 :     }
     327              : }
     328              : 
     329              : #[derive(Debug)]
     330              : pub struct EndpointCacheConfig {
     331              :     /// Batch size to receive all endpoints on the startup.
     332              :     pub initial_batch_size: usize,
     333              :     /// Batch size to receive endpoints.
     334              :     pub default_batch_size: usize,
     335              :     /// Timeouts for the stream read operation.
     336              :     pub xread_timeout: Duration,
     337              :     /// Stream name to read from.
     338              :     pub stream_name: String,
     339              :     /// Limiter info (to distinguish when to enable cache).
     340              :     pub limiter_info: Vec<RateBucketInfo>,
     341              :     /// Disable cache.
     342              :     /// If true, cache is ignored, but reports all statistics.
     343              :     pub disable_cache: bool,
     344              :     /// Retry interval for the stream read operation.
     345              :     pub retry_interval: Duration,
     346              : }
     347              : 
     348              : impl EndpointCacheConfig {
     349              :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     350              :     /// Notice that by default the limiter is empty, which means that cache is disabled.
     351              :     pub const CACHE_DEFAULT_OPTIONS: &'static str =
     352              :         "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
     353              : 
     354              :     /// Parse cache options passed via cmdline.
     355              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     356            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     357            0 :         let mut initial_batch_size = None;
     358            0 :         let mut default_batch_size = None;
     359            0 :         let mut xread_timeout = None;
     360            0 :         let mut stream_name = None;
     361            0 :         let mut limiter_info = vec![];
     362            0 :         let mut disable_cache = false;
     363            0 :         let mut retry_interval = None;
     364              : 
     365            0 :         for option in options.split(',') {
     366            0 :             let (key, value) = option
     367            0 :                 .split_once('=')
     368            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     369              : 
     370            0 :             match key {
     371            0 :                 "initial_batch_size" => initial_batch_size = Some(value.parse()?),
     372            0 :                 "default_batch_size" => default_batch_size = Some(value.parse()?),
     373            0 :                 "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
     374            0 :                 "stream_name" => stream_name = Some(value.to_string()),
     375            0 :                 "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
     376            0 :                 "disable_cache" => disable_cache = value.parse()?,
     377            0 :                 "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
     378            0 :                 unknown => bail!("unknown key: {unknown}"),
     379              :             }
     380              :         }
     381            0 :         RateBucketInfo::validate(&mut limiter_info)?;
     382              : 
     383              :         Ok(Self {
     384            0 :             initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
     385            0 :             default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
     386            0 :             xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
     387            0 :             stream_name: stream_name.context("missing `stream_name`")?,
     388            0 :             disable_cache,
     389            0 :             limiter_info,
     390            0 :             retry_interval: retry_interval.context("missing `retry_interval`")?,
     391              :         })
     392            0 :     }
     393              : }
     394              : 
     395              : impl FromStr for EndpointCacheConfig {
     396              :     type Err = anyhow::Error;
     397              : 
     398            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     399            0 :         let error = || format!("failed to parse endpoint cache options '{options}'");
     400            0 :         Self::parse(options).with_context(error)
     401            0 :     }
     402              : }
     403              : #[derive(Debug)]
     404              : pub struct MetricBackupCollectionConfig {
     405              :     pub interval: Duration,
     406              :     pub remote_storage_config: Option<RemoteStorageConfig>,
     407              :     pub chunk_size: usize,
     408              : }
     409              : 
     410            1 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
     411            1 :     RemoteStorageConfig::from_toml(&s.parse()?)
     412            1 : }
     413              : 
     414              : /// Helper for cmdline cache options parsing.
     415              : #[derive(Debug)]
     416              : pub struct CacheOptions {
     417              :     /// Max number of entries.
     418              :     pub size: usize,
     419              :     /// Entry's time-to-live.
     420              :     pub ttl: Duration,
     421              : }
     422              : 
     423              : impl CacheOptions {
     424              :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     425              :     pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
     426              : 
     427              :     /// Parse cache options passed via cmdline.
     428              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     429            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     430            4 :         let mut size = None;
     431            4 :         let mut ttl = None;
     432              : 
     433            7 :         for option in options.split(',') {
     434            7 :             let (key, value) = option
     435            7 :                 .split_once('=')
     436            7 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     437              : 
     438            7 :             match key {
     439            7 :                 "size" => size = Some(value.parse()?),
     440            3 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     441            0 :                 unknown => bail!("unknown key: {unknown}"),
     442              :             }
     443              :         }
     444              : 
     445              :         // TTL doesn't matter if cache is always empty.
     446            4 :         if let Some(0) = size {
     447            2 :             ttl.get_or_insert(Duration::default());
     448            2 :         }
     449              : 
     450              :         Ok(Self {
     451            4 :             size: size.context("missing `size`")?,
     452            4 :             ttl: ttl.context("missing `ttl`")?,
     453              :         })
     454            4 :     }
     455              : }
     456              : 
     457              : impl FromStr for CacheOptions {
     458              :     type Err = anyhow::Error;
     459              : 
     460            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     461            4 :         let error = || format!("failed to parse cache options '{options}'");
     462            4 :         Self::parse(options).with_context(error)
     463            4 :     }
     464              : }
     465              : 
     466              : /// Helper for cmdline cache options parsing.
     467              : #[derive(Debug)]
     468              : pub struct ProjectInfoCacheOptions {
     469              :     /// Max number of entries.
     470              :     pub size: usize,
     471              :     /// Entry's time-to-live.
     472              :     pub ttl: Duration,
     473              :     /// Max number of roles per endpoint.
     474              :     pub max_roles: usize,
     475              :     /// Gc interval.
     476              :     pub gc_interval: Duration,
     477              : }
     478              : 
     479              : impl ProjectInfoCacheOptions {
     480              :     /// Default options for [`crate::console::provider::NodeInfoCache`].
     481              :     pub const CACHE_DEFAULT_OPTIONS: &'static str =
     482              :         "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
     483              : 
     484              :     /// Parse cache options passed via cmdline.
     485              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     486            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     487            0 :         let mut size = None;
     488            0 :         let mut ttl = None;
     489            0 :         let mut max_roles = None;
     490            0 :         let mut gc_interval = None;
     491              : 
     492            0 :         for option in options.split(',') {
     493            0 :             let (key, value) = option
     494            0 :                 .split_once('=')
     495            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     496              : 
     497            0 :             match key {
     498            0 :                 "size" => size = Some(value.parse()?),
     499            0 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     500            0 :                 "max_roles" => max_roles = Some(value.parse()?),
     501            0 :                 "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
     502            0 :                 unknown => bail!("unknown key: {unknown}"),
     503              :             }
     504              :         }
     505              : 
     506              :         // TTL doesn't matter if cache is always empty.
     507            0 :         if let Some(0) = size {
     508            0 :             ttl.get_or_insert(Duration::default());
     509            0 :         }
     510              : 
     511              :         Ok(Self {
     512            0 :             size: size.context("missing `size`")?,
     513            0 :             ttl: ttl.context("missing `ttl`")?,
     514            0 :             max_roles: max_roles.context("missing `max_roles`")?,
     515            0 :             gc_interval: gc_interval.context("missing `gc_interval`")?,
     516              :         })
     517            0 :     }
     518              : }
     519              : 
     520              : impl FromStr for ProjectInfoCacheOptions {
     521              :     type Err = anyhow::Error;
     522              : 
     523            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     524            0 :         let error = || format!("failed to parse cache options '{options}'");
     525            0 :         Self::parse(options).with_context(error)
     526            0 :     }
     527              : }
     528              : 
     529              : /// This is a config for connect to compute and wake compute.
     530              : #[derive(Clone, Copy, Debug)]
     531              : pub struct RetryConfig {
     532              :     /// Number of times we should retry.
     533              :     pub max_retries: u32,
     534              :     /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
     535              :     pub base_delay: tokio::time::Duration,
     536              :     /// Exponential base for retry wait duration
     537              :     pub backoff_factor: f64,
     538              : }
     539              : 
     540              : impl RetryConfig {
     541              :     /// Default options for RetryConfig.
     542              : 
     543              :     /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
     544              :     pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
     545              :         "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
     546              :     /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
     547              :     /// Cplane has timeout of 60s on each request. 8m7s in total.
     548              :     pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
     549              :         "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
     550              : 
     551              :     /// Parse retry options passed via cmdline.
     552              :     /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
     553            0 :     pub fn parse(options: &str) -> anyhow::Result<Self> {
     554            0 :         let mut num_retries = None;
     555            0 :         let mut base_retry_wait_duration = None;
     556            0 :         let mut retry_wait_exponent_base = None;
     557              : 
     558            0 :         for option in options.split(',') {
     559            0 :             let (key, value) = option
     560            0 :                 .split_once('=')
     561            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     562              : 
     563            0 :             match key {
     564            0 :                 "num_retries" => num_retries = Some(value.parse()?),
     565            0 :                 "base_retry_wait_duration" => {
     566            0 :                     base_retry_wait_duration = Some(humantime::parse_duration(value)?);
     567              :                 }
     568            0 :                 "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
     569            0 :                 unknown => bail!("unknown key: {unknown}"),
     570              :             }
     571              :         }
     572              : 
     573              :         Ok(Self {
     574            0 :             max_retries: num_retries.context("missing `num_retries`")?,
     575            0 :             base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
     576            0 :             backoff_factor: retry_wait_exponent_base
     577            0 :                 .context("missing `retry_wait_exponent_base`")?,
     578              :         })
     579            0 :     }
     580              : }
     581              : 
     582              : /// Helper for cmdline cache options parsing.
     583            8 : #[derive(serde::Deserialize)]
     584              : pub struct ConcurrencyLockOptions {
     585              :     /// The number of shards the lock map should have
     586              :     pub shards: usize,
     587              :     /// The number of allowed concurrent requests for each endpoitn
     588              :     #[serde(flatten)]
     589              :     pub limiter: RateLimiterConfig,
     590              :     /// Garbage collection epoch
     591              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     592              :     pub epoch: Duration,
     593              :     /// Lock timeout
     594              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     595              :     pub timeout: Duration,
     596              : }
     597              : 
     598              : impl ConcurrencyLockOptions {
     599              :     /// Default options for [`crate::console::provider::ApiLocks`].
     600              :     pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
     601              :     /// Default options for [`crate::console::provider::ApiLocks`].
     602              :     pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
     603              :         "shards=64,permits=100,epoch=10m,timeout=10ms";
     604              : 
     605              :     // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
     606              : 
     607              :     /// Parse lock options passed via cmdline.
     608              :     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     609            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     610            4 :         let options = options.trim();
     611            4 :         if options.starts_with('{') && options.ends_with('}') {
     612            1 :             return Ok(serde_json::from_str(options)?);
     613            3 :         }
     614            3 : 
     615            3 :         let mut shards = None;
     616            3 :         let mut permits = None;
     617            3 :         let mut epoch = None;
     618            3 :         let mut timeout = None;
     619              : 
     620            9 :         for option in options.split(',') {
     621            9 :             let (key, value) = option
     622            9 :                 .split_once('=')
     623            9 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     624              : 
     625            9 :             match key {
     626            9 :                 "shards" => shards = Some(value.parse()?),
     627            7 :                 "permits" => permits = Some(value.parse()?),
     628            4 :                 "epoch" => epoch = Some(humantime::parse_duration(value)?),
     629            2 :                 "timeout" => timeout = Some(humantime::parse_duration(value)?),
     630            0 :                 unknown => bail!("unknown key: {unknown}"),
     631              :             }
     632              :         }
     633              : 
     634              :         // these dont matter if lock is disabled
     635            3 :         if let Some(0) = permits {
     636            1 :             timeout = Some(Duration::default());
     637            1 :             epoch = Some(Duration::default());
     638            1 :             shards = Some(2);
     639            2 :         }
     640              : 
     641            3 :         let permits = permits.context("missing `permits`")?;
     642            3 :         let out = Self {
     643            3 :             shards: shards.context("missing `shards`")?,
     644            3 :             limiter: RateLimiterConfig {
     645            3 :                 algorithm: RateLimitAlgorithm::Fixed,
     646            3 :                 initial_limit: permits,
     647            3 :             },
     648            3 :             epoch: epoch.context("missing `epoch`")?,
     649            3 :             timeout: timeout.context("missing `timeout`")?,
     650              :         };
     651              : 
     652            3 :         ensure!(out.shards > 1, "shard count must be > 1");
     653            3 :         ensure!(
     654            3 :             out.shards.is_power_of_two(),
     655            0 :             "shard count must be a power of two"
     656              :         );
     657              : 
     658            3 :         Ok(out)
     659            4 :     }
     660              : }
     661              : 
     662              : impl FromStr for ConcurrencyLockOptions {
     663              :     type Err = anyhow::Error;
     664              : 
     665            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     666            4 :         let error = || format!("failed to parse cache lock options '{options}'");
     667            4 :         Self::parse(options).with_context(error)
     668            4 :     }
     669              : }
     670              : 
     671              : #[cfg(test)]
     672              : mod tests {
     673              :     use crate::rate_limiter::Aimd;
     674              : 
     675              :     use super::*;
     676              : 
     677              :     #[test]
     678            1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     679            1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     680            1 :         assert_eq!(size, 4096);
     681            1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     682              : 
     683            1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     684            1 :         assert_eq!(size, 2);
     685            1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     686              : 
     687            1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     688            1 :         assert_eq!(size, 0);
     689            1 :         assert_eq!(ttl, Duration::from_secs(1));
     690              : 
     691            1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     692            1 :         assert_eq!(size, 0);
     693            1 :         assert_eq!(ttl, Duration::default());
     694              : 
     695            1 :         Ok(())
     696            1 :     }
     697              : 
     698              :     #[test]
     699            1 :     fn test_parse_lock_options() -> anyhow::Result<()> {
     700              :         let ConcurrencyLockOptions {
     701            1 :             epoch,
     702            1 :             limiter,
     703            1 :             shards,
     704            1 :             timeout,
     705            1 :         } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
     706            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     707            1 :         assert_eq!(timeout, Duration::from_secs(1));
     708            1 :         assert_eq!(shards, 32);
     709            1 :         assert_eq!(limiter.initial_limit, 4);
     710            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     711              : 
     712              :         let ConcurrencyLockOptions {
     713            1 :             epoch,
     714            1 :             limiter,
     715            1 :             shards,
     716            1 :             timeout,
     717            1 :         } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
     718            1 :         assert_eq!(epoch, Duration::from_secs(60));
     719            1 :         assert_eq!(timeout, Duration::from_millis(100));
     720            1 :         assert_eq!(shards, 16);
     721            1 :         assert_eq!(limiter.initial_limit, 8);
     722            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     723              : 
     724              :         let ConcurrencyLockOptions {
     725            1 :             epoch,
     726            1 :             limiter,
     727            1 :             shards,
     728            1 :             timeout,
     729            1 :         } = "permits=0".parse()?;
     730            1 :         assert_eq!(epoch, Duration::ZERO);
     731            1 :         assert_eq!(timeout, Duration::ZERO);
     732            1 :         assert_eq!(shards, 2);
     733            1 :         assert_eq!(limiter.initial_limit, 0);
     734            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     735              : 
     736            1 :         Ok(())
     737            1 :     }
     738              : 
     739              :     #[test]
     740            1 :     fn test_parse_json_lock_options() -> anyhow::Result<()> {
     741              :         let ConcurrencyLockOptions {
     742            1 :             epoch,
     743            1 :             limiter,
     744            1 :             shards,
     745            1 :             timeout,
     746            1 :         } = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
     747            1 :             .parse()?;
     748            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     749            1 :         assert_eq!(timeout, Duration::from_secs(1));
     750            1 :         assert_eq!(shards, 32);
     751            1 :         assert_eq!(limiter.initial_limit, 44);
     752            1 :         assert_eq!(
     753            1 :             limiter.algorithm,
     754            1 :             RateLimitAlgorithm::Aimd {
     755            1 :                 conf: Aimd {
     756            1 :                     min: 5,
     757            1 :                     max: 500,
     758            1 :                     dec: 0.9,
     759            1 :                     inc: 10,
     760            1 :                     utilisation: 0.8
     761            1 :                 }
     762            1 :             },
     763            1 :         );
     764              : 
     765            1 :         Ok(())
     766            1 :     }
     767              : }
        

Generated by: LCOV version 2.1-beta