Line data Source code
1 : use crate::{
2 : auth::{self, backend::AuthRateLimiter},
3 : console::locks::ApiLocks,
4 : rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
5 : scram::threadpool::ThreadPool,
6 : serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
7 : Host,
8 : };
9 : use anyhow::{bail, ensure, Context, Ok};
10 : use itertools::Itertools;
11 : use remote_storage::RemoteStorageConfig;
12 : use rustls::{
13 : crypto::ring::sign,
14 : pki_types::{CertificateDer, PrivateKeyDer},
15 : };
16 : use sha2::{Digest, Sha256};
17 : use std::{
18 : collections::{HashMap, HashSet},
19 : str::FromStr,
20 : sync::Arc,
21 : time::Duration,
22 : };
23 : use tracing::{error, info};
24 : use x509_parser::oid_registry;
25 :
/// Aggregated runtime configuration of the proxy.
///
/// NOTE(review): the per-field notes below are derived from field names and types
/// only; confirm against the code that builds and consumes this struct.
pub struct ProxyConfig {
    /// TLS setup (see [`configure_tls`]); presumably `None` when TLS is not configured.
    pub tls_config: Option<TlsConfig>,
    pub auth_backend: auth::BackendType<'static, (), ()>,
    /// Optional metrics export settings.
    pub metric_collection: Option<MetricCollectionConfig>,
    pub allow_self_signed_compute: bool,
    pub http_config: HttpConfig,
    pub authentication_config: AuthenticationConfig,
    pub require_client_ip: bool,
    pub disable_ip_check_for_http: bool,
    pub redis_rps_limit: Vec<RateBucketInfo>,
    pub region: String,
    pub handshake_timeout: Duration,
    pub aws_region: String,
    /// Retry policy used when waking a compute node (see [`RetryConfig`]).
    pub wake_compute_retry_config: RetryConfig,
    pub connect_compute_locks: ApiLocks<Host>,
    /// Retry policy used when connecting to a compute node (see [`RetryConfig`]).
    pub connect_to_compute_retry_config: RetryConfig,
}
43 :
/// Settings for metric collection.
#[derive(Debug)]
pub struct MetricCollectionConfig {
    /// URL associated with metric collection — presumably where metrics are sent; TODO confirm.
    pub endpoint: reqwest::Url,
    /// Collection interval — presumably the period between two collections; TODO confirm.
    pub interval: Duration,
    /// Settings for the backup metric collection (see [`MetricBackupCollectionConfig`]).
    pub backup_metric_collection_config: MetricBackupCollectionConfig,
}
50 :
/// TLS state of the main endpoint, as assembled by [`configure_tls`].
pub struct TlsConfig {
    /// Shared rustls server configuration; it uses `cert_resolver` to pick certificates.
    pub config: Arc<rustls::ServerConfig>,
    /// Common names of all registered certificates (see [`CertResolver::get_common_names`]).
    pub common_names: HashSet<String>,
    /// The resolver holding every loaded certificate and the default one.
    pub cert_resolver: Arc<CertResolver>,
}
56 :
/// Settings for the HTTP side of the proxy.
///
/// NOTE(review): field semantics are inferred from names only — confirm with the consumers.
pub struct HttpConfig {
    /// Timeout applied to requests (exact scope not visible here).
    pub request_timeout: tokio::time::Duration,
    /// Options of the global connection pool.
    pub pool_options: GlobalConnPoolOptions,
    pub cancel_set: CancelSet,
    pub client_conn_threshold: u64,
}
63 :
/// Settings used by the authentication flow.
///
/// NOTE(review): field semantics are inferred from names only — confirm with the consumers.
pub struct AuthenticationConfig {
    /// Shared SCRAM thread pool (`scram::threadpool::ThreadPool`).
    pub thread_pool: Arc<ThreadPool>,
    /// Timeout for the SCRAM protocol exchange.
    pub scram_protocol_timeout: tokio::time::Duration,
    /// Whether `rate_limiter` is presumably consulted at all — TODO confirm.
    pub rate_limiter_enabled: bool,
    pub rate_limiter: AuthRateLimiter,
    /// Presumably the subnet prefix length used to group client IPs for rate limiting — TODO confirm.
    pub rate_limit_ip_subnet: u8,
}
71 :
72 : impl TlsConfig {
73 40 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
74 40 : self.config.clone()
75 40 : }
76 : }
77 :
78 : /// Configure TLS for the main endpoint.
79 0 : pub fn configure_tls(
80 0 : key_path: &str,
81 0 : cert_path: &str,
82 0 : certs_dir: Option<&String>,
83 0 : ) -> anyhow::Result<TlsConfig> {
84 0 : let mut cert_resolver = CertResolver::new();
85 0 :
86 0 : // add default certificate
87 0 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
88 :
89 : // add extra certificates
90 0 : if let Some(certs_dir) = certs_dir {
91 0 : for entry in std::fs::read_dir(certs_dir)? {
92 0 : let entry = entry?;
93 0 : let path = entry.path();
94 0 : if path.is_dir() {
95 : // file names aligned with default cert-manager names
96 0 : let key_path = path.join("tls.key");
97 0 : let cert_path = path.join("tls.crt");
98 0 : if key_path.exists() && cert_path.exists() {
99 0 : cert_resolver.add_cert_path(
100 0 : &key_path.to_string_lossy(),
101 0 : &cert_path.to_string_lossy(),
102 0 : false,
103 0 : )?;
104 0 : }
105 0 : }
106 : }
107 0 : }
108 :
109 0 : let common_names = cert_resolver.get_common_names();
110 0 :
111 0 : let cert_resolver = Arc::new(cert_resolver);
112 0 :
113 0 : // allow TLS 1.2 to be compatible with older client libraries
114 0 : let config = rustls::ServerConfig::builder_with_protocol_versions(&[
115 0 : &rustls::version::TLS13,
116 0 : &rustls::version::TLS12,
117 0 : ])
118 0 : .with_no_client_auth()
119 0 : .with_cert_resolver(cert_resolver.clone())
120 0 : .into();
121 0 :
122 0 : Ok(TlsConfig {
123 0 : config,
124 0 : common_names,
125 0 : cert_resolver,
126 0 : })
127 0 : }
128 :
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function and that hash function is neither MD5 nor SHA-1, then use
///   the hash function associated with the certificate's
///   signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
///   uses multiple hash functions, then this channel binding type's
///   channel bindings are undefined at this time (updates to this channel
///   binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
    /// SHA-256 digest of the DER-encoded server certificate.
    Sha256([u8; 32]),
    /// The certificate's signature algorithm is not one of the recognised SHA-256 ones,
    /// so no channel binding data is available.
    Undefined,
}
156 :
157 : impl TlsServerEndPoint {
158 42 : pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
159 42 : let sha256_oids = [
160 42 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
161 42 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
162 42 : oid_registry::OID_PKCS1_SHA256WITHRSA,
163 42 : ];
164 :
165 42 : let pem = x509_parser::parse_x509_certificate(cert)
166 42 : .context("Failed to parse PEM object from cerficiate")?
167 : .1;
168 :
169 42 : info!(subject = %pem.subject, "parsing TLS certificate");
170 :
171 42 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
172 42 : let oid = pem.signature_algorithm.oid();
173 42 : let alg = reg.get(oid);
174 42 : if sha256_oids.contains(oid) {
175 42 : let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
176 42 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
177 42 : Ok(Self::Sha256(tls_server_end_point))
178 : } else {
179 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
180 0 : Ok(Self::Undefined)
181 : }
182 42 : }
183 :
184 32 : pub fn supported(&self) -> bool {
185 32 : !matches!(self, TlsServerEndPoint::Undefined)
186 32 : }
187 : }
188 :
/// Certificate store used to pick a server certificate for a TLS handshake.
#[derive(Default, Debug)]
pub struct CertResolver {
    /// Certificates keyed by common name (wildcard prefix `*.` removed), each paired
    /// with its channel binding data.
    certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
    /// Certificate returned when the client sends no SNI; set by adding a certificate
    /// with `is_default = true`.
    default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
194 :
195 : impl CertResolver {
196 42 : pub fn new() -> Self {
197 42 : Self::default()
198 42 : }
199 :
200 0 : fn add_cert_path(
201 0 : &mut self,
202 0 : key_path: &str,
203 0 : cert_path: &str,
204 0 : is_default: bool,
205 0 : ) -> anyhow::Result<()> {
206 0 : let priv_key = {
207 0 : let key_bytes = std::fs::read(key_path)
208 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
209 0 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
210 0 :
211 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
212 : PrivateKeyDer::Pkcs8(
213 0 : keys.pop()
214 0 : .unwrap()
215 0 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
216 : )
217 : };
218 :
219 0 : let cert_chain_bytes = std::fs::read(cert_path)
220 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
221 :
222 0 : let cert_chain = {
223 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
224 0 : .try_collect()
225 0 : .with_context(|| {
226 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
227 0 : })?
228 : };
229 :
230 0 : self.add_cert(priv_key, cert_chain, is_default)
231 0 : }
232 :
233 42 : pub fn add_cert(
234 42 : &mut self,
235 42 : priv_key: PrivateKeyDer<'static>,
236 42 : cert_chain: Vec<CertificateDer<'static>>,
237 42 : is_default: bool,
238 42 : ) -> anyhow::Result<()> {
239 42 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
240 :
241 42 : let first_cert = &cert_chain[0];
242 42 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
243 42 : let pem = x509_parser::parse_x509_certificate(first_cert)
244 42 : .context("Failed to parse PEM object from cerficiate")?
245 : .1;
246 :
247 42 : let common_name = pem.subject().to_string();
248 :
249 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
250 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
251 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
252 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
253 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
254 : // of cutting off '*.' parts.
255 42 : let common_name = if common_name.starts_with("CN=*.") {
256 0 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
257 : } else {
258 42 : common_name.strip_prefix("CN=").map(|s| s.to_string())
259 : }
260 42 : .context("Failed to parse common name from certificate")?;
261 :
262 42 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
263 42 :
264 42 : if is_default {
265 42 : self.default = Some((cert.clone(), tls_server_end_point));
266 42 : }
267 :
268 42 : self.certs.insert(common_name, (cert, tls_server_end_point));
269 42 :
270 42 : Ok(())
271 42 : }
272 :
273 42 : pub fn get_common_names(&self) -> HashSet<String> {
274 42 : self.certs.keys().map(|s| s.to_string()).collect()
275 42 : }
276 : }
277 :
278 : impl rustls::server::ResolvesServerCert for CertResolver {
279 0 : fn resolve(
280 0 : &self,
281 0 : client_hello: rustls::server::ClientHello,
282 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
283 0 : self.resolve(client_hello.server_name()).map(|x| x.0)
284 0 : }
285 : }
286 :
287 : impl CertResolver {
288 40 : pub fn resolve(
289 40 : &self,
290 40 : server_name: Option<&str>,
291 40 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
292 : // loop here and cut off more and more subdomains until we find
293 : // a match to get a proper wildcard support. OTOH, we now do not
294 : // use nested domains, so keep this simple for now.
295 : //
296 : // With the current coding foo.com will match *.foo.com and that
297 : // repeats behavior of the old code.
298 40 : if let Some(mut sni_name) = server_name {
299 : loop {
300 80 : if let Some(cert) = self.certs.get(sni_name) {
301 40 : return Some(cert.clone());
302 40 : }
303 40 : if let Some((_, rest)) = sni_name.split_once('.') {
304 40 : sni_name = rest;
305 40 : } else {
306 0 : return None;
307 : }
308 : }
309 : } else {
310 : // No SNI, use the default certificate, otherwise we can't get to
311 : // options parameter which can be used to set endpoint name too.
312 : // That means that non-SNI flow will not work for CNAME domains in
313 : // verify-full mode.
314 : //
315 : // If that will be a problem we can:
316 : //
317 : // a) Instead of multi-cert approach use single cert with extra
318 : // domains listed in Subject Alternative Name (SAN).
319 : // b) Deploy separate proxy instances for extra domains.
320 0 : self.default.as_ref().cloned()
321 : }
322 40 : }
323 : }
324 :
/// Options of the endpoint cache, parsed from a `key=value,...` string
/// (see the `FromStr` implementation).
#[derive(Debug)]
pub struct EndpointCacheConfig {
    /// Batch size to receive all endpoints on the startup.
    pub initial_batch_size: usize,
    /// Batch size to receive endpoints.
    pub default_batch_size: usize,
    /// Timeouts for the stream read operation.
    pub xread_timeout: Duration,
    /// Stream name to read from.
    pub stream_name: String,
    /// Limiter info (to distinguish when to enable cache).
    pub limiter_info: Vec<RateBucketInfo>,
    /// Disable cache.
    /// If true, cache is ignored, but reports all statistics.
    pub disable_cache: bool,
    /// Retry interval for the stream read operation.
    pub retry_interval: Duration,
}
343 :
344 : impl EndpointCacheConfig {
345 : /// Default options for [`crate::console::provider::NodeInfoCache`].
346 : /// Notice that by default the limiter is empty, which means that cache is disabled.
347 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
348 : "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
349 :
350 : /// Parse cache options passed via cmdline.
351 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
352 0 : fn parse(options: &str) -> anyhow::Result<Self> {
353 0 : let mut initial_batch_size = None;
354 0 : let mut default_batch_size = None;
355 0 : let mut xread_timeout = None;
356 0 : let mut stream_name = None;
357 0 : let mut limiter_info = vec![];
358 0 : let mut disable_cache = false;
359 0 : let mut retry_interval = None;
360 :
361 0 : for option in options.split(',') {
362 0 : let (key, value) = option
363 0 : .split_once('=')
364 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
365 :
366 0 : match key {
367 0 : "initial_batch_size" => initial_batch_size = Some(value.parse()?),
368 0 : "default_batch_size" => default_batch_size = Some(value.parse()?),
369 0 : "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
370 0 : "stream_name" => stream_name = Some(value.to_string()),
371 0 : "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
372 0 : "disable_cache" => disable_cache = value.parse()?,
373 0 : "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
374 0 : unknown => bail!("unknown key: {unknown}"),
375 : }
376 : }
377 0 : RateBucketInfo::validate(&mut limiter_info)?;
378 :
379 : Ok(Self {
380 0 : initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
381 0 : default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
382 0 : xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
383 0 : stream_name: stream_name.context("missing `stream_name`")?,
384 0 : disable_cache,
385 0 : limiter_info,
386 0 : retry_interval: retry_interval.context("missing `retry_interval`")?,
387 : })
388 0 : }
389 : }
390 :
391 : impl FromStr for EndpointCacheConfig {
392 : type Err = anyhow::Error;
393 :
394 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
395 0 : let error = || format!("failed to parse endpoint cache options '{options}'");
396 0 : Self::parse(options).with_context(error)
397 0 : }
398 : }
/// Settings for the backup metric collection.
///
/// NOTE(review): field semantics are inferred from names only — confirm with the consumers.
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
    /// Presumably the period between two backup collections.
    pub interval: Duration,
    /// Remote storage target; presumably `None` disables the upload.
    pub remote_storage_config: Option<RemoteStorageConfig>,
    /// Presumably the size of one uploaded chunk; unit not visible here.
    pub chunk_size: usize,
}
405 :
/// Builds a [`RemoteStorageConfig`] from TOML text.
///
/// The string is first parsed into the value type that `RemoteStorageConfig::from_toml`
/// expects (inferred from that function's signature), then handed to it.
///
/// # Errors
///
/// Fails if `s` does not parse, or if `from_toml` rejects the parsed value.
pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
    RemoteStorageConfig::from_toml(&s.parse()?)
}
409 :
/// Helper for cmdline cache options parsing.
///
/// Parsed from a `size=<n>,ttl=<duration>` string via its `FromStr` implementation.
#[derive(Debug)]
pub struct CacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
}
418 :
419 : impl CacheOptions {
420 : /// Default options for [`crate::console::provider::NodeInfoCache`].
421 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
422 :
423 : /// Parse cache options passed via cmdline.
424 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
425 8 : fn parse(options: &str) -> anyhow::Result<Self> {
426 8 : let mut size = None;
427 8 : let mut ttl = None;
428 :
429 14 : for option in options.split(',') {
430 14 : let (key, value) = option
431 14 : .split_once('=')
432 14 : .with_context(|| format!("bad key-value pair: {option}"))?;
433 :
434 14 : match key {
435 14 : "size" => size = Some(value.parse()?),
436 6 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
437 0 : unknown => bail!("unknown key: {unknown}"),
438 : }
439 : }
440 :
441 : // TTL doesn't matter if cache is always empty.
442 8 : if let Some(0) = size {
443 4 : ttl.get_or_insert(Duration::default());
444 4 : }
445 :
446 : Ok(Self {
447 8 : size: size.context("missing `size`")?,
448 8 : ttl: ttl.context("missing `ttl`")?,
449 : })
450 8 : }
451 : }
452 :
453 : impl FromStr for CacheOptions {
454 : type Err = anyhow::Error;
455 :
456 8 : fn from_str(options: &str) -> Result<Self, Self::Err> {
457 8 : let error = || format!("failed to parse cache options '{options}'");
458 8 : Self::parse(options).with_context(error)
459 8 : }
460 : }
461 :
/// Helper for cmdline cache options parsing.
///
/// Parsed from a `size=..,ttl=..,max_roles=..,gc_interval=..` string via its
/// `FromStr` implementation.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
    /// Max number of roles per endpoint.
    pub max_roles: usize,
    /// Gc interval.
    pub gc_interval: Duration,
}
474 :
475 : impl ProjectInfoCacheOptions {
476 : /// Default options for [`crate::console::provider::NodeInfoCache`].
477 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
478 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
479 :
480 : /// Parse cache options passed via cmdline.
481 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
482 0 : fn parse(options: &str) -> anyhow::Result<Self> {
483 0 : let mut size = None;
484 0 : let mut ttl = None;
485 0 : let mut max_roles = None;
486 0 : let mut gc_interval = None;
487 :
488 0 : for option in options.split(',') {
489 0 : let (key, value) = option
490 0 : .split_once('=')
491 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
492 :
493 0 : match key {
494 0 : "size" => size = Some(value.parse()?),
495 0 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
496 0 : "max_roles" => max_roles = Some(value.parse()?),
497 0 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
498 0 : unknown => bail!("unknown key: {unknown}"),
499 : }
500 : }
501 :
502 : // TTL doesn't matter if cache is always empty.
503 0 : if let Some(0) = size {
504 0 : ttl.get_or_insert(Duration::default());
505 0 : }
506 :
507 : Ok(Self {
508 0 : size: size.context("missing `size`")?,
509 0 : ttl: ttl.context("missing `ttl`")?,
510 0 : max_roles: max_roles.context("missing `max_roles`")?,
511 0 : gc_interval: gc_interval.context("missing `gc_interval`")?,
512 : })
513 0 : }
514 : }
515 :
516 : impl FromStr for ProjectInfoCacheOptions {
517 : type Err = anyhow::Error;
518 :
519 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
520 0 : let error = || format!("failed to parse cache options '{options}'");
521 0 : Self::parse(options).with_context(error)
522 0 : }
523 : }
524 :
/// This is a config for connect to compute and wake compute.
///
/// Built from a `num_retries=..,base_retry_wait_duration=..,retry_wait_exponent_base=..`
/// string by [`RetryConfig::parse`].
#[derive(Clone, Copy, Debug)]
pub struct RetryConfig {
    /// Number of times we should retry.
    pub max_retries: u32,
    /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
    pub base_delay: tokio::time::Duration,
    /// Exponential base for retry wait duration
    pub backoff_factor: f64,
}
535 :
536 : impl RetryConfig {
537 : /// Default options for RetryConfig.
538 :
539 : /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
540 : pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
541 : "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
542 : /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
543 : /// Cplane has timeout of 60s on each request. 8m7s in total.
544 : pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
545 : "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
546 :
547 : /// Parse retry options passed via cmdline.
548 : /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
549 0 : pub fn parse(options: &str) -> anyhow::Result<Self> {
550 0 : let mut num_retries = None;
551 0 : let mut base_retry_wait_duration = None;
552 0 : let mut retry_wait_exponent_base = None;
553 :
554 0 : for option in options.split(',') {
555 0 : let (key, value) = option
556 0 : .split_once('=')
557 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
558 :
559 0 : match key {
560 0 : "num_retries" => num_retries = Some(value.parse()?),
561 0 : "base_retry_wait_duration" => {
562 0 : base_retry_wait_duration = Some(humantime::parse_duration(value)?)
563 : }
564 0 : "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
565 0 : unknown => bail!("unknown key: {unknown}"),
566 : }
567 : }
568 :
569 : Ok(Self {
570 0 : max_retries: num_retries.context("missing `num_retries`")?,
571 0 : base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
572 0 : backoff_factor: retry_wait_exponent_base
573 0 : .context("missing `retry_wait_exponent_base`")?,
574 : })
575 0 : }
576 : }
577 :
/// Helper for cmdline concurrency lock options parsing.
///
/// Can be parsed either from a `key=value,...` string or from a JSON object
/// (see the `FromStr` implementation).
#[derive(serde::Deserialize)]
pub struct ConcurrencyLockOptions {
    /// The number of shards the lock map should have
    pub shards: usize,
    /// The number of allowed concurrent requests for each endpoint
    // Flattened: the limiter's own fields appear at the top level of the JSON object.
    #[serde(flatten)]
    pub limiter: RateLimiterConfig,
    /// Garbage collection epoch
    #[serde(deserialize_with = "humantime_serde::deserialize")]
    pub epoch: Duration,
    /// Lock timeout
    #[serde(deserialize_with = "humantime_serde::deserialize")]
    pub timeout: Duration,
}
593 :
594 : impl ConcurrencyLockOptions {
595 : /// Default options for [`crate::console::provider::ApiLocks`].
596 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
597 : /// Default options for [`crate::console::provider::ApiLocks`].
598 : pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
599 : "shards=64,permits=100,epoch=10m,timeout=10ms";
600 :
601 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
602 :
603 : /// Parse lock options passed via cmdline.
604 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
605 8 : fn parse(options: &str) -> anyhow::Result<Self> {
606 8 : let options = options.trim();
607 8 : if options.starts_with('{') && options.ends_with('}') {
608 2 : return Ok(serde_json::from_str(options)?);
609 6 : }
610 6 :
611 6 : let mut shards = None;
612 6 : let mut permits = None;
613 6 : let mut epoch = None;
614 6 : let mut timeout = None;
615 :
616 18 : for option in options.split(',') {
617 18 : let (key, value) = option
618 18 : .split_once('=')
619 18 : .with_context(|| format!("bad key-value pair: {option}"))?;
620 :
621 18 : match key {
622 18 : "shards" => shards = Some(value.parse()?),
623 14 : "permits" => permits = Some(value.parse()?),
624 8 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
625 4 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
626 0 : unknown => bail!("unknown key: {unknown}"),
627 : }
628 : }
629 :
630 : // these dont matter if lock is disabled
631 6 : if let Some(0) = permits {
632 2 : timeout = Some(Duration::default());
633 2 : epoch = Some(Duration::default());
634 2 : shards = Some(2);
635 4 : }
636 :
637 6 : let permits = permits.context("missing `permits`")?;
638 6 : let out = Self {
639 6 : shards: shards.context("missing `shards`")?,
640 6 : limiter: RateLimiterConfig {
641 6 : algorithm: RateLimitAlgorithm::Fixed,
642 6 : initial_limit: permits,
643 6 : },
644 6 : epoch: epoch.context("missing `epoch`")?,
645 6 : timeout: timeout.context("missing `timeout`")?,
646 : };
647 :
648 6 : ensure!(out.shards > 1, "shard count must be > 1");
649 6 : ensure!(
650 6 : out.shards.is_power_of_two(),
651 0 : "shard count must be a power of two"
652 : );
653 :
654 6 : Ok(out)
655 8 : }
656 : }
657 :
658 : impl FromStr for ConcurrencyLockOptions {
659 : type Err = anyhow::Error;
660 :
661 8 : fn from_str(options: &str) -> Result<Self, Self::Err> {
662 8 : let error = || format!("failed to parse cache lock options '{options}'");
663 8 : Self::parse(options).with_context(error)
664 8 : }
665 : }
666 :
// Unit tests for the command-line option parsers defined in this file.
#[cfg(test)]
mod tests {
    use crate::rate_limiter::Aimd;

