Line data Source code
1 : use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
2 : use anyhow::{bail, ensure, Context, Ok};
3 : use rustls::{sign, Certificate, PrivateKey};
4 : use sha2::{Digest, Sha256};
5 : use std::{
6 : collections::{HashMap, HashSet},
7 : str::FromStr,
8 : sync::Arc,
9 : time::Duration,
10 : };
11 : use tracing::{error, info};
12 : use x509_parser::oid_registry;
13 :
14 : pub struct ProxyConfig {
15 : pub tls_config: Option<TlsConfig>,
16 : pub auth_backend: auth::BackendType<'static, (), ()>,
17 : pub metric_collection: Option<MetricCollectionConfig>,
18 : pub allow_self_signed_compute: bool,
19 : pub http_config: HttpConfig,
20 : pub authentication_config: AuthenticationConfig,
21 : pub require_client_ip: bool,
22 : pub disable_ip_check_for_http: bool,
23 : pub endpoint_rps_limit: Vec<RateBucketInfo>,
24 : pub redis_rps_limit: Vec<RateBucketInfo>,
25 : pub region: String,
26 : pub handshake_timeout: Duration,
27 : }
28 :
29 1 : #[derive(Debug)]
30 : pub struct MetricCollectionConfig {
31 : pub endpoint: reqwest::Url,
32 : pub interval: Duration,
33 : }
34 :
35 : pub struct TlsConfig {
36 : pub config: Arc<rustls::ServerConfig>,
37 : pub common_names: HashSet<String>,
38 : pub cert_resolver: Arc<CertResolver>,
39 : }
40 :
41 : pub struct HttpConfig {
42 : pub request_timeout: tokio::time::Duration,
43 : pub pool_options: GlobalConnPoolOptions,
44 : }
45 :
46 : pub struct AuthenticationConfig {
47 : pub scram_protocol_timeout: tokio::time::Duration,
48 : }
49 :
50 : impl TlsConfig {
51 117 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
52 117 : self.config.clone()
53 117 : }
54 : }
55 :
56 : /// Configure TLS for the main endpoint.
57 25 : pub fn configure_tls(
58 25 : key_path: &str,
59 25 : cert_path: &str,
60 25 : certs_dir: Option<&String>,
61 25 : ) -> anyhow::Result<TlsConfig> {
62 25 : let mut cert_resolver = CertResolver::new();
63 25 :
64 25 : // add default certificate
65 25 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
66 :
67 : // add extra certificates
68 25 : if let Some(certs_dir) = certs_dir {
69 0 : for entry in std::fs::read_dir(certs_dir)? {
70 0 : let entry = entry?;
71 0 : let path = entry.path();
72 0 : if path.is_dir() {
73 : // file names aligned with default cert-manager names
74 0 : let key_path = path.join("tls.key");
75 0 : let cert_path = path.join("tls.crt");
76 0 : if key_path.exists() && cert_path.exists() {
77 0 : cert_resolver.add_cert_path(
78 0 : &key_path.to_string_lossy(),
79 0 : &cert_path.to_string_lossy(),
80 0 : false,
81 0 : )?;
82 0 : }
83 0 : }
84 : }
85 25 : }
86 :
87 25 : let common_names = cert_resolver.get_common_names();
88 25 :
89 25 : let cert_resolver = Arc::new(cert_resolver);
90 :
91 25 : let config = rustls::ServerConfig::builder()
92 25 : .with_safe_default_cipher_suites()
93 25 : .with_safe_default_kx_groups()
94 25 : // allow TLS 1.2 to be compatible with older client libraries
95 25 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
96 25 : .with_no_client_auth()
97 25 : .with_cert_resolver(cert_resolver.clone())
98 25 : .into();
99 25 :
100 25 : Ok(TlsConfig {
101 25 : config,
102 25 : common_names,
103 25 : cert_resolver,
104 25 : })
105 25 : }
106 :
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
///   function and that hash function neither MD5 nor SHA-1, then use
///   the hash function associated with the certificate's
///   signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
///   uses multiple hash functions, then this channel binding type's
///   channel bindings are undefined at this time (updates to is channel
///   binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
    /// SHA-256 digest of the server certificate.
    Sha256([u8; 32]),
    /// No channel binding could be determined for the certificate.
    Undefined,
}
134 :
135 : impl TlsServerEndPoint {
136 68 : pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
137 68 : let sha256_oids = [
138 68 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
139 68 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
140 68 : oid_registry::OID_PKCS1_SHA256WITHRSA,
141 68 : ];
142 :
143 68 : let pem = x509_parser::parse_x509_certificate(&cert.0)
144 68 : .context("Failed to parse PEM object from cerficiate")?
145 : .1;
146 :
147 68 : info!(subject = %pem.subject, "parsing TLS certificate");
148 :
149 68 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
150 68 : let oid = pem.signature_algorithm.oid();
151 68 : let alg = reg.get(oid);
152 68 : if sha256_oids.contains(oid) {
153 68 : let tls_server_end_point: [u8; 32] =
154 68 : Sha256::new().chain_update(&cert.0).finalize().into();
155 68 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
156 68 : Ok(Self::Sha256(tls_server_end_point))
157 : } else {
158 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
159 0 : Ok(Self::Undefined)
160 : }
161 68 : }
162 :
163 68 : pub fn supported(&self) -> bool {
164 68 : !matches!(self, TlsServerEndPoint::Undefined)
165 68 : }
166 : }
167 :
168 67 : #[derive(Default)]
169 : pub struct CertResolver {
170 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
171 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
172 : }
173 :
174 : impl CertResolver {
175 67 : pub fn new() -> Self {
176 67 : Self::default()
177 67 : }
178 :
179 25 : fn add_cert_path(
180 25 : &mut self,
181 25 : key_path: &str,
182 25 : cert_path: &str,
183 25 : is_default: bool,
184 25 : ) -> anyhow::Result<()> {
185 25 : let priv_key = {
186 25 : let key_bytes = std::fs::read(key_path)
187 25 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
188 25 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
189 25 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
190 :
191 25 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
192 25 : keys.pop().map(rustls::PrivateKey).unwrap()
193 : };
194 :
195 25 : let cert_chain_bytes = std::fs::read(cert_path)
196 25 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
197 :
198 25 : let cert_chain = {
199 25 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
200 25 : .with_context(|| {
201 0 : format!(
202 0 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
203 0 : )
204 25 : })?
205 25 : .into_iter()
206 25 : .map(rustls::Certificate)
207 25 : .collect()
208 25 : };
209 25 :
210 25 : self.add_cert(priv_key, cert_chain, is_default)
211 25 : }
212 :
213 67 : pub fn add_cert(
214 67 : &mut self,
215 67 : priv_key: PrivateKey,
216 67 : cert_chain: Vec<Certificate>,
217 67 : is_default: bool,
218 67 : ) -> anyhow::Result<()> {
219 67 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
220 :
221 67 : let first_cert = &cert_chain[0];
222 67 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
223 67 : let pem = x509_parser::parse_x509_certificate(&first_cert.0)
224 67 : .context("Failed to parse PEM object from cerficiate")?
225 : .1;
226 :
227 67 : let common_name = pem.subject().to_string();
228 :
229 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
230 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
231 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
232 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
233 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
234 : // of cutting off '*.' parts.
235 67 : let common_name = if common_name.starts_with("CN=*.") {
236 25 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
237 : } else {
238 42 : common_name.strip_prefix("CN=").map(|s| s.to_string())
239 : }
240 67 : .context("Failed to parse common name from certificate")?;
241 :
242 67 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
243 67 :
244 67 : if is_default {
245 67 : self.default = Some((cert.clone(), tls_server_end_point));
246 67 : }
247 :
248 67 : self.certs.insert(common_name, (cert, tls_server_end_point));
249 67 :
250 67 : Ok(())
251 67 : }
252 :
253 67 : pub fn get_common_names(&self) -> HashSet<String> {
254 67 : self.certs.keys().map(|s| s.to_string()).collect()
255 67 : }
256 : }
257 :
258 : impl rustls::server::ResolvesServerCert for CertResolver {
259 99 : fn resolve(
260 99 : &self,
261 99 : client_hello: rustls::server::ClientHello,
262 99 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
263 99 : self.resolve(client_hello.server_name()).map(|x| x.0)
264 99 : }
265 : }
266 :
267 : impl CertResolver {
268 191 : pub fn resolve(
269 191 : &self,
270 191 : server_name: Option<&str>,
271 191 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
272 : // loop here and cut off more and more subdomains until we find
273 : // a match to get a proper wildcard support. OTOH, we now do not
274 : // use nested domains, so keep this simple for now.
275 : //
276 : // With the current coding foo.com will match *.foo.com and that
277 : // repeats behavior of the old code.
278 191 : if let Some(mut sni_name) = server_name {
279 : loop {
280 302 : if let Some(cert) = self.certs.get(sni_name) {
281 151 : return Some(cert.clone());
282 151 : }
283 151 : if let Some((_, rest)) = sni_name.split_once('.') {
284 151 : sni_name = rest;
285 151 : } else {
286 0 : return None;
287 : }
288 : }
289 : } else {
290 : // No SNI, use the default certificate, otherwise we can't get to
291 : // options parameter which can be used to set endpoint name too.
292 : // That means that non-SNI flow will not work for CNAME domains in
293 : // verify-full mode.
294 : //
295 : // If that will be a problem we can:
296 : //
297 : // a) Instead of multi-cert approach use single cert with extra
298 : // domains listed in Subject Alternative Name (SAN).
299 : // b) Deploy separate proxy instances for extra domains.
300 40 : self.default.as_ref().cloned()
301 : }
302 191 : }
303 : }
304 :
/// Cache settings parsed from a `key=value` command-line string.
#[derive(Debug)]
pub struct CacheOptions {
    /// Maximum number of entries the cache may hold.
    pub size: usize,
    /// How long an entry stays valid.
    pub ttl: Duration,
}
313 :
314 : impl CacheOptions {
315 : /// Default options for [`crate::console::provider::NodeInfoCache`].
316 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
317 :
318 : /// Parse cache options passed via cmdline.
319 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
320 9 : fn parse(options: &str) -> anyhow::Result<Self> {
321 9 : let mut size = None;
322 9 : let mut ttl = None;
323 :
324 15 : for option in options.split(',') {
325 15 : let (key, value) = option
326 15 : .split_once('=')
327 15 : .with_context(|| format!("bad key-value pair: {option}"))?;
328 :
329 15 : match key {
330 15 : "size" => size = Some(value.parse()?),
331 6 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
332 0 : unknown => bail!("unknown key: {unknown}"),
333 : }
334 : }
335 :
336 : // TTL doesn't matter if cache is always empty.
337 9 : if let Some(0) = size {
338 5 : ttl.get_or_insert(Duration::default());
339 5 : }
340 :
341 : Ok(Self {
342 9 : size: size.context("missing `size`")?,
343 9 : ttl: ttl.context("missing `ttl`")?,
344 : })
345 9 : }
346 : }
347 :
348 : impl FromStr for CacheOptions {
349 : type Err = anyhow::Error;
350 :
351 9 : fn from_str(options: &str) -> Result<Self, Self::Err> {
352 9 : let error = || format!("failed to parse cache options '{options}'");
353 9 : Self::parse(options).with_context(error)
354 9 : }
355 : }
356 :
/// Project-info cache settings parsed from a `key=value` command-line string.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Maximum number of entries the cache may hold.
    pub size: usize,
    /// How long an entry stays valid.
    pub ttl: Duration,
    /// Maximum number of roles kept per endpoint.
    pub max_roles: usize,
    /// Garbage collection interval.
    pub gc_interval: Duration,
}
369 :
370 : impl ProjectInfoCacheOptions {
371 : /// Default options for [`crate::console::provider::NodeInfoCache`].
372 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
373 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
374 :
375 : /// Parse cache options passed via cmdline.
376 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
377 1 : fn parse(options: &str) -> anyhow::Result<Self> {
378 1 : let mut size = None;
379 1 : let mut ttl = None;
380 1 : let mut max_roles = None;
381 1 : let mut gc_interval = None;
382 :
383 4 : for option in options.split(',') {
384 4 : let (key, value) = option
385 4 : .split_once('=')
386 4 : .with_context(|| format!("bad key-value pair: {option}"))?;
387 :
388 4 : match key {
389 4 : "size" => size = Some(value.parse()?),
390 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
391 2 : "max_roles" => max_roles = Some(value.parse()?),
392 1 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
393 0 : unknown => bail!("unknown key: {unknown}"),
394 : }
395 : }
396 :
397 : // TTL doesn't matter if cache is always empty.
398 1 : if let Some(0) = size {
399 0 : ttl.get_or_insert(Duration::default());
400 1 : }
401 :
402 : Ok(Self {
403 1 : size: size.context("missing `size`")?,
404 1 : ttl: ttl.context("missing `ttl`")?,
405 1 : max_roles: max_roles.context("missing `max_roles`")?,
406 1 : gc_interval: gc_interval.context("missing `gc_interval`")?,
407 : })
408 1 : }
409 : }
410 :
411 : impl FromStr for ProjectInfoCacheOptions {
412 : type Err = anyhow::Error;
413 :
414 1 : fn from_str(options: &str) -> Result<Self, Self::Err> {
415 1 : let error = || format!("failed to parse cache options '{options}'");
416 1 : Self::parse(options).with_context(error)
417 1 : }
418 : }
419 :
/// Wake-compute lock settings parsed from a `key=value` command-line string.
pub struct WakeComputeLockOptions {
    /// Number of shards in the lock map.
    pub shards: usize,
    /// Number of concurrent requests allowed per endpoint.
    pub permits: usize,
    /// Garbage collection epoch.
    pub epoch: Duration,
    /// Lock timeout.
    pub timeout: Duration,
}
431 :
432 : impl WakeComputeLockOptions {
433 : /// Default options for [`crate::console::provider::ApiLocks`].
434 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
435 :
436 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
437 :
438 : /// Parse lock options passed via cmdline.
439 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
440 7 : fn parse(options: &str) -> anyhow::Result<Self> {
441 7 : let mut shards = None;
442 7 : let mut permits = None;
443 7 : let mut epoch = None;
444 7 : let mut timeout = None;
445 :
446 19 : for option in options.split(',') {
447 19 : let (key, value) = option
448 19 : .split_once('=')
449 19 : .with_context(|| format!("bad key-value pair: {option}"))?;
450 :
451 19 : match key {
452 19 : "shards" => shards = Some(value.parse()?),
453 15 : "permits" => permits = Some(value.parse()?),
454 8 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
455 4 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
456 0 : unknown => bail!("unknown key: {unknown}"),
457 : }
458 : }
459 :
460 : // these dont matter if lock is disabled
461 7 : if let Some(0) = permits {
462 3 : timeout = Some(Duration::default());
463 3 : epoch = Some(Duration::default());
464 3 : shards = Some(2);
465 4 : }
466 :
467 7 : let out = Self {
468 7 : shards: shards.context("missing `shards`")?,
469 7 : permits: permits.context("missing `permits`")?,
470 7 : epoch: epoch.context("missing `epoch`")?,
471 7 : timeout: timeout.context("missing `timeout`")?,
472 : };
473 :
474 7 : ensure!(out.shards > 1, "shard count must be > 1");
475 7 : ensure!(
476 7 : out.shards.is_power_of_two(),
477 0 : "shard count must be a power of two"
478 : );
479 :
480 7 : Ok(out)
481 7 : }
482 : }
483 :
484 : impl FromStr for WakeComputeLockOptions {
485 : type Err = anyhow::Error;
486 :
487 7 : fn from_str(options: &str) -> Result<Self, Self::Err> {
488 7 : let error = || format!("failed to parse cache lock options '{options}'");
489 7 : Self::parse(options).with_context(error)
490 7 : }
491 : }
492 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache options parse regardless of key order, and `ttl` is optional
    /// when `size` is zero.
    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        // (input, expected size, expected ttl)
        let cases = [
            ("size=4096,ttl=5min", 4096, Duration::from_secs(5 * 60)),
            ("ttl=4m,size=2", 2, Duration::from_secs(4 * 60)),
            ("size=0,ttl=1s", 0, Duration::from_secs(1)),
            ("size=0", 0, Duration::default()),
        ];

        for (input, want_size, want_ttl) in cases {
            let CacheOptions { size, ttl } = input.parse()?;
            assert_eq!(size, want_size);
            assert_eq!(ttl, want_ttl);
        }

        Ok(())
    }

    /// Lock options parse regardless of key order, and `permits=0` alone is
    /// accepted and fills in the remaining fields.
    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        // (input, expected epoch, expected timeout, expected shards, expected permits)
        let cases = [
            (
                "shards=32,permits=4,epoch=10m,timeout=1s",
                Duration::from_secs(10 * 60),
                Duration::from_secs(1),
                32,
                4,
            ),
            (
                "epoch=60s,shards=16,timeout=100ms,permits=8",
                Duration::from_secs(60),
                Duration::from_millis(100),
                16,
                8,
            ),
            ("permits=0", Duration::ZERO, Duration::ZERO, 2, 0),
        ];

        for (input, want_epoch, want_timeout, want_shards, want_permits) in cases {
            let WakeComputeLockOptions {
                epoch,
                permits,
                shards,
                timeout,
            } = input.parse()?;
            assert_eq!(epoch, want_epoch);
            assert_eq!(timeout, want_timeout);
            assert_eq!(shards, want_shards);
            assert_eq!(permits, want_permits);
        }

        Ok(())
    }
}
|