Line data Source code
1 : use crate::{
2 : auth::{self, backend::AuthRateLimiter},
3 : rate_limiter::RateBucketInfo,
4 : serverless::GlobalConnPoolOptions,
5 : };
6 : use anyhow::{bail, ensure, Context, Ok};
7 : use itertools::Itertools;
8 : use remote_storage::RemoteStorageConfig;
9 : use rustls::{
10 : crypto::ring::sign,
11 : pki_types::{CertificateDer, PrivateKeyDer},
12 : };
13 : use sha2::{Digest, Sha256};
14 : use std::{
15 : collections::{HashMap, HashSet},
16 : str::FromStr,
17 : sync::Arc,
18 : time::Duration,
19 : };
20 : use tracing::{error, info};
21 : use x509_parser::oid_registry;
22 :
23 : pub struct ProxyConfig {
24 : pub tls_config: Option<TlsConfig>,
25 : pub auth_backend: auth::BackendType<'static, (), ()>,
26 : pub metric_collection: Option<MetricCollectionConfig>,
27 : pub allow_self_signed_compute: bool,
28 : pub http_config: HttpConfig,
29 : pub authentication_config: AuthenticationConfig,
30 : pub require_client_ip: bool,
31 : pub disable_ip_check_for_http: bool,
32 : pub redis_rps_limit: Vec<RateBucketInfo>,
33 : pub region: String,
34 : pub handshake_timeout: Duration,
35 : pub aws_region: String,
36 : }
37 :
38 : #[derive(Debug)]
39 : pub struct MetricCollectionConfig {
40 : pub endpoint: reqwest::Url,
41 : pub interval: Duration,
42 : pub backup_metric_collection_config: MetricBackupCollectionConfig,
43 : }
44 :
45 : pub struct TlsConfig {
46 : pub config: Arc<rustls::ServerConfig>,
47 : pub common_names: HashSet<String>,
48 : pub cert_resolver: Arc<CertResolver>,
49 : }
50 :
51 : pub struct HttpConfig {
52 : pub request_timeout: tokio::time::Duration,
53 : pub pool_options: GlobalConnPoolOptions,
54 : }
55 :
56 : pub struct AuthenticationConfig {
57 : pub scram_protocol_timeout: tokio::time::Duration,
58 : pub rate_limiter_enabled: bool,
59 : pub rate_limiter: AuthRateLimiter,
60 : pub rate_limit_ip_subnet: u8,
61 : }
62 :
63 : impl TlsConfig {
64 40 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
65 40 : self.config.clone()
66 40 : }
67 : }
68 :
69 : /// Configure TLS for the main endpoint.
70 0 : pub fn configure_tls(
71 0 : key_path: &str,
72 0 : cert_path: &str,
73 0 : certs_dir: Option<&String>,
74 0 : ) -> anyhow::Result<TlsConfig> {
75 0 : let mut cert_resolver = CertResolver::new();
76 0 :
77 0 : // add default certificate
78 0 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
79 :
80 : // add extra certificates
81 0 : if let Some(certs_dir) = certs_dir {
82 0 : for entry in std::fs::read_dir(certs_dir)? {
83 0 : let entry = entry?;
84 0 : let path = entry.path();
85 0 : if path.is_dir() {
86 : // file names aligned with default cert-manager names
87 0 : let key_path = path.join("tls.key");
88 0 : let cert_path = path.join("tls.crt");
89 0 : if key_path.exists() && cert_path.exists() {
90 0 : cert_resolver.add_cert_path(
91 0 : &key_path.to_string_lossy(),
92 0 : &cert_path.to_string_lossy(),
93 0 : false,
94 0 : )?;
95 0 : }
96 0 : }
97 : }
98 0 : }
99 :
100 0 : let common_names = cert_resolver.get_common_names();
101 0 :
102 0 : let cert_resolver = Arc::new(cert_resolver);
103 0 :
104 0 : // allow TLS 1.2 to be compatible with older client libraries
105 0 : let config = rustls::ServerConfig::builder_with_protocol_versions(&[
106 0 : &rustls::version::TLS13,
107 0 : &rustls::version::TLS12,
108 0 : ])
109 0 : .with_no_client_auth()
110 0 : .with_cert_resolver(cert_resolver.clone())
111 0 : .into();
112 0 :
113 0 : Ok(TlsConfig {
114 0 : config,
115 0 : common_names,
116 0 : cert_resolver,
117 0 : })
118 0 : }
119 :
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function and that hash function is neither MD5 nor SHA-1, then use
///   the hash function associated with the certificate's
///   signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
///   uses multiple hash functions, then this channel binding type's
///   channel bindings are undefined at this time (updates to this channel
///   binding type may occur to address this issue if it ever arises).
///
/// Only SHA-256 based bindings are represented. Certificates whose signature
/// algorithm is not recognised as SHA-256 based map to
/// [`TlsServerEndPoint::Undefined`]. This includes MD5 and SHA-1 based ones,
/// which this implementation deliberately does not support.
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
    /// SHA-256 digest of the server's DER-encoded leaf certificate.
    Sha256([u8; 32]),
    /// No channel binding can be offered for this certificate.
    Undefined,
}
147 :
148 : impl TlsServerEndPoint {
149 42 : pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
150 42 : let sha256_oids = [
151 42 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
152 42 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
153 42 : oid_registry::OID_PKCS1_SHA256WITHRSA,
154 42 : ];
155 :
156 42 : let pem = x509_parser::parse_x509_certificate(cert)
157 42 : .context("Failed to parse PEM object from cerficiate")?
158 : .1;
159 :
160 42 : info!(subject = %pem.subject, "parsing TLS certificate");
161 :
162 42 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
163 42 : let oid = pem.signature_algorithm.oid();
164 42 : let alg = reg.get(oid);
165 42 : if sha256_oids.contains(oid) {
166 42 : let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
167 42 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
168 42 : Ok(Self::Sha256(tls_server_end_point))
169 : } else {
170 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
171 0 : Ok(Self::Undefined)
172 : }
173 42 : }
174 :
175 32 : pub fn supported(&self) -> bool {
176 32 : !matches!(self, TlsServerEndPoint::Undefined)
177 32 : }
178 : }
179 :
180 : #[derive(Default, Debug)]
181 : pub struct CertResolver {
182 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
183 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
184 : }
185 :
186 : impl CertResolver {
187 42 : pub fn new() -> Self {
188 42 : Self::default()
189 42 : }
190 :
191 0 : fn add_cert_path(
192 0 : &mut self,
193 0 : key_path: &str,
194 0 : cert_path: &str,
195 0 : is_default: bool,
196 0 : ) -> anyhow::Result<()> {
197 0 : let priv_key = {
198 0 : let key_bytes = std::fs::read(key_path)
199 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
200 0 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
201 0 :
202 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
203 : PrivateKeyDer::Pkcs8(
204 0 : keys.pop()
205 0 : .unwrap()
206 0 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
207 : )
208 : };
209 :
210 0 : let cert_chain_bytes = std::fs::read(cert_path)
211 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
212 :
213 0 : let cert_chain = {
214 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
215 0 : .try_collect()
216 0 : .with_context(|| {
217 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
218 0 : })?
219 : };
220 :
221 0 : self.add_cert(priv_key, cert_chain, is_default)
222 0 : }
223 :
224 42 : pub fn add_cert(
225 42 : &mut self,
226 42 : priv_key: PrivateKeyDer<'static>,
227 42 : cert_chain: Vec<CertificateDer<'static>>,
228 42 : is_default: bool,
229 42 : ) -> anyhow::Result<()> {
230 42 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
231 :
232 42 : let first_cert = &cert_chain[0];
233 42 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
234 42 : let pem = x509_parser::parse_x509_certificate(first_cert)
235 42 : .context("Failed to parse PEM object from cerficiate")?
236 : .1;
237 :
238 42 : let common_name = pem.subject().to_string();
239 :
240 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
241 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
242 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
243 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
244 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
245 : // of cutting off '*.' parts.
246 42 : let common_name = if common_name.starts_with("CN=*.") {
247 0 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
248 : } else {
249 42 : common_name.strip_prefix("CN=").map(|s| s.to_string())
250 : }
251 42 : .context("Failed to parse common name from certificate")?;
252 :
253 42 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
254 42 :
255 42 : if is_default {
256 42 : self.default = Some((cert.clone(), tls_server_end_point));
257 42 : }
258 :
259 42 : self.certs.insert(common_name, (cert, tls_server_end_point));
260 42 :
261 42 : Ok(())
262 42 : }
263 :
264 42 : pub fn get_common_names(&self) -> HashSet<String> {
265 42 : self.certs.keys().map(|s| s.to_string()).collect()
266 42 : }
267 : }
268 :
269 : impl rustls::server::ResolvesServerCert for CertResolver {
270 0 : fn resolve(
271 0 : &self,
272 0 : client_hello: rustls::server::ClientHello,
273 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
274 0 : self.resolve(client_hello.server_name()).map(|x| x.0)
275 0 : }
276 : }
277 :
278 : impl CertResolver {
279 40 : pub fn resolve(
280 40 : &self,
281 40 : server_name: Option<&str>,
282 40 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
283 : // loop here and cut off more and more subdomains until we find
284 : // a match to get a proper wildcard support. OTOH, we now do not
285 : // use nested domains, so keep this simple for now.
286 : //
287 : // With the current coding foo.com will match *.foo.com and that
288 : // repeats behavior of the old code.
289 40 : if let Some(mut sni_name) = server_name {
290 : loop {
291 80 : if let Some(cert) = self.certs.get(sni_name) {
292 40 : return Some(cert.clone());
293 40 : }
294 40 : if let Some((_, rest)) = sni_name.split_once('.') {
295 40 : sni_name = rest;
296 40 : } else {
297 0 : return None;
298 : }
299 : }
300 : } else {
301 : // No SNI, use the default certificate, otherwise we can't get to
302 : // options parameter which can be used to set endpoint name too.
303 : // That means that non-SNI flow will not work for CNAME domains in
304 : // verify-full mode.
305 : //
306 : // If that will be a problem we can:
307 : //
308 : // a) Instead of multi-cert approach use single cert with extra
309 : // domains listed in Subject Alternative Name (SAN).
310 : // b) Deploy separate proxy instances for extra domains.
311 0 : self.default.as_ref().cloned()
312 : }
313 40 : }
314 : }
315 :
316 : #[derive(Debug)]
317 : pub struct EndpointCacheConfig {
318 : /// Batch size to receive all endpoints on the startup.
319 : pub initial_batch_size: usize,
320 : /// Batch size to receive endpoints.
321 : pub default_batch_size: usize,
322 : /// Timeouts for the stream read operation.
323 : pub xread_timeout: Duration,
324 : /// Stream name to read from.
325 : pub stream_name: String,
326 : /// Limiter info (to distinguish when to enable cache).
327 : pub limiter_info: Vec<RateBucketInfo>,
328 : /// Disable cache.
329 : /// If true, cache is ignored, but reports all statistics.
330 : pub disable_cache: bool,
331 : /// Retry interval for the stream read operation.
332 : pub retry_interval: Duration,
333 : }
334 :
335 : impl EndpointCacheConfig {
336 : /// Default options for [`crate::console::provider::NodeInfoCache`].
337 : /// Notice that by default the limiter is empty, which means that cache is disabled.
338 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
339 : "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
340 :
341 : /// Parse cache options passed via cmdline.
342 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
343 0 : fn parse(options: &str) -> anyhow::Result<Self> {
344 0 : let mut initial_batch_size = None;
345 0 : let mut default_batch_size = None;
346 0 : let mut xread_timeout = None;
347 0 : let mut stream_name = None;
348 0 : let mut limiter_info = vec![];
349 0 : let mut disable_cache = false;
350 0 : let mut retry_interval = None;
351 :
352 0 : for option in options.split(',') {
353 0 : let (key, value) = option
354 0 : .split_once('=')
355 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
356 :
357 0 : match key {
358 0 : "initial_batch_size" => initial_batch_size = Some(value.parse()?),
359 0 : "default_batch_size" => default_batch_size = Some(value.parse()?),
360 0 : "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
361 0 : "stream_name" => stream_name = Some(value.to_string()),
362 0 : "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
363 0 : "disable_cache" => disable_cache = value.parse()?,
364 0 : "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
365 0 : unknown => bail!("unknown key: {unknown}"),
366 : }
367 : }
368 0 : RateBucketInfo::validate(&mut limiter_info)?;
369 :
370 : Ok(Self {
371 0 : initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
372 0 : default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
373 0 : xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
374 0 : stream_name: stream_name.context("missing `stream_name`")?,
375 0 : disable_cache,
376 0 : limiter_info,
377 0 : retry_interval: retry_interval.context("missing `retry_interval`")?,
378 : })
379 0 : }
380 : }
381 :
382 : impl FromStr for EndpointCacheConfig {
383 : type Err = anyhow::Error;
384 :
385 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
386 0 : let error = || format!("failed to parse endpoint cache options '{options}'");
387 0 : Self::parse(options).with_context(error)
388 0 : }
389 : }
390 : #[derive(Debug)]
391 : pub struct MetricBackupCollectionConfig {
392 : pub interval: Duration,
393 : pub remote_storage_config: OptRemoteStorageConfig,
394 : pub chunk_size: usize,
395 : }
396 :
397 : /// Hack to avoid clap being smarter. If you don't use this type alias, clap assumes more about the optional state and you get
398 : /// runtime type errors from the value parser we use.
399 : pub type OptRemoteStorageConfig = Option<RemoteStorageConfig>;
400 :
401 12 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfig> {
402 12 : RemoteStorageConfig::from_toml(&s.parse()?)
403 12 : }
404 :
/// Size/TTL settings for a cache, parsed from a cmdline option string.
#[derive(Debug)]
pub struct CacheOptions {
    /// Maximum number of cached entries.
    pub size: usize,
    /// How long an entry stays valid after insertion.
    pub ttl: Duration,
}
413 :
414 : impl CacheOptions {
415 : /// Default options for [`crate::console::provider::NodeInfoCache`].
416 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
417 :
418 : /// Parse cache options passed via cmdline.
419 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
420 8 : fn parse(options: &str) -> anyhow::Result<Self> {
421 8 : let mut size = None;
422 8 : let mut ttl = None;
423 :
424 14 : for option in options.split(',') {
425 14 : let (key, value) = option
426 14 : .split_once('=')
427 14 : .with_context(|| format!("bad key-value pair: {option}"))?;
428 :
429 14 : match key {
430 14 : "size" => size = Some(value.parse()?),
431 6 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
432 0 : unknown => bail!("unknown key: {unknown}"),
433 : }
434 : }
435 :
436 : // TTL doesn't matter if cache is always empty.
437 8 : if let Some(0) = size {
438 4 : ttl.get_or_insert(Duration::default());
439 4 : }
440 :
441 : Ok(Self {
442 8 : size: size.context("missing `size`")?,
443 8 : ttl: ttl.context("missing `ttl`")?,
444 : })
445 8 : }
446 : }
447 :
448 : impl FromStr for CacheOptions {
449 : type Err = anyhow::Error;
450 :
451 8 : fn from_str(options: &str) -> Result<Self, Self::Err> {
452 8 : let error = || format!("failed to parse cache options '{options}'");
453 8 : Self::parse(options).with_context(error)
454 8 : }
455 : }
456 :
/// Settings for the project info cache, parsed from a cmdline option string.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Maximum number of cached entries.
    pub size: usize,
    /// How long an entry stays valid after insertion.
    pub ttl: Duration,
    /// Maximum number of roles kept per endpoint.
    pub max_roles: usize,
    /// Interval between garbage-collection passes.
    pub gc_interval: Duration,
}
469 :
470 : impl ProjectInfoCacheOptions {
471 : /// Default options for [`crate::console::provider::NodeInfoCache`].
472 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
473 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
474 :
475 : /// Parse cache options passed via cmdline.
476 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
477 0 : fn parse(options: &str) -> anyhow::Result<Self> {
478 0 : let mut size = None;
479 0 : let mut ttl = None;
480 0 : let mut max_roles = None;
481 0 : let mut gc_interval = None;
482 :
483 0 : for option in options.split(',') {
484 0 : let (key, value) = option
485 0 : .split_once('=')
486 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
487 :
488 0 : match key {
489 0 : "size" => size = Some(value.parse()?),
490 0 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
491 0 : "max_roles" => max_roles = Some(value.parse()?),
492 0 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
493 0 : unknown => bail!("unknown key: {unknown}"),
494 : }
495 : }
496 :
497 : // TTL doesn't matter if cache is always empty.
498 0 : if let Some(0) = size {
499 0 : ttl.get_or_insert(Duration::default());
500 0 : }
501 :
502 : Ok(Self {
503 0 : size: size.context("missing `size`")?,
504 0 : ttl: ttl.context("missing `ttl`")?,
505 0 : max_roles: max_roles.context("missing `max_roles`")?,
506 0 : gc_interval: gc_interval.context("missing `gc_interval`")?,
507 : })
508 0 : }
509 : }
510 :
511 : impl FromStr for ProjectInfoCacheOptions {
512 : type Err = anyhow::Error;
513 :
514 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
515 0 : let error = || format!("failed to parse cache options '{options}'");
516 0 : Self::parse(options).with_context(error)
517 0 : }
518 : }
519 :
/// Settings for the wake-compute lock, parsed from a cmdline option string.
#[derive(Debug)]
pub struct WakeComputeLockOptions {
    /// The number of shards the lock map should have.
    pub shards: usize,
    /// The number of allowed concurrent requests for each endpoint (`0` disables the lock).
    pub permits: usize,
    /// Garbage collection epoch.
    pub epoch: Duration,
    /// Lock timeout.
    pub timeout: Duration,
}
531 :
532 : impl WakeComputeLockOptions {
533 : /// Default options for [`crate::console::provider::ApiLocks`].
534 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
535 :
536 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
537 :
538 : /// Parse lock options passed via cmdline.
539 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
540 6 : fn parse(options: &str) -> anyhow::Result<Self> {
541 6 : let mut shards = None;
542 6 : let mut permits = None;
543 6 : let mut epoch = None;
544 6 : let mut timeout = None;
545 :
546 18 : for option in options.split(',') {
547 18 : let (key, value) = option
548 18 : .split_once('=')
549 18 : .with_context(|| format!("bad key-value pair: {option}"))?;
550 :
551 18 : match key {
552 18 : "shards" => shards = Some(value.parse()?),
553 14 : "permits" => permits = Some(value.parse()?),
554 8 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
555 4 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
556 0 : unknown => bail!("unknown key: {unknown}"),
557 : }
558 : }
559 :
560 : // these dont matter if lock is disabled
561 6 : if let Some(0) = permits {
562 2 : timeout = Some(Duration::default());
563 2 : epoch = Some(Duration::default());
564 2 : shards = Some(2);
565 4 : }
566 :
567 6 : let out = Self {
568 6 : shards: shards.context("missing `shards`")?,
569 6 : permits: permits.context("missing `permits`")?,
570 6 : epoch: epoch.context("missing `epoch`")?,
571 6 : timeout: timeout.context("missing `timeout`")?,
572 : };
573 :
574 6 : ensure!(out.shards > 1, "shard count must be > 1");
575 6 : ensure!(
576 6 : out.shards.is_power_of_two(),
577 0 : "shard count must be a power of two"
578 : );
579 :
580 6 : Ok(out)
581 6 : }
582 : }
583 :
584 : impl FromStr for WakeComputeLockOptions {
585 : type Err = anyhow::Error;
586 :
587 6 : fn from_str(options: &str) -> Result<Self, Self::Err> {
588 6 : let error = || format!("failed to parse cache lock options '{options}'");
589 6 : Self::parse(options).with_context(error)
590 6 : }
591 : }
592 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    #[test]
    fn test_parse_cache_options_errors() {
        // missing `ttl` with a non-zero size
        assert!("size=10".parse::<CacheOptions>().is_err());
        // missing `size`
        assert!("ttl=1s".parse::<CacheOptions>().is_err());
        // unknown key
        assert!("size=1,ttl=1s,foo=2".parse::<CacheOptions>().is_err());
        // not a key=value pair
        assert!("size".parse::<CacheOptions>().is_err());
        // unparsable value
        assert!("size=abc,ttl=1s".parse::<CacheOptions>().is_err());
    }

    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(permits, 4);

        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
        assert_eq!(epoch, Duration::from_secs(60));
        assert_eq!(timeout, Duration::from_millis(100));
        assert_eq!(shards, 16);
        assert_eq!(permits, 8);

        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "permits=0".parse()?;
        assert_eq!(epoch, Duration::ZERO);
        assert_eq!(timeout, Duration::ZERO);
        assert_eq!(shards, 2);
        assert_eq!(permits, 0);

        Ok(())
    }

    #[test]
    fn test_parse_lock_options_errors() {
        // shard count not a power of two
        assert!("shards=3,permits=1,epoch=1s,timeout=1s"
            .parse::<WakeComputeLockOptions>()
            .is_err());
        // shard count not > 1
        assert!("shards=1,permits=1,epoch=1s,timeout=1s"
            .parse::<WakeComputeLockOptions>()
            .is_err());
        // enabled lock with missing keys
        assert!("permits=1".parse::<WakeComputeLockOptions>().is_err());
    }

    #[test]
    fn test_parse_project_info_cache_options() -> anyhow::Result<()> {
        let ProjectInfoCacheOptions {
            size,
            ttl,
            max_roles,
            gc_interval,
        } = ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(size, 10000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));
        assert_eq!(max_roles, 10);
        assert_eq!(gc_interval, Duration::from_secs(60 * 60));

        let ProjectInfoCacheOptions { size, ttl, .. } =
            "size=0,max_roles=1,gc_interval=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::ZERO);

        assert!("size=1,ttl=1s".parse::<ProjectInfoCacheOptions>().is_err());

        Ok(())
    }

    #[test]
    fn test_parse_default_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = CacheOptions::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(size, 4000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let lock: WakeComputeLockOptions =
            WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK.parse()?;
        assert_eq!(lock.permits, 0);

        let endpoint_cache: EndpointCacheConfig =
            EndpointCacheConfig::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(endpoint_cache.initial_batch_size, 1000);
        assert_eq!(endpoint_cache.default_batch_size, 10);
        assert_eq!(endpoint_cache.xread_timeout, Duration::from_secs(5 * 60));
        assert_eq!(endpoint_cache.stream_name, "controlPlane");
        assert!(endpoint_cache.disable_cache);
        assert_eq!(endpoint_cache.limiter_info.len(), 1);
        assert_eq!(endpoint_cache.retry_interval, Duration::from_secs(1));

        Ok(())
    }
}
|