LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 89 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 22 0

            Line data    Source code
       1              : use std::net::IpAddr;
       2              : 
       3              : use futures_util::TryStreamExt;
       4              : use postgres_protocol2::message::backend::Message;
       5              : use tokio::io::{AsyncRead, AsyncWrite};
       6              : use tokio::net::TcpStream;
       7              : use tokio::sync::mpsc;
       8              : 
       9              : use crate::client::SocketConfig;
      10              : use crate::config::{Host, SslMode};
      11              : use crate::connect_raw::StartupStream;
      12              : use crate::connect_socket::connect_socket;
      13              : use crate::tls::{MakeTlsConnect, TlsConnect};
      14              : use crate::{Client, Config, Connection, Error};
      15              : 
      16            0 : pub async fn connect<T>(
      17            0 :     tls: &T,
      18            0 :     config: &Config,
      19            0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
      20            0 : where
      21            0 :     T: MakeTlsConnect<TcpStream>,
      22            0 : {
      23            0 :     let hostname = match &config.host {
      24            0 :         Host::Tcp(host) => host.as_str(),
      25              :     };
      26              : 
      27            0 :     let tls = tls
      28            0 :         .make_tls_connect(hostname)
      29            0 :         .map_err(|e| Error::tls(e.into()))?;
      30              : 
      31            0 :     match connect_once(config.host_addr, &config.host, config.port, tls, config).await {
      32            0 :         Ok((client, connection)) => Ok((client, connection)),
      33            0 :         Err(e) => Err(e),
      34              :     }
      35            0 : }
      36              : 
      37            0 : async fn connect_once<T>(
      38            0 :     host_addr: Option<IpAddr>,
      39            0 :     host: &Host,
      40            0 :     port: u16,
      41            0 :     tls: T,
      42            0 :     config: &Config,
      43            0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
      44            0 : where
      45            0 :     T: TlsConnect<TcpStream>,
      46            0 : {
      47            0 :     let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
      48            0 :     let stream = config.tls_and_authenticate(socket, tls).await?;
      49            0 :     managed(
      50            0 :         stream,
      51            0 :         host_addr,
      52            0 :         host.clone(),
      53            0 :         port,
      54            0 :         config.ssl_mode,
      55            0 :         config.connect_timeout,
      56            0 :     )
      57            0 :     .await
      58            0 : }
      59              : 
      60            0 : pub async fn managed<TlsStream>(
      61            0 :     mut stream: StartupStream<TcpStream, TlsStream>,
      62            0 :     host_addr: Option<IpAddr>,
      63            0 :     host: Host,
      64            0 :     port: u16,
      65            0 :     ssl_mode: SslMode,
      66            0 :     connect_timeout: Option<std::time::Duration>,
      67            0 : ) -> Result<(Client, Connection<TcpStream, TlsStream>), Error>
      68            0 : where
      69            0 :     TlsStream: AsyncRead + AsyncWrite + Unpin,
      70            0 : {
      71            0 :     let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
      72              : 
      73            0 :     let socket_config = SocketConfig {
      74            0 :         host_addr,
      75            0 :         host,
      76            0 :         port,
      77            0 :         connect_timeout,
      78            0 :     };
      79              : 
      80            0 :     let mut stream = stream.into_framed();
      81            0 :     let write_buf = std::mem::take(stream.write_buffer_mut());
      82              : 
      83            0 :     let (client_tx, conn_rx) = mpsc::unbounded_channel();
      84            0 :     let (conn_tx, client_rx) = mpsc::channel(4);
      85            0 :     let client = Client::new(
      86            0 :         client_tx,
      87            0 :         client_rx,
      88            0 :         socket_config,
      89            0 :         ssl_mode,
      90            0 :         process_id,
      91            0 :         secret_key,
      92            0 :         write_buf,
      93              :     );
      94              : 
      95            0 :     let connection = Connection::new(stream, conn_tx, conn_rx);
      96              : 
      97            0 :     Ok((client, connection))
      98            0 : }
      99              : 
     100            0 : async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
     101            0 : where
     102            0 :     S: AsyncRead + AsyncWrite + Unpin,
     103            0 :     T: AsyncRead + AsyncWrite + Unpin,
     104            0 : {
     105            0 :     let mut process_id = 0;
     106            0 :     let mut secret_key = 0;
     107              : 
     108              :     loop {
     109            0 :         match stream.try_next().await.map_err(Error::io)? {
     110            0 :             Some(Message::BackendKeyData(body)) => {
     111            0 :                 process_id = body.process_id();
     112            0 :                 secret_key = body.secret_key();
     113            0 :             }
     114              :             // These values are currently not used by `Client`/`Connection`. Ignore them.
     115            0 :             Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
     116            0 :             Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
     117            0 :             Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     118            0 :             Some(_) => return Err(Error::unexpected_message()),
     119            0 :             None => return Err(Error::closed()),
     120              :         }
     121              :     }
     122            0 : }
        

Generated by: LCOV version 2.1-beta