LCOV - code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 55.7 % 370 206
Test Date: 2024-10-22 22:13:45 Functions: 34.9 % 63 22

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet};
       2              : use std::str::FromStr;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use anyhow::{bail, ensure, Context, Ok};
       7              : use clap::ValueEnum;
       8              : use itertools::Itertools;
       9              : use remote_storage::RemoteStorageConfig;
      10              : use rustls::crypto::aws_lc_rs::{self, sign};
      11              : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
      12              : use sha2::{Digest, Sha256};
      13              : use tracing::{error, info};
      14              : use x509_parser::oid_registry;
      15              : 
      16              : use crate::auth::backend::jwt::JwkCache;
      17              : use crate::auth::backend::AuthRateLimiter;
      18              : use crate::control_plane::locks::ApiLocks;
      19              : use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
      20              : use crate::scram::threadpool::ThreadPool;
      21              : use crate::serverless::cancel_set::CancelSet;
      22              : use crate::serverless::GlobalConnPoolOptions;
      23              : use crate::Host;
      24              : 
/// Top-level proxy configuration, grouping the TLS, HTTP, authentication and retry
/// settings declared in this module.
pub struct ProxyConfig {
    /// TLS settings; `None` when no TLS configuration is present.
    pub tls_config: Option<TlsConfig>,
    /// Metric collection settings; `None` when metric collection is not configured.
    pub metric_collection: Option<MetricCollectionConfig>,
    pub allow_self_signed_compute: bool,
    pub http_config: HttpConfig,
    pub authentication_config: AuthenticationConfig,
    /// How the PROXY protocol v2 header is treated (see [`ProxyProtocolV2`]).
    pub proxy_protocol_v2: ProxyProtocolV2,
    pub region: String,
    pub handshake_timeout: Duration,
    /// Retry policy used when waking compute.
    pub wake_compute_retry_config: RetryConfig,
    pub connect_compute_locks: ApiLocks<Host>,
    /// Retry policy used when connecting to compute.
    pub connect_to_compute_retry_config: RetryConfig,
}

// Policy for an incoming PROXY protocol v2 header. The variant doc comments below are
// also consumed by `clap::ValueEnum` (presumably as CLI help text), so edit them with care.
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
pub enum ProxyProtocolV2 {
    /// Connection will error if PROXY protocol v2 header is missing
    Required,
    /// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing.
    Supported,
    /// Connection will error if PROXY protocol v2 header is provided
    Rejected,
}

/// Settings for metric collection.
#[derive(Debug)]
pub struct MetricCollectionConfig {
    /// URL of the metric collection endpoint.
    pub endpoint: reqwest::Url,
    pub interval: Duration,
    pub backup_metric_collection_config: MetricBackupCollectionConfig,
}

/// TLS settings as produced by [`configure_tls`].
pub struct TlsConfig {
    /// Shared rustls server configuration; handed out by [`TlsConfig::to_server_config`].
    pub config: Arc<rustls::ServerConfig>,
    /// Common names of all loaded certificates, with a leading `*.` / `apiauth.` removed.
    pub common_names: HashSet<String>,
    /// Certificate resolver; the same instance is installed into `config`.
    pub cert_resolver: Arc<CertResolver>,
}

/// Settings for the HTTP entry point (pool options and cancel set come from
/// `crate::serverless`).
pub struct HttpConfig {
    pub accept_websockets: bool,
    pub pool_options: GlobalConnPoolOptions,
    pub cancel_set: CancelSet,
    pub client_conn_threshold: u64,
    // NOTE(review): request size limit is `u64` while the response limit is `usize` —
    // confirm the asymmetry is intended.
    pub max_request_size_bytes: u64,
    pub max_response_size_bytes: usize,
}

