Line data Source code
1 : use crate::{
2 : auth::{self, backend::AuthRateLimiter},
3 : console::locks::ApiLocks,
4 : rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
5 : scram::threadpool::ThreadPool,
6 : serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
7 : Host,
8 : };
9 : use anyhow::{bail, ensure, Context, Ok};
10 : use itertools::Itertools;
11 : use remote_storage::RemoteStorageConfig;
12 : use rustls::{
13 : crypto::ring::sign,
14 : pki_types::{CertificateDer, PrivateKeyDer},
15 : };
16 : use sha2::{Digest, Sha256};
17 : use std::{
18 : collections::{HashMap, HashSet},
19 : str::FromStr,
20 : sync::Arc,
21 : time::Duration,
22 : };
23 : use tracing::{error, info};
24 : use x509_parser::oid_registry;
25 :
26 : pub struct ProxyConfig {
27 : pub tls_config: Option<TlsConfig>,
28 : pub auth_backend: auth::Backend<'static, (), ()>,
29 : pub metric_collection: Option<MetricCollectionConfig>,
30 : pub allow_self_signed_compute: bool,
31 : pub http_config: HttpConfig,
32 : pub authentication_config: AuthenticationConfig,
33 : pub require_client_ip: bool,
34 : pub region: String,
35 : pub handshake_timeout: Duration,
36 : pub wake_compute_retry_config: RetryConfig,
37 : pub connect_compute_locks: ApiLocks<Host>,
38 : pub connect_to_compute_retry_config: RetryConfig,
39 : }
40 :
41 : #[derive(Debug)]
42 : pub struct MetricCollectionConfig {
43 : pub endpoint: reqwest::Url,
44 : pub interval: Duration,
45 : pub backup_metric_collection_config: MetricBackupCollectionConfig,
46 : }
47 :
48 : pub struct TlsConfig {
49 : pub config: Arc<rustls::ServerConfig>,
50 : pub common_names: HashSet<String>,
51 : pub cert_resolver: Arc<CertResolver>,
52 : }
53 :
54 : pub struct HttpConfig {
55 : pub accept_websockets: bool,
56 : pub pool_options: GlobalConnPoolOptions,
57 : pub cancel_set: CancelSet,
58 : pub client_conn_threshold: u64,
59 : pub max_request_size_bytes: u64,
60 : pub max_response_size_bytes: usize,
61 : }
62 :
63 : pub struct AuthenticationConfig {
64 : pub thread_pool: Arc<ThreadPool>,
65 : pub scram_protocol_timeout: tokio::time::Duration,
66 : pub rate_limiter_enabled: bool,
67 : pub rate_limiter: AuthRateLimiter,
68 : pub rate_limit_ip_subnet: u8,
69 : pub ip_allowlist_check_enabled: bool,
70 : }
71 :
72 : impl TlsConfig {
73 20 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
74 20 : self.config.clone()
75 20 : }
76 : }
77 :
/// ALPN protocol identifier for PostgreSQL, advertised by the TLS server.
///
/// See <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
80 :
81 : /// Configure TLS for the main endpoint.
82 0 : pub fn configure_tls(
83 0 : key_path: &str,
84 0 : cert_path: &str,
85 0 : certs_dir: Option<&String>,
86 0 : ) -> anyhow::Result<TlsConfig> {
87 0 : let mut cert_resolver = CertResolver::new();
88 0 :
89 0 : // add default certificate
90 0 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
91 :
92 : // add extra certificates
93 0 : if let Some(certs_dir) = certs_dir {
94 0 : for entry in std::fs::read_dir(certs_dir)? {
95 0 : let entry = entry?;
96 0 : let path = entry.path();
97 0 : if path.is_dir() {
98 : // file names aligned with default cert-manager names
99 0 : let key_path = path.join("tls.key");
100 0 : let cert_path = path.join("tls.crt");
101 0 : if key_path.exists() && cert_path.exists() {
102 0 : cert_resolver.add_cert_path(
103 0 : &key_path.to_string_lossy(),
104 0 : &cert_path.to_string_lossy(),
105 0 : false,
106 0 : )?;
107 0 : }
108 0 : }
109 : }
110 0 : }
111 :
112 0 : let common_names = cert_resolver.get_common_names();
113 0 :
114 0 : let cert_resolver = Arc::new(cert_resolver);
115 0 :
116 0 : // allow TLS 1.2 to be compatible with older client libraries
117 0 : let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
118 0 : &rustls::version::TLS13,
119 0 : &rustls::version::TLS12,
120 0 : ])
121 0 : .with_no_client_auth()
122 0 : .with_cert_resolver(cert_resolver.clone());
123 0 :
124 0 : config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
125 0 :
126 0 : Ok(TlsConfig {
127 0 : config: Arc::new(config),
128 0 : common_names,
129 0 : cert_resolver,
130 0 : })
131 0 : }
132 :
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
///
/// Implementation note: `TlsServerEndPoint::new` only yields [`TlsServerEndPoint::Sha256`]
/// for certificates signed with ECDSA-with-SHA256 or RSA-with-SHA256. MD5/SHA-1 signed
/// certificates are deliberately *not* mapped to SHA-256, and every other algorithm
/// yields [`TlsServerEndPoint::Undefined`].
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
    /// SHA-256 digest of the DER-encoded server certificate.
    Sha256([u8; 32]),
    /// Channel binding is not available for this certificate.
    Undefined,
}
160 :
161 : impl TlsServerEndPoint {
162 21 : pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
163 21 : let sha256_oids = [
164 21 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
165 21 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
166 21 : oid_registry::OID_PKCS1_SHA256WITHRSA,
167 21 : ];
168 :
169 21 : let pem = x509_parser::parse_x509_certificate(cert)
170 21 : .context("Failed to parse PEM object from cerficiate")?
171 : .1;
172 :
173 21 : info!(subject = %pem.subject, "parsing TLS certificate");
174 :
175 21 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
176 21 : let oid = pem.signature_algorithm.oid();
177 21 : let alg = reg.get(oid);
178 21 : if sha256_oids.contains(oid) {
179 21 : let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
180 21 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
181 21 : Ok(Self::Sha256(tls_server_end_point))
182 : } else {
183 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
184 0 : Ok(Self::Undefined)
185 : }
186 21 : }
187 :
188 16 : pub fn supported(&self) -> bool {
189 16 : !matches!(self, TlsServerEndPoint::Undefined)
190 16 : }
191 : }
192 :
193 : #[derive(Default, Debug)]
194 : pub struct CertResolver {
195 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
196 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
197 : }
198 :
199 : impl CertResolver {
200 21 : pub fn new() -> Self {
201 21 : Self::default()
202 21 : }
203 :
204 0 : fn add_cert_path(
205 0 : &mut self,
206 0 : key_path: &str,
207 0 : cert_path: &str,
208 0 : is_default: bool,
209 0 : ) -> anyhow::Result<()> {
210 0 : let priv_key = {
211 0 : let key_bytes = std::fs::read(key_path)
212 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
213 0 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
214 0 :
215 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
216 : PrivateKeyDer::Pkcs8(
217 0 : keys.pop()
218 0 : .unwrap()
219 0 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
220 : )
221 : };
222 :
223 0 : let cert_chain_bytes = std::fs::read(cert_path)
224 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
225 :
226 0 : let cert_chain = {
227 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
228 0 : .try_collect()
229 0 : .with_context(|| {
230 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
231 0 : })?
232 : };
233 :
234 0 : self.add_cert(priv_key, cert_chain, is_default)
235 0 : }
236 :
237 21 : pub fn add_cert(
238 21 : &mut self,
239 21 : priv_key: PrivateKeyDer<'static>,
240 21 : cert_chain: Vec<CertificateDer<'static>>,
241 21 : is_default: bool,
242 21 : ) -> anyhow::Result<()> {
243 21 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
244 :
245 21 : let first_cert = &cert_chain[0];
246 21 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
247 21 : let pem = x509_parser::parse_x509_certificate(first_cert)
248 21 : .context("Failed to parse PEM object from cerficiate")?
249 : .1;
250 :
251 21 : let common_name = pem.subject().to_string();
252 :
253 : // We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
254 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
255 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
256 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
257 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
258 : // of cutting off '*.' parts.
259 21 : let common_name = if common_name.starts_with("CN=*.") {
260 0 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
261 : } else {
262 21 : common_name.strip_prefix("CN=").map(|s| s.to_string())
263 : }
264 21 : .context("Failed to parse common name from certificate")?;
265 :
266 21 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
267 21 :
268 21 : if is_default {
269 21 : self.default = Some((cert.clone(), tls_server_end_point));
270 21 : }
271 :
272 21 : self.certs.insert(common_name, (cert, tls_server_end_point));
273 21 :
274 21 : Ok(())
275 21 : }
276 :
277 21 : pub fn get_common_names(&self) -> HashSet<String> {
278 21 : self.certs.keys().map(|s| s.to_string()).collect()
279 21 : }
280 : }
281 :
282 : impl rustls::server::ResolvesServerCert for CertResolver {
283 0 : fn resolve(
284 0 : &self,
285 0 : client_hello: rustls::server::ClientHello<'_>,
286 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
287 0 : self.resolve(client_hello.server_name()).map(|x| x.0)
288 0 : }
289 : }
290 :
291 : impl CertResolver {
292 20 : pub fn resolve(
293 20 : &self,
294 20 : server_name: Option<&str>,
295 20 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
296 : // loop here and cut off more and more subdomains until we find
297 : // a match to get a proper wildcard support. OTOH, we now do not
298 : // use nested domains, so keep this simple for now.
299 : //
300 : // With the current coding foo.com will match *.foo.com and that
301 : // repeats behavior of the old code.
302 20 : if let Some(mut sni_name) = server_name {
303 : loop {
304 40 : if let Some(cert) = self.certs.get(sni_name) {
305 20 : return Some(cert.clone());
306 20 : }
307 20 : if let Some((_, rest)) = sni_name.split_once('.') {
308 20 : sni_name = rest;
309 20 : } else {
310 0 : return None;
311 : }
312 : }
313 : } else {
314 : // No SNI, use the default certificate, otherwise we can't get to
315 : // options parameter which can be used to set endpoint name too.
316 : // That means that non-SNI flow will not work for CNAME domains in
317 : // verify-full mode.
318 : //
319 : // If that will be a problem we can:
320 : //
321 : // a) Instead of multi-cert approach use single cert with extra
322 : // domains listed in Subject Alternative Name (SAN).
323 : // b) Deploy separate proxy instances for extra domains.
324 0 : self.default.clone()
325 : }
326 20 : }
327 : }
328 :
329 : #[derive(Debug)]
330 : pub struct EndpointCacheConfig {
331 : /// Batch size to receive all endpoints on the startup.
332 : pub initial_batch_size: usize,
333 : /// Batch size to receive endpoints.
334 : pub default_batch_size: usize,
335 : /// Timeouts for the stream read operation.
336 : pub xread_timeout: Duration,
337 : /// Stream name to read from.
338 : pub stream_name: String,
339 : /// Limiter info (to distinguish when to enable cache).
340 : pub limiter_info: Vec<RateBucketInfo>,
341 : /// Disable cache.
342 : /// If true, cache is ignored, but reports all statistics.
343 : pub disable_cache: bool,
344 : /// Retry interval for the stream read operation.
345 : pub retry_interval: Duration,
346 : }
347 :
348 : impl EndpointCacheConfig {
349 : /// Default options for [`crate::console::provider::NodeInfoCache`].
350 : /// Notice that by default the limiter is empty, which means that cache is disabled.
351 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
352 : "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
353 :
354 : /// Parse cache options passed via cmdline.
355 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
356 0 : fn parse(options: &str) -> anyhow::Result<Self> {
357 0 : let mut initial_batch_size = None;
358 0 : let mut default_batch_size = None;
359 0 : let mut xread_timeout = None;
360 0 : let mut stream_name = None;
361 0 : let mut limiter_info = vec![];
362 0 : let mut disable_cache = false;
363 0 : let mut retry_interval = None;
364 :
365 0 : for option in options.split(',') {
366 0 : let (key, value) = option
367 0 : .split_once('=')
368 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
369 :
370 0 : match key {
371 0 : "initial_batch_size" => initial_batch_size = Some(value.parse()?),
372 0 : "default_batch_size" => default_batch_size = Some(value.parse()?),
373 0 : "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
374 0 : "stream_name" => stream_name = Some(value.to_string()),
375 0 : "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
376 0 : "disable_cache" => disable_cache = value.parse()?,
377 0 : "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
378 0 : unknown => bail!("unknown key: {unknown}"),
379 : }
380 : }
381 0 : RateBucketInfo::validate(&mut limiter_info)?;
382 :
383 : Ok(Self {
384 0 : initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
385 0 : default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
386 0 : xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
387 0 : stream_name: stream_name.context("missing `stream_name`")?,
388 0 : disable_cache,
389 0 : limiter_info,
390 0 : retry_interval: retry_interval.context("missing `retry_interval`")?,
391 : })
392 0 : }
393 : }
394 :
395 : impl FromStr for EndpointCacheConfig {
396 : type Err = anyhow::Error;
397 :
398 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
399 0 : let error = || format!("failed to parse endpoint cache options '{options}'");
400 0 : Self::parse(options).with_context(error)
401 0 : }
402 : }
403 : #[derive(Debug)]
404 : pub struct MetricBackupCollectionConfig {
405 : pub interval: Duration,
406 : pub remote_storage_config: Option<RemoteStorageConfig>,
407 : pub chunk_size: usize,
408 : }
409 :
410 1 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
411 1 : RemoteStorageConfig::from_toml(&s.parse()?)
412 1 : }
413 :
/// Helper for cmdline cache options parsing.
#[derive(Debug)]
pub struct CacheOptions {
    /// Max number of entries; `0` means the cache is always empty.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
}
422 :
423 : impl CacheOptions {
424 : /// Default options for [`crate::console::provider::NodeInfoCache`].
425 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
426 :
427 : /// Parse cache options passed via cmdline.
428 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
429 4 : fn parse(options: &str) -> anyhow::Result<Self> {
430 4 : let mut size = None;
431 4 : let mut ttl = None;
432 :
433 7 : for option in options.split(',') {
434 7 : let (key, value) = option
435 7 : .split_once('=')
436 7 : .with_context(|| format!("bad key-value pair: {option}"))?;
437 :
438 7 : match key {
439 7 : "size" => size = Some(value.parse()?),
440 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
441 0 : unknown => bail!("unknown key: {unknown}"),
442 : }
443 : }
444 :
445 : // TTL doesn't matter if cache is always empty.
446 4 : if let Some(0) = size {
447 2 : ttl.get_or_insert(Duration::default());
448 2 : }
449 :
450 : Ok(Self {
451 4 : size: size.context("missing `size`")?,
452 4 : ttl: ttl.context("missing `ttl`")?,
453 : })
454 4 : }
455 : }
456 :
457 : impl FromStr for CacheOptions {
458 : type Err = anyhow::Error;
459 :
460 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
461 4 : let error = || format!("failed to parse cache options '{options}'");
462 4 : Self::parse(options).with_context(error)
463 4 : }
464 : }
465 :
/// Helper for cmdline parsing of the project info cache options.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Max number of entries; `0` means the cache is always empty.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
    /// Max number of roles per endpoint.
    pub max_roles: usize,
    /// Gc interval.
    pub gc_interval: Duration,
}
478 :
479 : impl ProjectInfoCacheOptions {
480 : /// Default options for [`crate::console::provider::NodeInfoCache`].
481 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
482 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
483 :
484 : /// Parse cache options passed via cmdline.
485 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
486 0 : fn parse(options: &str) -> anyhow::Result<Self> {
487 0 : let mut size = None;
488 0 : let mut ttl = None;
489 0 : let mut max_roles = None;
490 0 : let mut gc_interval = None;
491 :
492 0 : for option in options.split(',') {
493 0 : let (key, value) = option
494 0 : .split_once('=')
495 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
496 :
497 0 : match key {
498 0 : "size" => size = Some(value.parse()?),
499 0 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
500 0 : "max_roles" => max_roles = Some(value.parse()?),
501 0 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
502 0 : unknown => bail!("unknown key: {unknown}"),
503 : }
504 : }
505 :
506 : // TTL doesn't matter if cache is always empty.
507 0 : if let Some(0) = size {
508 0 : ttl.get_or_insert(Duration::default());
509 0 : }
510 :
511 : Ok(Self {
512 0 : size: size.context("missing `size`")?,
513 0 : ttl: ttl.context("missing `ttl`")?,
514 0 : max_roles: max_roles.context("missing `max_roles`")?,
515 0 : gc_interval: gc_interval.context("missing `gc_interval`")?,
516 : })
517 0 : }
518 : }
519 :
520 : impl FromStr for ProjectInfoCacheOptions {
521 : type Err = anyhow::Error;
522 :
523 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
524 0 : let error = || format!("failed to parse cache options '{options}'");
525 0 : Self::parse(options).with_context(error)
526 0 : }
527 : }
528 :
529 : /// This is a config for connect to compute and wake compute.
530 : #[derive(Clone, Copy, Debug)]
531 : pub struct RetryConfig {
532 : /// Number of times we should retry.
533 : pub max_retries: u32,
534 : /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
535 : pub base_delay: tokio::time::Duration,
536 : /// Exponential base for retry wait duration
537 : pub backoff_factor: f64,
538 : }
539 :
540 : impl RetryConfig {
541 : /// Default options for RetryConfig.
542 :
543 : /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
544 : pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
545 : "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
546 : /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
547 : /// Cplane has timeout of 60s on each request. 8m7s in total.
548 : pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
549 : "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
550 :
551 : /// Parse retry options passed via cmdline.
552 : /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
553 0 : pub fn parse(options: &str) -> anyhow::Result<Self> {
554 0 : let mut num_retries = None;
555 0 : let mut base_retry_wait_duration = None;
556 0 : let mut retry_wait_exponent_base = None;
557 :
558 0 : for option in options.split(',') {
559 0 : let (key, value) = option
560 0 : .split_once('=')
561 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
562 :
563 0 : match key {
564 0 : "num_retries" => num_retries = Some(value.parse()?),
565 0 : "base_retry_wait_duration" => {
566 0 : base_retry_wait_duration = Some(humantime::parse_duration(value)?);
567 : }
568 0 : "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
569 0 : unknown => bail!("unknown key: {unknown}"),
570 : }
571 : }
572 :
573 : Ok(Self {
574 0 : max_retries: num_retries.context("missing `num_retries`")?,
575 0 : base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
576 0 : backoff_factor: retry_wait_exponent_base
577 0 : .context("missing `retry_wait_exponent_base`")?,
578 : })
579 0 : }
580 : }
581 :
582 : /// Helper for cmdline cache options parsing.
583 8 : #[derive(serde::Deserialize)]
584 : pub struct ConcurrencyLockOptions {
585 : /// The number of shards the lock map should have
586 : pub shards: usize,
587 : /// The number of allowed concurrent requests for each endpoitn
588 : #[serde(flatten)]
589 : pub limiter: RateLimiterConfig,
590 : /// Garbage collection epoch
591 : #[serde(deserialize_with = "humantime_serde::deserialize")]
592 : pub epoch: Duration,
593 : /// Lock timeout
594 : #[serde(deserialize_with = "humantime_serde::deserialize")]
595 : pub timeout: Duration,
596 : }
597 :
598 : impl ConcurrencyLockOptions {
599 : /// Default options for [`crate::console::provider::ApiLocks`].
600 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
601 : /// Default options for [`crate::console::provider::ApiLocks`].
602 : pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
603 : "shards=64,permits=100,epoch=10m,timeout=10ms";
604 :
605 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
606 :
607 : /// Parse lock options passed via cmdline.
608 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
609 4 : fn parse(options: &str) -> anyhow::Result<Self> {
610 4 : let options = options.trim();
611 4 : if options.starts_with('{') && options.ends_with('}') {
612 1 : return Ok(serde_json::from_str(options)?);
613 3 : }
614 3 :
615 3 : let mut shards = None;
616 3 : let mut permits = None;
617 3 : let mut epoch = None;
618 3 : let mut timeout = None;
619 :
620 9 : for option in options.split(',') {
621 9 : let (key, value) = option
622 9 : .split_once('=')
623 9 : .with_context(|| format!("bad key-value pair: {option}"))?;
624 :
625 9 : match key {
626 9 : "shards" => shards = Some(value.parse()?),
627 7 : "permits" => permits = Some(value.parse()?),
628 4 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
629 2 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
630 0 : unknown => bail!("unknown key: {unknown}"),
631 : }
632 : }
633 :
634 : // these dont matter if lock is disabled
635 3 : if let Some(0) = permits {
636 1 : timeout = Some(Duration::default());
637 1 : epoch = Some(Duration::default());
638 1 : shards = Some(2);
639 2 : }
640 :
641 3 : let permits = permits.context("missing `permits`")?;
642 3 : let out = Self {
643 3 : shards: shards.context("missing `shards`")?,
644 3 : limiter: RateLimiterConfig {
645 3 : algorithm: RateLimitAlgorithm::Fixed,
646 3 : initial_limit: permits,
647 3 : },
648 3 : epoch: epoch.context("missing `epoch`")?,
649 3 : timeout: timeout.context("missing `timeout`")?,
650 : };
651 :
652 3 : ensure!(out.shards > 1, "shard count must be > 1");
653 3 : ensure!(
654 3 : out.shards.is_power_of_two(),
655 0 : "shard count must be a power of two"
656 : );
657 :
658 3 : Ok(out)
659 4 : }
660 : }
661 :
662 : impl FromStr for ConcurrencyLockOptions {
663 : type Err = anyhow::Error;
664 :
665 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
666 4 : let error = || format!("failed to parse cache lock options '{options}'");
667 4 : Self::parse(options).with_context(error)
668 4 : }
669 : }
670 :
#[cfg(test)]
mod tests {
    use crate::rate_limiter::Aimd;

