Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use anyhow::Context;
4 : use arc_swap::ArcSwap;
5 : use camino::Utf8Path;
6 : use metrics::{IntCounterVec, UIntGaugeVec, register_int_counter_vec, register_uint_gauge_vec};
7 : use once_cell::sync::Lazy;
8 : use rustls::{
9 : pki_types::{CertificateDer, PrivateKeyDer, UnixTime},
10 : server::{ClientHello, ResolvesServerCert},
11 : sign::CertifiedKey,
12 : };
13 : use x509_cert::der::Reader;
14 :
15 0 : pub async fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
16 0 : let cert_data = tokio::fs::read(filename)
17 0 : .await
18 0 : .context(format!("failed reading certificate file {filename:?}"))?;
19 0 : let mut reader = std::io::Cursor::new(&cert_data);
20 :
21 0 : let cert_chain = rustls_pemfile::certs(&mut reader)
22 0 : .collect::<Result<Vec<_>, _>>()
23 0 : .context(format!("failed parsing certificate from file {filename:?}"))?;
24 :
25 0 : Ok(cert_chain)
26 0 : }
27 :
28 0 : pub async fn load_private_key(filename: &Utf8Path) -> anyhow::Result<PrivateKeyDer<'static>> {
29 0 : let key_data = tokio::fs::read(filename)
30 0 : .await
31 0 : .context(format!("failed reading private key file {filename:?}"))?;
32 0 : let mut reader = std::io::Cursor::new(&key_data);
33 :
34 0 : let key = rustls_pemfile::private_key(&mut reader)
35 0 : .context(format!("failed parsing private key from file {filename:?}"))?;
36 :
37 0 : key.ok_or(anyhow::anyhow!(
38 0 : "no private key found in {}",
39 0 : filename.as_str(),
40 0 : ))
41 0 : }
42 :
43 0 : pub async fn load_certified_key(
44 0 : key_filename: &Utf8Path,
45 0 : cert_filename: &Utf8Path,
46 0 : ) -> anyhow::Result<CertifiedKey> {
47 0 : let cert_chain = load_cert_chain(cert_filename).await?;
48 0 : let key = load_private_key(key_filename).await?;
49 :
50 0 : let key = rustls::crypto::ring::default_provider()
51 0 : .key_provider
52 0 : .load_private_key(key)?;
53 :
54 0 : let certified_key = CertifiedKey::new(cert_chain, key);
55 0 : certified_key.keys_match()?;
56 0 : Ok(certified_key)
57 0 : }
58 :
/// rustls's CertifiedKey with extra parsed fields used for metrics.
struct ParsedCertifiedKey {
    /// The certificate chain plus signing key, as served to TLS clients.
    certified_key: CertifiedKey,
    /// `notAfter` of the end-entity certificate in `certified_key`.
    expiration_time: UnixTime,
}
64 :
65 : /// Parse expiration time from an X509 certificate.
66 0 : fn parse_expiration_time(cert: &CertificateDer<'_>) -> anyhow::Result<UnixTime> {
67 0 : let parsed_cert = x509_cert::der::SliceReader::new(cert)
68 0 : .context("Failed to parse cerficiate")?
69 0 : .decode::<x509_cert::Certificate>()
70 0 : .context("Failed to parse cerficiate")?;
71 :
72 0 : Ok(UnixTime::since_unix_epoch(
73 0 : parsed_cert
74 0 : .tbs_certificate
75 0 : .validity
76 0 : .not_after
77 0 : .to_unix_duration(),
78 0 : ))
79 0 : }
80 :
81 0 : async fn load_and_parse_certified_key(
82 0 : key_filename: &Utf8Path,
83 0 : cert_filename: &Utf8Path,
84 0 : ) -> anyhow::Result<ParsedCertifiedKey> {
85 0 : let certified_key = load_certified_key(key_filename, cert_filename).await?;
86 0 : let expiration_time = parse_expiration_time(certified_key.end_entity_cert()?)?;
87 0 : Ok(ParsedCertifiedKey {
88 0 : certified_key,
89 0 : expiration_time,
90 0 : })
91 0 : }
92 :
93 0 : static CERT_EXPIRATION_TIME: Lazy<UIntGaugeVec> = Lazy::new(|| {
94 0 : register_uint_gauge_vec!(
95 0 : "tls_certs_expiration_time_seconds",
96 0 : "Expiration time of the loaded certificate since unix epoch in seconds",
97 0 : &["resolver_name"]
98 0 : )
99 0 : .expect("failed to define a metric")
100 0 : });
101 :
102 0 : static CERT_RELOAD_STARTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
103 0 : register_int_counter_vec!(
104 0 : "tls_certs_reload_started_total",
105 0 : "Number of certificate reload loop iterations started",
106 0 : &["resolver_name"]
107 0 : )
108 0 : .expect("failed to define a metric")
109 0 : });
110 :
111 0 : static CERT_RELOAD_UPDATED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
112 0 : register_int_counter_vec!(
113 0 : "tls_certs_reload_updated_total",
114 0 : "Number of times the certificate was updated to the new one",
115 0 : &["resolver_name"]
116 0 : )
117 0 : .expect("failed to define a metric")
118 0 : });
119 :
120 0 : static CERT_RELOAD_FAILED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
121 0 : register_int_counter_vec!(
122 0 : "tls_certs_reload_failed_total",
123 0 : "Number of times the certificate reload failed",
124 0 : &["resolver_name"]
125 0 : )
126 0 : .expect("failed to define a metric")
127 0 : });
128 :
/// Implementation of [`rustls::server::ResolvesServerCert`] which reloads certificates from
/// the disk periodically.
///
/// Every handshake is answered with the most recently loaded [`CertifiedKey`], independent
/// of the `ClientHello`. Construct it with [`ReloadingCertificateResolver::new`], which also
/// starts the background reload task.
#[derive(Debug)]
pub struct ReloadingCertificateResolver {
    /// The key pair currently being served. The reload task finishes loading and parsing
    /// the files before storing into this `ArcSwap`, so the swap replaces the whole value at once.
    certified_key: ArcSwap<CertifiedKey>,
}
135 :
136 : impl ReloadingCertificateResolver {
137 : /// Creates a new Resolver by loading certificate and private key from FS and
138 : /// creating tokio::task to reload them with provided reload_period.
139 : /// resolver_name is used as metric's label.
140 0 : pub async fn new(
141 0 : resolver_name: &str,
142 0 : key_filename: &Utf8Path,
143 0 : cert_filename: &Utf8Path,
144 0 : reload_period: Duration,
145 0 : ) -> anyhow::Result<Arc<Self>> {
146 0 : // Create metrics for current resolver.
147 0 : let cert_expiration_time = CERT_EXPIRATION_TIME.with_label_values(&[resolver_name]);
148 0 : let cert_reload_started_counter =
149 0 : CERT_RELOAD_STARTED_COUNTER.with_label_values(&[resolver_name]);
150 0 : let cert_reload_updated_counter =
151 0 : CERT_RELOAD_UPDATED_COUNTER.with_label_values(&[resolver_name]);
152 0 : let cert_reload_failed_counter =
153 0 : CERT_RELOAD_FAILED_COUNTER.with_label_values(&[resolver_name]);
154 :
155 0 : let parsed_key = load_and_parse_certified_key(key_filename, cert_filename).await?;
156 :
157 0 : let this = Arc::new(Self {
158 0 : certified_key: ArcSwap::from_pointee(parsed_key.certified_key),
159 0 : });
160 0 : cert_expiration_time.set(parsed_key.expiration_time.as_secs());
161 0 :
162 0 : tokio::spawn({
163 0 : let weak_this = Arc::downgrade(&this);
164 0 : let key_filename = key_filename.to_owned();
165 0 : let cert_filename = cert_filename.to_owned();
166 0 : async move {
167 0 : let start = tokio::time::Instant::now() + reload_period;
168 0 : let mut interval = tokio::time::interval_at(start, reload_period);
169 0 : let mut last_reload_failed = false;
170 : loop {
171 0 : interval.tick().await;
172 0 : let this = match weak_this.upgrade() {
173 0 : Some(this) => this,
174 0 : None => break, // Resolver has been destroyed, exit.
175 : };
176 0 : cert_reload_started_counter.inc();
177 0 :
178 0 : match load_and_parse_certified_key(&key_filename, &cert_filename).await {
179 0 : Ok(parsed_key) => {
180 0 : if parsed_key.certified_key.cert == this.certified_key.load().cert {
181 0 : tracing::debug!("Certificate has not changed since last reloading");
182 : } else {
183 0 : tracing::info!("Certificate has been reloaded");
184 0 : this.certified_key.store(Arc::new(parsed_key.certified_key));
185 0 : cert_expiration_time.set(parsed_key.expiration_time.as_secs());
186 0 : cert_reload_updated_counter.inc();
187 : }
188 0 : last_reload_failed = false;
189 : }
190 0 : Err(err) => {
191 0 : cert_reload_failed_counter.inc();
192 0 : // Note: Reloading certs may fail if it conflicts with the script updating
193 0 : // the files at the same time. Warn only if the error is persistent.
194 0 : if last_reload_failed {
195 0 : tracing::warn!("Error reloading certificate: {err:#}");
196 : } else {
197 0 : tracing::info!("Error reloading certificate: {err:#}");
198 : }
199 0 : last_reload_failed = true;
200 : }
201 : }
202 : }
203 0 : }
204 0 : });
205 0 :
206 0 : Ok(this)
207 0 : }
208 : }
209 :
210 : impl ResolvesServerCert for ReloadingCertificateResolver {
211 0 : fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
212 0 : Some(self.certified_key.load_full())
213 0 : }
214 : }
|