LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_tls.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 90.6 % 32 29
Test Date: 2025-02-20 13:11:02 Functions: 26.7 % 15 4

            Line data    Source code
       1              : use crate::config::SslMode;
       2              : use crate::maybe_tls_stream::MaybeTlsStream;
       3              : use crate::tls::private::ForcePrivateApi;
       4              : use crate::tls::TlsConnect;
       5              : use crate::Error;
       6              : use bytes::BytesMut;
       7              : use postgres_protocol2::message::frontend;
       8              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
       9              : 
      10           15 : pub async fn connect_tls<S, T>(
      11           15 :     mut stream: S,
      12           15 :     mode: SslMode,
      13           15 :     tls: T,
      14           15 : ) -> Result<MaybeTlsStream<S, T::Stream>, Error>
      15           15 : where
      16           15 :     S: AsyncRead + AsyncWrite + Unpin,
      17           15 :     T: TlsConnect<S>,
      18           15 : {
      19            1 :     match mode {
      20            1 :         SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
      21            1 :         SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
      22            1 :             return Ok(MaybeTlsStream::Raw(stream))
      23              :         }
      24           13 :         SslMode::Prefer | SslMode::Require => {}
      25           13 :     }
      26           13 : 
      27           13 :     let mut buf = BytesMut::new();
      28           13 :     frontend::ssl_request(&mut buf);
      29           13 :     stream.write_all(&buf).await.map_err(Error::io)?;
      30              : 
      31           13 :     let mut buf = [0];
      32           13 :     stream.read_exact(&mut buf).await.map_err(Error::io)?;
      33              : 
      34           13 :     if buf[0] != b'S' {
      35            0 :         if SslMode::Require == mode {
      36            0 :             return Err(Error::tls("server does not support TLS".into()));
      37              :         } else {
      38            0 :             return Ok(MaybeTlsStream::Raw(stream));
      39              :         }
      40           13 :     }
      41              : 
      42           13 :     let stream = tls
      43           13 :         .connect(stream)
      44           13 :         .await
      45           13 :         .map_err(|e| Error::tls(e.into()))?;
      46              : 
      47           13 :     Ok(MaybeTlsStream::Tls(stream))
      48           15 : }
        

Generated by: LCOV version 2.1-beta