LCOV - code coverage report
Current view: top level - proxy/src/compute - tls.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 31 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 3 0

            Line data    Source code
       1              : use futures::FutureExt;
       2              : use postgres_client::config::SslMode;
       3              : use postgres_client::maybe_tls_stream::MaybeTlsStream;
       4              : use postgres_client::tls::{MakeTlsConnect, TlsConnect};
       5              : use rustls::pki_types::InvalidDnsNameError;
       6              : use thiserror::Error;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : 
       9              : use crate::pqproto::request_tls;
      10              : use crate::proxy::retry::CouldRetry;
      11              : 
      12              : #[derive(Debug, Error)]
      13              : pub enum TlsError {
      14              :     #[error(transparent)]
      15              :     Dns(#[from] InvalidDnsNameError),
      16              :     #[error(transparent)]
      17              :     Connection(#[from] std::io::Error),
      18              :     #[error("TLS required but not provided")]
      19              :     Required,
      20              : }
      21              : 
      22              : impl CouldRetry for TlsError {
      23            0 :     fn could_retry(&self) -> bool {
      24            0 :         match self {
      25            0 :             TlsError::Dns(_) => false,
      26            0 :             TlsError::Connection(err) => err.could_retry(),
      27              :             // perhaps compute didn't realise it supports TLS?
      28            0 :             TlsError::Required => true,
      29              :         }
      30            0 :     }
      31              : }
      32              : 
      33            0 : pub async fn connect_tls<S, T>(
      34            0 :     mut stream: S,
      35            0 :     mode: SslMode,
      36            0 :     tls: &T,
      37            0 :     host: &str,
      38            0 : ) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
      39            0 : where
      40            0 :     S: AsyncRead + AsyncWrite + Unpin + Send,
      41            0 :     T: MakeTlsConnect<
      42            0 :             S,
      43            0 :             Error = InvalidDnsNameError,
      44            0 :             TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
      45            0 :         >,
      46            0 : {
      47            0 :     match mode {
      48            0 :         SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
      49            0 :         SslMode::Prefer | SslMode::Require => {}
      50              :     }
      51              : 
      52            0 :     if !request_tls(&mut stream).await? {
      53            0 :         if SslMode::Require == mode {
      54            0 :             return Err(TlsError::Required);
      55            0 :         }
      56              : 
      57            0 :         return Ok(MaybeTlsStream::Raw(stream));
      58            0 :     }
      59              : 
      60              :     Ok(MaybeTlsStream::Tls(
      61            0 :         tls.make_tls_connect(host)?.connect(stream).boxed().await?,
      62              :     ))
      63            0 : }
        

Generated by: LCOV version 2.1-beta