    use super::*;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        // (input, expected size, expected ttl); with `size=0` the ttl may be omitted.
        let cases: [(&str, usize, Duration); 4] = [
            ("size=4096,ttl=5min", 4096, Duration::from_secs(5 * 60)),
            ("ttl=4m,size=2", 2, Duration::from_secs(4 * 60)),
            ("size=0,ttl=1s", 0, Duration::from_secs(1)),
            ("size=0", 0, Duration::default()),
        ];

        for (input, size, ttl) in cases {
            let parsed: CacheOptions = input.parse()?;
            assert_eq!(parsed.size, size);
            assert_eq!(parsed.ttl, ttl);
        }

        Ok(())
    }

    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        let opts: ConcurrencyLockOptions = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert_eq!(opts.epoch, Duration::from_secs(10 * 60));
        assert_eq!(opts.timeout, Duration::from_secs(1));
        assert_eq!(opts.shards, 32);
        assert_eq!(opts.limiter.initial_limit, 4);
        assert_eq!(opts.limiter.algorithm, RateLimitAlgorithm::Fixed);

        // keys may appear in any order
        let opts: ConcurrencyLockOptions = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
        assert_eq!(opts.epoch, Duration::from_secs(60));
        assert_eq!(opts.timeout, Duration::from_millis(100));
        assert_eq!(opts.shards, 16);
        assert_eq!(opts.limiter.initial_limit, 8);
        assert_eq!(opts.limiter.algorithm, RateLimitAlgorithm::Fixed);

        // `permits=0` disables the lock and fills in the remaining fields
        let opts: ConcurrencyLockOptions = "permits=0".parse()?;
        assert_eq!(opts.epoch, Duration::ZERO);
        assert_eq!(opts.timeout, Duration::ZERO);
        assert_eq!(opts.shards, 2);
        assert_eq!(opts.limiter.initial_limit, 0);
        assert_eq!(opts.limiter.algorithm, RateLimitAlgorithm::Fixed);

        Ok(())
    }

    #[test]
    fn test_parse_json_lock_options() -> anyhow::Result<()> {
        let opts: ConcurrencyLockOptions = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
            .parse()?;
        assert_eq!(opts.epoch, Duration::from_secs(10 * 60));
        assert_eq!(opts.timeout, Duration::from_secs(1));
        assert_eq!(opts.shards, 32);
        assert_eq!(opts.limiter.initial_limit, 44);

        let expected_algorithm = RateLimitAlgorithm::Aimd {
            conf: Aimd {
                min: 5,
                max: 500,
                dec: 0.9,
                inc: 10,
                utilisation: 0.8,
            },
        };
        assert_eq!(opts.limiter.algorithm, expected_algorithm);

        Ok(())
    }
}
|