LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - tls.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 20.0 % 40 8
Test Date: 2025-03-12 00:01:28 Functions: 10.5 % 19 2

            Line data    Source code
       1              : //! TLS support.
       2              : 
       3              : use std::error::Error;
       4              : use std::future::Future;
       5              : use std::pin::Pin;
       6              : use std::task::{Context, Poll};
       7              : use std::{fmt, io};
       8              : 
       9              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      10              : 
      11              : pub(crate) mod private {
      12              :     pub struct ForcePrivateApi;
      13              : }
      14              : 
      15              : /// Channel binding information returned from a TLS handshake.
      16              : pub struct ChannelBinding {
      17              :     pub(crate) tls_server_end_point: Option<Vec<u8>>,
      18              : }
      19              : 
      20              : impl ChannelBinding {
      21              :     /// Creates a `ChannelBinding` containing no information.
      22            0 :     pub fn none() -> ChannelBinding {
      23            0 :         ChannelBinding {
      24            0 :             tls_server_end_point: None,
      25            0 :         }
      26            0 :     }
      27              : 
      28              :     /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
      29           12 :     pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
      30           12 :         ChannelBinding {
      31           12 :             tls_server_end_point: Some(tls_server_end_point),
      32           12 :         }
      33           12 :     }
      34              : }
      35              : 
      36              : /// A constructor of `TlsConnect`ors.
      37              : ///
      38              : /// Requires the `runtime` Cargo feature (enabled by default).
      39              : pub trait MakeTlsConnect<S> {
      40              :     /// The stream type created by the `TlsConnect` implementation.
      41              :     type Stream: TlsStream + Unpin;
      42              :     /// The `TlsConnect` implementation created by this type.
      43              :     type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
      44              :     /// The error type returned by the `TlsConnect` implementation.
      45              :     type Error: Into<Box<dyn Error + Sync + Send>>;
      46              : 
      47              :     /// Creates a new `TlsConnect`or.
      48              :     ///
      49              :     /// The domain name is provided for certificate verification and SNI.
      50              :     fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
      51              : }
      52              : 
      53              : /// An asynchronous function wrapping a stream in a TLS session.
      54              : pub trait TlsConnect<S> {
      55              :     /// The stream returned by the future.
      56              :     type Stream: TlsStream + Unpin;
      57              :     /// The error returned by the future.
      58              :     type Error: Into<Box<dyn Error + Sync + Send>>;
      59              :     /// The future returned by the connector.
      60              :     type Future: Future<Output = Result<Self::Stream, Self::Error>>;
      61              : 
      62              :     /// Returns a future performing a TLS handshake over the stream.
      63              :     fn connect(self, stream: S) -> Self::Future;
      64              : 
      65              :     #[doc(hidden)]
      66            0 :     fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
      67            0 :         true
      68            0 :     }
      69              : }
      70              : 
      71              : /// A TLS-wrapped connection to a PostgreSQL database.
      72              : pub trait TlsStream: AsyncRead + AsyncWrite {
      73              :     /// Returns channel binding information for the session.
      74              :     fn channel_binding(&self) -> ChannelBinding;
      75              : }
      76              : 
      77              : /// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
      78              : ///
      79              : /// This can be used when `sslmode` is `none` or `prefer`.
      80              : #[derive(Debug, Copy, Clone)]
      81              : pub struct NoTls;
      82              : 
      83              : impl<S> MakeTlsConnect<S> for NoTls {
      84              :     type Stream = NoTlsStream;
      85              :     type TlsConnect = NoTls;
      86              :     type Error = NoTlsError;
      87              : 
      88            0 :     fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
      89            0 :         Ok(NoTls)
      90            0 :     }
      91              : }
      92              : 
      93              : impl<S> TlsConnect<S> for NoTls {
      94              :     type Stream = NoTlsStream;
      95              :     type Error = NoTlsError;
      96              :     type Future = NoTlsFuture;
      97              : 
      98            0 :     fn connect(self, _: S) -> NoTlsFuture {
      99            0 :         NoTlsFuture(())
     100            0 :     }
     101              : 
     102            1 :     fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
     103            1 :         false
     104            1 :     }
     105              : }
     106              : 
     107              : /// The future returned by `NoTls`.
     108              : pub struct NoTlsFuture(());
     109              : 
     110              : impl Future for NoTlsFuture {
     111              :     type Output = Result<NoTlsStream, NoTlsError>;
     112              : 
     113            0 :     fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
     114            0 :         Poll::Ready(Err(NoTlsError(())))
     115            0 :     }
     116              : }
     117              : 
     118              : /// The TLS "stream" type produced by the `NoTls` connector.
     119              : ///
     120              : /// Since `NoTls` doesn't support TLS, this type is uninhabited.
     121              : pub enum NoTlsStream {}
     122              : 
     123              : impl AsyncRead for NoTlsStream {
     124            0 :     fn poll_read(
     125            0 :         self: Pin<&mut Self>,
     126            0 :         _: &mut Context<'_>,
     127            0 :         _: &mut ReadBuf<'_>,
     128            0 :     ) -> Poll<io::Result<()>> {
     129            0 :         match *self {}
     130              :     }
     131              : }
     132              : 
     133              : impl AsyncWrite for NoTlsStream {
     134            0 :     fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
     135            0 :         match *self {}
     136              :     }
     137              : 
     138            0 :     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
     139            0 :         match *self {}
     140              :     }
     141              : 
     142            0 :     fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
     143            0 :         match *self {}
     144              :     }
     145              : }
     146              : 
     147              : impl TlsStream for NoTlsStream {
     148              :     fn channel_binding(&self) -> ChannelBinding {
     149              :         match *self {}
     150              :     }
     151              : }
     152              : 
     153              : /// The error returned by `NoTls`.
     154              : #[derive(Debug)]
     155              : pub struct NoTlsError(());
     156              : 
     157              : impl fmt::Display for NoTlsError {
     158            0 :     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
     159            0 :         fmt.write_str("no TLS implementation configured")
     160            0 :     }
     161              : }
     162              : 
     163              : impl Error for NoTlsError {}
        

Generated by: LCOV version 2.1-beta