LCOV - code coverage report
Current view: top level - libs/http-utils/src - tls_certs.rs (source / functions) Coverage Total Hit
Test: 046155f5c3321e806c1c5acca9ccd26414587b38.info Lines: 0.0 % 76 0
Test Date: 2025-03-27 12:42:09 Functions: 0.0 % 10 0

            Line data    Source code
       1              : use std::{sync::Arc, time::Duration};
       2              : 
       3              : use anyhow::Context;
       4              : use arc_swap::ArcSwap;
       5              : use camino::Utf8Path;
       6              : use rustls::{
       7              :     pki_types::{CertificateDer, PrivateKeyDer},
       8              :     server::{ClientHello, ResolvesServerCert},
       9              :     sign::CertifiedKey,
      10              : };
      11              : 
      12            0 : pub async fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
      13            0 :     let cert_data = tokio::fs::read(filename)
      14            0 :         .await
      15            0 :         .context(format!("failed reading certificate file {filename:?}"))?;
      16            0 :     let mut reader = std::io::Cursor::new(&cert_data);
      17              : 
      18            0 :     let cert_chain = rustls_pemfile::certs(&mut reader)
      19            0 :         .collect::<Result<Vec<_>, _>>()
      20            0 :         .context(format!("failed parsing certificate from file {filename:?}"))?;
      21              : 
      22            0 :     Ok(cert_chain)
      23            0 : }
      24              : 
      25            0 : pub async fn load_private_key(filename: &Utf8Path) -> anyhow::Result<PrivateKeyDer<'static>> {
      26            0 :     let key_data = tokio::fs::read(filename)
      27            0 :         .await
      28            0 :         .context(format!("failed reading private key file {filename:?}"))?;
      29            0 :     let mut reader = std::io::Cursor::new(&key_data);
      30              : 
      31            0 :     let key = rustls_pemfile::private_key(&mut reader)
      32            0 :         .context(format!("failed parsing private key from file {filename:?}"))?;
      33              : 
      34            0 :     key.ok_or(anyhow::anyhow!(
      35            0 :         "no private key found in {}",
      36            0 :         filename.as_str(),
      37            0 :     ))
      38            0 : }
      39              : 
      40            0 : pub async fn load_certified_key(
      41            0 :     key_filename: &Utf8Path,
      42            0 :     cert_filename: &Utf8Path,
      43            0 : ) -> anyhow::Result<CertifiedKey> {
      44            0 :     let cert_chain = load_cert_chain(cert_filename).await?;
      45            0 :     let key = load_private_key(key_filename).await?;
      46              : 
      47            0 :     let key = rustls::crypto::ring::default_provider()
      48            0 :         .key_provider
      49            0 :         .load_private_key(key)?;
      50              : 
      51            0 :     let certified_key = CertifiedKey::new(cert_chain, key);
      52            0 :     certified_key.keys_match()?;
      53            0 :     Ok(certified_key)
      54            0 : }
      55              : 
      56              : /// Implementation of [`rustls::server::ResolvesServerCert`] which reloads certificates from
      57              : /// the disk periodically.
      58              : #[derive(Debug)]
      59              : pub struct ReloadingCertificateResolver {
      60              :     certified_key: ArcSwap<CertifiedKey>,
      61              : }
      62              : 
      63              : impl ReloadingCertificateResolver {
      64              :     /// Creates a new Resolver by loading certificate and private key from FS and
      65              :     /// creating tokio::task to reload them with provided reload_period.
      66            0 :     pub async fn new(
      67            0 :         key_filename: &Utf8Path,
      68            0 :         cert_filename: &Utf8Path,
      69            0 :         reload_period: Duration,
      70            0 :     ) -> anyhow::Result<Arc<Self>> {
      71            0 :         let this = Arc::new(Self {
      72              :             certified_key: ArcSwap::from_pointee(
      73            0 :                 load_certified_key(key_filename, cert_filename).await?,
      74              :             ),
      75              :         });
      76              : 
      77            0 :         tokio::spawn({
      78            0 :             let weak_this = Arc::downgrade(&this);
      79            0 :             let key_filename = key_filename.to_owned();
      80            0 :             let cert_filename = cert_filename.to_owned();
      81            0 :             async move {
      82            0 :                 let start = tokio::time::Instant::now() + reload_period;
      83            0 :                 let mut interval = tokio::time::interval_at(start, reload_period);
      84            0 :                 let mut last_reload_failed = false;
      85              :                 loop {
      86            0 :                     interval.tick().await;
      87            0 :                     let this = match weak_this.upgrade() {
      88            0 :                         Some(this) => this,
      89            0 :                         None => break, // Resolver has been destroyed, exit.
      90              :                     };
      91            0 :                     match load_certified_key(&key_filename, &cert_filename).await {
      92            0 :                         Ok(new_certified_key) => {
      93            0 :                             if new_certified_key.cert == this.certified_key.load().cert {
      94            0 :                                 tracing::debug!("Certificate has not changed since last reloading");
      95              :                             } else {
      96            0 :                                 tracing::info!("Certificate has been reloaded");
      97            0 :                                 this.certified_key.store(Arc::new(new_certified_key));
      98              :                             }
      99            0 :                             last_reload_failed = false;
     100              :                         }
     101            0 :                         Err(err) => {
     102            0 :                             // Note: Reloading certs may fail if it conflicts with the script updating
     103            0 :                             // the files at the same time. Warn only if the error is persistent.
     104            0 :                             if last_reload_failed {
     105            0 :                                 tracing::warn!("Error reloading certificate: {err:#}");
     106              :                             } else {
     107            0 :                                 tracing::info!("Error reloading certificate: {err:#}");
     108              :                             }
     109            0 :                             last_reload_failed = true;
     110              :                         }
     111              :                     }
     112              :                 }
     113            0 :             }
     114            0 :         });
     115            0 : 
     116            0 :         Ok(this)
     117            0 :     }
     118              : }
     119              : 
     120              : impl ResolvesServerCert for ReloadingCertificateResolver {
     121            0 :     fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
     122            0 :         Some(self.certified_key.load_full())
     123            0 :     }
     124              : }
        

Generated by: LCOV version 2.1-beta