LCOV - code coverage report
Current view: top level - compute_tools/src - tls.rs (source / functions) Coverage Total Hit
Test: ae4948feae6a1d420c855050eb8c189119446a71.info Lines: 0.0 % 70 0
Test Date: 2025-03-18 18:33:46 Functions: 0.0 % 10 0

            Line data    Source code
       1              : use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
       2              : 
       3              : use anyhow::{Context, Result, bail};
       4              : use compute_api::responses::TlsConfig;
       5              : use ring::digest;
       6              : use spki::der::{Decode, PemReader};
       7              : use x509_cert::Certificate;
       8              : 
       9              : #[derive(Clone, Copy)]
      10              : pub struct CertDigest(digest::Digest);
      11              : 
      12            0 : pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
      13            0 :     let mut digest = compute_digest(&cert_path).await;
      14            0 :     let (tx, rx) = tokio::sync::watch::channel(digest);
      15            0 :     tokio::spawn(async move {
      16            0 :         while !tx.is_closed() {
      17            0 :             let new_digest = compute_digest(&cert_path).await;
      18            0 :             if digest.0.as_ref() != new_digest.0.as_ref() {
      19            0 :                 digest = new_digest;
      20            0 :                 _ = tx.send(digest);
      21            0 :             }
      22              : 
      23            0 :             tokio::time::sleep(Duration::from_secs(60)).await
      24              :         }
      25            0 :     });
      26            0 :     rx
      27            0 : }
      28              : 
      29            0 : async fn compute_digest(cert_path: &str) -> CertDigest {
      30              :     loop {
      31            0 :         match try_compute_digest(cert_path).await {
      32            0 :             Ok(d) => break d,
      33            0 :             Err(e) => {
      34            0 :                 tracing::error!("could not read cert file {e:?}");
      35            0 :                 tokio::time::sleep(Duration::from_secs(1)).await
      36              :             }
      37              :         }
      38              :     }
      39            0 : }
      40              : 
      41            0 : async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
      42            0 :     let data = tokio::fs::read(cert_path).await?;
      43              :     // sha256 is extremely collision resistent. can safely assume the digest to be unique
      44            0 :     Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
      45            0 : }
      46              : 
/// File name, relative to the Postgres data directory, under which the TLS
/// certificate is written.
pub const SERVER_CRT: &str = "server.crt";
/// File name, relative to the Postgres data directory, under which the TLS
/// private key is written.
pub const SERVER_KEY: &str = "server.key";
      49              : 
      50            0 : pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
      51              :     loop {
      52            0 :         match try_update_key_path_blocking(pg_data, tls_config) {
      53            0 :             Ok(()) => break,
      54            0 :             Err(e) => {
      55            0 :                 tracing::error!("could not create key file {e:?}");
      56            0 :                 std::thread::sleep(Duration::from_secs(1))
      57              :             }
      58              :         }
      59              :     }
      60            0 : }
      61              : 
      62              : // Postgres requires the keypath be "secure". This means
      63              : // 1. Owned by the postgres user.
      64              : // 2. Have permission 600.
      65            0 : fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
      66            0 :     let key = std::fs::read_to_string(&tls_config.key_path)?;
      67            0 :     let crt = std::fs::read_to_string(&tls_config.cert_path)?;
      68              : 
      69              :     // to mitigate a race condition during renewal.
      70            0 :     verify_key_cert(&key, &crt)?;
      71              : 
      72            0 :     let mut key_file = std::fs::OpenOptions::new()
      73            0 :         .write(true)
      74            0 :         .create(true)
      75            0 :         .truncate(true)
      76            0 :         .mode(0o600)
      77            0 :         .open(pg_data.join(SERVER_KEY))?;
      78              : 
      79            0 :     let mut crt_file = std::fs::OpenOptions::new()
      80            0 :         .write(true)
      81            0 :         .create(true)
      82            0 :         .truncate(true)
      83            0 :         .mode(0o600)
      84            0 :         .open(pg_data.join(SERVER_CRT))?;
      85              : 
      86            0 :     key_file.write_all(key.as_bytes())?;
      87            0 :     crt_file.write_all(crt.as_bytes())?;
      88              : 
      89            0 :     Ok(())
      90            0 : }
      91              : 
      92            0 : fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
      93              :     use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256;
      94              : 
      95            0 :     let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
      96            0 :         .context("decode cert")?;
      97              : 
      98            0 :     match cert.signature_algorithm.oid {
      99              :         ECDSA_WITH_SHA_256 => {
     100            0 :             let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
     101              : 
     102            0 :             let a = key.public_key().to_sec1_bytes();
     103            0 :             let b = cert
     104            0 :                 .tbs_certificate
     105            0 :                 .subject_public_key_info
     106            0 :                 .subject_public_key
     107            0 :                 .raw_bytes();
     108            0 : 
     109            0 :             if *a != *b {
     110            0 :                 bail!("private key file does not match certificate")
     111            0 :             }
     112              :         }
     113            0 :         _ => bail!("unknown TLS key type"),
     114              :     }
     115              : 
     116            0 :     Ok(())
     117            0 : }
        

Generated by: LCOV version 2.1-beta