LCOV - code coverage report
Current view: top level - proxy/src/tls - postgres_rustls.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 86.4 % 59 51
Test Date: 2025-03-12 18:05:02 Functions: 47.8 % 23 11

            Line data    Source code
       1              : use std::convert::TryFrom;
       2              : use std::sync::Arc;
       3              : 
       4              : use postgres_client::tls::MakeTlsConnect;
       5              : use rustls::ClientConfig;
       6              : use rustls::pki_types::ServerName;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : 
       9              : mod private {
      10              :     use std::future::Future;
      11              :     use std::io;
      12              :     use std::pin::Pin;
      13              :     use std::task::{Context, Poll};
      14              : 
      15              :     use postgres_client::tls::{ChannelBinding, TlsConnect};
      16              :     use rustls::pki_types::ServerName;
      17              :     use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      18              :     use tokio_rustls::TlsConnector;
      19              :     use tokio_rustls::client::TlsStream;
      20              : 
      21              :     use crate::tls::TlsServerEndPoint;
      22              : 
      23              :     pub struct TlsConnectFuture<S> {
      24              :         inner: tokio_rustls::Connect<S>,
      25              :     }
      26              : 
      27              :     impl<S> Future for TlsConnectFuture<S>
      28              :     where
      29              :         S: AsyncRead + AsyncWrite + Unpin,
      30              :     {
      31              :         type Output = io::Result<RustlsStream<S>>;
      32              : 
      33           40 :         fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
      34           40 :             Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
      35           40 :         }
      36              :     }
      37              : 
      38              :     pub struct RustlsConnect(pub RustlsConnectData);
      39              : 
      40              :     pub struct RustlsConnectData {
      41              :         pub hostname: ServerName<'static>,
      42              :         pub connector: TlsConnector,
      43              :     }
      44              : 
      45              :     impl<S> TlsConnect<S> for RustlsConnect
      46              :     where
      47              :         S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
      48              :     {
      49              :         type Stream = RustlsStream<S>;
      50              :         type Error = io::Error;
      51              :         type Future = TlsConnectFuture<S>;
      52              : 
      53           20 :         fn connect(self, stream: S) -> Self::Future {
      54           20 :             TlsConnectFuture {
      55           20 :                 inner: self.0.connector.connect(self.0.hostname, stream),
      56           20 :             }
      57           20 :         }
      58              :     }
      59              : 
      60              :     pub struct RustlsStream<S>(TlsStream<S>);
      61              : 
      62              :     impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
      63              :     where
      64              :         S: AsyncRead + AsyncWrite + Unpin,
      65              :     {
      66           12 :         fn channel_binding(&self) -> ChannelBinding {
      67           12 :             let (_, session) = self.0.get_ref();
      68           12 :             match session.peer_certificates() {
      69           12 :                 Some([cert, ..]) => TlsServerEndPoint::new(cert)
      70           12 :                     .ok()
      71           12 :                     .and_then(|cb| match cb {
      72           12 :                         TlsServerEndPoint::Sha256(hash) => Some(hash),
      73            0 :                         TlsServerEndPoint::Undefined => None,
      74           12 :                     })
      75           12 :                     .map_or_else(ChannelBinding::none, |hash| {
      76           12 :                         ChannelBinding::tls_server_end_point(hash.to_vec())
      77           12 :                     }),
      78            0 :                 _ => ChannelBinding::none(),
      79              :             }
      80           12 :         }
      81              :     }
      82              : 
      83              :     impl<S> AsyncRead for RustlsStream<S>
      84              :     where
      85              :         S: AsyncRead + AsyncWrite + Unpin,
      86              :     {
      87          123 :         fn poll_read(
      88          123 :             mut self: Pin<&mut Self>,
      89          123 :             cx: &mut Context<'_>,
      90          123 :             buf: &mut ReadBuf<'_>,
      91          123 :         ) -> Poll<tokio::io::Result<()>> {
      92          123 :             Pin::new(&mut self.0).poll_read(cx, buf)
      93          123 :         }
      94              :     }
      95              : 
      96              :     impl<S> AsyncWrite for RustlsStream<S>
      97              :     where
      98              :         S: AsyncRead + AsyncWrite + Unpin,
      99              :     {
     100           52 :         fn poll_write(
     101           52 :             mut self: Pin<&mut Self>,
     102           52 :             cx: &mut Context<'_>,
     103           52 :             buf: &[u8],
     104           52 :         ) -> Poll<tokio::io::Result<usize>> {
     105           52 :             Pin::new(&mut self.0).poll_write(cx, buf)
     106           52 :         }
     107              : 
     108           52 :         fn poll_flush(
     109           52 :             mut self: Pin<&mut Self>,
     110           52 :             cx: &mut Context<'_>,
     111           52 :         ) -> Poll<tokio::io::Result<()>> {
     112           52 :             Pin::new(&mut self.0).poll_flush(cx)
     113           52 :         }
     114              : 
     115            0 :         fn poll_shutdown(
     116            0 :             mut self: Pin<&mut Self>,
     117            0 :             cx: &mut Context<'_>,
     118            0 :         ) -> Poll<tokio::io::Result<()>> {
     119            0 :             Pin::new(&mut self.0).poll_shutdown(cx)
     120            0 :         }
     121              :     }
     122              : }
     123              : 
     124              : /// A `MakeTlsConnect` implementation using `rustls`.
     125              : ///
     126              : /// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
     127              : #[derive(Clone)]
     128              : pub struct MakeRustlsConnect {
     129              :     pub config: Arc<ClientConfig>,
     130              : }
     131              : 
     132              : impl MakeRustlsConnect {
     133              :     /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
     134              :     #[must_use]
     135           20 :     pub fn new(config: Arc<ClientConfig>) -> Self {
     136           20 :         Self { config }
     137           20 :     }
     138              : }
     139              : 
     140              : impl<S> MakeTlsConnect<S> for MakeRustlsConnect
     141              : where
     142              :     S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     143              : {
     144              :     type Stream = private::RustlsStream<S>;
     145              :     type TlsConnect = private::RustlsConnect;
     146              :     type Error = rustls::pki_types::InvalidDnsNameError;
     147              : 
     148           20 :     fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
     149           20 :         ServerName::try_from(hostname).map(|dns_name| {
     150           20 :             private::RustlsConnect(private::RustlsConnectData {
     151           20 :                 hostname: dns_name.to_owned(),
     152           20 :                 connector: Arc::clone(&self.config).into(),
     153           20 :             })
     154           20 :         })
     155           20 :     }
     156              : }
        

Generated by: LCOV version 2.1-beta