Line data Source code
1 : use std::str::FromStr;
2 : use std::sync::Arc;
3 : use std::time::Duration;
4 :
5 : use anyhow::{Context, Ok, bail, ensure};
6 : use arc_swap::ArcSwapOption;
7 : use camino::{Utf8Path, Utf8PathBuf};
8 : use clap::ValueEnum;
9 : use compute_api::spec::LocalProxySpec;
10 : use remote_storage::RemoteStorageConfig;
11 : use thiserror::Error;
12 : use tokio::sync::Notify;
13 : use tracing::{debug, error, info, warn};
14 :
15 : use crate::auth::backend::jwt::JwkCache;
16 : use crate::auth::backend::local::JWKS_ROLE_MAP;
17 : use crate::control_plane::locks::ApiLocks;
18 : use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
19 : use crate::ext::TaskExt;
20 : use crate::intern::RoleNameInt;
21 : use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
22 : use crate::scram::threadpool::ThreadPool;
23 : use crate::serverless::GlobalConnPoolOptions;
24 : use crate::serverless::cancel_set::CancelSet;
25 : pub use crate::tls::server_config::{TlsConfig, configure_tls};
26 : use crate::types::{Host, RoleName};
27 :
28 : pub struct ProxyConfig {
29 : pub tls_config: ArcSwapOption<TlsConfig>,
30 : pub metric_collection: Option<MetricCollectionConfig>,
31 : pub http_config: HttpConfig,
32 : pub authentication_config: AuthenticationConfig,
33 : pub proxy_protocol_v2: ProxyProtocolV2,
34 : pub handshake_timeout: Duration,
35 : pub wake_compute_retry_config: RetryConfig,
36 : pub connect_compute_locks: ApiLocks<Host>,
37 : pub connect_to_compute: ComputeConfig,
38 : #[cfg(feature = "testing")]
39 : pub disable_pg_session_jwt: bool,
40 : }
41 :
42 : pub struct ComputeConfig {
43 : pub retry: RetryConfig,
44 : pub tls: Arc<rustls::ClientConfig>,
45 : pub timeout: Duration,
46 : }
47 :
48 : #[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
49 : pub enum ProxyProtocolV2 {
50 : /// Connection will error if PROXY protocol v2 header is missing
51 : Required,
52 : /// Connection will error if PROXY protocol v2 header is provided
53 : Rejected,
54 : }
55 :
56 : #[derive(Debug)]
57 : pub struct MetricCollectionConfig {
58 : pub endpoint: reqwest::Url,
59 : pub interval: Duration,
60 : pub backup_metric_collection_config: MetricBackupCollectionConfig,
61 : }
62 :
63 : pub struct HttpConfig {
64 : pub accept_websockets: bool,
65 : pub pool_options: GlobalConnPoolOptions,
66 : pub cancel_set: CancelSet,
67 : pub client_conn_threshold: u64,
68 : pub max_request_size_bytes: usize,
69 : pub max_response_size_bytes: usize,
70 : }
71 :
72 : pub struct AuthenticationConfig {
73 : pub thread_pool: Arc<ThreadPool>,
74 : pub scram_protocol_timeout: tokio::time::Duration,
75 : pub ip_allowlist_check_enabled: bool,
76 : pub is_vpc_acccess_proxy: bool,
77 : pub jwks_cache: JwkCache,
78 : pub is_auth_broker: bool,
79 : pub accept_jwts: bool,
80 : pub console_redirect_confirmation_timeout: tokio::time::Duration,
81 : }
82 :
83 : #[derive(Debug)]
84 : pub struct MetricBackupCollectionConfig {
85 : pub remote_storage_config: Option<RemoteStorageConfig>,
86 : pub chunk_size: usize,
87 : }
88 :
89 1 : pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
90 1 : RemoteStorageConfig::from_toml(&s.parse()?)
91 1 : }
92 :
/// Size and time-to-live of a cache, parsed from a command-line string such as
/// `size=4000,ttl=4m`.
#[derive(Debug)]
pub struct CacheOptions {
    /// Maximum number of entries the cache may hold.
    pub size: usize,
    /// Time-to-live of each entry.
    pub ttl: Duration,
}
101 :
102 : impl CacheOptions {
103 : /// Default options for [`crate::control_plane::NodeInfoCache`].
104 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
105 :
106 : /// Parse cache options passed via cmdline.
107 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
108 4 : fn parse(options: &str) -> anyhow::Result<Self> {
109 4 : let mut size = None;
110 4 : let mut ttl = None;
111 :
112 7 : for option in options.split(',') {
113 7 : let (key, value) = option
114 7 : .split_once('=')
115 7 : .with_context(|| format!("bad key-value pair: {option}"))?;
116 :
117 7 : match key {
118 7 : "size" => size = Some(value.parse()?),
119 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
120 0 : unknown => bail!("unknown key: {unknown}"),
121 : }
122 : }
123 :
124 : // TTL doesn't matter if cache is always empty.
125 4 : if let Some(0) = size {
126 2 : ttl.get_or_insert(Duration::default());
127 2 : }
128 :
129 4 : Ok(Self {
130 4 : size: size.context("missing `size`")?,
131 4 : ttl: ttl.context("missing `ttl`")?,
132 : })
133 4 : }
134 : }
135 :
136 : impl FromStr for CacheOptions {
137 : type Err = anyhow::Error;
138 :
139 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
140 4 : let error = || format!("failed to parse cache options '{options}'");
141 4 : Self::parse(options).with_context(error)
142 4 : }
143 : }
144 :
/// Options for the project info cache, parsed from a command-line string such
/// as `size=10000,ttl=4m,max_roles=10,gc_interval=60m`.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Maximum number of entries the cache may hold.
    pub size: usize,
    /// Time-to-live of each entry.
    pub ttl: Duration,
    /// Maximum number of roles kept per endpoint.
    pub max_roles: usize,
    /// Interval between garbage-collection passes.
    pub gc_interval: Duration,
}
157 :
158 : impl ProjectInfoCacheOptions {
159 : /// Default options for [`crate::control_plane::NodeInfoCache`].
160 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
161 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
162 :
163 : /// Parse cache options passed via cmdline.
164 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
165 0 : fn parse(options: &str) -> anyhow::Result<Self> {
166 0 : let mut size = None;
167 0 : let mut ttl = None;
168 0 : let mut max_roles = None;
169 0 : let mut gc_interval = None;
170 :
171 0 : for option in options.split(',') {
172 0 : let (key, value) = option
173 0 : .split_once('=')
174 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
175 :
176 0 : match key {
177 0 : "size" => size = Some(value.parse()?),
178 0 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
179 0 : "max_roles" => max_roles = Some(value.parse()?),
180 0 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
181 0 : unknown => bail!("unknown key: {unknown}"),
182 : }
183 : }
184 :
185 : // TTL doesn't matter if cache is always empty.
186 0 : if let Some(0) = size {
187 0 : ttl.get_or_insert(Duration::default());
188 0 : }
189 :
190 0 : Ok(Self {
191 0 : size: size.context("missing `size`")?,
192 0 : ttl: ttl.context("missing `ttl`")?,
193 0 : max_roles: max_roles.context("missing `max_roles`")?,
194 0 : gc_interval: gc_interval.context("missing `gc_interval`")?,
195 : })
196 0 : }
197 : }
198 :
199 : impl FromStr for ProjectInfoCacheOptions {
200 : type Err = anyhow::Error;
201 :
202 0 : fn from_str(options: &str) -> Result<Self, Self::Err> {
203 0 : let error = || format!("failed to parse cache options '{options}'");
204 0 : Self::parse(options).with_context(error)
205 0 : }
206 : }
207 :
208 : /// This is a config for connect to compute and wake compute.
209 : #[derive(Clone, Copy, Debug)]
210 : pub struct RetryConfig {
211 : /// Number of times we should retry.
212 : pub max_retries: u32,
213 : /// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
214 : pub base_delay: tokio::time::Duration,
215 : /// Exponential base for retry wait duration
216 : pub backoff_factor: f64,
217 : }
218 :
219 : impl RetryConfig {
220 : // Default options for RetryConfig.
221 :
222 : /// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
223 : pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
224 : "num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
225 : /// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
226 : /// Cplane has timeout of 60s on each request. 8m7s in total.
227 : pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
228 : "num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
229 :
230 : /// Parse retry options passed via cmdline.
231 : /// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
232 0 : pub fn parse(options: &str) -> anyhow::Result<Self> {
233 0 : let mut num_retries = None;
234 0 : let mut base_retry_wait_duration = None;
235 0 : let mut retry_wait_exponent_base = None;
236 :
237 0 : for option in options.split(',') {
238 0 : let (key, value) = option
239 0 : .split_once('=')
240 0 : .with_context(|| format!("bad key-value pair: {option}"))?;
241 :
242 0 : match key {
243 0 : "num_retries" => num_retries = Some(value.parse()?),
244 0 : "base_retry_wait_duration" => {
245 0 : base_retry_wait_duration = Some(humantime::parse_duration(value)?);
246 : }
247 0 : "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
248 0 : unknown => bail!("unknown key: {unknown}"),
249 : }
250 : }
251 :
252 0 : Ok(Self {
253 0 : max_retries: num_retries.context("missing `num_retries`")?,
254 0 : base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
255 0 : backoff_factor: retry_wait_exponent_base
256 0 : .context("missing `retry_wait_exponent_base`")?,
257 : })
258 0 : }
259 : }
260 :
261 : /// Helper for cmdline cache options parsing.
262 : #[derive(serde::Deserialize)]
263 : pub struct ConcurrencyLockOptions {
264 : /// The number of shards the lock map should have
265 : pub shards: usize,
266 : /// The number of allowed concurrent requests for each endpoitn
267 : #[serde(flatten)]
268 : pub limiter: RateLimiterConfig,
269 : /// Garbage collection epoch
270 : #[serde(deserialize_with = "humantime_serde::deserialize")]
271 : pub epoch: Duration,
272 : /// Lock timeout
273 : #[serde(deserialize_with = "humantime_serde::deserialize")]
274 : pub timeout: Duration,
275 : }
276 :
277 : impl ConcurrencyLockOptions {
278 : /// Default options for [`crate::control_plane::client::ApiLocks`].
279 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
280 : /// Default options for [`crate::control_plane::client::ApiLocks`].
281 : pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
282 : "shards=64,permits=100,epoch=10m,timeout=10ms";
283 :
284 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
285 :
286 : /// Parse lock options passed via cmdline.
287 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
288 4 : fn parse(options: &str) -> anyhow::Result<Self> {
289 4 : let options = options.trim();
290 4 : if options.starts_with('{') && options.ends_with('}') {
291 1 : return Ok(serde_json::from_str(options)?);
292 3 : }
293 :
294 3 : let mut shards = None;
295 3 : let mut permits = None;
296 3 : let mut epoch = None;
297 3 : let mut timeout = None;
298 :
299 9 : for option in options.split(',') {
300 9 : let (key, value) = option
301 9 : .split_once('=')
302 9 : .with_context(|| format!("bad key-value pair: {option}"))?;
303 :
304 9 : match key {
305 9 : "shards" => shards = Some(value.parse()?),
306 7 : "permits" => permits = Some(value.parse()?),
307 4 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
308 2 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
309 0 : unknown => bail!("unknown key: {unknown}"),
310 : }
311 : }
312 :
313 : // these dont matter if lock is disabled
314 3 : if let Some(0) = permits {
315 1 : timeout = Some(Duration::default());
316 1 : epoch = Some(Duration::default());
317 1 : shards = Some(2);
318 2 : }
319 :
320 3 : let permits = permits.context("missing `permits`")?;
321 3 : let out = Self {
322 3 : shards: shards.context("missing `shards`")?,
323 3 : limiter: RateLimiterConfig {
324 3 : algorithm: RateLimitAlgorithm::Fixed,
325 3 : initial_limit: permits,
326 3 : },
327 3 : epoch: epoch.context("missing `epoch`")?,
328 3 : timeout: timeout.context("missing `timeout`")?,
329 : };
330 :
331 3 : ensure!(out.shards > 1, "shard count must be > 1");
332 3 : ensure!(
333 3 : out.shards.is_power_of_two(),
334 0 : "shard count must be a power of two"
335 : );
336 :
337 3 : Ok(out)
338 4 : }
339 : }
340 :
341 : impl FromStr for ConcurrencyLockOptions {
342 : type Err = anyhow::Error;
343 :
344 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
345 4 : let error = || format!("failed to parse cache lock options '{options}'");
346 4 : Self::parse(options).with_context(error)
347 4 : }
348 : }
349 :
350 : #[derive(Error, Debug)]
351 : pub(crate) enum RefreshConfigError {
352 : #[error(transparent)]
353 : Read(#[from] std::io::Error),
354 : #[error(transparent)]
355 : Parse(#[from] serde_json::Error),
356 : #[error(transparent)]
357 : Validate(anyhow::Error),
358 : #[error(transparent)]
359 : Tls(anyhow::Error),
360 : }
361 :
362 0 : pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
363 0 : let mut init = true;
364 : loop {
365 0 : rx.notified().await;
366 :
367 0 : match refresh_config_inner(config, &path).await {
368 0 : std::result::Result::Ok(()) => {}
369 : // don't log for file not found errors if this is the first time we are checking
370 : // for computes that don't use local_proxy, this is not an error.
371 0 : Err(RefreshConfigError::Read(e))
372 0 : if init && e.kind() == std::io::ErrorKind::NotFound =>
373 : {
374 0 : debug!(error=?e, ?path, "could not read config file");
375 : }
376 0 : Err(RefreshConfigError::Tls(e)) => {
377 0 : error!(error=?e, ?path, "could not read TLS certificates");
378 : }
379 0 : Err(e) => {
380 0 : error!(error=?e, ?path, "could not read config file");
381 : }
382 : }
383 :
384 0 : init = false;
385 : }
386 : }
387 :
388 0 : pub(crate) async fn refresh_config_inner(
389 0 : config: &ProxyConfig,
390 0 : path: &Utf8Path,
391 0 : ) -> Result<(), RefreshConfigError> {
392 0 : let bytes = tokio::fs::read(&path).await?;
393 0 : let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
394 :
395 0 : let mut jwks_set = vec![];
396 :
397 0 : fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
398 0 : let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
399 :
400 0 : ensure!(
401 0 : jwks_url.has_authority()
402 0 : && (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
403 0 : "Invalid JWKS url. Must be HTTP",
404 : );
405 :
406 0 : ensure!(
407 0 : jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
408 0 : "Invalid JWKS url. No domain listed",
409 : );
410 :
411 : // clear username, password and ports
412 0 : jwks_url
413 0 : .set_username("")
414 0 : .expect("url can be a base and has a valid host and is not a file. should not error");
415 0 : jwks_url
416 0 : .set_password(None)
417 0 : .expect("url can be a base and has a valid host and is not a file. should not error");
418 : // local testing is hard if we need to have a specific restricted port
419 0 : if cfg!(not(feature = "testing")) {
420 0 : jwks_url.set_port(None).expect(
421 0 : "url can be a base and has a valid host and is not a file. should not error",
422 0 : );
423 0 : }
424 :
425 : // clear query params
426 0 : jwks_url.set_fragment(None);
427 0 : jwks_url.query_pairs_mut().clear().finish();
428 :
429 0 : if jwks_url.scheme() != "https" {
430 : // local testing is hard if we need to set up https support.
431 0 : if cfg!(not(feature = "testing")) {
432 0 : jwks_url
433 0 : .set_scheme("https")
434 0 : .expect("should not error to set the scheme to https if it was http");
435 0 : } else {
436 0 : warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
437 : }
438 0 : }
439 :
440 0 : Ok(JwksSettings {
441 0 : id: jwks.id,
442 0 : jwks_url,
443 0 : _provider_name: jwks.provider_name,
444 0 : jwt_audience: jwks.jwt_audience,
445 0 : role_names: jwks
446 0 : .role_names
447 0 : .into_iter()
448 0 : .map(RoleName::from)
449 0 : .map(|s| RoleNameInt::from(&s))
450 0 : .collect(),
451 : })
452 0 : }
453 :
454 0 : for jwks in data.jwks.into_iter().flatten() {
455 0 : jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
456 : }
457 :
458 0 : info!("successfully loaded new config");
459 0 : JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
460 :
461 0 : if let Some(tls_config) = data.tls {
462 0 : let tls_config = tokio::task::spawn_blocking(move || {
463 0 : crate::tls::server_config::configure_tls(
464 0 : tls_config.key_path.as_ref(),
465 0 : tls_config.cert_path.as_ref(),
466 0 : None,
467 : false,
468 : )
469 0 : })
470 0 : .await
471 0 : .propagate_task_panic()
472 0 : .map_err(RefreshConfigError::Tls)?;
473 0 : config.tls_config.store(Some(Arc::new(tls_config)));
474 0 : }
475 :
476 0 : std::result::Result::Ok(())
477 0 : }
478 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rate_limiter::Aimd;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    #[test]
    fn test_parse_cache_options_errors() {
        // missing `ttl`, missing `size`, unknown key, pair without `=`
        assert!("size=1".parse::<CacheOptions>().is_err());
        assert!("ttl=1s".parse::<CacheOptions>().is_err());
        assert!("size=1,ttl=1s,extra=1".parse::<CacheOptions>().is_err());
        assert!("size".parse::<CacheOptions>().is_err());
    }

    #[test]
    fn test_parse_project_info_cache_options() -> anyhow::Result<()> {
        let ProjectInfoCacheOptions {
            size,
            ttl,
            max_roles,
            gc_interval,
        } = ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(size, 10000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));
        assert_eq!(max_roles, 10);
        assert_eq!(gc_interval, Duration::from_secs(60 * 60));

        // `ttl` may be omitted for an always-empty cache.
        let ProjectInfoCacheOptions { size, ttl, .. } =
            "size=0,max_roles=1,gc_interval=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::ZERO);

        // `max_roles` and `gc_interval` are always required.
        assert!(
            "size=1,ttl=1s,max_roles=1"
                .parse::<ProjectInfoCacheOptions>()
                .is_err()
        );
        assert!(
            "size=1,ttl=1s,gc_interval=1s"
                .parse::<ProjectInfoCacheOptions>()
                .is_err()
        );

        Ok(())
    }

    #[test]
    fn test_parse_retry_config() -> anyhow::Result<()> {
        let RetryConfig {
            max_retries,
            base_delay,
            backoff_factor,
        } = RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?;
        assert_eq!(max_retries, 5);
        assert_eq!(base_delay, Duration::from_millis(200));
        assert!((backoff_factor - 2.0).abs() < f64::EPSILON);

        let RetryConfig {
            max_retries,
            base_delay,
            backoff_factor,
        } = RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?;
        assert_eq!(max_retries, 8);
        assert_eq!(base_delay, Duration::from_millis(100));
        assert!((backoff_factor - 1.6).abs() < f64::EPSILON);

        assert!(RetryConfig::parse("num_retries=1").is_err());
        assert!(
            RetryConfig::parse(
                "num_retries=1,base_retry_wait_duration=1s,retry_wait_exponent_base=2,extra=1"
            )
            .is_err()
        );

        Ok(())
    }

    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(limiter.initial_limit, 4);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
        assert_eq!(epoch, Duration::from_secs(60));
        assert_eq!(timeout, Duration::from_millis(100));
        assert_eq!(shards, 16);
        assert_eq!(limiter.initial_limit, 8);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = "permits=0".parse()?;
        assert_eq!(epoch, Duration::ZERO);
        assert_eq!(timeout, Duration::ZERO);
        assert_eq!(shards, 2);
        assert_eq!(limiter.initial_limit, 0);
        assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);

        Ok(())
    }

    #[test]
    fn test_parse_lock_options_errors() {
        // shard count must be a power of two greater than one
        assert!(
            "shards=1,permits=1,epoch=1s,timeout=1s"
                .parse::<ConcurrencyLockOptions>()
                .is_err()
        );
        assert!(
            "shards=3,permits=1,epoch=1s,timeout=1s"
                .parse::<ConcurrencyLockOptions>()
                .is_err()
        );
        // `permits` is always required
        assert!(
            "shards=4,epoch=1s,timeout=1s"
                .parse::<ConcurrencyLockOptions>()
                .is_err()
        );
    }

    #[test]
    fn test_parse_json_lock_options() -> anyhow::Result<()> {
        let ConcurrencyLockOptions {
            epoch,
            limiter,
            shards,
            timeout,
        } = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
            .parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(limiter.initial_limit, 44);
        assert_eq!(
            limiter.algorithm,
            RateLimitAlgorithm::Aimd {
                conf: Aimd {
                    min: 5,
                    max: 500,
                    dec: 0.9,
                    inc: 10,
                    utilisation: 0.8
                }
            },
        );

        Ok(())
    }
}
|