LCOV - code coverage report
Current view: top level - proxy/src/compute - tls.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 31 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 3 0

            Line data    Source code
       1              : use futures::FutureExt;
       2              : use postgres_client::config::SslMode;
       3              : use postgres_client::maybe_tls_stream::MaybeTlsStream;
       4              : use postgres_client::tls::{MakeTlsConnect, TlsConnect};
       5              : use rustls::pki_types::InvalidDnsNameError;
       6              : use thiserror::Error;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : 
       9              : use crate::pqproto::request_tls;
      10              : use crate::proxy::connect_compute::TlsNegotiation;
      11              : use crate::proxy::retry::CouldRetry;
      12              : 
      13              : #[derive(Debug, Error)]
      14              : pub enum TlsError {
      15              :     #[error(transparent)]
      16              :     Dns(#[from] InvalidDnsNameError),
      17              :     #[error(transparent)]
      18              :     Connection(#[from] std::io::Error),
      19              :     #[error("TLS required but not provided")]
      20              :     Required,
      21              : }
      22              : 
      23              : impl CouldRetry for TlsError {
      24            0 :     fn could_retry(&self) -> bool {
      25            0 :         match self {
      26            0 :             TlsError::Dns(_) => false,
      27            0 :             TlsError::Connection(err) => err.could_retry(),
      28              :             // perhaps compute didn't realise it supports TLS?
      29            0 :             TlsError::Required => true,
      30              :         }
      31            0 :     }
      32              : }
      33              : 
      34            0 : pub async fn connect_tls<S, T>(
      35            0 :     mut stream: S,
      36            0 :     mode: SslMode,
      37            0 :     tls: &T,
      38            0 :     host: &str,
      39            0 :     negotiation: TlsNegotiation,
      40            0 : ) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
      41            0 : where
      42            0 :     S: AsyncRead + AsyncWrite + Unpin + Send,
      43            0 :     T: MakeTlsConnect<
      44            0 :             S,
      45            0 :             Error = InvalidDnsNameError,
      46            0 :             TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
      47            0 :         >,
      48            0 : {
      49            0 :     match mode {
      50            0 :         SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
      51            0 :         SslMode::Prefer | SslMode::Require => {}
      52              :     }
      53              : 
      54            0 :     match negotiation {
      55              :         // No TLS request needed
      56            0 :         TlsNegotiation::Direct => {}
      57              :         // TLS request successful
      58            0 :         TlsNegotiation::Postgres if request_tls(&mut stream).await? => {}
      59              :         // TLS request failed but is required
      60            0 :         TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required),
      61              :         // TLS request failed but is not required
      62            0 :         TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)),
      63              :     }
      64              : 
      65              :     Ok(MaybeTlsStream::Tls(
      66            0 :         tls.make_tls_connect(host)?.connect(stream).boxed().await?,
      67              :     ))
      68            0 : }
        

Generated by: LCOV version 2.1-beta