Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use anyhow::Context;
4 : use arc_swap::ArcSwap;
5 : use camino::Utf8Path;
6 : use rustls::{
7 : pki_types::{CertificateDer, PrivateKeyDer},
8 : server::{ClientHello, ResolvesServerCert},
9 : sign::CertifiedKey,
10 : };
11 :
12 0 : pub async fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
13 0 : let cert_data = tokio::fs::read(filename)
14 0 : .await
15 0 : .context(format!("failed reading certificate file {filename:?}"))?;
16 0 : let mut reader = std::io::Cursor::new(&cert_data);
17 :
18 0 : let cert_chain = rustls_pemfile::certs(&mut reader)
19 0 : .collect::<Result<Vec<_>, _>>()
20 0 : .context(format!("failed parsing certificate from file {filename:?}"))?;
21 :
22 0 : Ok(cert_chain)
23 0 : }
24 :
25 0 : pub async fn load_private_key(filename: &Utf8Path) -> anyhow::Result<PrivateKeyDer<'static>> {
26 0 : let key_data = tokio::fs::read(filename)
27 0 : .await
28 0 : .context(format!("failed reading private key file {filename:?}"))?;
29 0 : let mut reader = std::io::Cursor::new(&key_data);
30 :
31 0 : let key = rustls_pemfile::private_key(&mut reader)
32 0 : .context(format!("failed parsing private key from file {filename:?}"))?;
33 :
34 0 : key.ok_or(anyhow::anyhow!(
35 0 : "no private key found in {}",
36 0 : filename.as_str(),
37 0 : ))
38 0 : }
39 :
40 0 : pub async fn load_certified_key(
41 0 : key_filename: &Utf8Path,
42 0 : cert_filename: &Utf8Path,
43 0 : ) -> anyhow::Result<CertifiedKey> {
44 0 : let cert_chain = load_cert_chain(cert_filename).await?;
45 0 : let key = load_private_key(key_filename).await?;
46 :
47 0 : let key = rustls::crypto::ring::default_provider()
48 0 : .key_provider
49 0 : .load_private_key(key)?;
50 :
51 0 : let certified_key = CertifiedKey::new(cert_chain, key);
52 0 : certified_key.keys_match()?;
53 0 : Ok(certified_key)
54 0 : }
55 :
56 : /// Implementation of [`rustls::server::ResolvesServerCert`] which reloads certificates from
57 : /// the disk periodically.
58 : #[derive(Debug)]
59 : pub struct ReloadingCertificateResolver {
60 : certified_key: ArcSwap<CertifiedKey>,
61 : }
62 :
63 : impl ReloadingCertificateResolver {
64 : /// Creates a new Resolver by loading certificate and private key from FS and
65 : /// creating tokio::task to reload them with provided reload_period.
66 0 : pub async fn new(
67 0 : key_filename: &Utf8Path,
68 0 : cert_filename: &Utf8Path,
69 0 : reload_period: Duration,
70 0 : ) -> anyhow::Result<Arc<Self>> {
71 0 : let this = Arc::new(Self {
72 : certified_key: ArcSwap::from_pointee(
73 0 : load_certified_key(key_filename, cert_filename).await?,
74 : ),
75 : });
76 :
77 0 : tokio::spawn({
78 0 : let weak_this = Arc::downgrade(&this);
79 0 : let key_filename = key_filename.to_owned();
80 0 : let cert_filename = cert_filename.to_owned();
81 0 : async move {
82 0 : let start = tokio::time::Instant::now() + reload_period;
83 0 : let mut interval = tokio::time::interval_at(start, reload_period);
84 0 : let mut last_reload_failed = false;
85 : loop {
86 0 : interval.tick().await;
87 0 : let this = match weak_this.upgrade() {
88 0 : Some(this) => this,
89 0 : None => break, // Resolver has been destroyed, exit.
90 : };
91 0 : match load_certified_key(&key_filename, &cert_filename).await {
92 0 : Ok(new_certified_key) => {
93 0 : if new_certified_key.cert == this.certified_key.load().cert {
94 0 : tracing::debug!("Certificate has not changed since last reloading");
95 : } else {
96 0 : tracing::info!("Certificate has been reloaded");
97 0 : this.certified_key.store(Arc::new(new_certified_key));
98 : }
99 0 : last_reload_failed = false;
100 : }
101 0 : Err(err) => {
102 0 : // Note: Reloading certs may fail if it conflicts with the script updating
103 0 : // the files at the same time. Warn only if the error is persistent.
104 0 : if last_reload_failed {
105 0 : tracing::warn!("Error reloading certificate: {err:#}");
106 : } else {
107 0 : tracing::info!("Error reloading certificate: {err:#}");
108 : }
109 0 : last_reload_failed = true;
110 : }
111 : }
112 : }
113 0 : }
114 0 : });
115 0 :
116 0 : Ok(this)
117 0 : }
118 : }
119 :
120 : impl ResolvesServerCert for ReloadingCertificateResolver {
121 0 : fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
122 0 : Some(self.certified_key.load_full())
123 0 : }
124 : }
|