/// Authentication-related settings.
pub struct AuthenticationConfig {
    /// Shared thread pool from `crate::scram::threadpool`.
    pub thread_pool: Arc<ThreadPool>,
    pub scram_protocol_timeout: tokio::time::Duration,
    pub rate_limiter_enabled: bool,
    pub rate_limiter: AuthRateLimiter,
    // NOTE(review): presumably the prefix length used to group client IPs for rate
    // limiting — confirm against the rate limiter.
    pub rate_limit_ip_subnet: u8,
    pub ip_allowlist_check_enabled: bool,
    pub jwks_cache: JwkCache,
    pub is_auth_broker: bool,
    pub accept_jwts: bool,
    pub webauth_confirmation_timeout: tokio::time::Duration,
}
      83              : 
      84              : impl TlsConfig {
      85           20 :     pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
      86           20 :         self.config.clone()
      87           20 :     }
      88              : }
      89              : 
      90              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
      91              : pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
      92              : 
      93              : /// Configure TLS for the main endpoint.
      94            0 : pub fn configure_tls(
      95            0 :     key_path: &str,
      96            0 :     cert_path: &str,
      97            0 :     certs_dir: Option<&String>,
      98            0 : ) -> anyhow::Result<TlsConfig> {
      99            0 :     let mut cert_resolver = CertResolver::new();
     100            0 : 
     101            0 :     // add default certificate
     102            0 :     cert_resolver.add_cert_path(key_path, cert_path, true)?;
     103              : 
     104              :     // add extra certificates
     105            0 :     if let Some(certs_dir) = certs_dir {
     106            0 :         for entry in std::fs::read_dir(certs_dir)? {
     107            0 :             let entry = entry?;
     108            0 :             let path = entry.path();
     109            0 :             if path.is_dir() {
     110              :                 // file names aligned with default cert-manager names
     111            0 :                 let key_path = path.join("tls.key");
     112            0 :                 let cert_path = path.join("tls.crt");
     113            0 :                 if key_path.exists() && cert_path.exists() {
     114            0 :                     cert_resolver.add_cert_path(
     115            0 :                         &key_path.to_string_lossy(),
     116            0 :                         &cert_path.to_string_lossy(),
     117            0 :                         false,
     118            0 :                     )?;
     119            0 :                 }
     120            0 :             }
     121              :         }
     122            0 :     }
     123              : 
     124            0 :     let common_names = cert_resolver.get_common_names();
     125            0 : 
     126            0 :     let cert_resolver = Arc::new(cert_resolver);
     127              : 
     128              :     // allow TLS 1.2 to be compatible with older client libraries
     129            0 :     let mut config =
     130            0 :         rustls::ServerConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider()))
     131            0 :             .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
     132            0 :             .context("aws_lc_rs should support TLS1.2 and TLS1.3")?
     133            0 :             .with_no_client_auth()
     134            0 :             .with_cert_resolver(cert_resolver.clone());
     135            0 : 
     136            0 :     config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
     137            0 : 
     138            0 :     Ok(TlsConfig {
     139            0 :         config: Arc::new(config),
     140            0 :         common_names,
     141            0 :         cert_resolver,
     142            0 :     })
     143            0 : }
     144              : 
     145              : /// Channel binding parameter
     146              : ///
     147              : /// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
     148              : /// Description: The hash of the TLS server's certificate as it
     149              : /// appears, octet for octet, in the server's Certificate message.  Note
     150              : /// that the Certificate message contains a certificate_list, in which
     151              : /// the first element is the server's certificate.
     152              : ///
     153              : /// The hash function is to be selected as follows:
     154              : ///
     155              : /// * if the certificate's signatureAlgorithm uses a single hash
     156              : ///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
     157              : ///
     158              : /// * if the certificate's signatureAlgorithm uses a single hash
     159              : ///   function and that hash function neither MD5 nor SHA-1, then use
     160              : ///   the hash function associated with the certificate's
     161              : ///   signatureAlgorithm;
     162              : ///
     163              : /// * if the certificate's signatureAlgorithm uses no hash functions or
     164              : ///   uses multiple hash functions, then this channel binding type's
     165              : ///   channel bindings are undefined at this time (updates to is channel
     166              : ///   binding type may occur to address this issue if it ever arises).
     167              : #[derive(Debug, Clone, Copy)]
     168              : pub enum TlsServerEndPoint {
     169              :     Sha256([u8; 32]),
     170              :     Undefined,
     171              : }
     172              : 
     173              : impl TlsServerEndPoint {
     174           21 :     pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
     175           21 :         let sha256_oids = [
     176           21 :             // I'm explicitly not adding MD5 or SHA1 here... They're bad.
     177           21 :             oid_registry::OID_SIG_ECDSA_WITH_SHA256,
     178           21 :             oid_registry::OID_PKCS1_SHA256WITHRSA,
     179           21 :         ];
     180              : 
     181           21 :         let pem = x509_parser::parse_x509_certificate(cert)
     182           21 :             .context("Failed to parse PEM object from cerficiate")?
     183              :             .1;
     184              : 
     185           21 :         info!(subject = %pem.subject, "parsing TLS certificate");
     186              : 
     187           21 :         let reg = oid_registry::OidRegistry::default().with_all_crypto();
     188           21 :         let oid = pem.signature_algorithm.oid();
     189           21 :         let alg = reg.get(oid);
     190           21 :         if sha256_oids.contains(oid) {
     191           21 :             let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
     192           21 :             info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
     193           21 :             Ok(Self::Sha256(tls_server_end_point))
     194              :         } else {
     195            0 :             error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
     196            0 :             Ok(Self::Undefined)
     197              :         }
     198           21 :     }
     199              : 
     200           16 :     pub fn supported(&self) -> bool {
     201           16 :         !matches!(self, TlsServerEndPoint::Undefined)
     202           16 :     }
     203              : }
     204              : 
     205              : #[derive(Default, Debug)]
     206              : pub struct CertResolver {
     207              :     certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     208              :     default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
     209              : }
     210              : 
     211              : impl CertResolver {
     212           21 :     pub fn new() -> Self {
     213           21 :         Self::default()
     214           21 :     }
     215              : 
     216            0 :     fn add_cert_path(
     217            0 :         &mut self,
     218            0 :         key_path: &str,
     219            0 :         cert_path: &str,
     220            0 :         is_default: bool,
     221            0 :     ) -> anyhow::Result<()> {
     222            0 :         let priv_key = {
     223            0 :             let key_bytes = std::fs::read(key_path)
     224            0 :                 .context(format!("Failed to read TLS keys at '{key_path}'"))?;
     225            0 :             let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
     226            0 : 
     227            0 :             ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
     228              :             PrivateKeyDer::Pkcs8(
     229            0 :                 keys.pop()
     230            0 :                     .unwrap()
     231            0 :                     .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
     232              :             )
     233              :         };
     234              : 
     235            0 :         let cert_chain_bytes = std::fs::read(cert_path)
     236            0 :             .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
     237              : 
     238            0 :         let cert_chain = {
     239            0 :             rustls_pemfile::certs(&mut &cert_chain_bytes[..])
     240            0 :                 .try_collect()
     241            0 :                 .with_context(|| {
     242            0 :                     format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
     243            0 :                 })?
     244              :         };
     245              : 
     246            0 :         self.add_cert(priv_key, cert_chain, is_default)
     247            0 :     }
     248              : 
     249           21 :     pub fn add_cert(
     250           21 :         &mut self,
     251           21 :         priv_key: PrivateKeyDer<'static>,
     252           21 :         cert_chain: Vec<CertificateDer<'static>>,
     253           21 :         is_default: bool,
     254           21 :     ) -> anyhow::Result<()> {
     255           21 :         let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
     256              : 
     257           21 :         let first_cert = &cert_chain[0];
     258           21 :         let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
     259           21 :         let pem = x509_parser::parse_x509_certificate(first_cert)
     260           21 :             .context("Failed to parse PEM object from cerficiate")?
     261              :             .1;
     262              : 
     263           21 :         let common_name = pem.subject().to_string();
     264              : 
     265              :         // We need to get the canonical name for this certificate so we can match them against any domain names
     266              :         // seen within the proxy codebase.
     267              :         //
     268              :         // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
     269              :         // We need to remove the wildcard prefix for the purposes of certificate selection.
     270              :         //
     271              :         // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
     272              :         // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
     273              :         //
     274              :         // Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
     275              :         // validation, so let's we can continue with any common-name
     276           21 :         let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
     277            0 :             s.to_string()
     278           21 :         } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
     279            0 :             s.to_string()
     280           21 :         } else if let Some(s) = common_name.strip_prefix("CN=") {
     281           21 :             s.to_string()
     282              :         } else {
     283            0 :             bail!("Failed to parse common name from certificate")
     284              :         };
     285              : 
     286           21 :         let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
     287           21 : 
     288           21 :         if is_default {
     289           21 :             self.default = Some((cert.clone(), tls_server_end_point));
     290           21 :         }
     291              : 
     292           21 :         self.certs.insert(common_name, (cert, tls_server_end_point));
     293           21 : 
     294           21 :         Ok(())
     295           21 :     }
     296              : 
     297           21 :     pub fn get_common_names(&self) -> HashSet<String> {
     298           21 :         self.certs.keys().map(|s| s.to_string()).collect()
     299           21 :     }
     300              : }
     301              : 
     302              : impl rustls::server::ResolvesServerCert for CertResolver {
     303            0 :     fn resolve(
     304            0 :         &self,
     305            0 :         client_hello: rustls::server::ClientHello<'_>,
     306            0 :     ) -> Option<Arc<rustls::sign::CertifiedKey>> {
     307            0 :         self.resolve(client_hello.server_name()).map(|x| x.0)
     308            0 :     }
     309              : }
     310              : 
     311              : impl CertResolver {
     312           20 :     pub fn resolve(
     313           20 :         &self,
     314           20 :         server_name: Option<&str>,
     315           20 :     ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
     316              :         // loop here and cut off more and more subdomains until we find
     317              :         // a match to get a proper wildcard support. OTOH, we now do not
     318              :         // use nested domains, so keep this simple for now.
     319              :         //
     320              :         // With the current coding foo.com will match *.foo.com and that
     321              :         // repeats behavior of the old code.
     322           20 :         if let Some(mut sni_name) = server_name {
     323              :             loop {
     324           40 :                 if let Some(cert) = self.certs.get(sni_name) {
     325           20 :                     return Some(cert.clone());
     326           20 :                 }
     327           20 :                 if let Some((_, rest)) = sni_name.split_once('.') {
     328           20 :                     sni_name = rest;
     329           20 :                 } else {
     330            0 :                     return None;
     331              :                 }
     332              :             }
     333              :         } else {
     334              :             // No SNI, use the default certificate, otherwise we can't get to
     335              :             // options parameter which can be used to set endpoint name too.
     336              :             // That means that non-SNI flow will not work for CNAME domains in
     337              :             // verify-full mode.
     338              :             //
     339              :             // If that will be a problem we can:
     340              :             //
     341              :             // a) Instead of multi-cert approach use single cert with extra
     342              :             //    domains listed in Subject Alternative Name (SAN).
     343              :             // b) Deploy separate proxy instances for extra domains.
     344            0 :             self.default.clone()
     345              :         }
     346           20 :     }
     347              : }
     348              : 
     349              : #[derive(Debug)]
     350              : pub struct EndpointCacheConfig {
     351              :     /// Batch size to receive all endpoints on the startup.
     352              :     pub initial_batch_size: usize,
     353              :     /// Batch size to receive endpoints.
     354              :     pub default_batch_size: usize,
     355              :     /// Timeouts for the stream read operation.
     356              :     pub xread_timeout: Duration,
     357              :     /// Stream name to read from.
     358              :     pub stream_name: String,
     359              :     /// Limiter info (to distinguish when to enable cache).
     360              :     pub limiter_info: Vec<RateBucketInfo>,
     361              :     /// Disable cache.
     362              :     /// If true, cache is ignored, but reports all statistics.
     363              :     pub disable_cache: bool,
     364              :     /// Retry interval for the stream read operation.
     365              :     pub retry_interval: Duration,
     366              : }
     367              : 
     368              : impl EndpointCacheConfig {
     369              :     /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
     370              :     /// Notice that by default the limiter is empty, which means that cache is disabled.
     371              :     pub const CACHE_DEFAULT_OPTIONS: &'static str =
     372              :         "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
     373              : 
     374              :     /// Parse cache options passed via cmdline.
     375              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     376            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     377            0 :         let mut initial_batch_size = None;
     378            0 :         let mut default_batch_size = None;
     379            0 :         let mut xread_timeout = None;
     380            0 :         let mut stream_name = None;
     381            0 :         let mut limiter_info = vec![];
     382            0 :         let mut disable_cache = false;
     383            0 :         let mut retry_interval = None;
     384              : 
     385            0 :         for option in options.split(',') {
     386            0 :             let (key, value) = option
     387            0 :                 .split_once('=')
     388            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     389              : 
     390            0 :             match key {
     391            0 :                 "initial_batch_size" => initial_batch_size = Some(value.parse()?),
     392            0 :                 "default_batch_size" => default_batch_size = Some(value.parse()?),
     393            0 :                 "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
     394            0 :                 "stream_name" => stream_name = Some(value.to_string()),
     395            0 :                 "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
     396            0 :                 "disable_cache" => disable_cache = value.parse()?,
     397            0 :                 "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
     398            0 :                 unknown => bail!("unknown key: {unknown}"),
     399              :             }
     400              :         }
     401            0 :         RateBucketInfo::validate(&mut limiter_info)?;
     402              : 
     403              :         Ok(Self {
     404            0 :             initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
     405            0 :             default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
     406            0 :             xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
     407            0 :             stream_name: stream_name.context("missing `stream_name`")?,
     408            0 :             disable_cache,
     409            0 :             limiter_info,
     410            0 :             retry_interval: retry_interval.context("missing `retry_interval`")?,
     411              :         })
     412            0 :     }
     413              : }
     414              : 
     415              : impl FromStr for EndpointCacheConfig {
     416              :     type Err = anyhow::Error;
     417              : 
     418            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     419            0 :         let error = || format!("failed to parse endpoint cache options '{options}'");
     420            0 :         Self::parse(options).with_context(error)
     421            0 :     }
     422              : }
