Line data Source code
1 : use std::collections::{HashMap, HashSet};
2 : use std::str::FromStr;
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use anyhow::{bail, ensure, Context, Ok};
7 : use clap::ValueEnum;
8 : use itertools::Itertools;
9 : use remote_storage::RemoteStorageConfig;
10 : use rustls::crypto::ring::{self, sign};
11 : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
12 : use sha2::{Digest, Sha256};
13 : use tracing::{error, info};
14 : use x509_parser::oid_registry;
15 :
16 : use crate::auth::backend::jwt::JwkCache;
17 : use crate::auth::backend::AuthRateLimiter;
18 : use crate::control_plane::locks::ApiLocks;
19 : use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
20 : use crate::scram::threadpool::ThreadPool;
21 : use crate::serverless::cancel_set::CancelSet;
22 : use crate::serverless::GlobalConnPoolOptions;
23 : use crate::types::Host;
24 :
25 : pub struct ProxyConfig {
26 : pub tls_config: Option<TlsConfig>,
27 : pub metric_collection: Option<MetricCollectionConfig>,
28 : pub allow_self_signed_compute: bool,
29 : pub http_config: HttpConfig,
30 : pub authentication_config: AuthenticationConfig,
31 : pub proxy_protocol_v2: ProxyProtocolV2,
32 : pub region: String,
33 : pub handshake_timeout: Duration,
34 : pub wake_compute_retry_config: RetryConfig,
35 : pub connect_compute_locks: ApiLocks<Host>,
36 : pub connect_to_compute_retry_config: RetryConfig,
37 : }
38 :
/// How the proxy treats a PROXY protocol v2 header on incoming connections.
///
/// NOTE(review): this derives `clap::ValueEnum`, so clap uses the variant doc comments
/// below as possible-value help text; edit them with that in mind.
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
pub enum ProxyProtocolV2 {
    /// Connection will error if PROXY protocol v2 header is missing
    Required,
    /// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing.
    Supported,
    /// Connection will error if PROXY protocol v2 header is provided
    Rejected,
}
48 :
49 : #[derive(Debug)]
50 : pub struct MetricCollectionConfig {
51 : pub endpoint: reqwest::Url,
52 : pub interval: Duration,
53 : pub backup_metric_collection_config: MetricBackupCollectionConfig,
54 : }
55 :
56 : pub struct TlsConfig {
57 : pub config: Arc<rustls::ServerConfig>,
58 : pub common_names: HashSet<String>,
59 : pub cert_resolver: Arc<CertResolver>,
60 : }
61 :
62 : pub struct HttpConfig {
63 : pub accept_websockets: bool,
64 : pub pool_options: GlobalConnPoolOptions,
65 : pub cancel_set: CancelSet,
66 : pub client_conn_threshold: u64,
67 : pub max_request_size_bytes: usize,
68 : pub max_response_size_bytes: usize,
69 : }
70 :
71 : pub struct AuthenticationConfig {
72 : pub thread_pool: Arc<ThreadPool>,
73 : pub scram_protocol_timeout: tokio::time::Duration,
74 : pub rate_limiter_enabled: bool,
75 : pub rate_limiter: AuthRateLimiter,
76 : pub rate_limit_ip_subnet: u8,
77 : pub ip_allowlist_check_enabled: bool,
78 : pub jwks_cache: JwkCache,
79 : pub is_auth_broker: bool,
80 : pub accept_jwts: bool,
81 : pub console_redirect_confirmation_timeout: tokio::time::Duration,
82 : }
83 :
84 : impl TlsConfig {
85 20 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
86 20 : self.config.clone()
87 20 : }
88 : }
89 :
90 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
91 : pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
92 :
93 : /// Configure TLS for the main endpoint.
94 0 : pub fn configure_tls(
95 0 : key_path: &str,
96 0 : cert_path: &str,
97 0 : certs_dir: Option<&String>,
98 0 : ) -> anyhow::Result<TlsConfig> {
99 0 : let mut cert_resolver = CertResolver::new();
100 0 :
101 0 : // add default certificate
102 0 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
103 :
104 : // add extra certificates
105 0 : if let Some(certs_dir) = certs_dir {
106 0 : for entry in std::fs::read_dir(certs_dir)? {
107 0 : let entry = entry?;
108 0 : let path = entry.path();
109 0 : if path.is_dir() {
110 : // file names aligned with default cert-manager names
111 0 : let key_path = path.join("tls.key");
112 0 : let cert_path = path.join("tls.crt");
113 0 : if key_path.exists() && cert_path.exists() {
114 0 : cert_resolver.add_cert_path(
115 0 : &key_path.to_string_lossy(),
116 0 : &cert_path.to_string_lossy(),
117 0 : false,
118 0 : )?;
119 0 : }
120 0 : }
121 : }
122 0 : }
123 :
124 0 : let common_names = cert_resolver.get_common_names();
125 0 :
126 0 : let cert_resolver = Arc::new(cert_resolver);
127 :
128 : // allow TLS 1.2 to be compatible with older client libraries
129 0 : let mut config =
130 0 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
131 0 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
132 0 : .context("ring should support TLS1.2 and TLS1.3")?
133 0 : .with_no_client_auth()
134 0 : .with_cert_resolver(cert_resolver.clone());
135 0 :
136 0 : config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
137 0 :
138 0 : Ok(TlsConfig {
139 0 : config: Arc::new(config),
140 0 : common_names,
141 0 : cert_resolver,
142 0 : })
143 0 : }
144 :
145 : /// Channel binding parameter
146 : ///
147 : /// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
148 : /// Description: The hash of the TLS server's certificate as it
149 : /// appears, octet for octet, in the server's Certificate message. Note
150 : /// that the Certificate message contains a certificate_list, in which
151 : /// the first element is the server's certificate.
152 : ///
153 : /// The hash function is to be selected as follows:
154 : ///
155 : /// * if the certificate's signatureAlgorithm uses a single hash
156 : /// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
157 : ///
158 : /// * if the certificate's signatureAlgorithm uses a single hash
159 : /// function and that hash function neither MD5 nor SHA-1, then use
160 : /// the hash function associated with the certificate's
161 : /// signatureAlgorithm;
162 : ///
163 : /// * if the certificate's signatureAlgorithm uses no hash functions or
164 : /// uses multiple hash functions, then this channel binding type's
165 : /// channel bindings are undefined at this time (updates to is channel
166 : /// binding type may occur to address this issue if it ever arises).
167 : #[derive(Debug, Clone, Copy)]
168 : pub enum TlsServerEndPoint {
169 : Sha256([u8; 32]),
170 : Undefined,
171 : }
172 :
173 : impl TlsServerEndPoint {
174 21 : pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
175 21 : let sha256_oids = [
176 21 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
177 21 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
178 21 : oid_registry::OID_PKCS1_SHA256WITHRSA,
179 21 : ];
180 :
181 21 : let pem = x509_parser::parse_x509_certificate(cert)
182 21 : .context("Failed to parse PEM object from cerficiate")?
183 : .1;
184 :
185 21 : info!(subject = %pem.subject, "parsing TLS certificate");
186 :
187 21 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
188 21 : let oid = pem.signature_algorithm.oid();
189 21 : let alg = reg.get(oid);
190 21 : if sha256_oids.contains(oid) {
191 21 : let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
192 21 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
193 21 : Ok(Self::Sha256(tls_server_end_point))
194 : } else {
195 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
196 0 : Ok(Self::Undefined)
197 : }
198 21 : }
199 :
200 16 : pub fn supported(&self) -> bool {
201 16 : !matches!(self, TlsServerEndPoint::Undefined)
202 16 : }
203 : }
204 :
205 : #[derive(Default, Debug)]
206 : pub struct CertResolver {
207 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
208 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
209 : }
210 :
211 : impl CertResolver {
212 21 : pub fn new() -> Self {
213 21 : Self::default()
214 21 : }
215 :
216 0 : fn add_cert_path(
217 0 : &mut self,
218 0 : key_path: &str,
219 0 : cert_path: &str,
220 0 : is_default: bool,
221 0 : ) -> anyhow::Result<()> {
222 0 : let priv_key = {
223 0 : let key_bytes = std::fs::read(key_path)
224 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
225 0 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
226 0 :
227 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
228 : PrivateKeyDer::Pkcs8(
229 0 : keys.pop()
230 0 : .unwrap()
231 0 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?,
232 : )
233 : };
234 :
235 0 : let cert_chain_bytes = std::fs::read(cert_path)
236 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
237 :
238 0 : let cert_chain = {
239 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
240 0 : .try_collect()
241 0 : .with_context(|| {
242 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
243 0 : })?
244 : };
245 :
246 0 : self.add_cert(priv_key, cert_chain, is_default)
247 0 : }
248 :
249 21 : pub fn add_cert(
250 21 : &mut self,
251 21 : priv_key: PrivateKeyDer<'static>,
252 21 : cert_chain: Vec<CertificateDer<'static>>,
253 21 : is_default: bool,
254 21 : ) -> anyhow::Result<()> {
255 21 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
256 :
257 21 : let first_cert = &cert_chain[0];
258 21 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
259 21 : let pem = x509_parser::parse_x509_certificate(first_cert)
260 21 : .context("Failed to parse PEM object from cerficiate")?
261 : .1;
262 :
263 21 : let common_name = pem.subject().to_string();
264 :
265 : // We need to get the canonical name for this certificate so we can match them against any domain names
266 : // seen within the proxy codebase.
267 : //
268 : // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
269 : // We need to remove the wildcard prefix for the purposes of certificate selection.
270 : //
271 : // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
272 : // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
273 : //
274 : // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
275 : // validation, so let's we can continue with any common-name
276 21 : let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
277 0 : s.to_string()
278 21 : } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
279 0 : s.to_string()
280 21 : } else if let Some(s) = common_name.strip_prefix("CN=") {
281 21 : s.to_string()
282 : } else {
283 0 : bail!("Failed to parse common name from certificate")
284 : };
285 :
286 21 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
287 21 :
288 21 : if is_default {
289 21 : self.default = Some((cert.clone(), tls_server_end_point));
290 21 : }
291 :
292 21 : self.certs.insert(common_name, (cert, tls_server_end_point));
293 21 :
294 21 : Ok(())
295 21 : }
296 :
297 21 : pub fn get_common_names(&self) -> HashSet<String> {
298 21 : self.certs.keys().map(|s| s.to_string()).collect()
299 21 : }
300 : }
301 :
302 : impl rustls::server::ResolvesServerCert for CertResolver {
303 0 : fn resolve(
304 0 : &self,
305 0 : client_hello: rustls::server::ClientHello<'_>,
306 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
307 0 : self.resolve(client_hello.server_name()).map(|x| x.0)
308 0 : }
309 : }
310 :
311 : impl CertResolver {
312 20 : pub fn resolve(
313 20 : &self,
314 20 : server_name: Option<&str>,
315 20 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
316 : // loop here and cut off more and more subdomains until we find
317 : // a match to get a proper wildcard support. OTOH, we now do not
318 : // use nested domains, so keep this simple for now.
319 : //
320 : // With the current coding foo.com will match *.foo.com and that
321 : // repeats behavior of the old code.
322 20 : if let Some(mut sni_name) = server_name {
323 : loop {
324 40 : if let Some(cert) = self.certs.get(sni_name) {
325 20 : return Some(cert.clone());
326 20 : }
327 20 : if let Some((_, rest)) = sni_name.split_once('.') {
328 20 : sni_name = rest;
329 20 : } else {
330 0 : return None;
331 : }
332 : }
333 : } else {
334 : // No SNI, use the default certificate, otherwise we can't get to
335 : // options parameter which can be used to set endpoint name too.
336 : // That means that non-SNI flow will not work for CNAME domains in
337 : // verify-full mode.
338 : //
339 : // If that will be a problem we can:
340 : //
341 : // a) Instead of multi-cert approach use single cert with extra
342 : // domains listed in Subject Alternative Name (SAN).
343 : // b) Deploy separate proxy instances for extra domains.
344 0 : self.default.clone()
345 : }
346 20 : }
347 : }
348 :
349 : #[derive(Debug)]
350 : pub struct EndpointCacheConfig {
351 : /// Batch size to receive all endpoints on the startup.
352 : pub initial_batch_size: usize,
353 : /// Batch size to receive endpoints.
354 : pub default_batch_size: usize,
355 : /// Timeouts for the stream read operation.
356 : pub xread_timeout: Duration,
357 : /// Stream name to read from.
358 : pub stream_name: String,
359 : /// Limiter info (to distinguish when to enable cache).
360 : pub limiter_info: Vec<RateBucketInfo>,
361 : /// Disable cache.
362 : /// If true, cache is ignored, but reports all statistics.
363 : pub disable_cache: bool,
364 : /// Retry interval for the stream read operation.
365 : pub retry_interval: Duration,
366 : }
367 :
368 : impl EndpointCacheConfig {
369 : /// Default options for [`crate::control_plane::NodeInfoCache`].
370 : /// Notice that by default the limiter is empty, which means that cache is disabled.
371 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
372 : "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
373 :
374 : /// Parse cache options passed via cmdline.
375 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
376 0 : fn parse(options: &str) -> anyhow::Result<Self> {
377 0 : let mut initial_batch_size = None;
378 0 : let mut default_batch_size = None;
379 0 : let mut xread_timeout = None;
380 0 : let mut stream_name = None;
381 0 : let mut limiter_info = vec![];
382 0 : let mut disable_cache = false;
383 0 : let mut retry_interval = None;
384 :
385 0 : for option in options.split(',') {
386 0 : let (key, value) = option
387 0 : .split_once('=')
388 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
389 :
390 0 : match key {
391 0 : "initial_batch_size" => initial_batch_size = Some(value.parse()?),
392 0 : "default_batch_size" => default_batch_size = Some(value.parse()?),
393 0 : "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
394 0 : "stream_name" => stream_name = Some(value.to_string()),
395 0 : "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
396 0 : "disable_cache" => disable_cache = value.parse()?,
397 0 : "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
398 0 : unknown => bail!("unknown key: {unknown}"),
399 : }
400 : }
401 0 : RateBucketInfo::validate(&mut limiter_info)?;
402 :
403 : Ok(Self {
404 0 : initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
405 0 : default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
406 0 : xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
407 0 : stream_name: stream_name.context("missing `stream_name`")?,
408 0 : disable_cache,
409 0 : limiter_info,
410 0 : retry_interval: retry_interval.context("missing `retry_interval`")?,
411 : })
412 0 : }
413 : }
414 :
415 : impl FromStr for EndpointCacheConfig {
416 : type Err = anyhow::Error;
417 :
418 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
419 0 : let error = || format!("failed to parse endpoint cache options '{options}'");
420 0 : Self::parse(options).with_context(error)
421 0 : }
422 : }
423 : #[derive(Debug)]
424 : pub struct MetricBackupCollectionConfig {
425 : pub interval: Duration,
426 : pub remote_storage_config: Option<RemoteStorageConfig>,
427 : pub chunk_size: usize,
428 : }
429 :
430 1 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
431 1 : RemoteStorageConfig::from_toml(&s.parse()?)
432 1 : }
433 :
434 : /// Helper for cmdline cache options parsing.
435 : #[derive(Debug)]
436 : pub struct CacheOptions {
437 : /// Max number of entries.
438 : pub size: usize,
439 : /// Entry's time-to-live.
440 : pub ttl: Duration,
441 : }
442 :
443 : impl CacheOptions {
444 : /// Default options for [`crate::control_plane::NodeInfoCache`].
445 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
446 :
447 : /// Parse cache options passed via cmdline.
448 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
449 4 : fn parse(options: &str) -> anyhow::Result<Self> {
450 4 : let mut size = None;
451 4 : let mut ttl = None;
452 :
453 7 : for option in options.split(',') {
454 7 : let (key, value) = option
455 7 : .split_once('=')
456 7 : .with_context(|| format!("bad key-value pair: {option}"))?;
457 :
458 7 : match key {
459 7 : "size" => size = Some(value.parse()?),
460 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
461 0 : unknown => bail!("unknown key: {unknown}"),
462 : }
463 : }
464 :
465 : // TTL doesn't matter if cache is always empty.
466 4 : if let Some(0) = size {
467 2 : ttl.get_or_insert(Duration::default());
468 2 : }
469 :
470 : Ok(Self {
471 4 : size: size.context("missing `size`")?,
472 4 : ttl: ttl.context("missing `ttl`")?,
473 : })
474 4 : }
475 : }
476 :
477 : impl FromStr for CacheOptions {
478 : type Err = anyhow::Error;
479 :
480 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
481 4 : let error = || format!("failed to parse cache options '{options}'");
482 4 : Self::parse(options).with_context(error)
483 4 : }
484 : }
485 :
486 : /// Helper for cmdline cache options parsing.
487 : #[derive(Debug)]
488 : pub struct ProjectInfoCacheOptions {
489 : /// Max number of entries.
490 : pub size: usize,
491 : /// Entry's time-to-live.
492 : pub ttl: Duration,
493 : /// Max number of roles per endpoint.
494 : pub max_roles: usize,
495 : /// Gc interval.
496 : pub gc_interval: Duration,
497 : }
498 :
499 : impl ProjectInfoCacheOptions {
500 : /// Default options for [`crate::control_plane::NodeInfoCache`].
501 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
502 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
503 :
504 : /// Parse cache options passed via cmdline.
505 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
506 0 : fn parse(options: &str) -> anyhow::Result<Self> {
507 0 : let mut size = None;
508 0 : let mut ttl = None;
509 0 : let mut max_roles = None;
510 0 : let mut gc_interval = None;
511 :
512 0 : for option in options.split(',') {
513 0 : let (key, value) = option
514 0 : .split_once('=')
515 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
516 :
517 0 : match key {
518 0 : "size" => size = Some(value.parse()?),
519 0 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
520 0 : "max_roles" => max_roles = Some(value.parse()?),
521 0 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
522 0 : unknown => bail!("unknown key: {unknown}"),
523 : }
524 : }
525 :
526 : // TTL doesn't matter if cache is always empty.
527 0 : if let Some(0) = size {
528 0 : ttl.get_or_insert(Duration::default());
529 0 : }
530 :
531 : Ok(Self {
532 0 : size: size.context("missing `size`")?,
533 0 : ttl: ttl.context("missing `ttl`")?,
534 0 : max_roles: max_roles.context("missing `max_roles`")?,
535 0 : gc_interval: gc_interval.context("missing `gc_interval`")?,
536 : })
537 0 : }
538 : }
539 :
540 : impl FromStr for ProjectInfoCacheOptions {
541 : type Err = anyhow::Error;
542 :
543 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
544 0 : let error = || format!("failed to parse cache options '{options}'");
545 0 : Self::parse(options).with_context(error)
546 0 : }
547 : }
548 :
549 : /// This is a config for connect to compute and wake compute.
550 : #[derive(Clone, Copy, Debug)]
551 : pub struct RetryConfig {
552 : /// Number of times we should retry.
553 : pub max_retries: u32,
554 : /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
555 : pub base_delay: tokio::time::Duration,
556 : /// Exponential base for retry wait duration
557 : pub backoff_factor: f64,
558 : }
559 :
560 : impl RetryConfig {
561 : // Default options for RetryConfig.
562 :
563 : /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
564 : pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
565 : "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
566 : /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
567 : /// Cplane has timeout of 60s on each request. 8m7s in total.
568 : pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
569 : "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
570 :
571 : /// Parse retry options passed via cmdline.
572 : /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
573 0 : pub fn parse(options: &str) -> anyhow::Result<Self> {
574 0 : let mut num_retries = None;
575 0 : let mut base_retry_wait_duration = None;
576 0 : let mut retry_wait_exponent_base = None;
577 :
578 0 : for option in options.split(',') {
579 0 : let (key, value) = option
580 0 : .split_once('=')
581 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
582 :
583 0 : match key {
584 0 : "num_retries" => num_retries = Some(value.parse()?),
585 0 : "base_retry_wait_duration" => {
586 0 : base_retry_wait_duration = Some(humantime::parse_duration(value)?);
587 : }
588 0 : "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
589 0 : unknown => bail!("unknown key: {unknown}"),
590 : }
591 : }
592 :
593 : Ok(Self {
594 0 : max_retries: num_retries.context("missing `num_retries`")?,
595 0 : base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
596 0 : backoff_factor: retry_wait_exponent_base
597 0 : .context("missing `retry_wait_exponent_base`")?,
598 : })
599 0 : }
600 : }
601 :
602 : /// Helper for cmdline cache options parsing.
603 8 : #[derive(serde::Deserialize)]
604 : pub struct ConcurrencyLockOptions {
605 : /// The number of shards the lock map should have
606 : pub shards: usize,
607 : /// The number of allowed concurrent requests for each endpoitn
608 : #[serde(flatten)]
609 : pub limiter: RateLimiterConfig,
610 : /// Garbage collection epoch
611 : #[serde(deserialize_with = "humantime_serde::deserialize")]
612 : pub epoch: Duration,
613 : /// Lock timeout
614 : #[serde(deserialize_with = "humantime_serde::deserialize")]
615 : pub timeout: Duration,
616 : }
617 :
618 : impl ConcurrencyLockOptions {
619 : /// Default options for [`crate::control_plane::client::ApiLocks`].
620 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
621 : /// Default options for [`crate::control_plane::client::ApiLocks`].
622 : pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
623 : "shards=64,permits=100,epoch=10m,timeout=10ms";
624 :
625 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
626 :
627 : /// Parse lock options passed via cmdline.
628 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
629 4 : fn parse(options: &str) -> anyhow::Result<Self> {
630 4 : let options = options.trim();
631 4 : if options.starts_with('{') && options.ends_with('}') {
632 1 : return Ok(serde_json::from_str(options)?);
633 3 : }
634 3 :
635 3 : let mut shards = None;
636 3 : let mut permits = None;
637 3 : let mut epoch = None;
638 3 : let mut timeout = None;
639 :
640 9 : for option in options.split(',') {
641 9 : let (key, value) = option
642 9 : .split_once('=')
643 9 : .with_context(|| format!("bad key-value pair: {option}"))?;
644 :
645 9 : match key {
646 9 : "shards" => shards = Some(value.parse()?),
647 7 : "permits" => permits = Some(value.parse()?),
648 4 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
649 2 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
650 0 : unknown => bail!("unknown key: {unknown}"),
651 : }
652 : }
653 :
654 : // these dont matter if lock is disabled
655 3 : if let Some(0) = permits {
656 1 : timeout = Some(Duration::default());
657 1 : epoch = Some(Duration::default());
658 1 : shards = Some(2);
659 2 : }
660 :
661 3 : let permits = permits.context("missing `permits`")?;
662 3 : let out = Self {
663 3 : shards: shards.context("missing `shards`")?,
664 3 : limiter: RateLimiterConfig {
665 3 : algorithm: RateLimitAlgorithm::Fixed,
666 3 : initial_limit: permits,
667 3 : },
668 3 : epoch: epoch.context("missing `epoch`")?,
669 3 : timeout: timeout.context("missing `timeout`")?,
670 : };
671 :
672 3 : ensure!(out.shards > 1, "shard count must be > 1");
673 3 : ensure!(
674 3 : out.shards.is_power_of_two(),
675 0 : "shard count must be a power of two"
676 : );
677 :
678 3 : Ok(out)
679 4 : }
680 : }
681 :
682 : impl FromStr for ConcurrencyLockOptions {
683 : type Err = anyhow::Error;
684 :
685 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
686 4 : let error = || format!("failed to parse cache lock options '{options}'");
687 4 : Self::parse(options).with_context(error)
688 4 : }
689 : }
690 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rate_limiter::Aimd;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        // (input, expected size, expected ttl)
        let cases: [(&str, usize, Duration); 4] = [
            ("size=4096,ttl=5min", 4096, Duration::from_secs(5 * 60)),
            ("ttl=4m,size=2", 2, Duration::from_secs(4 * 60)),
            ("size=0,ttl=1s", 0, Duration::from_secs(1)),
            ("size=0", 0, Duration::default()),
        ];

        for (input, expected_size, expected_ttl) in cases {
            let opts: CacheOptions = input.parse()?;
            assert_eq!(opts.size, expected_size, "size for {input:?}");
            assert_eq!(opts.ttl, expected_ttl, "ttl for {input:?}");
        }

        Ok(())
    }

    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        // (input, expected epoch, expected timeout, expected shards, expected permits)
        let cases = [
            (
                "shards=32,permits=4,epoch=10m,timeout=1s",
                Duration::from_secs(10 * 60),
                Duration::from_secs(1),
                32,
                4,
            ),
            (
                "epoch=60s,shards=16,timeout=100ms,permits=8",
                Duration::from_secs(60),
                Duration::from_millis(100),
                16,
                8,
            ),
            ("permits=0", Duration::ZERO, Duration::ZERO, 2, 0),
        ];

        for (input, epoch, timeout, shards, permits) in cases {
            let opts: ConcurrencyLockOptions = input.parse()?;
            assert_eq!(opts.epoch, epoch, "epoch for {input:?}");
            assert_eq!(opts.timeout, timeout, "timeout for {input:?}");
            assert_eq!(opts.shards, shards, "shards for {input:?}");
            assert_eq!(opts.limiter.initial_limit, permits, "permits for {input:?}");
            assert_eq!(opts.limiter.algorithm, RateLimitAlgorithm::Fixed);
        }

        Ok(())
    }

    #[test]
    fn test_parse_json_lock_options() -> anyhow::Result<()> {
        let json = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#;
        let opts: ConcurrencyLockOptions = json.parse()?;

        assert_eq!(opts.epoch, Duration::from_secs(10 * 60));
        assert_eq!(opts.timeout, Duration::from_secs(1));
        assert_eq!(opts.shards, 32);
        assert_eq!(opts.limiter.initial_limit, 44);

        let expected_algorithm = RateLimitAlgorithm::Aimd {
            conf: Aimd {
                min: 5,
                max: 500,
                dec: 0.9,
                inc: 10,
                utilisation: 0.8,
            },
        };
        assert_eq!(opts.limiter.algorithm, expected_algorithm);

        Ok(())
    }
}
|