LCOV - code coverage report
Current view: top level - proxy/src/tls - postgres_rustls.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 82.8 % 64 53
Test Date: 2025-07-16 12:29:03 Functions: 50.0 % 22 11

            Line data    Source code
       1              : use std::convert::TryFrom;
       2              : use std::sync::Arc;
       3              : 
       4              : use postgres_client::tls::MakeTlsConnect;
       5              : use rustls::pki_types::{InvalidDnsNameError, ServerName};
       6              : use tokio::io::{AsyncRead, AsyncWrite};
       7              : 
       8              : use crate::config::ComputeConfig;
       9              : 
      10              : mod private {
      11              :     use std::future::Future;
      12              :     use std::io;
      13              :     use std::pin::Pin;
      14              :     use std::task::{Context, Poll};
      15              : 
      16              :     use postgres_client::tls::{ChannelBinding, TlsConnect};
      17              :     use rustls::pki_types::ServerName;
      18              :     use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      19              :     use tokio_rustls::TlsConnector;
      20              :     use tokio_rustls::client::TlsStream;
      21              : 
      22              :     use crate::tls::TlsServerEndPoint;
      23              : 
      24              :     pub struct TlsConnectFuture<S> {
      25              :         inner: tokio_rustls::Connect<S>,
      26              :     }
      27              : 
      28              :     impl<S> Future for TlsConnectFuture<S>
      29              :     where
      30              :         S: AsyncRead + AsyncWrite + Unpin,
      31              :     {
      32              :         type Output = io::Result<RustlsStream<S>>;
      33              : 
      34           40 :         fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
      35           40 :             Pin::new(&mut self.inner)
      36           40 :                 .poll(cx)
      37           40 :                 .map_ok(|s| RustlsStream(Box::new(s)))
      38           40 :         }
      39              :     }
      40              : 
      41              :     pub struct RustlsConnect(pub RustlsConnectData);
      42              : 
      43              :     pub struct RustlsConnectData {
      44              :         pub hostname: ServerName<'static>,
      45              :         pub connector: TlsConnector,
      46              :     }
      47              : 
      48              :     impl<S> TlsConnect<S> for RustlsConnect
      49              :     where
      50              :         S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
      51              :     {
      52              :         type Stream = RustlsStream<S>;
      53              :         type Error = io::Error;
      54              :         type Future = TlsConnectFuture<S>;
      55              : 
      56           20 :         fn connect(self, stream: S) -> Self::Future {
      57           20 :             TlsConnectFuture {
      58           20 :                 inner: self.0.connector.connect(self.0.hostname, stream),
      59           20 :             }
      60           20 :         }
      61              :     }
      62              : 
      63              :     pub struct RustlsStream<S>(Box<TlsStream<S>>);
      64              : 
      65              :     impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
      66              :     where
      67              :         S: AsyncRead + AsyncWrite + Unpin,
      68              :     {
      69           12 :         fn channel_binding(&self) -> ChannelBinding {
      70           12 :             let (_, session) = self.0.get_ref();
      71           12 :             match session.peer_certificates() {
      72           12 :                 Some([cert, ..]) => TlsServerEndPoint::new(cert)
      73           12 :                     .ok()
      74           12 :                     .and_then(|cb| match cb {
      75           12 :                         TlsServerEndPoint::Sha256(hash) => Some(hash),
      76            0 :                         TlsServerEndPoint::Undefined => None,
      77           12 :                     })
      78           12 :                     .map_or_else(ChannelBinding::none, |hash| {
      79           12 :                         ChannelBinding::tls_server_end_point(hash.to_vec())
      80           12 :                     }),
      81            0 :                 _ => ChannelBinding::none(),
      82              :             }
      83           12 :         }
      84              :     }
      85              : 
      86              :     impl<S> AsyncRead for RustlsStream<S>
      87              :     where
      88              :         S: AsyncRead + AsyncWrite + Unpin,
      89              :     {
      90          127 :         fn poll_read(
      91          127 :             mut self: Pin<&mut Self>,
      92          127 :             cx: &mut Context<'_>,
      93          127 :             buf: &mut ReadBuf<'_>,
      94          127 :         ) -> Poll<tokio::io::Result<()>> {
      95          127 :             Pin::new(&mut self.0).poll_read(cx, buf)
      96          127 :         }
      97              :     }
      98              : 
      99              :     impl<S> AsyncWrite for RustlsStream<S>
     100              :     where
     101              :         S: AsyncRead + AsyncWrite + Unpin,
     102              :     {
     103           52 :         fn poll_write(
     104           52 :             mut self: Pin<&mut Self>,
     105           52 :             cx: &mut Context<'_>,
     106           52 :             buf: &[u8],
     107           52 :         ) -> Poll<tokio::io::Result<usize>> {
     108           52 :             Pin::new(&mut self.0).poll_write(cx, buf)
     109           52 :         }
     110              : 
     111           52 :         fn poll_flush(
     112           52 :             mut self: Pin<&mut Self>,
     113           52 :             cx: &mut Context<'_>,
     114           52 :         ) -> Poll<tokio::io::Result<()>> {
     115           52 :             Pin::new(&mut self.0).poll_flush(cx)
     116           52 :         }
     117              : 
     118            0 :         fn poll_shutdown(
     119            0 :             mut self: Pin<&mut Self>,
     120            0 :             cx: &mut Context<'_>,
     121            0 :         ) -> Poll<tokio::io::Result<()>> {
     122            0 :             Pin::new(&mut self.0).poll_shutdown(cx)
     123            0 :         }
     124              :     }
     125              : }
     126              : 
     127              : impl<S> MakeTlsConnect<S> for ComputeConfig
     128              : where
     129              :     S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     130              : {
     131              :     type Stream = private::RustlsStream<S>;
     132              :     type TlsConnect = private::RustlsConnect;
     133              :     type Error = InvalidDnsNameError;
     134              : 
     135            0 :     fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
     136            0 :         make_tls_connect(&self.tls, hostname)
     137            0 :     }
     138              : }
     139              : 
     140           20 : pub fn make_tls_connect(
     141           20 :     tls: &Arc<rustls::ClientConfig>,
     142           20 :     hostname: &str,
     143           20 : ) -> Result<private::RustlsConnect, InvalidDnsNameError> {
     144           20 :     ServerName::try_from(hostname).map(|dns_name| {
     145           20 :         private::RustlsConnect(private::RustlsConnectData {
     146           20 :             hostname: dns_name.to_owned(),
     147           20 :             connector: tls.clone().into(),
     148           20 :         })
     149           20 :     })
     150           20 : }
        

Generated by: LCOV version 2.1-beta