    use super::*;

    /// `CacheOptions` accepts keys in any order and defaults `ttl` when `size=0`.
    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        // Key order does not matter.
        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        // An explicit TTL is kept even for an empty cache.
        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        // With `size=0` the TTL may be omitted and defaults to zero.
        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    /// The `key=value` form of `ConcurrencyLockOptions` always yields a fixed limiter.
    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(limiter.initial_limit, 4);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        // Key order does not matter.
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
        assert_eq!(epoch, Duration::from_secs(60));
        assert_eq!(timeout, Duration::from_millis(100));
        assert_eq!(shards, 16);
        assert_eq!(limiter.initial_limit, 8);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        // `permits=0` disables the lock; the remaining fields get placeholder values.
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "permits=0".parse()?;
        assert_eq!(epoch, Duration::ZERO);
        assert_eq!(timeout, Duration::ZERO);
        assert_eq!(shards, 2);
        assert_eq!(limiter.initial_limit, 0);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        Ok(())
    }

    /// The JSON form of `ConcurrencyLockOptions` can select the AIMD algorithm.
    #[test]
    fn test_parse_json_lock_options() -> anyhow::Result<()> {
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
            .parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(limiter.initial_limit, 44);
        assert_eq!(
            limiter.algorithm,
            RateLimitAlgorithm::Aimd {
                conf: Aimd {
                    min: 5,
                    max: 500,
                    dec: 0.9,
                    inc: 10,
                    utilisation: 0.8
                }
            },
        );

        Ok(())
    }
}
|