Line data Source code
1 : use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
2 : use anyhow::{bail, ensure, Context, Ok};
3 : use rustls::{sign, Certificate, PrivateKey};
4 : use sha2::{Digest, Sha256};
5 : use std::{
6 : collections::{HashMap, HashSet},
7 : str::FromStr,
8 : sync::Arc,
9 : time::Duration,
10 : };
11 : use tracing::{error, info};
12 : use x509_parser::oid_registry;
13 :
14 : pub struct ProxyConfig {
15 : pub tls_config: Option<TlsConfig>,
16 : pub auth_backend: auth::BackendType<'static, ()>,
17 : pub metric_collection: Option<MetricCollectionConfig>,
18 : pub allow_self_signed_compute: bool,
19 : pub http_config: HttpConfig,
20 : pub authentication_config: AuthenticationConfig,
21 : pub require_client_ip: bool,
22 : pub disable_ip_check_for_http: bool,
23 : pub endpoint_rps_limit: Vec<RateBucketInfo>,
24 : pub region: String,
25 : }
26 :
27 1 : #[derive(Debug)]
28 : pub struct MetricCollectionConfig {
29 : pub endpoint: reqwest::Url,
30 : pub interval: Duration,
31 : }
32 :
33 : pub struct TlsConfig {
34 : pub config: Arc<rustls::ServerConfig>,
35 : pub common_names: HashSet<String>,
36 : pub cert_resolver: Arc<CertResolver>,
37 : }
38 :
39 : pub struct HttpConfig {
40 : pub request_timeout: tokio::time::Duration,
41 : pub pool_options: GlobalConnPoolOptions,
42 : }
43 :
44 : pub struct AuthenticationConfig {
45 : pub scram_protocol_timeout: tokio::time::Duration,
46 : }
47 :
48 : impl TlsConfig {
49 113 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
50 113 : self.config.clone()
51 113 : }
52 : }
53 :
54 : /// Configure TLS for the main endpoint.
55 23 : pub fn configure_tls(
56 23 : key_path: &str,
57 23 : cert_path: &str,
58 23 : certs_dir: Option<&String>,
59 23 : ) -> anyhow::Result<TlsConfig> {
60 23 : let mut cert_resolver = CertResolver::new();
61 23 :
62 23 : // add default certificate
63 23 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
64 :
65 : // add extra certificates
66 23 : if let Some(certs_dir) = certs_dir {
67 0 : for entry in std::fs::read_dir(certs_dir)? {
68 0 : let entry = entry?;
69 0 : let path = entry.path();
70 0 : if path.is_dir() {
71 : // file names aligned with default cert-manager names
72 0 : let key_path = path.join("tls.key");
73 0 : let cert_path = path.join("tls.crt");
74 0 : if key_path.exists() && cert_path.exists() {
75 0 : cert_resolver.add_cert_path(
76 0 : &key_path.to_string_lossy(),
77 0 : &cert_path.to_string_lossy(),
78 0 : false,
79 0 : )?;
80 0 : }
81 0 : }
82 : }
83 23 : }
84 :
85 23 : let common_names = cert_resolver.get_common_names();
86 23 :
87 23 : let cert_resolver = Arc::new(cert_resolver);
88 :
89 23 : let config = rustls::ServerConfig::builder()
90 23 : .with_safe_default_cipher_suites()
91 23 : .with_safe_default_kx_groups()
92 23 : // allow TLS 1.2 to be compatible with older client libraries
93 23 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
94 23 : .with_no_client_auth()
95 23 : .with_cert_resolver(cert_resolver.clone())
96 23 : .into();
97 23 :
98 23 : Ok(TlsConfig {
99 23 : config,
100 23 : common_names,
101 23 : cert_resolver,
102 23 : })
103 23 : }
104 :
/// Channel binding parameter (`tls-server-end-point`).
///
/// Quoted from <https://www.rfc-editor.org/rfc/rfc5929#section-4>:
///
/// > Description: The hash of the TLS server's certificate as it
/// > appears, octet for octet, in the server's Certificate message. Note
/// > that the Certificate message contains a certificate_list, in which
/// > the first element is the server's certificate.
/// >
/// > The hash function is to be selected as follows:
/// >
/// > * if the certificate's signatureAlgorithm uses a single hash
/// >   function, and that hash function is either MD5 or SHA-1, then use SHA-256;
/// >
/// > * if the certificate's signatureAlgorithm uses a single hash
/// >   function and that hash function neither MD5 nor SHA-1, then use
/// >   the hash function associated with the certificate's
/// >   signatureAlgorithm;
/// >
/// > * if the certificate's signatureAlgorithm uses no hash functions or
/// >   uses multiple hash functions, then this channel binding type's
/// >   channel bindings are undefined at this time (updates to is channel
/// >   binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
    /// SHA-256 digest of the server certificate.
    Sha256([u8; 32]),
    /// No channel binding is available for this certificate.
    Undefined,
}
132 :
133 : impl TlsServerEndPoint {
134 66 : pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
135 66 : let sha256_oids = [
136 66 : // I'm explicitly not adding MD5 or SHA1 here... They're bad.
137 66 : oid_registry::OID_SIG_ECDSA_WITH_SHA256,
138 66 : oid_registry::OID_PKCS1_SHA256WITHRSA,
139 66 : ];
140 :
141 66 : let pem = x509_parser::parse_x509_certificate(&cert.0)
142 66 : .context("Failed to parse PEM object from cerficiate")?
143 : .1;
144 :
145 66 : info!(subject = %pem.subject, "parsing TLS certificate");
146 :
147 66 : let reg = oid_registry::OidRegistry::default().with_all_crypto();
148 66 : let oid = pem.signature_algorithm.oid();
149 66 : let alg = reg.get(oid);
150 66 : if sha256_oids.contains(oid) {
151 66 : let tls_server_end_point: [u8; 32] =
152 66 : Sha256::new().chain_update(&cert.0).finalize().into();
153 66 : info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
154 66 : Ok(Self::Sha256(tls_server_end_point))
155 : } else {
156 0 : error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
157 0 : Ok(Self::Undefined)
158 : }
159 66 : }
160 :
161 66 : pub fn supported(&self) -> bool {
162 66 : !matches!(self, TlsServerEndPoint::Undefined)
163 66 : }
164 : }
165 :
166 65 : #[derive(Default)]
167 : pub struct CertResolver {
168 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
169 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
170 : }
171 :
172 : impl CertResolver {
173 65 : pub fn new() -> Self {
174 65 : Self::default()
175 65 : }
176 :
177 23 : fn add_cert_path(
178 23 : &mut self,
179 23 : key_path: &str,
180 23 : cert_path: &str,
181 23 : is_default: bool,
182 23 : ) -> anyhow::Result<()> {
183 23 : let priv_key = {
184 23 : let key_bytes = std::fs::read(key_path)
185 23 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
186 23 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
187 23 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
188 :
189 23 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
190 23 : keys.pop().map(rustls::PrivateKey).unwrap()
191 : };
192 :
193 23 : let cert_chain_bytes = std::fs::read(cert_path)
194 23 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
195 :
196 23 : let cert_chain = {
197 23 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
198 23 : .with_context(|| {
199 0 : format!(
200 0 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
201 0 : )
202 23 : })?
203 23 : .into_iter()
204 23 : .map(rustls::Certificate)
205 23 : .collect()
206 23 : };
207 23 :
208 23 : self.add_cert(priv_key, cert_chain, is_default)
209 23 : }
210 :
211 65 : pub fn add_cert(
212 65 : &mut self,
213 65 : priv_key: PrivateKey,
214 65 : cert_chain: Vec<Certificate>,
215 65 : is_default: bool,
216 65 : ) -> anyhow::Result<()> {
217 65 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
218 :
219 65 : let first_cert = &cert_chain[0];
220 65 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
221 65 : let pem = x509_parser::parse_x509_certificate(&first_cert.0)
222 65 : .context("Failed to parse PEM object from cerficiate")?
223 : .1;
224 :
225 65 : let common_name = pem.subject().to_string();
226 :
227 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
228 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
229 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
230 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
231 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
232 : // of cutting off '*.' parts.
233 65 : let common_name = if common_name.starts_with("CN=*.") {
234 23 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
235 : } else {
236 42 : common_name.strip_prefix("CN=").map(|s| s.to_string())
237 : }
238 65 : .context("Failed to parse common name from certificate")?;
239 :
240 65 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
241 65 :
242 65 : if is_default {
243 65 : self.default = Some((cert.clone(), tls_server_end_point));
244 65 : }
245 :
246 65 : self.certs.insert(common_name, (cert, tls_server_end_point));
247 65 :
248 65 : Ok(())
249 65 : }
250 :
251 65 : pub fn get_common_names(&self) -> HashSet<String> {
252 65 : self.certs.keys().map(|s| s.to_string()).collect()
253 65 : }
254 : }
255 :
256 : impl rustls::server::ResolvesServerCert for CertResolver {
257 96 : fn resolve(
258 96 : &self,
259 96 : client_hello: rustls::server::ClientHello,
260 96 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
261 96 : self.resolve(client_hello.server_name()).map(|x| x.0)
262 96 : }
263 : }
264 :
265 : impl CertResolver {
266 186 : pub fn resolve(
267 186 : &self,
268 186 : server_name: Option<&str>,
269 186 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
270 : // loop here and cut off more and more subdomains until we find
271 : // a match to get a proper wildcard support. OTOH, we now do not
272 : // use nested domains, so keep this simple for now.
273 : //
274 : // With the current coding foo.com will match *.foo.com and that
275 : // repeats behavior of the old code.
276 186 : if let Some(mut sni_name) = server_name {
277 : loop {
278 292 : if let Some(cert) = self.certs.get(sni_name) {
279 146 : return Some(cert.clone());
280 146 : }
281 146 : if let Some((_, rest)) = sni_name.split_once('.') {
282 146 : sni_name = rest;
283 146 : } else {
284 0 : return None;
285 : }
286 : }
287 : } else {
288 : // No SNI, use the default certificate, otherwise we can't get to
289 : // options parameter which can be used to set endpoint name too.
290 : // That means that non-SNI flow will not work for CNAME domains in
291 : // verify-full mode.
292 : //
293 : // If that will be a problem we can:
294 : //
295 : // a) Instead of multi-cert approach use single cert with extra
296 : // domains listed in Subject Alternative Name (SAN).
297 : // b) Deploy separate proxy instances for extra domains.
298 40 : self.default.as_ref().cloned()
299 : }
300 186 : }
301 : }
302 :
303 : /// Helper for cmdline cache options parsing.
304 1 : #[derive(Debug)]
305 : pub struct CacheOptions {
306 : /// Max number of entries.
307 : pub size: usize,
308 : /// Entry's time-to-live.
309 : pub ttl: Duration,
310 : }
311 :
312 : impl CacheOptions {
313 : /// Default options for [`crate::console::provider::NodeInfoCache`].
314 : pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
315 :
316 : /// Parse cache options passed via cmdline.
317 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
318 9 : fn parse(options: &str) -> anyhow::Result<Self> {
319 9 : let mut size = None;
320 9 : let mut ttl = None;
321 :
322 15 : for option in options.split(',') {
323 15 : let (key, value) = option
324 15 : .split_once('=')
325 15 : .with_context(|| format!("bad key-value pair: {option}"))?;
326 :
327 15 : match key {
328 15 : "size" => size = Some(value.parse()?),
329 6 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
330 0 : unknown => bail!("unknown key: {unknown}"),
331 : }
332 : }
333 :
334 : // TTL doesn't matter if cache is always empty.
335 9 : if let Some(0) = size {
336 5 : ttl.get_or_insert(Duration::default());
337 5 : }
338 :
339 : Ok(Self {
340 9 : size: size.context("missing `size`")?,
341 9 : ttl: ttl.context("missing `ttl`")?,
342 : })
343 9 : }
344 : }
345 :
346 : impl FromStr for CacheOptions {
347 : type Err = anyhow::Error;
348 :
349 9 : fn from_str(options: &str) -> Result<Self, Self::Err> {
350 9 : let error = || format!("failed to parse cache options '{options}'");
351 9 : Self::parse(options).with_context(error)
352 9 : }
353 : }
354 :
/// Project-info cache settings parsed from a command-line string such as
/// `size=10000,ttl=4m,max_roles=10,gc_interval=60m`.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Maximum number of cached entries.
    pub size: usize,
    /// How long an entry stays valid.
    pub ttl: Duration,
    /// Maximum number of roles kept per endpoint.
    pub max_roles: usize,
    /// How often garbage collection runs.
    pub gc_interval: Duration,
}
367 :
368 : impl ProjectInfoCacheOptions {
369 : /// Default options for [`crate::console::provider::NodeInfoCache`].
370 : pub const CACHE_DEFAULT_OPTIONS: &'static str =
371 : "size=10000,ttl=4m,max_roles=10,gc_interval=60m";
372 :
373 : /// Parse cache options passed via cmdline.
374 : /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
375 1 : fn parse(options: &str) -> anyhow::Result<Self> {
376 1 : let mut size = None;
377 1 : let mut ttl = None;
378 1 : let mut max_roles = None;
379 1 : let mut gc_interval = None;
380 :
381 4 : for option in options.split(',') {
382 4 : let (key, value) = option
383 4 : .split_once('=')
384 4 : .with_context(|| format!("bad key-value pair: {option}"))?;
385 :
386 4 : match key {
387 4 : "size" => size = Some(value.parse()?),
388 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
389 2 : "max_roles" => max_roles = Some(value.parse()?),
390 1 : "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
391 0 : unknown => bail!("unknown key: {unknown}"),
392 : }
393 : }
394 :
395 : // TTL doesn't matter if cache is always empty.
396 1 : if let Some(0) = size {
397 0 : ttl.get_or_insert(Duration::default());
398 1 : }
399 :
400 : Ok(Self {
401 1 : size: size.context("missing `size`")?,
402 1 : ttl: ttl.context("missing `ttl`")?,
403 1 : max_roles: max_roles.context("missing `max_roles`")?,
404 1 : gc_interval: gc_interval.context("missing `gc_interval`")?,
405 : })
406 1 : }
407 : }
408 :
409 : impl FromStr for ProjectInfoCacheOptions {
410 : type Err = anyhow::Error;
411 :
412 1 : fn from_str(options: &str) -> Result<Self, Self::Err> {
413 1 : let error = || format!("failed to parse cache options '{options}'");
414 1 : Self::parse(options).with_context(error)
415 1 : }
416 : }
417 :
/// Helper for parsing the wake_compute lock options from the command line.
#[derive(Debug)]
pub struct WakeComputeLockOptions {
    /// The number of shards the lock map should have
    pub shards: usize,
    /// The number of allowed concurrent requests for each endpoint
    pub permits: usize,
    /// Garbage collection epoch
    pub epoch: Duration,
    /// Lock timeout
    pub timeout: Duration,
}
429 :
430 : impl WakeComputeLockOptions {
431 : /// Default options for [`crate::console::provider::ApiLocks`].
432 : pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
433 :
434 : // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
435 :
436 : /// Parse lock options passed via cmdline.
437 : /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
438 7 : fn parse(options: &str) -> anyhow::Result<Self> {
439 7 : let mut shards = None;
440 7 : let mut permits = None;
441 7 : let mut epoch = None;
442 7 : let mut timeout = None;
443 :
444 19 : for option in options.split(',') {
445 19 : let (key, value) = option
446 19 : .split_once('=')
447 19 : .with_context(|| format!("bad key-value pair: {option}"))?;
448 :
449 19 : match key {
450 19 : "shards" => shards = Some(value.parse()?),
451 15 : "permits" => permits = Some(value.parse()?),
452 8 : "epoch" => epoch = Some(humantime::parse_duration(value)?),
453 4 : "timeout" => timeout = Some(humantime::parse_duration(value)?),
454 0 : unknown => bail!("unknown key: {unknown}"),
455 : }
456 : }
457 :
458 : // these dont matter if lock is disabled
459 7 : if let Some(0) = permits {
460 3 : timeout = Some(Duration::default());
461 3 : epoch = Some(Duration::default());
462 3 : shards = Some(2);
463 4 : }
464 :
465 7 : let out = Self {
466 7 : shards: shards.context("missing `shards`")?,
467 7 : permits: permits.context("missing `permits`")?,
468 7 : epoch: epoch.context("missing `epoch`")?,
469 7 : timeout: timeout.context("missing `timeout`")?,
470 : };
471 :
472 7 : ensure!(out.shards > 1, "shard count must be > 1");
473 : ensure!(
474 7 : out.shards.is_power_of_two(),
475 0 : "shard count must be a power of two"
476 : );
477 :
478 7 : Ok(out)
479 7 : }
480 : }
481 :
482 : impl FromStr for WakeComputeLockOptions {
483 : type Err = anyhow::Error;
484 :
485 7 : fn from_str(options: &str) -> Result<Self, Self::Err> {
486 7 : let error = || format!("failed to parse cache lock options '{options}'");
487 7 : Self::parse(options).with_context(error)
488 7 : }
489 : }
490 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as `T`, expecting failure, and returns the full error chain
    /// (`{:#}` formatting) so assertions can inspect the innermost cause.
    fn parse_err<T>(input: &str) -> String
    where
        T: FromStr<Err = anyhow::Error>,
    {
        match input.parse::<T>() {
            Err(err) => format!("{err:#}"),
            _ => panic!("expected `{input}` to fail to parse"),
        }
    }

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    #[test]
    fn test_parse_cache_options_errors() {
        assert!(parse_err::<CacheOptions>("size").contains("bad key-value pair: size"));
        assert!(parse_err::<CacheOptions>("size=1,ttl=1s,foo=bar").contains("unknown key: foo"));
        // `ttl` is only optional when the cache is disabled.
        assert!(parse_err::<CacheOptions>("size=1").contains("missing `ttl`"));
    }

    #[test]
    fn test_parse_project_info_cache_options() -> anyhow::Result<()> {
        let ProjectInfoCacheOptions {
            size,
            ttl,
            max_roles,
            gc_interval,
        } = "size=10000,ttl=4m,max_roles=10,gc_interval=60m".parse()?;
        assert_eq!(size, 10000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));
        assert_eq!(max_roles, 10);
        assert_eq!(gc_interval, Duration::from_secs(60 * 60));

        // TTL may be omitted when the cache is disabled.
        let ProjectInfoCacheOptions { size, ttl, .. } =
            "size=0,max_roles=1,gc_interval=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::ZERO);

        assert!(parse_err::<ProjectInfoCacheOptions>("size=1,ttl=1s,gc_interval=1s")
            .contains("missing `max_roles`"));

        Ok(())
    }

    #[test]
    fn test_default_options_parse() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = CacheOptions::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(size, 4000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let ProjectInfoCacheOptions { size, .. } =
            ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS.parse()?;
        assert_eq!(size, 10000);

        let WakeComputeLockOptions {
            permits, shards, ..
        } = WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK.parse()?;
        assert_eq!(permits, 0);
        assert_eq!(shards, 2);

        Ok(())
    }

    #[test]
    fn test_parse_lock_options() -> anyhow::Result<()> {
        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert_eq!(epoch, Duration::from_secs(10 * 60));
        assert_eq!(timeout, Duration::from_secs(1));
        assert_eq!(shards, 32);
        assert_eq!(permits, 4);

        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
        assert_eq!(epoch, Duration::from_secs(60));
        assert_eq!(timeout, Duration::from_millis(100));
        assert_eq!(shards, 16);
        assert_eq!(permits, 8);

        let WakeComputeLockOptions {
            epoch,
            permits,
            shards,
            timeout,
        } = "permits=0".parse()?;
        assert_eq!(epoch, Duration::ZERO);
        assert_eq!(timeout, Duration::ZERO);
        assert_eq!(shards, 2);
        assert_eq!(permits, 0);

        Ok(())
    }

    #[test]
    fn test_parse_lock_options_errors() {
        assert!(
            parse_err::<WakeComputeLockOptions>("shards=3,permits=1,epoch=1s,timeout=1s")
                .contains("shard count must be a power of two")
        );
        assert!(
            parse_err::<WakeComputeLockOptions>("shards=1,permits=1,epoch=1s,timeout=1s")
                .contains("shard count must be > 1")
        );
        assert!(parse_err::<WakeComputeLockOptions>("permits=1").contains("missing `shards`"));
        assert!(parse_err::<WakeComputeLockOptions>("permits=0,locks=2")
            .contains("unknown key: locks"));
    }
}
|