Line data Source code
1 : use std::collections::{HashMap, HashSet};
2 : use std::sync::Arc;
3 :
4 : use anyhow::{Context, bail};
5 : use itertools::Itertools;
6 : use rustls::crypto::ring::{self, sign};
7 : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
8 : use x509_cert::der::{Reader, SliceReader};
9 :
10 : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
11 :
12 : pub struct TlsConfig {
13 : // unfortunate split since we cannot change the ALPN on demand.
14 : // <https://github.com/rustls/rustls/issues/2260>
15 : pub http_config: Arc<rustls::ServerConfig>,
16 : pub pg_config: Arc<rustls::ServerConfig>,
17 : pub common_names: HashSet<String>,
18 : pub cert_resolver: Arc<CertResolver>,
19 : }
20 :
21 : /// Configure TLS for the main endpoint.
22 0 : pub fn configure_tls(
23 0 : key_path: &str,
24 0 : cert_path: &str,
25 0 : certs_dir: Option<&String>,
26 0 : allow_tls_keylogfile: bool,
27 0 : ) -> anyhow::Result<TlsConfig> {
28 0 : let mut cert_resolver = CertResolver::new();
29 0 :
30 0 : // add default certificate
31 0 : cert_resolver.add_cert_path(key_path, cert_path, true)?;
32 :
33 : // add extra certificates
34 0 : if let Some(certs_dir) = certs_dir {
35 0 : for entry in std::fs::read_dir(certs_dir)? {
36 0 : let entry = entry?;
37 0 : let path = entry.path();
38 0 : if path.is_dir() {
39 : // file names aligned with default cert-manager names
40 0 : let key_path = path.join("tls.key");
41 0 : let cert_path = path.join("tls.crt");
42 0 : if key_path.exists() && cert_path.exists() {
43 0 : cert_resolver.add_cert_path(
44 0 : &key_path.to_string_lossy(),
45 0 : &cert_path.to_string_lossy(),
46 0 : false,
47 0 : )?;
48 0 : }
49 0 : }
50 : }
51 0 : }
52 :
53 0 : let common_names = cert_resolver.get_common_names();
54 0 :
55 0 : let cert_resolver = Arc::new(cert_resolver);
56 :
57 : // allow TLS 1.2 to be compatible with older client libraries
58 0 : let mut config =
59 0 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
60 0 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
61 0 : .context("ring should support TLS1.2 and TLS1.3")?
62 0 : .with_no_client_auth()
63 0 : .with_cert_resolver(cert_resolver.clone());
64 0 :
65 0 : config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
66 0 :
67 0 : if allow_tls_keylogfile {
68 0 : // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
69 0 : config.key_log = Arc::new(rustls::KeyLogFile::new());
70 0 : }
71 :
72 0 : let mut http_config = config.clone();
73 0 : let mut pg_config = config;
74 0 :
75 0 : http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
76 0 : pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
77 0 :
78 0 : Ok(TlsConfig {
79 0 : http_config: Arc::new(http_config),
80 0 : pg_config: Arc::new(pg_config),
81 0 : common_names,
82 0 : cert_resolver,
83 0 : })
84 0 : }
85 :
86 : #[derive(Default, Debug)]
87 : pub struct CertResolver {
88 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
89 : default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
90 : }
91 :
92 : impl CertResolver {
93 21 : pub fn new() -> Self {
94 21 : Self::default()
95 21 : }
96 :
97 0 : fn add_cert_path(
98 0 : &mut self,
99 0 : key_path: &str,
100 0 : cert_path: &str,
101 0 : is_default: bool,
102 0 : ) -> anyhow::Result<()> {
103 0 : let priv_key = {
104 0 : let key_bytes = std::fs::read(key_path)
105 0 : .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
106 0 : rustls_pemfile::private_key(&mut &key_bytes[..])
107 0 : .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
108 0 : .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
109 : };
110 :
111 0 : let cert_chain_bytes = std::fs::read(cert_path)
112 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
113 :
114 0 : let cert_chain = {
115 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
116 0 : .try_collect()
117 0 : .with_context(|| {
118 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
119 0 : })?
120 : };
121 :
122 0 : self.add_cert(priv_key, cert_chain, is_default)
123 0 : }
124 :
125 21 : pub fn add_cert(
126 21 : &mut self,
127 21 : priv_key: PrivateKeyDer<'static>,
128 21 : cert_chain: Vec<CertificateDer<'static>>,
129 21 : is_default: bool,
130 21 : ) -> anyhow::Result<()> {
131 21 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
132 :
133 21 : let first_cert = &cert_chain[0];
134 21 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
135 :
136 21 : let certificate = SliceReader::new(first_cert)
137 21 : .context("Failed to parse cerficiate")?
138 21 : .decode::<x509_cert::Certificate>()
139 21 : .context("Failed to parse cerficiate")?;
140 :
141 21 : let common_name = certificate.tbs_certificate.subject.to_string();
142 :
143 : // We need to get the canonical name for this certificate so we can match them against any domain names
144 : // seen within the proxy codebase.
145 : //
146 : // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
147 : // We need to remove the wildcard prefix for the purposes of certificate selection.
148 : //
149 : // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
150 : // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
151 : //
152 : // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
153 : // validation, so let's we can continue with any common-name
154 21 : let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
155 0 : s.to_string()
156 21 : } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
157 0 : s.to_string()
158 21 : } else if let Some(s) = common_name.strip_prefix("CN=") {
159 21 : s.to_string()
160 : } else {
161 0 : bail!("Failed to parse common name from certificate")
162 : };
163 :
164 21 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
165 21 :
166 21 : if is_default {
167 21 : self.default = Some((cert.clone(), tls_server_end_point));
168 21 : }
169 :
170 21 : self.certs.insert(common_name, (cert, tls_server_end_point));
171 21 :
172 21 : Ok(())
173 21 : }
174 :
175 21 : pub fn get_common_names(&self) -> HashSet<String> {
176 21 : self.certs.keys().map(|s| s.to_string()).collect()
177 21 : }
178 : }
179 :
180 : impl rustls::server::ResolvesServerCert for CertResolver {
181 0 : fn resolve(
182 0 : &self,
183 0 : client_hello: rustls::server::ClientHello<'_>,
184 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
185 0 : self.resolve(client_hello.server_name()).map(|x| x.0)
186 0 : }
187 : }
188 :
189 : impl CertResolver {
190 20 : pub fn resolve(
191 20 : &self,
192 20 : server_name: Option<&str>,
193 20 : ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
194 : // loop here and cut off more and more subdomains until we find
195 : // a match to get a proper wildcard support. OTOH, we now do not
196 : // use nested domains, so keep this simple for now.
197 : //
198 : // With the current coding foo.com will match *.foo.com and that
199 : // repeats behavior of the old code.
200 20 : if let Some(mut sni_name) = server_name {
201 : loop {
202 40 : if let Some(cert) = self.certs.get(sni_name) {
203 20 : return Some(cert.clone());
204 20 : }
205 20 : if let Some((_, rest)) = sni_name.split_once('.') {
206 20 : sni_name = rest;
207 20 : } else {
208 0 : return None;
209 : }
210 : }
211 : } else {
212 : // No SNI, use the default certificate, otherwise we can't get to
213 : // options parameter which can be used to set endpoint name too.
214 : // That means that non-SNI flow will not work for CNAME domains in
215 : // verify-full mode.
216 : //
217 : // If that will be a problem we can:
218 : //
219 : // a) Instead of multi-cert approach use single cert with extra
220 : // domains listed in Subject Alternative Name (SAN).
221 : // b) Deploy separate proxy instances for extra domains.
222 0 : self.default.clone()
223 : }
224 20 : }
225 : }
|