LCOV - code coverage report
Current view: top level - libs/http-utils/src - tls_certs.rs (source / functions) Coverage Total Hit
Test: 5e392a02abbad1ab595f4dba672e219a49f7f539.info Lines: 0.0 % 150 0
Test Date: 2025-04-11 22:43:24 Functions: 0.0 % 17 0

            Line data    Source code
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use anyhow::Context;
       4              : use arc_swap::ArcSwap;
       5              : use camino::Utf8Path;
       6              : use metrics::{IntCounterVec, UIntGaugeVec, register_int_counter_vec, register_uint_gauge_vec};
       7              : use once_cell::sync::Lazy;
       8              : use rustls::{
       9              :     pki_types::{CertificateDer, PrivateKeyDer, UnixTime},
      10              :     server::{ClientHello, ResolvesServerCert},
      11              :     sign::CertifiedKey,
      12              : };
      13              : use x509_cert::der::Reader;
      14              : 
      15            0 : pub async fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
      16            0 :     let cert_data = tokio::fs::read(filename)
      17            0 :         .await
      18            0 :         .context(format!("failed reading certificate file {filename:?}"))?;
      19            0 :     let mut reader = std::io::Cursor::new(&cert_data);
      20              : 
      21            0 :     let cert_chain = rustls_pemfile::certs(&mut reader)
      22            0 :         .collect::<Result<Vec<_>, _>>()
      23            0 :         .context(format!("failed parsing certificate from file {filename:?}"))?;
      24              : 
      25            0 :     Ok(cert_chain)
      26            0 : }
      27              : 
      28            0 : pub async fn load_private_key(filename: &Utf8Path) -> anyhow::Result<PrivateKeyDer<'static>> {
      29            0 :     let key_data = tokio::fs::read(filename)
      30            0 :         .await
      31            0 :         .context(format!("failed reading private key file {filename:?}"))?;
      32            0 :     let mut reader = std::io::Cursor::new(&key_data);
      33              : 
      34            0 :     let key = rustls_pemfile::private_key(&mut reader)
      35            0 :         .context(format!("failed parsing private key from file {filename:?}"))?;
      36              : 
      37            0 :     key.ok_or(anyhow::anyhow!(
      38            0 :         "no private key found in {}",
      39            0 :         filename.as_str(),
      40            0 :     ))
      41            0 : }
      42              : 
      43            0 : pub async fn load_certified_key(
      44            0 :     key_filename: &Utf8Path,
      45            0 :     cert_filename: &Utf8Path,
      46            0 : ) -> anyhow::Result<CertifiedKey> {
      47            0 :     let cert_chain = load_cert_chain(cert_filename).await?;
      48            0 :     let key = load_private_key(key_filename).await?;
      49              : 
      50            0 :     let key = rustls::crypto::ring::default_provider()
      51            0 :         .key_provider
      52            0 :         .load_private_key(key)?;
      53              : 
      54            0 :     let certified_key = CertifiedKey::new(cert_chain, key);
      55            0 :     certified_key.keys_match()?;
      56            0 :     Ok(certified_key)
      57            0 : }
      58              : 
      59              : /// rustls's CertifiedKey with extra parsed fields used for metrics.
      60              : struct ParsedCertifiedKey {
      61              :     certified_key: CertifiedKey,
      62              :     expiration_time: UnixTime,
      63              : }
      64              : 
      65              : /// Parse expiration time from an X509 certificate.
      66            0 : fn parse_expiration_time(cert: &CertificateDer<'_>) -> anyhow::Result<UnixTime> {
      67            0 :     let parsed_cert = x509_cert::der::SliceReader::new(cert)
      68            0 :         .context("Failed to parse cerficiate")?
      69            0 :         .decode::<x509_cert::Certificate>()
      70            0 :         .context("Failed to parse cerficiate")?;
      71              : 
      72            0 :     Ok(UnixTime::since_unix_epoch(
      73            0 :         parsed_cert
      74            0 :             .tbs_certificate
      75            0 :             .validity
      76            0 :             .not_after
      77            0 :             .to_unix_duration(),
      78            0 :     ))
      79            0 : }
      80              : 
      81            0 : async fn load_and_parse_certified_key(
      82            0 :     key_filename: &Utf8Path,
      83            0 :     cert_filename: &Utf8Path,
      84            0 : ) -> anyhow::Result<ParsedCertifiedKey> {
      85            0 :     let certified_key = load_certified_key(key_filename, cert_filename).await?;
      86            0 :     let expiration_time = parse_expiration_time(certified_key.end_entity_cert()?)?;
      87            0 :     Ok(ParsedCertifiedKey {
      88            0 :         certified_key,
      89            0 :         expiration_time,
      90            0 :     })
      91            0 : }
      92              : 
      93            0 : static CERT_EXPIRATION_TIME: Lazy<UIntGaugeVec> = Lazy::new(|| {
      94            0 :     register_uint_gauge_vec!(
      95            0 :         "tls_certs_expiration_time_seconds",
      96            0 :         "Expiration time of the loaded certificate since unix epoch in seconds",
      97            0 :         &["resolver_name"]
      98            0 :     )
      99            0 :     .expect("failed to define a metric")
     100            0 : });
     101              : 
     102            0 : static CERT_RELOAD_STARTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
     103            0 :     register_int_counter_vec!(
     104            0 :         "tls_certs_reload_started_total",
     105            0 :         "Number of certificate reload loop iterations started",
     106            0 :         &["resolver_name"]
     107            0 :     )
     108            0 :     .expect("failed to define a metric")
     109            0 : });
     110              : 
     111            0 : static CERT_RELOAD_UPDATED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
     112            0 :     register_int_counter_vec!(
     113            0 :         "tls_certs_reload_updated_total",
     114            0 :         "Number of times the certificate was updated to the new one",
     115            0 :         &["resolver_name"]
     116            0 :     )
     117            0 :     .expect("failed to define a metric")
     118            0 : });
     119              : 
     120            0 : static CERT_RELOAD_FAILED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
     121            0 :     register_int_counter_vec!(
     122            0 :         "tls_certs_reload_failed_total",
     123            0 :         "Number of times the certificate reload failed",
     124            0 :         &["resolver_name"]
     125            0 :     )
     126            0 :     .expect("failed to define a metric")
     127            0 : });
     128              : 
     129              : /// Implementation of [`rustls::server::ResolvesServerCert`] which reloads certificates from
     130              : /// the disk periodically.
     131              : #[derive(Debug)]
     132              : pub struct ReloadingCertificateResolver {
     133              :     certified_key: ArcSwap<CertifiedKey>,
     134              : }
     135              : 
     136              : impl ReloadingCertificateResolver {
     137              :     /// Creates a new Resolver by loading certificate and private key from FS and
     138              :     /// creating tokio::task to reload them with provided reload_period.
     139              :     /// resolver_name is used as metric's label.
     140            0 :     pub async fn new(
     141            0 :         resolver_name: &str,
     142            0 :         key_filename: &Utf8Path,
     143            0 :         cert_filename: &Utf8Path,
     144            0 :         reload_period: Duration,
     145            0 :     ) -> anyhow::Result<Arc<Self>> {
     146            0 :         // Create metrics for current resolver.
     147            0 :         let cert_expiration_time = CERT_EXPIRATION_TIME.with_label_values(&[resolver_name]);
     148            0 :         let cert_reload_started_counter =
     149            0 :             CERT_RELOAD_STARTED_COUNTER.with_label_values(&[resolver_name]);
     150            0 :         let cert_reload_updated_counter =
     151            0 :             CERT_RELOAD_UPDATED_COUNTER.with_label_values(&[resolver_name]);
     152            0 :         let cert_reload_failed_counter =
     153            0 :             CERT_RELOAD_FAILED_COUNTER.with_label_values(&[resolver_name]);
     154              : 
     155            0 :         let parsed_key = load_and_parse_certified_key(key_filename, cert_filename).await?;
     156              : 
     157            0 :         let this = Arc::new(Self {
     158            0 :             certified_key: ArcSwap::from_pointee(parsed_key.certified_key),
     159            0 :         });
     160            0 :         cert_expiration_time.set(parsed_key.expiration_time.as_secs());
     161            0 : 
     162            0 :         tokio::spawn({
     163            0 :             let weak_this = Arc::downgrade(&this);
     164            0 :             let key_filename = key_filename.to_owned();
     165            0 :             let cert_filename = cert_filename.to_owned();
     166            0 :             async move {
     167            0 :                 let start = tokio::time::Instant::now() + reload_period;
     168            0 :                 let mut interval = tokio::time::interval_at(start, reload_period);
     169            0 :                 let mut last_reload_failed = false;
     170              :                 loop {
     171            0 :                     interval.tick().await;
     172            0 :                     let this = match weak_this.upgrade() {
     173            0 :                         Some(this) => this,
     174            0 :                         None => break, // Resolver has been destroyed, exit.
     175              :                     };
     176            0 :                     cert_reload_started_counter.inc();
     177            0 : 
     178            0 :                     match load_and_parse_certified_key(&key_filename, &cert_filename).await {
     179            0 :                         Ok(parsed_key) => {
     180            0 :                             if parsed_key.certified_key.cert == this.certified_key.load().cert {
     181            0 :                                 tracing::debug!("Certificate has not changed since last reloading");
     182              :                             } else {
     183            0 :                                 tracing::info!("Certificate has been reloaded");
     184            0 :                                 this.certified_key.store(Arc::new(parsed_key.certified_key));
     185            0 :                                 cert_expiration_time.set(parsed_key.expiration_time.as_secs());
     186            0 :                                 cert_reload_updated_counter.inc();
     187              :                             }
     188            0 :                             last_reload_failed = false;
     189              :                         }
     190            0 :                         Err(err) => {
     191            0 :                             cert_reload_failed_counter.inc();
     192            0 :                             // Note: Reloading certs may fail if it conflicts with the script updating
     193            0 :                             // the files at the same time. Warn only if the error is persistent.
     194            0 :                             if last_reload_failed {
     195            0 :                                 tracing::warn!("Error reloading certificate: {err:#}");
     196              :                             } else {
     197            0 :                                 tracing::info!("Error reloading certificate: {err:#}");
     198              :                             }
     199            0 :                             last_reload_failed = true;
     200              :                         }
     201              :                     }
     202              :                 }
     203            0 :             }
     204            0 :         });
     205            0 : 
     206            0 :         Ok(this)
     207            0 :     }
     208              : }
     209              : 
     210              : impl ResolvesServerCert for ReloadingCertificateResolver {
     211            0 :     fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
     212            0 :         Some(self.certified_key.load_full())
     213            0 :     }
     214              : }
        

Generated by: LCOV version 2.1-beta