LCOV - code coverage report
Current view: top level - compute_tools/src - tls.rs (source / functions) Coverage Total Hit
Test: 62212f4d57a7ad0f69dc82a04629a0bbd5f7c824.info Lines: 0.0 % 70 0
Test Date: 2025-03-17 10:41:39 Functions: 0.0 % 10 0

            Line data    Source code
       1              : use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
       2              : 
       3              : use anyhow::{Context, Result, bail};
       4              : use compute_api::responses::TlsConfig;
       5              : use ring::digest;
       6              : use spki::ObjectIdentifier;
       7              : use spki::der::{Decode, PemReader};
       8              : use x509_cert::Certificate;
       9              : 
      10              : #[derive(Clone, Copy)]
      11              : pub struct CertDigest(digest::Digest);
      12              : 
      13            0 : pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
      14            0 :     let mut digest = compute_digest(&cert_path).await;
      15            0 :     let (tx, rx) = tokio::sync::watch::channel(digest);
      16            0 :     tokio::spawn(async move {
      17            0 :         while !tx.is_closed() {
      18            0 :             let new_digest = compute_digest(&cert_path).await;
      19            0 :             if digest.0.as_ref() != new_digest.0.as_ref() {
      20            0 :                 digest = new_digest;
      21            0 :                 _ = tx.send(digest);
      22            0 :             }
      23              : 
      24            0 :             tokio::time::sleep(Duration::from_secs(60)).await
      25              :         }
      26            0 :     });
      27            0 :     rx
      28            0 : }
      29              : 
      30            0 : async fn compute_digest(cert_path: &str) -> CertDigest {
      31              :     loop {
      32            0 :         match try_compute_digest(cert_path).await {
      33            0 :             Ok(d) => break d,
      34            0 :             Err(e) => {
      35            0 :                 tracing::error!("could not read cert file {e:?}");
      36            0 :                 tokio::time::sleep(Duration::from_secs(1)).await
      37              :             }
      38              :         }
      39              :     }
      40            0 : }
      41              : 
      42            0 : async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
      43            0 :     let data = tokio::fs::read(cert_path).await?;
      44              :     // sha256 is extremely collision resistent. can safely assume the digest to be unique
      45            0 :     Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
      46            0 : }
      47              : 
      48              : pub const SERVER_CRT: &str = "server.crt";
      49              : pub const SERVER_KEY: &str = "server.key";
      50              : 
      51            0 : pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
      52              :     loop {
      53            0 :         match try_update_key_path_blocking(pg_data, tls_config) {
      54            0 :             Ok(()) => break,
      55            0 :             Err(e) => {
      56            0 :                 tracing::error!("could not create key file {e:?}");
      57            0 :                 std::thread::sleep(Duration::from_secs(1))
      58              :             }
      59              :         }
      60              :     }
      61            0 : }
      62              : 
      63              : // Postgres requires the keypath be "secure". This means
      64              : // 1. Owned by the postgres user.
      65              : // 2. Have permission 600.
      66            0 : fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
      67            0 :     let key = std::fs::read_to_string(&tls_config.key_path)?;
      68            0 :     let crt = std::fs::read_to_string(&tls_config.cert_path)?;
      69              : 
      70              :     // to mitigate a race condition during renewal.
      71            0 :     verify_key_cert(&key, &crt)?;
      72              : 
      73            0 :     let mut key_file = std::fs::OpenOptions::new()
      74            0 :         .write(true)
      75            0 :         .create(true)
      76            0 :         .truncate(true)
      77            0 :         .mode(0o600)
      78            0 :         .open(pg_data.join(SERVER_KEY))?;
      79              : 
      80            0 :     let mut crt_file = std::fs::OpenOptions::new()
      81            0 :         .write(true)
      82            0 :         .create(true)
      83            0 :         .truncate(true)
      84            0 :         .mode(0o600)
      85            0 :         .open(pg_data.join(SERVER_CRT))?;
      86              : 
      87            0 :     key_file.write_all(key.as_bytes())?;
      88            0 :     crt_file.write_all(crt.as_bytes())?;
      89              : 
      90            0 :     Ok(())
      91            0 : }
      92              : 
      93            0 : fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
      94              :     const ECDSA_WITH_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.2");
      95              : 
      96            0 :     let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
      97            0 :         .context("decode cert")?;
      98              : 
      99            0 :     match cert.signature_algorithm.oid {
     100              :         ECDSA_WITH_SHA256 => {
     101            0 :             let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
     102              : 
     103            0 :             let a = key.public_key().to_sec1_bytes();
     104            0 :             let b = cert
     105            0 :                 .tbs_certificate
     106            0 :                 .subject_public_key_info
     107            0 :                 .subject_public_key
     108            0 :                 .raw_bytes();
     109            0 : 
     110            0 :             if *a != *b {
     111            0 :                 bail!("private key file does not match certificate")
     112            0 :             }
     113              :         }
     114            0 :         _ => bail!("unknown TLS key type"),
     115              :     }
     116              : 
     117            0 :     Ok(())
     118            0 : }
        

Generated by: LCOV version 2.1-beta