Line data Source code
1 : use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
2 : use anyhow::{bail, ensure, Context, Ok};
3 : use rustls::{sign, Certificate, PrivateKey};
4 : use sha2::{Digest, Sha256};
5 : use std::{
6 : collections::{HashMap, HashSet},
7 : str::FromStr,
8 : sync::Arc,
9 : time::Duration,
10 : };
11 : use tracing::{error, info};
12 : use x509_parser::oid_registry;
13 :
/// Aggregated proxy configuration.
///
/// Bundles the optional TLS setup, the authentication backend, metric
/// collection, the HTTP and authentication sub-configs, plus a handful of
/// scalar switches and limits.
14 : pub struct ProxyConfig {
    /// Optional TLS settings (see [`TlsConfig`]).
15 : pub tls_config: Option<TlsConfig>,
    /// Authentication backend in use.
16 : pub auth_backend: auth::BackendType<'static, ()>,
    /// Optional metric collection settings (see [`MetricCollectionConfig`]).
17 : pub metric_collection: Option<MetricCollectionConfig>,
    // NOTE(review): presumably permits self-signed certificates on compute
    // connections — confirm at the use site.
18 : pub allow_self_signed_compute: bool,
    /// HTTP-related settings (see [`HttpConfig`]).
19 : pub http_config: HttpConfig,
    /// Authentication-related settings (see [`AuthenticationConfig`]).
20 : pub authentication_config: AuthenticationConfig,
    // NOTE(review): presumably rejects connections whose client IP is
    // unknown — confirm at the use site.
21 : pub require_client_ip: bool,
    // NOTE(review): presumably skips the IP check for HTTP connections —
    // confirm at the use site.
22 : pub disable_ip_check_for_http: bool,
    /// Rate limiter buckets.
    // NOTE(review): presumably per-endpoint requests-per-second limits — confirm.
23 : pub endpoint_rps_limit: Vec<RateBucketInfo>,
    /// Region name.
24 : pub region: String,
    // NOTE(review): presumably bounds the duration of the client handshake —
    // confirm at the use site.
25 : pub handshake_timeout: Duration,
26 : }
27 :
/// Settings for metric collection: a target URL and an interval.
28 1 : #[derive(Debug)]
29 : pub struct MetricCollectionConfig {
    /// URL of the metric collection endpoint.
30 : pub endpoint: reqwest::Url,
    /// Collection interval.
    // NOTE(review): presumably the period between two consecutive uploads — confirm.
31 : pub interval: Duration,
32 : }
33 :
/// TLS state produced by `configure_tls`.
34 : pub struct TlsConfig {
    /// Shared rustls server configuration.
35 : pub config: Arc<rustls::ServerConfig>,
    /// Common names of all loaded certificates, as reported by
    /// `CertResolver::get_common_names`.
36 : pub common_names: HashSet<String>,
    /// The resolver that `config` uses to pick a certificate.
37 : pub cert_resolver: Arc<CertResolver>,
38 : }
39 :
/// HTTP-related settings.
40 : pub struct HttpConfig {
    /// Request timeout.
    // NOTE(review): presumably applied per HTTP request — confirm at the use site.
41 : pub request_timeout: tokio::time::Duration,
    /// Options for the global connection pool.
42 : pub pool_options: GlobalConnPoolOptions,
43 : }
44 :
/// Authentication-related settings.
45 : pub struct AuthenticationConfig {
    /// Timeout for the SCRAM protocol.
    // NOTE(review): presumably covers the whole SCRAM exchange — confirm at the use site.
46 : pub scram_protocol_timeout: tokio::time::Duration,
47 : }
48 :
49 : impl TlsConfig {
50 117 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
51 117 : self.config.clone()
52 117 : }
53 : }
54 :
55 : /// Configure TLS for the main endpoint.
56 25 : pub fn configure_tls(
57 25 : key_path: &str,
58 25 : cert_path: &str,
59 25 : certs_dir: Option<&String>,
60 25 : ) -> anyhow::Result<TlsConfig> {
61 25 : let mut cert_resolver = CertResolver::new();
62 25 :
63 25 : // add default certificate
64 25 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
65 :
66 : // add extra certificates
67 25 : if let Some(certs_dir) = certs_dir {
68 0 : for entry in std::fs::read_dir(certs_dir)? {
69 0 : let entry = entry?;
70 0 : let path = entry.path();
71 0 : if path.is_dir() {
72 : // file names aligned with default cert-manager names
73 0 : let key_path = path.join("tls.key");
74 0 : let cert_path = path.join("tls.crt");
75 0 : if key_path.exists() && cert_path.exists() {
76 0 : cert_resolver.add_cert_path(
77 0 : &key_path.to_string_lossy(),
78 0 : &cert_path.to_string_lossy(),
79 0 : false,
80 0 : )?;
81 0 : }
82 0 : }
83 : }
84 25 : }
85 :
86 25 : let common_names = cert_resolver.get_common_names();
87 25 :
88 25 : let cert_resolver = Arc::new(cert_resolver);
89 :
90 25 : let config = rustls::ServerConfig::builder()
91 25 : .with_safe_default_cipher_suites()
92 25 : .with_safe_default_kx_groups()
93 25 : // allow TLS 1.2 to be compatible with older client libraries
94 25 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
95 25 : .with_no_client_auth()
96 25 : .with_cert_resolver(cert_resolver.clone())
97 25 : .into();
98 25 :
99 25 : Ok(TlsConfig {
100 25 : config,
101 25 : common_names,
102 25 : cert_resolver,
103 25 : })
104 25 : }
105 :
106 : /// Channel binding parameter
107 : ///
108 : /// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
109 : /// Description: The hash of the TLS server's certificate as it
110 : /// appears, octet for octet, in the server's Certificate message. Note
111 : /// that the Certificate message contains a certificate_list, in which
112 : /// the first element is the server's certificate.
113 : ///
114 : /// The hash function is to be selected as follows:
115 : ///
116 : /// * if the certificate's signatureAlgorithm uses a single hash
117 : /// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
118 : ///
119 : /// * if the certificate's signatureAlgorithm uses a single hash
120 : /// function and that hash function neither MD5 nor SHA-1, then use
121 : /// the hash function associated with the certificate's
122 : /// signatureAlgorithm;
123 : ///
124 : /// * if the certificate's signatureAlgorithm uses no hash functions or
125 : /// uses multiple hash functions, then this channel binding type's
126 : /// channel bindings are undefined at this time (updates to is channel
127 : /// binding type may occur to address this issue if it ever arises).
128 191 : #[derive(Debug, Clone, Copy)]
129 : pub enum TlsServerEndPoint {
    /// SHA-256 digest of the certificate bytes, as computed by
    /// `TlsServerEndPoint::new`.
130 : Sha256([u8; 32]),
    /// The certificate's signature algorithm is not one of the SHA-256 based
    /// algorithms recognised by `TlsServerEndPoint::new`, so no channel
    /// binding value is available.
131 : Undefined,
132 : }
133 :
134 : impl TlsServerEndPoint {
135 68 : pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
136 68 : let sha256_oids = [
137 68 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
138 68 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
139 68 : oid_registry::OID_PKCS1_SHA256WITHRSA,
140 68 : ];
141 :
142 68 : let pem = x509_parser::parse_x509_certificate(&cert.0)
143 68 : .context("Failed to parse PEM object from cerficiate")?
144 : .1;
145 :
146 68 : info!(subject = %pem.subject, "parsing TLS certificate");
147 :
148 68 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
149 68 : let oid = pem.signature_algorithm.oid();
150 68 : let alg = reg.get(oid);
151 68 : if sha256_oids.contains(oid) {
152 68 : let tls_server_end_point: [u8; 32] =
153 68 : Sha256::new().chain_update(&cert.0).finalize().into();
154 68 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
155 68 : Ok(Self::Sha256(tls_server_end_point))
156 : } else {
157 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
158 0 : Ok(Self::Undefined)
159 : }
160 68 : }
161 :
162 68 : pub fn supported(&self) -> bool {
163 68 : !matches!(self, TlsServerEndPoint::Undefined)
164 68 : }
165 : }
166 :
/// Certificate store that selects a certificate by server name.
167 67 : #[derive(Default)]
168 : pub struct CertResolver {
    /// Certificates keyed by the name derived from the certificate subject in
    /// `add_cert` (the `CN=` prefix and a leading `*.` are removed), each
    /// paired with its channel binding.
169 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
    /// Certificate returned by `resolve` when no server name is supplied;
    /// set by `add_cert` when called with `is_default == true`.
170 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
171 : }
172 :
173 : impl CertResolver {
174 67 : pub fn new() -> Self {
175 67 : Self::default()
176 67 : }
177 :
178 25 : fn add_cert_path(
179 25 : &mut self,
180 25 : key_path: &str,
181 25 : cert_path: &str,
182 25 : is_default: bool,
183 25 : ) -> anyhow::Result<()> {
184 25 : let priv_key = {
185 25 : let key_bytes = std::fs::read(key_path)
186 25 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
187 25 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
188 25 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
189 :
190 25 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
191 25 : keys.pop().map(rustls::PrivateKey).unwrap()
192 : };
193 :
194 25 : let cert_chain_bytes = std::fs::read(cert_path)
195 25 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
196 :
197 25 : let cert_chain = {
198 25 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
199 25 : .with_context(|| {
200 0 : format!(
201 0 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
202 0 : )
203 25 : })?
204 25 : .into_iter()
205 25 : .map(rustls::Certificate)
206 25 : .collect()
207 25 : };
208 25 :
209 25 : self.add_cert(priv_key, cert_chain, is_default)
210 25 : }
211 :
212 67 : pub fn add_cert(
213 67 : &mut self,
214 67 : priv_key: PrivateKey,
215 67 : cert_chain: Vec<Certificate>,
216 67 : is_default: bool,
217 67 : ) -> anyhow::Result<()> {
218 67 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
219 :
220 67 : let first_cert = &cert_chain[0];
221 67 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
222 67 : let pem = x509_parser::parse_x509_certificate(&first_cert.0)
223 67 : .context("Failed to parse PEM object from cerficiate")?
224 : .1;
225 :
226 67 : let common_name = pem.subject().to_string();
227 :
228 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
229 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
230 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
231 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
232 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
233 : // of cutting off '*.' parts.
234 67 : let common_name = if common_name.starts_with("CN=*.") {
235 25 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
236 : } else {
237 42 : common_name.strip_prefix("CN=").map(|s| s.to_string())
238 : }
239 67 : .context("Failed to parse common name from certificate")?;
240 :
241 67 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
242 67 :
243 67 : if is_default {
244 67 : self.default = Some((cert.clone(), tls_server_end_point));
245 67 : }
246 :
247 67 : self.certs.insert(common_name, (cert, tls_server_end_point));
248 67 :
249 67 : Ok(())
250 67 : }
251 :
252 67 : pub fn get_common_names(&self) -> HashSet<String> {
253 67 : self.certs.keys().map(|s| s.to_string()).collect()
254 67 : }
255 : }
256 :
257 : impl rustls::server::ResolvesServerCert for CertResolver {
258 99 : fn resolve(
259 99 : &self,
260 99 : client_hello: rustls::server::ClientHello,
261 99 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
262 99 : self.resolve(client_hello.server_name()).map(|x| x.0)
263 99 : }
264 : }
265 :
266 : impl CertResolver {
267 191 : pub fn resolve(
268 191 : &self,
269 191 : server_name: Option<&str>,
270 191 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
271 : // loop here and cut off more and more subdomains until we find
272 : // a match to get a proper wildcard support. OTOH, we now do not
273 : // use nested domains, so keep this simple for now.
274 : //
275 : // With the current coding foo.com will match *.foo.com and that
276 : // repeats behavior of the old code.
277 191 : if let Some(mut sni_name) = server_name {
278 : loop {
279 302 : if let Some(cert) = self.certs.get(sni_name) {
280 151 : return Some(cert.clone());
281 151 : }
282 151 : if let Some((_, rest)) = sni_name.split_once('.') {
283 151 : sni_name = rest;
284 151 : } else {
285 0 : return None;
286 : }
287 : }
288 : } else {
289 : // No SNI, use the default certificate, otherwise we can't get to
290 : // options parameter which can be used to set endpoint name too.
291 : // That means that non-SNI flow will not work for CNAME domains in
292 : // verify-full mode.
293 : //
294 : // If that will be a problem we can:
295 : //
296 : // a) Instead of multi-cert approach use single cert with extra
297 : // domains listed in Subject Alternative Name (SAN).
298 : // b) Deploy separate proxy instances for extra domains.
299 40 : self.default.as_ref().cloned()
300 : }
301 191 : }
302 : }
303 :
304 : /// Helper for cmdline cache options parsing.
305 1 : #[derive(Debug)]
306 : pub struct CacheOptions {
307 : /// Max number of entries.
308 : pub size: usize,
309 : /// Entry's time-to-live.
310 : pub ttl: Duration,
311 : }
312 :
313 : impl CacheOptions {
314 : /// Default options for [`crate::console::provider::NodeInfoCache`].
315 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
316 :
317 : /// Parse cache options passed via cmdline.
318 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
319 9 : fn parse(options: &str) -> anyhow::Result<Self> {
320 9 : let mut size = None;
321 9 : let mut ttl = None;
322 :
323 15 : for option in options.split(',') {
324 15 : let (key, value) = option
325 15 : .split_once('=')
326 15 : .with_context(|| format!("bad key-value pair: {option}"))?;
327 :
328 15 : match key {
329 15 : "size" => size = Some(value.parse()?),
330 6 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
331 0 : unknown => bail!("unknown key: {unknown}"),
332 : }
333 : }
334 :
335 : // TTL doesn't matter if cache is always empty.
336 9 : if let Some(0) = size {
337 5 : ttl.get_or_insert(Duration::default());
338 5 : }
339 :
340 : Ok(Self {
341 9 : size: size.context("missing `size`")?,
342 9 : ttl: ttl.context("missing `ttl`")?,
343 : })
344 9 : }
345 : }
346 :
347 : impl FromStr for CacheOptions {
348 : type Err = anyhow::Error;
349 :
350 9 : fn from_str(options: &str) -> Result<Self, Self::Err> {
351 9 : let error = || format!("failed to parse cache options '{options}'");
352 9 : Self::parse(options).with_context(error)
353 9 : }
354 : }
355 :
356 : /// Helper for cmdline cache options parsing.
357 1 : #[derive(Debug)]
358 : pub struct ProjectInfoCacheOptions {
359 : /// Max number of entries.
360 : pub size: usize,
361 : /// Entry's time-to-live.
362 : pub ttl: Duration,
363 : /// Max number of roles per endpoint.
364 : pub max_roles: usize,
365 : /// Gc interval.
366 : pub gc_interval: Duration,
367 : }
368 :
369 : impl ProjectInfoCacheOptions {
370 : /// Default options for [`crate::console::provider::NodeInfoCache`].
371 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
372 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
373 :
374 : /// Parse cache options passed via cmdline.
375 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
376 1 : fn parse(options: &str) -> anyhow::Result<Self> {
377 1 : let mut size = None;
378 1 : let mut ttl = None;
379 1 : let mut max_roles = None;
380 1 : let mut gc_interval = None;
381 :
382 4 : for option in options.split(',') {
383 4 : let (key, value) = option
384 4 : .split_once('=')
385 4 : .with_context(|| format!("bad key-value pair: {option}"))?;
386 :
387 4 : match key {
388 4 : "size" => size = Some(value.parse()?),
389 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
390 2 : "max_roles" => max_roles = Some(value.parse()?),
391 1 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
392 0 : unknown => bail!("unknown key: {unknown}"),
393 : }
394 : }
395 :
396 : // TTL doesn't matter if cache is always empty.
397 1 : if let Some(0) = size {
398 0 : ttl.get_or_insert(Duration::default());
399 1 : }
400 :
401 : Ok(Self {
402 1 : size: size.context("missing `size`")?,
403 1 : ttl: ttl.context("missing `ttl`")?,
404 1 : max_roles: max_roles.context("missing `max_roles`")?,
405 1 : gc_interval: gc_interval.context("missing `gc_interval`")?,
406 : })
407 1 : }
408 : }
409 :
410 : impl FromStr for ProjectInfoCacheOptions {
411 : type Err = anyhow::Error;
412 :
413 1 : fn from_str(options: &str) -> Result<Self, Self::Err> {
414 1 : let error = || format!("failed to parse cache options '{options}'");
415 1 : Self::parse(options).with_context(error)
416 1 : }
417 : }
418 :
419 : /// Helper for cmdline cache options parsing.
420 : pub struct WakeComputeLockOptions {
421 : /// The number of shards the lock map should have
422 : pub shards: usize,
423 : /// The number of allowed concurrent requests for each endpoitn
424 : pub permits: usize,
425 : /// Garbage collection epoch
426 : pub epoch: Duration,
427 : /// Lock timeout
428 : pub timeout: Duration,
429 : }
430 :
431 : impl WakeComputeLockOptions {
432 : /// Default options for [`crate::console::provider::ApiLocks`].
433 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
434 :
435 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
436 :
437 : /// Parse lock options passed via cmdline.
438 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
439 7 : fn parse(options: &str) -> anyhow::Result<Self> {
440 7 : let mut shards = None;
441 7 : let mut permits = None;
442 7 : let mut epoch = None;
443 7 : let mut timeout = None;
444 :
445 19 : for option in options.split(',') {
446 19 : let (key, value) = option
447 19 : .split_once('=')
448 19 : .with_context(|| format!("bad key-value pair: {option}"))?;
449 :
450 19 : match key {
451 19 : "shards" => shards = Some(value.parse()?),
452 15 : "permits" => permits = Some(value.parse()?),
453 8 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
454 4 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
455 0 : unknown => bail!("unknown key: {unknown}"),
456 : }
457 : }
458 :
459 : // these dont matter if lock is disabled
460 7 : if let Some(0) = permits {
461 3 : timeout = Some(Duration::default());
462 3 : epoch = Some(Duration::default());
463 3 : shards = Some(2);
464 4 : }
465 :
466 7 : let out = Self {
467 7 : shards: shards.context("missing `shards`")?,
468 7 : permits: permits.context("missing `permits`")?,
469 7 : epoch: epoch.context("missing `epoch`")?,
470 7 : timeout: timeout.context("missing `timeout`")?,
471 : };
472 :
473 7 : ensure!(out.shards > 1, "shard count must be > 1");
474 7 : ensure!(
475 7 : out.shards.is_power_of_two(),
476 0 : "shard count must be a power of two"
477 : );
478 :
479 7 : Ok(out)
480 7 : }
481 : }
482 :
483 : impl FromStr for WakeComputeLockOptions {
484 : type Err = anyhow::Error;
485 :
486 7 : fn from_str(options: &str) -> Result<Self, Self::Err> {
487 7 : let error = || format!("failed to parse cache lock options '{options}'");
488 7 : Self::parse(options).with_context(error)
489 7 : }
490 : }
491 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache options parse regardless of key order, and `ttl` may be omitted
    /// only when `size` is zero.
    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        // (input, expected size, expected ttl)
        let cases = [
            ("size=4096,ttl=5min", 4096, Duration::from_secs(5 * 60)),
            ("ttl=4m,size=2", 2, Duration::from_secs(4 * 60)),
            ("size=0,ttl=1s", 0, Duration::from_secs(1)),
            ("size=0", 0, Duration::default()),
        ];

        for &(input, expected_size, expected_ttl) in &cases {
            let parsed: CacheOptions = input.parse()?;
            assert_eq!(parsed.size, expected_size);
            assert_eq!(parsed.ttl, expected_ttl);
        }

        Ok(())
    }

    /// Lock options parse regardless of key order, and `permits=0` alone is
    /// enough because the remaining values are filled in.
    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        // (input, expected epoch, expected timeout, expected shards, expected permits)
        let cases = [
            (
                "shards=32,permits=4,epoch=10m,timeout=1s",
                Duration::from_secs(10 * 60),
                Duration::from_secs(1),
                32,
                4,
            ),
            (
                "epoch=60s,shards=16,timeout=100ms,permits=8",
                Duration::from_secs(60),
                Duration::from_millis(100),
                16,
                8,
            ),
            ("permits=0", Duration::ZERO, Duration::ZERO, 2, 0),
        ];

        for &(input, expected_epoch, expected_timeout, expected_shards, expected_permits) in &cases
        {
            let parsed: WakeComputeLockOptions = input.parse()?;
            assert_eq!(parsed.epoch, expected_epoch);
            assert_eq!(parsed.timeout, expected_timeout);
            assert_eq!(parsed.shards, expected_shards);
            assert_eq!(parsed.permits, expected_permits);
        }

        Ok(())
    }
}
|