LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_tls.rs (source / functions) Coverage Total Hit
Test: 5fe7fa8d483b39476409aee736d6d5e32728bfac.info Lines: 84.4 % 32 27
Test Date: 2025-03-12 16:10:49 Functions: 20.0 % 15 3

            Line data    Source code
       1              : use bytes::BytesMut;
       2              : use postgres_protocol2::message::frontend;
       3              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
       4              : 
       5              : use crate::Error;
       6              : use crate::config::SslMode;
       7              : use crate::maybe_tls_stream::MaybeTlsStream;
       8              : use crate::tls::TlsConnect;
       9              : use crate::tls::private::ForcePrivateApi;
      10              : 
      11           15 : pub async fn connect_tls<S, T>(
      12           15 :     mut stream: S,
      13           15 :     mode: SslMode,
      14           15 :     tls: T,
      15           15 : ) -> Result<MaybeTlsStream<S, T::Stream>, Error>
      16           15 : where
      17           15 :     S: AsyncRead + AsyncWrite + Unpin,
      18           15 :     T: TlsConnect<S>,
      19           15 : {
      20            1 :     match mode {
      21            1 :         SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
      22            1 :         SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
      23            1 :             return Ok(MaybeTlsStream::Raw(stream));
      24              :         }
      25           13 :         SslMode::Prefer | SslMode::Require => {}
      26           13 :     }
      27           13 : 
      28           13 :     let mut buf = BytesMut::new();
      29           13 :     frontend::ssl_request(&mut buf);
      30           13 :     stream.write_all(&buf).await.map_err(Error::io)?;
      31              : 
      32           13 :     let mut buf = [0];
      33           13 :     stream.read_exact(&mut buf).await.map_err(Error::io)?;
      34              : 
      35           13 :     if buf[0] != b'S' {
      36            0 :         if SslMode::Require == mode {
      37            0 :             return Err(Error::tls("server does not support TLS".into()));
      38              :         } else {
      39            0 :             return Ok(MaybeTlsStream::Raw(stream));
      40              :         }
      41            0 :     }
      42              : 
      43           13 :     let stream = tls
      44           13 :         .connect(stream)
      45           13 :         .await
      46           13 :         .map_err(|e| Error::tls(e.into()))?;
      47              : 
      48           13 :     Ok(MaybeTlsStream::Tls(stream))
      49            0 : }
        

Generated by: LCOV version 2.1-beta