/// Settings for the backup metric collection.
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
    pub interval: Duration,
    /// Remote storage for the backup; `None` when not configured.
    pub remote_storage_config: Option<RemoteStorageConfig>,
    pub chunk_size: usize,
}

/// Parses a [`RemoteStorageConfig`] from a TOML string.
///
/// `s` is first parsed with `str::parse` into whatever type `RemoteStorageConfig::from_toml`
/// takes (presumably a TOML document/item — confirm), and that value is then converted.
///
/// # Errors
///
/// Fails if `s` cannot be parsed or if `from_toml` rejects the parsed value.
pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
    RemoteStorageConfig::from_toml(&s.parse()?)
}
     433              : 
     434              : /// Helper for cmdline cache options parsing.
     435              : #[derive(Debug)]
     436              : pub struct CacheOptions {
     437              :     /// Max number of entries.
     438              :     pub size: usize,
     439              :     /// Entry's time-to-live.
     440              :     pub ttl: Duration,
     441              : }
     442              : 
     443              : impl CacheOptions {
     444              :     /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
     445              :     pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
     446              : 
     447              :     /// Parse cache options passed via cmdline.
     448              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     449            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     450            4 :         let mut size = None;
     451            4 :         let mut ttl = None;
     452              : 
     453            7 :         for option in options.split(',') {
     454            7 :             let (key, value) = option
     455            7 :                 .split_once('=')
     456            7 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     457              : 
     458            7 :             match key {
     459            7 :                 "size" => size = Some(value.parse()?),
     460            3 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     461            0 :                 unknown => bail!("unknown key: {unknown}"),
     462              :             }
     463              :         }
     464              : 
     465              :         // TTL doesn't matter if cache is always empty.
     466            4 :         if let Some(0) = size {
     467            2 :             ttl.get_or_insert(Duration::default());
     468            2 :         }
     469              : 
     470              :         Ok(Self {
     471            4 :             size: size.context("missing `size`")?,
     472            4 :             ttl: ttl.context("missing `ttl`")?,
     473              :         })
     474            4 :     }
     475              : }
     476              : 
     477              : impl FromStr for CacheOptions {
     478              :     type Err = anyhow::Error;
     479              : 
     480            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     481            4 :         let error = || format!("failed to parse cache options '{options}'");
     482            4 :         Self::parse(options).with_context(error)
     483            4 :     }
     484              : }
     485              : 
     486              : /// Helper for cmdline cache options parsing.
     487              : #[derive(Debug)]
     488              : pub struct ProjectInfoCacheOptions {
     489              :     /// Max number of entries.
     490              :     pub size: usize,
     491              :     /// Entry's time-to-live.
     492              :     pub ttl: Duration,
     493              :     /// Max number of roles per endpoint.
     494              :     pub max_roles: usize,
     495              :     /// Gc interval.
     496              :     pub gc_interval: Duration,
     497              : }
     498              : 
     499              : impl ProjectInfoCacheOptions {
     500              :     /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
     501              :     pub const CACHE_DEFAULT_OPTIONS: &'static str =
     502              :         "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
     503              : 
     504              :     /// Parse cache options passed via cmdline.
     505              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     506            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     507            0 :         let mut size = None;
     508            0 :         let mut ttl = None;
     509            0 :         let mut max_roles = None;
     510            0 :         let mut gc_interval = None;
     511              : 
     512            0 :         for option in options.split(',') {
     513            0 :             let (key, value) = option
     514            0 :                 .split_once('=')
     515            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     516              : 
     517            0 :             match key {
     518            0 :                 "size" => size = Some(value.parse()?),
     519            0 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     520            0 :                 "max_roles" => max_roles = Some(value.parse()?),
     521            0 :                 "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
     522            0 :                 unknown => bail!("unknown key: {unknown}"),
     523              :             }
     524              :         }
     525              : 
     526              :         // TTL doesn't matter if cache is always empty.
     527            0 :         if let Some(0) = size {
     528            0 :             ttl.get_or_insert(Duration::default());
     529            0 :         }
     530              : 
     531              :         Ok(Self {
     532            0 :             size: size.context("missing `size`")?,
     533            0 :             ttl: ttl.context("missing `ttl`")?,
     534            0 :             max_roles: max_roles.context("missing `max_roles`")?,
     535            0 :             gc_interval: gc_interval.context("missing `gc_interval`")?,
     536              :         })
     537            0 :     }
     538              : }
     539              : 
     540              : impl FromStr for ProjectInfoCacheOptions {
     541              :     type Err = anyhow::Error;
     542              : 
     543            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     544            0 :         let error = || format!("failed to parse cache options '{options}'");
     545            0 :         Self::parse(options).with_context(error)
     546            0 :     }
     547              : }
     548              : 
     549              : /// This is a config for connect to compute and wake compute.
     550              : #[derive(Clone, Copy, Debug)]
     551              : pub struct RetryConfig {
     552              :     /// Number of times we should retry.
     553              :     pub max_retries: u32,
     554              :     /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
     555              :     pub base_delay: tokio::time::Duration,
     556              :     /// Exponential base for retry wait duration
     557              :     pub backoff_factor: f64,
     558              : }
     559              : 
     560              : impl RetryConfig {
     561              :     // Default options for RetryConfig.
     562              : 
     563              :     /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
     564              :     pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
     565              :         "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
     566              :     /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
     567              :     /// Cplane has timeout of 60s on each request. 8m7s in total.
     568              :     pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
     569              :         "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
     570              : 
     571              :     /// Parse retry options passed via cmdline.
     572              :     /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
     573            0 :     pub fn parse(options: &str) -> anyhow::Result<Self> {
     574            0 :         let mut num_retries = None;
     575            0 :         let mut base_retry_wait_duration = None;
     576            0 :         let mut retry_wait_exponent_base = None;
     577              : 
     578            0 :         for option in options.split(',') {
     579            0 :             let (key, value) = option
     580            0 :                 .split_once('=')
     581            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     582              : 
     583            0 :             match key {
     584            0 :                 "num_retries" => num_retries = Some(value.parse()?),
     585            0 :                 "base_retry_wait_duration" => {
     586            0 :                     base_retry_wait_duration = Some(humantime::parse_duration(value)?);
     587              :                 }
     588            0 :                 "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
     589            0 :                 unknown => bail!("unknown key: {unknown}"),
     590              :             }
     591              :         }
     592              : 
     593              :         Ok(Self {
     594            0 :             max_retries: num_retries.context("missing `num_retries`")?,
     595            0 :             base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
     596            0 :             backoff_factor: retry_wait_exponent_base
     597            0 :                 .context("missing `retry_wait_exponent_base`")?,
     598              :         })
     599            0 :     }
     600              : }
     601              : 
     602              : /// Helper for cmdline cache options parsing.
     603            8 : #[derive(serde::Deserialize)]
     604              : pub struct ConcurrencyLockOptions {
     605              :     /// The number of shards the lock map should have
     606              :     pub shards: usize,
     607              :     /// The number of allowed concurrent requests for each endpoitn
     608              :     #[serde(flatten)]
     609              :     pub limiter: RateLimiterConfig,
     610              :     /// Garbage collection epoch
     611              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     612              :     pub epoch: Duration,
     613              :     /// Lock timeout
     614              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     615              :     pub timeout: Duration,
     616              : }
     617              : 
     618              : impl ConcurrencyLockOptions {
     619              :     /// Default options for [`crate::control_plane::provider::ApiLocks`].
     620              :     pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
     621              :     /// Default options for [`crate::control_plane::provider::ApiLocks`].
     622              :     pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
     623              :         "shards=64,permits=100,epoch=10m,timeout=10ms";
     624              : 
     625              :     // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
     626              : 
     627              :     /// Parse lock options passed via cmdline.
     628              :     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     629            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     630            4 :         let options = options.trim();
     631            4 :         if options.starts_with('{') && options.ends_with('}') {
     632            1 :             return Ok(serde_json::from_str(options)?);
     633            3 :         }
     634            3 : 
     635            3 :         let mut shards = None;
     636            3 :         let mut permits = None;
     637            3 :         let mut epoch = None;
     638            3 :         let mut timeout = None;
     639              : 
     640            9 :         for option in options.split(',') {
     641            9 :             let (key, value) = option
     642            9 :                 .split_once('=')
     643            9 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     644              : 
     645            9 :             match key {
     646            9 :                 "shards" => shards = Some(value.parse()?),
     647            7 :                 "permits" => permits = Some(value.parse()?),
     648            4 :                 "epoch" => epoch = Some(humantime::parse_duration(value)?),
     649            2 :                 "timeout" => timeout = Some(humantime::parse_duration(value)?),
     650            0 :                 unknown => bail!("unknown key: {unknown}"),
     651              :             }
     652              :         }
     653              : 
     654              :         // these dont matter if lock is disabled
     655            3 :         if let Some(0) = permits {
     656            1 :             timeout = Some(Duration::default());
     657            1 :             epoch = Some(Duration::default());
     658            1 :             shards = Some(2);
     659            2 :         }
     660              : 
     661            3 :         let permits = permits.context("missing `permits`")?;
     662            3 :         let out = Self {
     663            3 :             shards: shards.context("missing `shards`")?,
     664            3 :             limiter: RateLimiterConfig {
     665            3 :                 algorithm: RateLimitAlgorithm::Fixed,
     666            3 :                 initial_limit: permits,
     667            3 :             },
     668            3 :             epoch: epoch.context("missing `epoch`")?,
     669            3 :             timeout: timeout.context("missing `timeout`")?,
     670              :         };
     671              : 
     672            3 :         ensure!(out.shards > 1, "shard count must be > 1");
     673            3 :         ensure!(
     674            3 :             out.shards.is_power_of_two(),
     675            0 :             "shard count must be a power of two"
     676              :         );
     677              : 
     678            3 :         Ok(out)
     679            4 :     }
     680              : }
     681              : 
     682              : impl FromStr for ConcurrencyLockOptions {
     683              :     type Err = anyhow::Error;
     684              : 
     685            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     686            4 :         let error = || format!("failed to parse cache lock options '{options}'");
     687            4 :         Self::parse(options).with_context(error)
     688            4 :     }
     689              : }
     690              : 
     691              : #[cfg(test)]
     692              : mod tests {
     693              :     use super::*;
     694              :     use crate::rate_limiter::Aimd;
     695              : 
     696              :     #[test]
     697            1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     698            1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     699            1 :         assert_eq!(size, 4096);
     700            1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     701              : 
     702            1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     703            1 :         assert_eq!(size, 2);
     704            1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     705              : 
     706            1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     707            1 :         assert_eq!(size, 0);
     708            1 :         assert_eq!(ttl, Duration::from_secs(1));
     709              : 
     710            1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     711            1 :         assert_eq!(size, 0);
     712            1 :         assert_eq!(ttl, Duration::default());
     713              : 
     714            1 :         Ok(())
     715            1 :     }
     716              : 
     717              :     #[test]
     718            1 :     fn test_parse_lock_options() -> anyhow::Result<()> {
     719              :         let ConcurrencyLockOptions {
     720            1 :             epoch,
     721            1 :             limiter,
     722            1 :             shards,
     723            1 :             timeout,
     724            1 :         } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
     725            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     726            1 :         assert_eq!(timeout, Duration::from_secs(1));
     727            1 :         assert_eq!(shards, 32);
     728            1 :         assert_eq!(limiter.initial_limit, 4);
     729            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     730              : 
     731              :         let ConcurrencyLockOptions {
     732            1 :             epoch,
     733            1 :             limiter,
     734            1 :             shards,
     735            1 :             timeout,
     736            1 :         } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
     737            1 :         assert_eq!(epoch, Duration::from_secs(60));
     738            1 :         assert_eq!(timeout, Duration::from_millis(100));
     739            1 :         assert_eq!(shards, 16);
     740            1 :         assert_eq!(limiter.initial_limit, 8);
     741            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     742              : 
     743              :         let ConcurrencyLockOptions {
     744            1 :             epoch,
     745            1 :             limiter,
     746            1 :             shards,
     747            1 :             timeout,
     748            1 :         } = "permits=0".parse()?;
     749            1 :         assert_eq!(epoch, Duration::ZERO);
     750            1 :         assert_eq!(timeout, Duration::ZERO);
     751            1 :         assert_eq!(shards, 2);
     752            1 :         assert_eq!(limiter.initial_limit, 0);
     753            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     754              : 
     755            1 :         Ok(())
     756            1 :     }
     757              : 
     758              :     #[test]
     759            1 :     fn test_parse_json_lock_options() -> anyhow::Result<()> {
     760              :         let ConcurrencyLockOptions {
     761            1 :             epoch,
     762            1 :             limiter,
     763            1 :             shards,
     764            1 :             timeout,
     765            1 :         } = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
     766            1 :             .parse()?;
     767            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     768            1 :         assert_eq!(timeout, Duration::from_secs(1));
     769            1 :         assert_eq!(shards, 32);
     770            1 :         assert_eq!(limiter.initial_limit, 44);
     771            1 :         assert_eq!(
     772            1 :             limiter.algorithm,
     773            1 :             RateLimitAlgorithm::Aimd {
     774            1 :                 conf: Aimd {
     775            1 :                     min: 5,
     776            1 :                     max: 500,
     777            1 :                     dec: 0.9,
     778            1 :                     inc: 10,
     779            1 :                     utilisation: 0.8
     780            1 :                 }
     781            1 :             },
     782            1 :         );
     783              : 
     784            1 :         Ok(())
     785            1 :     }
     786              : }
        

Generated by: LCOV version 2.1-beta