LCOV - code coverage report
Current view: top level - proxy/src - config.rs (source / functions) Coverage Total Hit
Test: 553e39c2773e5840c720c90d86e56f89a4330d43.info Lines: 62.6 % 222 139
Test Date: 2025-06-13 20:01:21 Functions: 23.7 % 38 9

            Line data    Source code
       1              : use std::str::FromStr;
       2              : use std::sync::Arc;
       3              : use std::time::Duration;
       4              : 
       5              : use anyhow::{Context, Ok, bail, ensure};
       6              : use arc_swap::ArcSwapOption;
       7              : use clap::ValueEnum;
       8              : use remote_storage::RemoteStorageConfig;
       9              : 
      10              : use crate::auth::backend::jwt::JwkCache;
      11              : use crate::control_plane::locks::ApiLocks;
      12              : use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
      13              : use crate::scram::threadpool::ThreadPool;
      14              : use crate::serverless::GlobalConnPoolOptions;
      15              : use crate::serverless::cancel_set::CancelSet;
      16              : pub use crate::tls::server_config::{TlsConfig, configure_tls};
      17              : use crate::types::Host;
      18              : 
      19              : pub struct ProxyConfig {
      20              :     pub tls_config: ArcSwapOption<TlsConfig>,
      21              :     pub metric_collection: Option<MetricCollectionConfig>,
      22              :     pub http_config: HttpConfig,
      23              :     pub authentication_config: AuthenticationConfig,
      24              :     pub proxy_protocol_v2: ProxyProtocolV2,
      25              :     pub region: String,
      26              :     pub handshake_timeout: Duration,
      27              :     pub wake_compute_retry_config: RetryConfig,
      28              :     pub connect_compute_locks: ApiLocks<Host>,
      29              :     pub connect_to_compute: ComputeConfig,
      30              : }
      31              : 
      32              : pub struct ComputeConfig {
      33              :     pub retry: RetryConfig,
      34              :     pub tls: Arc<rustls::ClientConfig>,
      35              :     pub timeout: Duration,
      36              : }
      37              : 
      38              : #[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
      39              : pub enum ProxyProtocolV2 {
      40              :     /// Connection will error if PROXY protocol v2 header is missing
      41              :     Required,
      42              :     /// Connection will error if PROXY protocol v2 header is provided
      43              :     Rejected,
      44              : }
      45              : 
      46              : #[derive(Debug)]
      47              : pub struct MetricCollectionConfig {
      48              :     pub endpoint: reqwest::Url,
      49              :     pub interval: Duration,
      50              :     pub backup_metric_collection_config: MetricBackupCollectionConfig,
      51              : }
      52              : 
      53              : pub struct HttpConfig {
      54              :     pub accept_websockets: bool,
      55              :     pub pool_options: GlobalConnPoolOptions,
      56              :     pub cancel_set: CancelSet,
      57              :     pub client_conn_threshold: u64,
      58              :     pub max_request_size_bytes: usize,
      59              :     pub max_response_size_bytes: usize,
      60              : }
      61              : 
      62              : pub struct AuthenticationConfig {
      63              :     pub thread_pool: Arc<ThreadPool>,
      64              :     pub scram_protocol_timeout: tokio::time::Duration,
      65              :     pub ip_allowlist_check_enabled: bool,
      66              :     pub is_vpc_acccess_proxy: bool,
      67              :     pub jwks_cache: JwkCache,
      68              :     pub is_auth_broker: bool,
      69              :     pub accept_jwts: bool,
      70              :     pub console_redirect_confirmation_timeout: tokio::time::Duration,
      71              : }
      72              : 
      73              : #[derive(Debug)]
      74              : pub struct EndpointCacheConfig {
      75              :     /// Batch size used to receive all endpoints on startup.
      76              :     pub initial_batch_size: usize,
      77              :     /// Batch size to receive endpoints.
      78              :     pub default_batch_size: usize,
      79              :     /// Timeout for the stream read operation.
      80              :     pub xread_timeout: Duration,
      81              :     /// Stream name to read from.
      82              :     pub stream_name: String,
      83              :     /// Limiter info (to distinguish when to enable cache).
      84              :     pub limiter_info: Vec<RateBucketInfo>,
      85              :     /// Disable cache.
      86              :     /// If true, cache is ignored, but reports all statistics.
      87              :     pub disable_cache: bool,
      88              :     /// Retry interval for the stream read operation.
      89              :     pub retry_interval: Duration,
      90              : }
      91              : 
      92              : impl EndpointCacheConfig {
      93              :     /// Default options for [`crate::control_plane::NodeInfoCache`].
      94              :     /// Notice that the default options set `disable_cache=true`, which means that the cache is disabled by default.
      95              :     pub const CACHE_DEFAULT_OPTIONS: &'static str = "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
      96              : 
      97              :     /// Parse cache options passed via cmdline.
      98              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
      99            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     100            0 :         let mut initial_batch_size = None;
     101            0 :         let mut default_batch_size = None;
     102            0 :         let mut xread_timeout = None;
     103            0 :         let mut stream_name = None;
     104            0 :         let mut limiter_info = vec![];
     105            0 :         let mut disable_cache = false;
     106            0 :         let mut retry_interval = None;
     107              : 
     108            0 :         for option in options.split(',') {
     109            0 :             let (key, value) = option
     110            0 :                 .split_once('=')
     111            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     112              : 
     113            0 :             match key {
     114            0 :                 "initial_batch_size" => initial_batch_size = Some(value.parse()?),
     115            0 :                 "default_batch_size" => default_batch_size = Some(value.parse()?),
     116            0 :                 "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
     117            0 :                 "stream_name" => stream_name = Some(value.to_string()),
     118            0 :                 "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
     119            0 :                 "disable_cache" => disable_cache = value.parse()?,
     120            0 :                 "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
     121            0 :                 unknown => bail!("unknown key: {unknown}"),
     122              :             }
     123              :         }
     124            0 :         RateBucketInfo::validate(&mut limiter_info)?;
     125              : 
     126              :         Ok(Self {
     127            0 :             initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
     128            0 :             default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
     129            0 :             xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
     130            0 :             stream_name: stream_name.context("missing `stream_name`")?,
     131            0 :             disable_cache,
     132            0 :             limiter_info,
     133            0 :             retry_interval: retry_interval.context("missing `retry_interval`")?,
     134              :         })
     135            0 :     }
     136              : }
     137              : 
     138              : impl FromStr for EndpointCacheConfig {
     139              :     type Err = anyhow::Error;
     140              : 
     141            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     142            0 :         let error = || format!("failed to parse endpoint cache options '{options}'");
     143            0 :         Self::parse(options).with_context(error)
     144            0 :     }
     145              : }
     146              : #[derive(Debug)]
     147              : pub struct MetricBackupCollectionConfig {
     148              :     pub remote_storage_config: Option<RemoteStorageConfig>,
     149              :     pub chunk_size: usize,
     150              : }
     151              : 
     152            1 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
     153            1 :     RemoteStorageConfig::from_toml(&s.parse()?)
     154            1 : }
     155              : 
     156              : /// Helper for cmdline cache options parsing.
     157              : #[derive(Debug)]
     158              : pub struct CacheOptions {
     159              :     /// Max number of entries.
     160              :     pub size: usize,
     161              :     /// Entry's time-to-live.
     162              :     pub ttl: Duration,
     163              : }
     164              : 
     165              : impl CacheOptions {
     166              :     /// Default options for [`crate::control_plane::NodeInfoCache`].
     167              :     pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
     168              : 
     169              :     /// Parse cache options passed via cmdline.
     170              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     171            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     172            4 :         let mut size = None;
     173            4 :         let mut ttl = None;
     174              : 
     175            7 :         for option in options.split(',') {
     176            7 :             let (key, value) = option
     177            7 :                 .split_once('=')
     178            7 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     179              : 
     180            7 :             match key {
     181            7 :                 "size" => size = Some(value.parse()?),
     182            3 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     183            0 :                 unknown => bail!("unknown key: {unknown}"),
     184              :             }
     185              :         }
     186              : 
     187              :         // TTL doesn't matter if cache is always empty.
     188            4 :         if let Some(0) = size {
     189            2 :             ttl.get_or_insert(Duration::default());
     190            2 :         }
     191              : 
     192              :         Ok(Self {
     193            4 :             size: size.context("missing `size`")?,
     194            4 :             ttl: ttl.context("missing `ttl`")?,
     195              :         })
     196            4 :     }
     197              : }
     198              : 
     199              : impl FromStr for CacheOptions {
     200              :     type Err = anyhow::Error;
     201              : 
     202            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     203            4 :         let error = || format!("failed to parse cache options '{options}'");
     204            4 :         Self::parse(options).with_context(error)
     205            4 :     }
     206              : }
     207              : 
     208              : /// Helper for cmdline cache options parsing.
     209              : #[derive(Debug)]
     210              : pub struct ProjectInfoCacheOptions {
     211              :     /// Max number of entries.
     212              :     pub size: usize,
     213              :     /// Entry's time-to-live.
     214              :     pub ttl: Duration,
     215              :     /// Max number of roles per endpoint.
     216              :     pub max_roles: usize,
     217              :     /// Gc interval.
     218              :     pub gc_interval: Duration,
     219              : }
     220              : 
     221              : impl ProjectInfoCacheOptions {
     222              :     /// Default options for [`crate::control_plane::NodeInfoCache`].
     223              :     pub const CACHE_DEFAULT_OPTIONS: &'static str =
     224              :         "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
     225              : 
     226              :     /// Parse cache options passed via cmdline.
     227              :     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     228            0 :     fn parse(options: &str) -> anyhow::Result<Self> {
     229            0 :         let mut size = None;
     230            0 :         let mut ttl = None;
     231            0 :         let mut max_roles = None;
     232            0 :         let mut gc_interval = None;
     233              : 
     234            0 :         for option in options.split(',') {
     235            0 :             let (key, value) = option
     236            0 :                 .split_once('=')
     237            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     238              : 
     239            0 :             match key {
     240            0 :                 "size" => size = Some(value.parse()?),
     241            0 :                 "ttl" => ttl = Some(humantime::parse_duration(value)?),
     242            0 :                 "max_roles" => max_roles = Some(value.parse()?),
     243            0 :                 "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
     244            0 :                 unknown => bail!("unknown key: {unknown}"),
     245              :             }
     246              :         }
     247              : 
     248              :         // TTL doesn't matter if cache is always empty.
     249            0 :         if let Some(0) = size {
     250            0 :             ttl.get_or_insert(Duration::default());
     251            0 :         }
     252              : 
     253              :         Ok(Self {
     254            0 :             size: size.context("missing `size`")?,
     255            0 :             ttl: ttl.context("missing `ttl`")?,
     256            0 :             max_roles: max_roles.context("missing `max_roles`")?,
     257            0 :             gc_interval: gc_interval.context("missing `gc_interval`")?,
     258              :         })
     259            0 :     }
     260              : }
     261              : 
     262              : impl FromStr for ProjectInfoCacheOptions {
     263              :     type Err = anyhow::Error;
     264              : 
     265            0 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     266            0 :         let error = || format!("failed to parse cache options '{options}'");
     267            0 :         Self::parse(options).with_context(error)
     268            0 :     }
     269              : }
     270              : 
     271              : /// Retry configuration used for connecting to compute and for waking compute.
     272              : #[derive(Clone, Copy, Debug)]
     273              : pub struct RetryConfig {
     274              :     /// Number of times we should retry.
     275              :     pub max_retries: u32,
     276              :     /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
     277              :     pub base_delay: tokio::time::Duration,
     278              :     /// Exponential base for retry wait duration
     279              :     pub backoff_factor: f64,
     280              : }
     281              : 
     282              : impl RetryConfig {
     283              :     // Default options for RetryConfig.
     284              : 
     285              :     /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
     286              :     pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
     287              :         "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
     288              :     /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
     289              :     /// Cplane has a timeout of 60s on each request, so the worst case is about 8m7s in total.
     290              :     pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
     291              :         "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
     292              : 
     293              :     /// Parse retry options passed via cmdline.
     294              :     /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
     295            0 :     pub fn parse(options: &str) -> anyhow::Result<Self> {
     296            0 :         let mut num_retries = None;
     297            0 :         let mut base_retry_wait_duration = None;
     298            0 :         let mut retry_wait_exponent_base = None;
     299              : 
     300            0 :         for option in options.split(',') {
     301            0 :             let (key, value) = option
     302            0 :                 .split_once('=')
     303            0 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     304              : 
     305            0 :             match key {
     306            0 :                 "num_retries" => num_retries = Some(value.parse()?),
     307            0 :                 "base_retry_wait_duration" => {
     308            0 :                     base_retry_wait_duration = Some(humantime::parse_duration(value)?);
     309              :                 }
     310            0 :                 "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
     311            0 :                 unknown => bail!("unknown key: {unknown}"),
     312              :             }
     313              :         }
     314              : 
     315              :         Ok(Self {
     316            0 :             max_retries: num_retries.context("missing `num_retries`")?,
     317            0 :             base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
     318            0 :             backoff_factor: retry_wait_exponent_base
     319            0 :                 .context("missing `retry_wait_exponent_base`")?,
     320              :         })
     321            0 :     }
     322              : }
     323              : 
     324              : /// Helper for cmdline lock options parsing.
     325            5 : #[derive(serde::Deserialize)]
     326              : pub struct ConcurrencyLockOptions {
     327              :     /// The number of shards the lock map should have
     328              :     pub shards: usize,
     329              :     /// The number of allowed concurrent requests for each endpoint
     330              :     #[serde(flatten)]
     331              :     pub limiter: RateLimiterConfig,
     332              :     /// Garbage collection epoch
     333              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     334              :     pub epoch: Duration,
     335              :     /// Lock timeout
     336              :     #[serde(deserialize_with = "humantime_serde::deserialize")]
     337              :     pub timeout: Duration,
     338              : }
     339              : 
     340              : impl ConcurrencyLockOptions {
     341              :     /// Default options for [`crate::control_plane::client::ApiLocks`].
     342              :     pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
     343              :     /// Default options for [`crate::control_plane::client::ApiLocks`].
     344              :     pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
     345              :         "shards=64,permits=100,epoch=10m,timeout=10ms";
     346              : 
     347              :     // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
     348              : 
     349              :     /// Parse lock options passed via cmdline.
     350              :     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     351            4 :     fn parse(options: &str) -> anyhow::Result<Self> {
     352            4 :         let options = options.trim();
     353            4 :         if options.starts_with('{') && options.ends_with('}') {
     354            1 :             return Ok(serde_json::from_str(options)?);
     355            3 :         }
     356            3 : 
     357            3 :         let mut shards = None;
     358            3 :         let mut permits = None;
     359            3 :         let mut epoch = None;
     360            3 :         let mut timeout = None;
     361              : 
     362            9 :         for option in options.split(',') {
     363            9 :             let (key, value) = option
     364            9 :                 .split_once('=')
     365            9 :                 .with_context(|| format!("bad key-value pair: {option}"))?;
     366              : 
     367            9 :             match key {
     368            9 :                 "shards" => shards = Some(value.parse()?),
     369            7 :                 "permits" => permits = Some(value.parse()?),
     370            4 :                 "epoch" => epoch = Some(humantime::parse_duration(value)?),
     371            2 :                 "timeout" => timeout = Some(humantime::parse_duration(value)?),
     372            0 :                 unknown => bail!("unknown key: {unknown}"),
     373              :             }
     374              :         }
     375              : 
     376              :         // These don't matter if the lock is disabled.
     377            3 :         if let Some(0) = permits {
     378            1 :             timeout = Some(Duration::default());
     379            1 :             epoch = Some(Duration::default());
     380            1 :             shards = Some(2);
     381            2 :         }
     382              : 
     383            3 :         let permits = permits.context("missing `permits`")?;
     384            3 :         let out = Self {
     385            3 :             shards: shards.context("missing `shards`")?,
     386            3 :             limiter: RateLimiterConfig {
     387            3 :                 algorithm: RateLimitAlgorithm::Fixed,
     388            3 :                 initial_limit: permits,
     389            3 :             },
     390            3 :             epoch: epoch.context("missing `epoch`")?,
     391            3 :             timeout: timeout.context("missing `timeout`")?,
     392              :         };
     393              : 
     394            3 :         ensure!(out.shards > 1, "shard count must be > 1");
     395            3 :         ensure!(
     396            3 :             out.shards.is_power_of_two(),
     397            0 :             "shard count must be a power of two"
     398              :         );
     399              : 
     400            3 :         Ok(out)
     401            4 :     }
     402              : }
     403              : 
     404              : impl FromStr for ConcurrencyLockOptions {
     405              :     type Err = anyhow::Error;
     406              : 
     407            4 :     fn from_str(options: &str) -> Result<Self, Self::Err> {
     408            4 :         let error = || format!("failed to parse cache lock options '{options}'");
     409            4 :         Self::parse(options).with_context(error)
     410            4 :     }
     411              : }
     412              : 
     413              : #[cfg(test)]
     414              : mod tests {
     415              :     use super::*;
     416              :     use crate::rate_limiter::Aimd;
     417              : 
     418              :     #[test]
     419            1 :     fn test_parse_cache_options() -> anyhow::Result<()> {
     420            1 :         let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
     421            1 :         assert_eq!(size, 4096);
     422            1 :         assert_eq!(ttl, Duration::from_secs(5 * 60));
     423              : 
     424            1 :         let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
     425            1 :         assert_eq!(size, 2);
     426            1 :         assert_eq!(ttl, Duration::from_secs(4 * 60));
     427              : 
     428            1 :         let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
     429            1 :         assert_eq!(size, 0);
     430            1 :         assert_eq!(ttl, Duration::from_secs(1));
     431              : 
     432            1 :         let CacheOptions { size, ttl } = "size=0".parse()?;
     433            1 :         assert_eq!(size, 0);
     434            1 :         assert_eq!(ttl, Duration::default());
     435              : 
     436            1 :         Ok(())
     437            1 :     }
     438              : 
     439              :     #[test]
     440            1 :     fn test_parse_lock_options() -> anyhow::Result<()> {
     441              :         let ConcurrencyLockOptions {
     442            1 :             epoch,
     443            1 :             limiter,
     444            1 :             shards,
     445            1 :             timeout,
     446            1 :         } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
     447            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     448            1 :         assert_eq!(timeout, Duration::from_secs(1));
     449            1 :         assert_eq!(shards, 32);
     450            1 :         assert_eq!(limiter.initial_limit, 4);
     451            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     452              : 
     453              :         let ConcurrencyLockOptions {
     454            1 :             epoch,
     455            1 :             limiter,
     456            1 :             shards,
     457            1 :             timeout,
     458            1 :         } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
     459            1 :         assert_eq!(epoch, Duration::from_secs(60));
     460            1 :         assert_eq!(timeout, Duration::from_millis(100));
     461            1 :         assert_eq!(shards, 16);
     462            1 :         assert_eq!(limiter.initial_limit, 8);
     463            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     464              : 
     465              :         let ConcurrencyLockOptions {
     466            1 :             epoch,
     467            1 :             limiter,
     468            1 :             shards,
     469            1 :             timeout,
     470            1 :         } = "permits=0".parse()?;
     471            1 :         assert_eq!(epoch, Duration::ZERO);
     472            1 :         assert_eq!(timeout, Duration::ZERO);
     473            1 :         assert_eq!(shards, 2);
     474            1 :         assert_eq!(limiter.initial_limit, 0);
     475            1 :         assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
     476              : 
     477            1 :         Ok(())
     478            1 :     }
     479              : 
     480              :     #[test]
     481            1 :     fn test_parse_json_lock_options() -> anyhow::Result<()> {
     482              :         let ConcurrencyLockOptions {
     483            1 :             epoch,
     484            1 :             limiter,
     485            1 :             shards,
     486            1 :             timeout,
     487            1 :         } = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
     488            1 :             .parse()?;
     489            1 :         assert_eq!(epoch, Duration::from_secs(10 * 60));
     490            1 :         assert_eq!(timeout, Duration::from_secs(1));
     491            1 :         assert_eq!(shards, 32);
     492            1 :         assert_eq!(limiter.initial_limit, 44);
     493            1 :         assert_eq!(
     494            1 :             limiter.algorithm,
     495            1 :             RateLimitAlgorithm::Aimd {
     496            1 :                 conf: Aimd {
     497            1 :                     min: 5,
     498            1 :                     max: 500,
     499            1 :                     dec: 0.9,
     500            1 :                     inc: 10,
     501            1 :                     utilisation: 0.8
     502            1 :                 }
     503            1 :             },
     504            1 :         );
     505              : 
     506            1 :         Ok(())
     507            1 :     }
     508              : }
        

Generated by: LCOV version 2.1-beta