Line data Source code
1 : use std::collections::{HashMap, HashSet};
2 : use std::path::Path;
3 : use std::sync::Arc;
4 :
5 : use anyhow::{Context, bail};
6 : use itertools::Itertools;
7 : use rustls::crypto::ring::{self, sign};
8 : use rustls::pki_types::{CertificateDer, PrivateKeyDer};
9 : use rustls::sign::CertifiedKey;
10 : use x509_cert::der::{Reader, SliceReader};
11 :
12 : use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
13 :
14 : pub struct TlsConfig {
15 : // unfortunate split since we cannot change the ALPN on demand.
16 : // <https://github.com/rustls/rustls/issues/2260>
17 : pub http_config: Arc<rustls::ServerConfig>,
18 : pub pg_config: Arc<rustls::ServerConfig>,
19 : pub common_names: HashSet<String>,
20 : pub cert_resolver: Arc<CertResolver>,
21 : }
22 :
23 : /// Configure TLS for the main endpoint.
24 0 : pub fn configure_tls(
25 0 : key_path: &Path,
26 0 : cert_path: &Path,
27 0 : certs_dir: Option<&Path>,
28 0 : allow_tls_keylogfile: bool,
29 0 : ) -> anyhow::Result<TlsConfig> {
30 : // add default certificate
31 0 : let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
32 :
33 : // add extra certificates
34 0 : if let Some(certs_dir) = certs_dir {
35 0 : for entry in std::fs::read_dir(certs_dir)? {
36 0 : let entry = entry?;
37 0 : let path = entry.path();
38 0 : if path.is_dir() {
39 : // file names aligned with default cert-manager names
40 0 : let key_path = path.join("tls.key");
41 0 : let cert_path = path.join("tls.crt");
42 0 : if key_path.exists() && cert_path.exists() {
43 0 : cert_resolver.add_cert_path(&key_path, &cert_path)?;
44 0 : }
45 0 : }
46 : }
47 0 : }
48 :
49 0 : let common_names = cert_resolver.get_common_names();
50 :
51 0 : let cert_resolver = Arc::new(cert_resolver);
52 :
53 : // allow TLS 1.2 to be compatible with older client libraries
54 0 : let mut config =
55 0 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
56 0 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
57 0 : .context("ring should support TLS1.2 and TLS1.3")?
58 0 : .with_no_client_auth()
59 0 : .with_cert_resolver(cert_resolver.clone());
60 :
61 0 : config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
62 :
63 0 : if allow_tls_keylogfile {
64 0 : // KeyLogFile will check for the SSLKEYLOGFILE environment variable.
65 0 : config.key_log = Arc::new(rustls::KeyLogFile::new());
66 0 : }
67 :
68 0 : let mut http_config = config.clone();
69 0 : let mut pg_config = config;
70 :
71 0 : http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
72 0 : pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
73 :
74 0 : Ok(TlsConfig {
75 0 : http_config: Arc::new(http_config),
76 0 : pg_config: Arc::new(pg_config),
77 0 : common_names,
78 0 : cert_resolver,
79 0 : })
80 0 : }
81 :
82 : #[derive(Debug)]
83 : pub struct CertResolver {
84 : certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
85 : default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
86 : }
87 :
88 : impl CertResolver {
89 0 : fn parse_new(key_path: &Path, cert_path: &Path) -> anyhow::Result<Self> {
90 0 : let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
91 0 : Self::new(priv_key, cert_chain)
92 0 : }
93 :
94 21 : pub fn new(
95 21 : priv_key: PrivateKeyDer<'static>,
96 21 : cert_chain: Vec<CertificateDer<'static>>,
97 21 : ) -> anyhow::Result<Self> {
98 21 : let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
99 :
100 21 : let mut certs = HashMap::new();
101 21 : let default = (cert.clone(), tls_server_end_point);
102 21 : certs.insert(common_name, (cert, tls_server_end_point));
103 21 : Ok(Self { certs, default })
104 21 : }
105 :
106 0 : fn add_cert_path(&mut self, key_path: &Path, cert_path: &Path) -> anyhow::Result<()> {
107 0 : let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
108 0 : self.add_cert(priv_key, cert_chain)
109 0 : }
110 :
111 0 : fn add_cert(
112 0 : &mut self,
113 0 : priv_key: PrivateKeyDer<'static>,
114 0 : cert_chain: Vec<CertificateDer<'static>>,
115 0 : ) -> anyhow::Result<()> {
116 0 : let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
117 0 : self.certs.insert(common_name, (cert, tls_server_end_point));
118 0 : Ok(())
119 0 : }
120 :
121 21 : pub fn get_common_names(&self) -> HashSet<String> {
122 21 : self.certs.keys().cloned().collect()
123 21 : }
124 : }
125 :
126 0 : fn parse_key_cert(
127 0 : key_path: &Path,
128 0 : cert_path: &Path,
129 0 : ) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
130 0 : let priv_key = {
131 0 : let key_bytes = std::fs::read(key_path)
132 0 : .with_context(|| format!("Failed to read TLS keys at '{}'", key_path.display()))?;
133 0 : rustls_pemfile::private_key(&mut &key_bytes[..])
134 0 : .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
135 0 : .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
136 : };
137 :
138 0 : let cert_chain_bytes = std::fs::read(cert_path).context(format!(
139 0 : "Failed to read TLS cert file at '{}.'",
140 0 : cert_path.display()
141 0 : ))?;
142 :
143 0 : let cert_chain = {
144 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
145 0 : .try_collect()
146 0 : .with_context(|| {
147 0 : format!(
148 0 : "Failed to read TLS certificate chain from bytes from file at '{}'.",
149 0 : cert_path.display()
150 : )
151 0 : })?
152 : };
153 :
154 0 : Ok((priv_key, cert_chain))
155 0 : }
156 :
157 21 : fn process_key_cert(
158 21 : priv_key: PrivateKeyDer<'static>,
159 21 : cert_chain: Vec<CertificateDer<'static>>,
160 21 : ) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
161 21 : let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
162 :
163 21 : let first_cert = &cert_chain[0];
164 21 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
165 :
166 21 : let certificate = SliceReader::new(first_cert)
167 21 : .context("Failed to parse cerficiate")?
168 21 : .decode::<x509_cert::Certificate>()
169 21 : .context("Failed to parse cerficiate")?;
170 :
171 21 : let common_name = certificate.tbs_certificate.subject.to_string();
172 :
173 : // We need to get the canonical name for this certificate so we can match them against any domain names
174 : // seen within the proxy codebase.
175 : //
176 : // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
177 : // We need to remove the wildcard prefix for the purposes of certificate selection.
178 : //
179 : // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
180 : // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
181 : //
182 : // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
183 : // validation, so let's we can continue with any common-name
184 21 : let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
185 0 : s.to_string()
186 21 : } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
187 0 : s.to_string()
188 21 : } else if let Some(s) = common_name.strip_prefix("CN=") {
189 21 : s.to_string()
190 : } else {
191 0 : bail!("Failed to parse common name from certificate")
192 : };
193 :
194 21 : let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
195 :
196 21 : Ok((common_name, cert, tls_server_end_point))
197 21 : }
198 :
199 : impl rustls::server::ResolvesServerCert for CertResolver {
200 0 : fn resolve(
201 0 : &self,
202 0 : client_hello: rustls::server::ClientHello<'_>,
203 0 : ) -> Option<Arc<rustls::sign::CertifiedKey>> {
204 0 : Some(self.resolve(client_hello.server_name()).0)
205 0 : }
206 : }
207 :
208 : impl CertResolver {
209 20 : pub fn resolve(
210 20 : &self,
211 20 : server_name: Option<&str>,
212 20 : ) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
213 : // loop here and cut off more and more subdomains until we find
214 : // a match to get a proper wildcard support. OTOH, we now do not
215 : // use nested domains, so keep this simple for now.
216 : //
217 : // With the current coding foo.com will match *.foo.com and that
218 : // repeats behavior of the old code.
219 20 : if let Some(mut sni_name) = server_name {
220 : loop {
221 40 : if let Some(cert) = self.certs.get(sni_name) {
222 20 : return cert.clone();
223 20 : }
224 20 : if let Some((_, rest)) = sni_name.split_once('.') {
225 20 : sni_name = rest;
226 20 : } else {
227 : // The customer has some custom DNS mapping - just return
228 : // a default certificate.
229 : //
230 : // This will error if the customer uses anything stronger
231 : // than sslmode=require. That's a choice they can make.
232 0 : return self.default.clone();
233 : }
234 : }
235 : } else {
236 : // No SNI, use the default certificate, otherwise we can't get to
237 : // options parameter which can be used to set endpoint name too.
238 : // That means that non-SNI flow will not work for CNAME domains in
239 : // verify-full mode.
240 : //
241 : // If that will be a problem we can:
242 : //
243 : // a) Instead of multi-cert approach use single cert with extra
244 : // domains listed in Subject Alternative Name (SAN).
245 : // b) Deploy separate proxy instances for extra domains.
246 0 : self.default.clone()
247 : }
248 20 : }
249 : }
|