LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - tls.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 20.0 % 40 8
Test Date: 2025-02-20 13:11:02 Functions: 10.5 % 19 2

            Line data    Source code
       1              : //! TLS support.
       2              : 
       3              : use std::error::Error;
       4              : use std::future::Future;
       5              : use std::pin::Pin;
       6              : use std::task::{Context, Poll};
       7              : use std::{fmt, io};
       8              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
       9              : 
      10              : pub(crate) mod private {
      11              :     pub struct ForcePrivateApi;
      12              : }
      13              : 
      14              : /// Channel binding information returned from a TLS handshake.
      15              : pub struct ChannelBinding {
      16              :     pub(crate) tls_server_end_point: Option<Vec<u8>>,
      17              : }
      18              : 
      19              : impl ChannelBinding {
      20              :     /// Creates a `ChannelBinding` containing no information.
      21            0 :     pub fn none() -> ChannelBinding {
      22            0 :         ChannelBinding {
      23            0 :             tls_server_end_point: None,
      24            0 :         }
      25            0 :     }
      26              : 
      27              :     /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
      28           12 :     pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
      29           12 :         ChannelBinding {
      30           12 :             tls_server_end_point: Some(tls_server_end_point),
      31           12 :         }
      32           12 :     }
      33              : }
      34              : 
      35              : /// A constructor of `TlsConnect`ors.
      36              : ///
      37              : /// Requires the `runtime` Cargo feature (enabled by default).
      38              : pub trait MakeTlsConnect<S> {
      39              :     /// The stream type created by the `TlsConnect` implementation.
      40              :     type Stream: TlsStream + Unpin;
      41              :     /// The `TlsConnect` implementation created by this type.
      42              :     type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
      43              :     /// The error type returned by the `TlsConnect` implementation.
      44              :     type Error: Into<Box<dyn Error + Sync + Send>>;
      45              : 
      46              :     /// Creates a new `TlsConnect`or.
      47              :     ///
      48              :     /// The domain name is provided for certificate verification and SNI.
      49              :     fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
      50              : }
      51              : 
      52              : /// An asynchronous function wrapping a stream in a TLS session.
      53              : pub trait TlsConnect<S> {
      54              :     /// The stream returned by the future.
      55              :     type Stream: TlsStream + Unpin;
      56              :     /// The error returned by the future.
      57              :     type Error: Into<Box<dyn Error + Sync + Send>>;
      58              :     /// The future returned by the connector.
      59              :     type Future: Future<Output = Result<Self::Stream, Self::Error>>;
      60              : 
      61              :     /// Returns a future performing a TLS handshake over the stream.
      62              :     fn connect(self, stream: S) -> Self::Future;
      63              : 
      64              :     #[doc(hidden)]
      65            0 :     fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
      66            0 :         true
      67            0 :     }
      68              : }
      69              : 
      70              : /// A TLS-wrapped connection to a PostgreSQL database.
      71              : pub trait TlsStream: AsyncRead + AsyncWrite {
      72              :     /// Returns channel binding information for the session.
      73              :     fn channel_binding(&self) -> ChannelBinding;
      74              : }
      75              : 
      76              : /// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
      77              : ///
      78              : /// This can be used when `sslmode` is `none` or `prefer`.
      79              : #[derive(Debug, Copy, Clone)]
      80              : pub struct NoTls;
      81              : 
      82              : impl<S> MakeTlsConnect<S> for NoTls {
      83              :     type Stream = NoTlsStream;
      84              :     type TlsConnect = NoTls;
      85              :     type Error = NoTlsError;
      86              : 
      87            0 :     fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
      88            0 :         Ok(NoTls)
      89            0 :     }
      90              : }
      91              : 
      92              : impl<S> TlsConnect<S> for NoTls {
      93              :     type Stream = NoTlsStream;
      94              :     type Error = NoTlsError;
      95              :     type Future = NoTlsFuture;
      96              : 
      97            0 :     fn connect(self, _: S) -> NoTlsFuture {
      98            0 :         NoTlsFuture(())
      99            0 :     }
     100              : 
     101            1 :     fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
     102            1 :         false
     103            1 :     }
     104              : }
     105              : 
     106              : /// The future returned by `NoTls`.
     107              : pub struct NoTlsFuture(());
     108              : 
     109              : impl Future for NoTlsFuture {
     110              :     type Output = Result<NoTlsStream, NoTlsError>;
     111              : 
     112            0 :     fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
     113            0 :         Poll::Ready(Err(NoTlsError(())))
     114            0 :     }
     115              : }
     116              : 
     117              : /// The TLS "stream" type produced by the `NoTls` connector.
     118              : ///
     119              : /// Since `NoTls` doesn't support TLS, this type is uninhabited.
     120              : pub enum NoTlsStream {}
     121              : 
     122              : impl AsyncRead for NoTlsStream {
     123            0 :     fn poll_read(
     124            0 :         self: Pin<&mut Self>,
     125            0 :         _: &mut Context<'_>,
     126            0 :         _: &mut ReadBuf<'_>,
     127            0 :     ) -> Poll<io::Result<()>> {
     128            0 :         match *self {}
     129              :     }
     130              : }
     131              : 
     132              : impl AsyncWrite for NoTlsStream {
     133            0 :     fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
     134            0 :         match *self {}
     135              :     }
     136              : 
     137            0 :     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
     138            0 :         match *self {}
     139              :     }
     140              : 
     141            0 :     fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
     142            0 :         match *self {}
     143              :     }
     144              : }
     145              : 
     146              : impl TlsStream for NoTlsStream {
     147              :     fn channel_binding(&self) -> ChannelBinding {
     148              :         match *self {}
     149              :     }
     150              : }
     151              : 
     152              : /// The error returned by `NoTls`.
     153              : #[derive(Debug)]
     154              : pub struct NoTlsError(());
     155              : 
     156              : impl fmt::Display for NoTlsError {
     157            0 :     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
     158            0 :         fmt.write_str("no TLS implementation configured")
     159            0 :     }
     160              : }
     161              : 
     162              : impl Error for NoTlsError {}
        

Generated by: LCOV version 2.1-beta