Line data Source code
1 : use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
2 :
3 : use anyhow::{Context, Result, bail};
4 : use compute_api::responses::TlsConfig;
5 : use ring::digest;
6 : use spki::ObjectIdentifier;
7 : use spki::der::{Decode, PemReader};
8 : use x509_cert::Certificate;
9 :
10 : #[derive(Clone, Copy)]
11 : pub struct CertDigest(digest::Digest);
12 :
13 0 : pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
14 0 : let mut digest = compute_digest(&cert_path).await;
15 0 : let (tx, rx) = tokio::sync::watch::channel(digest);
16 0 : tokio::spawn(async move {
17 0 : while !tx.is_closed() {
18 0 : let new_digest = compute_digest(&cert_path).await;
19 0 : if digest.0.as_ref() != new_digest.0.as_ref() {
20 0 : digest = new_digest;
21 0 : _ = tx.send(digest);
22 0 : }
23 :
24 0 : tokio::time::sleep(Duration::from_secs(60)).await
25 : }
26 0 : });
27 0 : rx
28 0 : }
29 :
30 0 : async fn compute_digest(cert_path: &str) -> CertDigest {
31 : loop {
32 0 : match try_compute_digest(cert_path).await {
33 0 : Ok(d) => break d,
34 0 : Err(e) => {
35 0 : tracing::error!("could not read cert file {e:?}");
36 0 : tokio::time::sleep(Duration::from_secs(1)).await
37 : }
38 : }
39 : }
40 0 : }
41 :
42 0 : async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
43 0 : let data = tokio::fs::read(cert_path).await?;
44 : // sha256 is extremely collision resistent. can safely assume the digest to be unique
45 0 : Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
46 0 : }
47 :
/// File name of the private key written into the Postgres data directory.
pub const SERVER_KEY: &str = "server.key";
/// File name of the certificate written into the Postgres data directory.
pub const SERVER_CRT: &str = "server.crt";
50 :
51 0 : pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
52 : loop {
53 0 : match try_update_key_path_blocking(pg_data, tls_config) {
54 0 : Ok(()) => break,
55 0 : Err(e) => {
56 0 : tracing::error!("could not create key file {e:?}");
57 0 : std::thread::sleep(Duration::from_secs(1))
58 : }
59 : }
60 : }
61 0 : }
62 :
63 : // Postgres requires the keypath be "secure". This means
64 : // 1. Owned by the postgres user.
65 : // 2. Have permission 600.
66 0 : fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
67 0 : let key = std::fs::read_to_string(&tls_config.key_path)?;
68 0 : let crt = std::fs::read_to_string(&tls_config.cert_path)?;
69 :
70 : // to mitigate a race condition during renewal.
71 0 : verify_key_cert(&key, &crt)?;
72 :
73 0 : let mut key_file = std::fs::OpenOptions::new()
74 0 : .write(true)
75 0 : .create(true)
76 0 : .truncate(true)
77 0 : .mode(0o600)
78 0 : .open(pg_data.join(SERVER_KEY))?;
79 :
80 0 : let mut crt_file = std::fs::OpenOptions::new()
81 0 : .write(true)
82 0 : .create(true)
83 0 : .truncate(true)
84 0 : .mode(0o600)
85 0 : .open(pg_data.join(SERVER_CRT))?;
86 :
87 0 : key_file.write_all(key.as_bytes())?;
88 0 : crt_file.write_all(crt.as_bytes())?;
89 :
90 0 : Ok(())
91 0 : }
92 :
93 0 : fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
94 : const ECDSA_WITH_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.2");
95 :
96 0 : let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
97 0 : .context("decode cert")?;
98 :
99 0 : match cert.signature_algorithm.oid {
100 : ECDSA_WITH_SHA256 => {
101 0 : let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
102 :
103 0 : let a = key.public_key().to_sec1_bytes();
104 0 : let b = cert
105 0 : .tbs_certificate
106 0 : .subject_public_key_info
107 0 : .subject_public_key
108 0 : .raw_bytes();
109 0 :
110 0 : if *a != *b {
111 0 : bail!("private key file does not match certificate")
112 0 : }
113 : }
114 0 : _ => bail!("unknown TLS key type"),
115 : }
116 :
117 0 : Ok(())
118 0 : }
|