Line data Source code
1 : use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
2 :
3 : use anyhow::{Context, Result, bail};
4 : use compute_api::responses::TlsConfig;
5 : use ring::digest;
6 : use spki::der::{Decode, PemReader};
7 : use x509_cert::Certificate;
8 :
9 : #[derive(Clone, Copy)]
10 : pub struct CertDigest(digest::Digest);
11 :
12 0 : pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
13 0 : let mut digest = compute_digest(&cert_path).await;
14 0 : let (tx, rx) = tokio::sync::watch::channel(digest);
15 0 : tokio::spawn(async move {
16 0 : while !tx.is_closed() {
17 0 : let new_digest = compute_digest(&cert_path).await;
18 0 : if digest.0.as_ref() != new_digest.0.as_ref() {
19 0 : digest = new_digest;
20 0 : _ = tx.send(digest);
21 0 : }
22 :
23 0 : tokio::time::sleep(Duration::from_secs(60)).await
24 : }
25 0 : });
26 0 : rx
27 0 : }
28 :
29 0 : async fn compute_digest(cert_path: &str) -> CertDigest {
30 : loop {
31 0 : match try_compute_digest(cert_path).await {
32 0 : Ok(d) => break d,
33 0 : Err(e) => {
34 0 : tracing::error!("could not read cert file {e:?}");
35 0 : tokio::time::sleep(Duration::from_secs(1)).await
36 : }
37 : }
38 : }
39 0 : }
40 :
41 0 : async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
42 0 : let data = tokio::fs::read(cert_path).await?;
43 : // sha256 is extremely collision resistent. can safely assume the digest to be unique
44 0 : Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
45 0 : }
46 :
47 : pub const SERVER_CRT: &str = "server.crt";
48 : pub const SERVER_KEY: &str = "server.key";
49 :
50 0 : pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
51 : loop {
52 0 : match try_update_key_path_blocking(pg_data, tls_config) {
53 0 : Ok(()) => break,
54 0 : Err(e) => {
55 0 : tracing::error!("could not create key file {e:?}");
56 0 : std::thread::sleep(Duration::from_secs(1))
57 : }
58 : }
59 : }
60 0 : }
61 :
62 : // Postgres requires the keypath be "secure". This means
63 : // 1. Owned by the postgres user.
64 : // 2. Have permission 600.
65 0 : fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
66 0 : let key = std::fs::read_to_string(&tls_config.key_path)?;
67 0 : let crt = std::fs::read_to_string(&tls_config.cert_path)?;
68 :
69 : // to mitigate a race condition during renewal.
70 0 : verify_key_cert(&key, &crt)?;
71 :
72 0 : let mut key_file = std::fs::OpenOptions::new()
73 0 : .write(true)
74 0 : .create(true)
75 0 : .truncate(true)
76 0 : .mode(0o600)
77 0 : .open(pg_data.join(SERVER_KEY))?;
78 :
79 0 : let mut crt_file = std::fs::OpenOptions::new()
80 0 : .write(true)
81 0 : .create(true)
82 0 : .truncate(true)
83 0 : .mode(0o600)
84 0 : .open(pg_data.join(SERVER_CRT))?;
85 :
86 0 : key_file.write_all(key.as_bytes())?;
87 0 : crt_file.write_all(crt.as_bytes())?;
88 :
89 0 : Ok(())
90 0 : }
91 :
92 0 : fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
93 : use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256;
94 :
95 0 : let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
96 0 : .context("decode cert")?;
97 :
98 0 : match cert.signature_algorithm.oid {
99 : ECDSA_WITH_SHA_256 => {
100 0 : let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
101 :
102 0 : let a = key.public_key().to_sec1_bytes();
103 0 : let b = cert
104 0 : .tbs_certificate
105 0 : .subject_public_key_info
106 0 : .subject_public_key
107 0 : .raw_bytes();
108 0 :
109 0 : if *a != *b {
110 0 : bail!("private key file does not match certificate")
111 0 : }
112 : }
113 0 : _ => bail!("unknown TLS key type"),
114 : }
115 :
116 0 : Ok(())
117 0 : }
|