TLA Line data Source code
1 : use crate::auth;
2 : use anyhow::{bail, ensure, Context, Ok};
3 : use rustls::sign;
4 : use std::{
5 : collections::{HashMap, HashSet},
6 : str::FromStr,
7 : sync::Arc,
8 : time::Duration,
9 : };
10 :
/// Proxy configuration, aggregating the TLS, authentication, metric
/// collection and HTTP settings.
pub struct ProxyConfig {
    /// TLS settings for the main endpoint (see `configure_tls`).
    /// NOTE(review): presumably `None` means TLS is not configured — confirm at
    /// the use sites.
    pub tls_config: Option<TlsConfig>,
    /// Authentication backend (see `auth::BackendType`).
    pub auth_backend: auth::BackendType<'static, ()>,
    /// Metric collection settings.
    /// NOTE(review): presumably `None` disables collection — confirm.
    pub metric_collection: Option<MetricCollectionConfig>,
    /// NOTE(review): by its name, whether self-signed certificates presented
    /// by compute nodes are accepted — confirm at the use sites.
    pub allow_self_signed_compute: bool,
    /// HTTP-related settings (see `HttpConfig`).
    pub http_config: HttpConfig,
    /// NOTE(review): by its name, whether a client IP address must be known
    /// for each connection — confirm at the use sites.
    pub require_client_ip: bool,
}
18 : }
19 :
#[derive(Debug)]
/// Settings for metric collection.
pub struct MetricCollectionConfig {
    /// Collection endpoint URL.
    /// NOTE(review): presumably where collected metrics are sent — confirm.
    pub endpoint: reqwest::Url,
    /// Collection interval.
    /// NOTE(review): presumably the period between reports — confirm.
    pub interval: Duration,
}
25 :
/// TLS settings for the main endpoint; `configure_tls` builds one.
pub struct TlsConfig {
    /// Shared rustls server configuration.
    pub config: Arc<rustls::ServerConfig>,
    /// Common names of the loaded certificates, with a leading `*.` wildcard
    /// label removed (`*.example.com` is stored as `example.com`).
    /// `configure_tls` always sets this to `Some`.
    pub common_names: Option<HashSet<String>>,
}
30 :
/// HTTP-related settings.
pub struct HttpConfig {
    /// Timeout for SQL-over-HTTP requests.
    /// NOTE(review): presumably applied per request — confirm at the use sites.
    pub sql_over_http_timeout: tokio::time::Duration,
}
34 :
35 : impl TlsConfig {
36 54 : pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
37 54 : self.config.clone()
38 54 : }
39 : }
40 :
41 : /// Configure TLS for the main endpoint.
42 16 : pub fn configure_tls(
43 16 : key_path: &str,
44 16 : cert_path: &str,
45 16 : certs_dir: Option<&String>,
46 16 : ) -> anyhow::Result<TlsConfig> {
47 16 : let mut cert_resolver = CertResolver::new();
48 16 :
49 16 : // add default certificate
50 16 : cert_resolver.add_cert(key_path, cert_path, true)?;
51 :
52 : // add extra certificates
53 16 : if let Some(certs_dir) = certs_dir {
54 UBC 0 : for entry in std::fs::read_dir(certs_dir)? {
55 0 : let entry = entry?;
56 0 : let path = entry.path();
57 0 : if path.is_dir() {
58 : // file names aligned with default cert-manager names
59 0 : let key_path = path.join("tls.key");
60 0 : let cert_path = path.join("tls.crt");
61 0 : if key_path.exists() && cert_path.exists() {
62 0 : cert_resolver.add_cert(
63 0 : &key_path.to_string_lossy(),
64 0 : &cert_path.to_string_lossy(),
65 0 : false,
66 0 : )?;
67 0 : }
68 0 : }
69 : }
70 CBC 16 : }
71 :
72 16 : let common_names = cert_resolver.get_common_names();
73 :
74 16 : let config = rustls::ServerConfig::builder()
75 16 : .with_safe_default_cipher_suites()
76 16 : .with_safe_default_kx_groups()
77 16 : // allow TLS 1.2 to be compatible with older client libraries
78 16 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
79 16 : .with_no_client_auth()
80 16 : .with_cert_resolver(Arc::new(cert_resolver))
81 16 : .into();
82 16 :
83 16 : Ok(TlsConfig {
84 16 : config,
85 16 : common_names: Some(common_names),
86 16 : })
87 16 : }
88 :
/// Certificate store used by rustls to pick a server certificate per handshake.
struct CertResolver {
    /// Certificates keyed by common name, with a leading `*.` wildcard label
    /// removed (`*.example.com` is keyed as `example.com`).
    certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
    /// Certificate served when the client sends no SNI server name; set by
    /// `add_cert` when called with `is_default = true`.
    default: Option<Arc<rustls::sign::CertifiedKey>>,
}
93 :
94 : impl CertResolver {
95 16 : fn new() -> Self {
96 16 : Self {
97 16 : certs: HashMap::new(),
98 16 : default: None,
99 16 : }
100 16 : }
101 :
102 16 : fn add_cert(
103 16 : &mut self,
104 16 : key_path: &str,
105 16 : cert_path: &str,
106 16 : is_default: bool,
107 16 : ) -> anyhow::Result<()> {
108 16 : let priv_key = {
109 16 : let key_bytes = std::fs::read(key_path)
110 16 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
111 16 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
112 16 : .context(format!("Failed to parse TLS keys at '{key_path}'"))?;
113 :
114 16 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
115 16 : keys.pop().map(rustls::PrivateKey).unwrap()
116 : };
117 :
118 16 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
119 :
120 16 : let cert_chain_bytes = std::fs::read(cert_path)
121 16 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
122 :
123 16 : let cert_chain = {
124 16 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
125 16 : .context(format!(
126 16 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
127 16 : ))?
128 16 : .into_iter()
129 16 : .map(rustls::Certificate)
130 16 : .collect()
131 : };
132 :
133 16 : let common_name = {
134 16 : let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
135 16 : .context(format!(
136 16 : "Failed to parse PEM object from bytes from file at '{cert_path}'."
137 16 : ))?
138 : .1;
139 16 : let common_name = pem.parse_x509()?.subject().to_string();
140 16 :
141 16 : // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
142 16 : // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
143 16 : // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
144 16 : // and passed None instead, which blows up number of cases downstream code should handle. Proper coding
145 16 : // here should better avoid Option for common_names, and do wildcard-based certificate selection instead
146 16 : // of cutting off '*.' parts.
147 16 : if common_name.starts_with("CN=*.") {
148 16 : common_name.strip_prefix("CN=*.").map(|s| s.to_string())
149 : } else {
150 UBC 0 : common_name.strip_prefix("CN=").map(|s| s.to_string())
151 : }
152 : }
153 CBC 16 : .context(format!(
154 16 : "Failed to parse common name from certificate at '{cert_path}'."
155 16 : ))?;
156 :
157 16 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
158 16 :
159 16 : if is_default {
160 16 : self.default = Some(cert.clone());
161 16 : }
162 :
163 16 : self.certs.insert(common_name, cert);
164 16 :
165 16 : Ok(())
166 16 : }
167 :
168 16 : fn get_common_names(&self) -> HashSet<String> {
169 16 : self.certs.keys().map(|s| s.to_string()).collect()
170 16 : }
171 : }
172 :
173 : impl rustls::server::ResolvesServerCert for CertResolver {
174 : fn resolve(
175 : &self,
176 : _client_hello: rustls::server::ClientHello,
177 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
178 : // loop here and cut off more and more subdomains until we find
179 : // a match to get a proper wildcard support. OTOH, we now do not
180 : // use nested domains, so keep this simple for now.
181 : //
182 : // With the current coding foo.com will match *.foo.com and that
183 : // repeats behavior of the old code.
184 63 : if let Some(mut sni_name) = _client_hello.server_name() {
185 : loop {
186 102 : if let Some(cert) = self.certs.get(sni_name) {
187 51 : return Some(cert.clone());
188 51 : }
189 51 : if let Some((_, rest)) = sni_name.split_once('.') {
190 51 : sni_name = rest;
191 51 : } else {
192 UBC 0 : return None;
193 : }
194 : }
195 : } else {
196 : // No SNI, use the default certificate, otherwise we can't get to
197 : // options parameter which can be used to set endpoint name too.
198 : // That means that non-SNI flow will not work for CNAME domains in
199 : // verify-full mode.
200 : //
201 : // If that will be a problem we can:
202 : //
203 : // a) Instead of multi-cert approach use single cert with extra
204 : // domains listed in Subject Alternative Name (SAN).
205 : // b) Deploy separate proxy instances for extra domains.
206 CBC 12 : self.default.as_ref().cloned()
207 : }
208 63 : }
209 : }
210 :
/// Helper for cmdline cache options parsing.
///
/// Both fields are plain values, so the type is `Copy`; `Debug` allows the
/// parsed options to be logged.
#[derive(Debug, Clone, Copy)]
pub struct CacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
}
218 :
219 : impl CacheOptions {
220 : /// Default options for [`crate::console::provider::NodeInfoCache`].
221 : pub const DEFAULT_OPTIONS_NODE_INFO: &str = "size=4000,ttl=4m";
222 :
223 : /// Parse cache options passed via cmdline.
224 : /// Example: [`Self::DEFAULT_OPTIONS_NODE_INFO`].
225 4 : fn parse(options: &str) -> anyhow::Result<Self> {
226 4 : let mut size = None;
227 4 : let mut ttl = None;
228 :
229 7 : for option in options.split(',') {
230 7 : let (key, value) = option
231 7 : .split_once('=')
232 7 : .with_context(|| format!("bad key-value pair: {option}"))?;
233 :
234 7 : match key {
235 7 : "size" => size = Some(value.parse()?),
236 3 : "ttl" => ttl = Some(humantime::parse_duration(value)?),
237 UBC 0 : unknown => bail!("unknown key: {unknown}"),
238 : }
239 : }
240 :
241 : // TTL doesn't matter if cache is always empty.
242 CBC 4 : if let Some(0) = size {
243 2 : ttl.get_or_insert(Duration::default());
244 2 : }
245 :
246 : Ok(Self {
247 4 : size: size.context("missing `size`")?,
248 4 : ttl: ttl.context("missing `ttl`")?,
249 : })
250 4 : }
251 : }
252 :
253 : impl FromStr for CacheOptions {
254 : type Err = anyhow::Error;
255 :
256 4 : fn from_str(options: &str) -> Result<Self, Self::Err> {
257 4 : let error = || format!("failed to parse cache options '{options}'");
258 4 : Self::parse(options).with_context(error)
259 4 : }
260 : }
261 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
        assert_eq!(size, 4096);
        assert_eq!(ttl, Duration::from_secs(5 * 60));

        let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
        assert_eq!(size, 2);
        assert_eq!(ttl, Duration::from_secs(4 * 60));

        let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::from_secs(1));

        let CacheOptions { size, ttl } = "size=0".parse()?;
        assert_eq!(size, 0);
        assert_eq!(ttl, Duration::default());

        Ok(())
    }

    /// Guards the built-in default option string against typos.
    #[test]
    fn test_parse_default_cache_options() -> anyhow::Result<()> {
        let CacheOptions { size, ttl } = CacheOptions::DEFAULT_OPTIONS_NODE_INFO.parse()?;
        assert_eq!(size, 4000);
        assert_eq!(ttl, Duration::from_secs(4 * 60));
        Ok(())
    }

    /// Every rejection path of `CacheOptions::parse`, checked through the
    /// rendered error chain.
    #[test]
    fn test_parse_cache_options_errors() {
        for (input, expected) in [
            ("", "bad key-value pair"),
            ("size", "bad key-value pair"),
            ("size=1", "missing `ttl`"),
            ("ttl=1s", "missing `size`"),
            ("size=1,ttl=1s,foo=bar", "unknown key: foo"),
        ] {
            let err = input
                .parse::<CacheOptions>()
                .err()
                .unwrap_or_else(|| panic!("{input:?} should be rejected"));
            let message = format!("{err:#}");
            assert!(
                message.contains(expected),
                "{input:?}: unexpected error {message:?}"
            );
        }

        // Malformed values: the exact wording comes from the value parsers.
        for input in ["size=-1,ttl=1s", "size=1,ttl=forever"] {
            assert!(
                input.parse::<CacheOptions>().is_err(),
                "{input:?} should be rejected"
            );
        }
    }
}
|