Line data Source code
1 : use crate::auth;
2 : use anyhow::{bail, ensure, Context, Ok};
3 : use rustls::sign;
4 : use std::{
5 : collections::{HashMap, HashSet},
6 : str::FromStr,
7 : sync::Arc,
8 : time::Duration,
9 : };
10 :
/// Top-level proxy configuration.
pub struct ProxyConfig {
    /// TLS settings for the main endpoint, as built by `configure_tls`.
    /// NOTE(review): presumably `None` means TLS is not configured — confirm at the caller.
    pub tls_config: Option<TlsConfig>,
    /// Authentication backend selection.
    pub auth_backend: auth::BackendType<'static, ()>,
    /// Metric collection settings.
    /// NOTE(review): presumably `None` disables metric collection — confirm at the caller.
    pub metric_collection: Option<MetricCollectionConfig>,
    /// NOTE(review): presumably permits compute nodes presenting self-signed certificates —
    /// confirm at the usage site.
    pub allow_self_signed_compute: bool,
}
17 :
/// Metric collection settings.
#[derive(Debug)]
pub struct MetricCollectionConfig {
    /// Metrics endpoint URL.
    /// NOTE(review): whether metrics are pushed to or pulled from here is not visible in this file.
    pub endpoint: reqwest::Url,
    /// Collection interval.
    pub interval: Duration,
}
23 :
/// TLS configuration for the main endpoint, produced by `configure_tls`.
pub struct TlsConfig {
    /// Shared rustls server configuration (certificate selection is done by `CertResolver`).
    pub config: Arc<rustls::ServerConfig>,
    /// Common names of all loaded certificates, with a leading `*.` wildcard label removed.
    /// `configure_tls` always sets this to `Some`.
    pub common_names: Option<HashSet<String>>,
}
28 :
29 : impl TlsConfig {
30 50 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
31 50 : self.config.clone()
32 50 : }
33 : }
34 :
35 : /// Configure TLS for the main endpoint.
36 14 : pub fn configure_tls(
37 14 : key_path: &str,
38 14 : cert_path: &str,
39 14 : certs_dir: Option<&String>,
40 14 : ) -> anyhow::Result<TlsConfig> {
41 14 : let mut cert_resolver = CertResolver::new();
42 14 :
43 14 : // add default certificate
44 14 : cert_resolver.add_cert(key_path, cert_path, true)?;
45 :
46 : // add extra certificates
47 14 : if let Some(certs_dir) = certs_dir {
48 0 : for entry in std::fs::read_dir(certs_dir)? {
49 0 : let entry = entry?;
50 0 : let path = entry.path();
51 0 : if path.is_dir() {
52 : // file names aligned with default cert-manager names
53 0 : let key_path = path.join("tls.key");
54 0 : let cert_path = path.join("tls.crt");
55 0 : if key_path.exists() && cert_path.exists() {
56 0 : cert_resolver.add_cert(
57 0 : &key_path.to_string_lossy(),
58 0 : &cert_path.to_string_lossy(),
59 0 : false,
60 0 : )?;
61 0 : }
62 0 : }
63 : }
64 14 : }
65 :
66 14 : let common_names = cert_resolver.get_common_names();
67 :
68 14 : let config = rustls::ServerConfig::builder()
69 14 : .with_safe_default_cipher_suites()
70 14 : .with_safe_default_kx_groups()
71 14 : // allow TLS 1.2 to be compatible with older client libraries
72 14 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
73 14 : .with_no_client_auth()
74 14 : .with_cert_resolver(Arc::new(cert_resolver))
75 14 : .into();
76 14 :
77 14 : Ok(TlsConfig {
78 14 : config,
79 14 : common_names: Some(common_names),
80 14 : })
81 14 : }
82 :
/// rustls certificate resolver that selects a server certificate by SNI name.
struct CertResolver {
    /// Certificates keyed by common name, with a leading `*.` wildcard label removed.
    certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
    /// Certificate returned when the client sends no SNI.
    default: Option<Arc<rustls::sign::CertifiedKey>>,
}
87 :
88 : impl CertResolver {
89 14 : fn new() -> Self {
90 14 : Self {
91 14 : certs: HashMap::new(),
92 14 : default: None,
93 14 : }
94 14 : }
95 :
96 14 : fn add_cert(
97 14 : &mut self,
98 14 : key_path: &str,
99 14 : cert_path: &str,
100 14 : is_default: bool,
101 14 : ) -> anyhow::Result<()> {
102 14 : let priv_key = {
103 14 : let key_bytes = std::fs::read(key_path)
104 14 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
105 14 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
106 14 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
107 :
108 14 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
109 14 : keys.pop().map(rustls::PrivateKey).unwrap()
110 : };
111 :
112 14 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
113 :
114 14 : let cert_chain_bytes = std::fs::read(cert_path)
115 14 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
116 :
117 14 : let cert_chain = {
118 14 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
119 14 : .context(format!(
120 14 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
121 14 : ))?
122 14 : .into_iter()
123 14 : .map(rustls::Certificate)
124 14 : .collect()
125 : };
126 :
127 14 : let common_name = {
128 14 : let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
129 14 : .context(format!(
130 14 : "Failed to parse PEM object from bytes from file at '{cert_path}'."
131 14 : ))?
132 : .1;
133 14 : let common_name = pem.parse_x509()?.subject().to_string();
134 14 :
135 14 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
136 14 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
137 14 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
138 14 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
139 14 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
140 14 : // of cutting off '*.' parts.
141 14 : if common_name.starts_with("CN=*.") {
142 14 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
143 : } else {
144 0 : common_name.strip_prefix("CN=").map(|s| s.to_string())
145 : }
146 : }
147 14 : .context(format!(
148 14 : "Failed to parse common name from certificate at '{cert_path}'."
149 14 : ))?;
150 :
151 14 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
152 14 :
153 14 : if is_default {
154 14 : self.default = Some(cert.clone());
155 14 : }
156 :
157 14 : self.certs.insert(common_name, cert);
158 14 :
159 14 : Ok(())
160 14 : }
161 :
162 14 : fn get_common_names(&self) -> HashSet<String> {
163 14 : self.certs.keys().map(|s| s.to_string()).collect()
164 14 : }
165 : }
166 :
167 : impl rustls::server::ResolvesServerCert for CertResolver {
168 : fn resolve(
169 : &self,
170 : _client_hello: rustls::server::ClientHello,
171 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
172 : // loop here and cut off more and more subdomains until we find
173 : // a match to get a proper wildcard support. OTOH, we now do not
174 : // use nested domains, so keep this simple for now.
175 : //
176 : // With the current coding foo.com will match *.foo.com and that
177 : // repeats behavior of the old code.
178 53 : if let Some(mut sni_name) = _client_hello.server_name() {
179 : loop {
180 82 : if let Some(cert) = self.certs.get(sni_name) {
181 41 : return Some(cert.clone());
182 41 : }
183 41 : if let Some((_, rest)) = sni_name.split_once('.') {
184 41 : sni_name = rest;
185 41 : } else {
186 0 : return None;
187 : }
188 : }
189 : } else {
190 : // No SNI, use the default certificate, otherwise we can't get to
191 : // options parameter which can be used to set endpoint name too.
192 : // That means that non-SNI flow will not work for CNAME domains in
193 : // verify-full mode.
194 : //
195 : // If that will be a problem we can:
196 : //
197 : // a) Instead of multi-cert approach use single cert with extra
198 : // domains listed in Subject Alternative Name (SAN).
199 : // b) Deploy separate proxy instances for extra domains.
200 12 : self.default.as_ref().cloned()
201 : }
202 53 : }
203 : }
204 :
/// Helper for cmdline cache options parsing.
///
/// Parsed from a comma-separated list of `key=value` pairs with keys `size` and `ttl`,
/// e.g. `"size=4000,ttl=4m"`.
pub struct CacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
}
212 :
213 : impl CacheOptions {
214 : /// Default options for [`crate::console::provider::NodeInfoCache`].
215 : pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=4m";
216 :
217 : /// Parse cache options passed via cmdline.
218 : /// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
219 4 : fn parse(options: &str) -> anyhow::Result<Self> {
220 4 : let mut size = None;
221 4 : let mut ttl = None;
222 :
223 7 : for option in options.split(',') {
224 7 : let (key, value) = option
225 7 : .split_once('=')
226 7 : .with_context(|| format!("bad key-value pair: {option}"))?;
227 :
228 7 : match key {
229 7 : "size" => size = Some(value.parse()?),
230 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
231 0 : unknown => bail!("unknown key: {unknown}"),
232 : }
233 : }
234 :
235 : // TTL doesn't matter if cache is always empty.
236 4 : if let Some(0) = size {
237 2 : ttl.get_or_insert(Duration::default());
238 2 : }
239 :
240 : Ok(Self {
241 4 : size: size.context("missing `size`")?,
242 4 : ttl: ttl.context("missing `ttl`")?,
243 : })
244 4 : }
245 : }
246 :
247 : impl FromStr for CacheOptions {
248 : type Err = anyhow::Error;
249 :
250 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
251 4 : let error = || format!("failed to parse cache options '{options}'");
252 4 : Self::parse(options).with_context(error)
253 4 : }
254 : }
255 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    /// Parses `options`, expecting failure, and renders the whole error chain
    /// (`outer: inner`) so tests can check both the wrapper and the root cause.
    fn parse_error(options: &str) -> String {
        let err = options
            .parse::<CacheOptions>()
            .err()
            .expect("parsing should have failed");
        format!("{err:#}")
    }

    #[test]
    fn test_parse_cache_options_errors() {
        let err = parse_error("size=1,ttl=1s,foo=bar");
        assert!(
            err.starts_with("failed to parse cache options 'size=1,ttl=1s,foo=bar'"),
            "{err}"
        );
        assert!(err.contains("unknown key: foo"), "{err}");

        let err = parse_error("ttl=1s");
        assert!(err.contains("missing `size`"), "{err}");

        // `ttl` may only be omitted for a zero-sized cache.
        let err = parse_error("size=1");
        assert!(err.contains("missing `ttl`"), "{err}");

        let err = parse_error("size");
        assert!(err.contains("bad key-value pair: size"), "{err}");

        let err = parse_error("");
        assert!(err.contains("bad key-value pair: "), "{err}");

        let err = parse_error("size=lots,ttl=1s");
        assert!(err.starts_with("failed to parse cache options"), "{err}");
    }
}
|