Line data Source code
1 : use std::net::IpAddr;
2 :
3 : use futures_util::TryStreamExt;
4 : use postgres_protocol2::message::backend::Message;
5 : use tokio::io::{AsyncRead, AsyncWrite};
6 : use tokio::net::TcpStream;
7 : use tokio::sync::mpsc;
8 :
9 : use crate::client::SocketConfig;
10 : use crate::config::{Host, SslMode};
11 : use crate::connect_raw::StartupStream;
12 : use crate::connect_socket::connect_socket;
13 : use crate::tls::{MakeTlsConnect, TlsConnect};
14 : use crate::{Client, Config, Connection, Error};
15 :
16 0 : pub async fn connect<T>(
17 0 : tls: &T,
18 0 : config: &Config,
19 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
20 0 : where
21 0 : T: MakeTlsConnect<TcpStream>,
22 0 : {
23 0 : let hostname = match &config.host {
24 0 : Host::Tcp(host) => host.as_str(),
25 : };
26 :
27 0 : let tls = tls
28 0 : .make_tls_connect(hostname)
29 0 : .map_err(|e| Error::tls(e.into()))?;
30 :
31 0 : match connect_once(config.host_addr, &config.host, config.port, tls, config).await {
32 0 : Ok((client, connection)) => Ok((client, connection)),
33 0 : Err(e) => Err(e),
34 : }
35 0 : }
36 :
37 0 : async fn connect_once<T>(
38 0 : host_addr: Option<IpAddr>,
39 0 : host: &Host,
40 0 : port: u16,
41 0 : tls: T,
42 0 : config: &Config,
43 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
44 0 : where
45 0 : T: TlsConnect<TcpStream>,
46 0 : {
47 0 : let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
48 0 : let stream = config.tls_and_authenticate(socket, tls).await?;
49 0 : managed(
50 0 : stream,
51 0 : host_addr,
52 0 : host.clone(),
53 0 : port,
54 0 : config.ssl_mode,
55 0 : config.connect_timeout,
56 0 : )
57 0 : .await
58 0 : }
59 :
60 0 : pub async fn managed<TlsStream>(
61 0 : mut stream: StartupStream<TcpStream, TlsStream>,
62 0 : host_addr: Option<IpAddr>,
63 0 : host: Host,
64 0 : port: u16,
65 0 : ssl_mode: SslMode,
66 0 : connect_timeout: Option<std::time::Duration>,
67 0 : ) -> Result<(Client, Connection<TcpStream, TlsStream>), Error>
68 0 : where
69 0 : TlsStream: AsyncRead + AsyncWrite + Unpin,
70 0 : {
71 0 : let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
72 :
73 0 : let socket_config = SocketConfig {
74 0 : host_addr,
75 0 : host,
76 0 : port,
77 0 : connect_timeout,
78 0 : };
79 :
80 0 : let mut stream = stream.into_framed();
81 0 : let write_buf = std::mem::take(stream.write_buffer_mut());
82 :
83 0 : let (client_tx, conn_rx) = mpsc::unbounded_channel();
84 0 : let (conn_tx, client_rx) = mpsc::channel(4);
85 0 : let client = Client::new(
86 0 : client_tx,
87 0 : client_rx,
88 0 : socket_config,
89 0 : ssl_mode,
90 0 : process_id,
91 0 : secret_key,
92 0 : write_buf,
93 : );
94 :
95 0 : let connection = Connection::new(stream, conn_tx, conn_rx);
96 :
97 0 : Ok((client, connection))
98 0 : }
99 :
100 0 : async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
101 0 : where
102 0 : S: AsyncRead + AsyncWrite + Unpin,
103 0 : T: AsyncRead + AsyncWrite + Unpin,
104 0 : {
105 0 : let mut process_id = 0;
106 0 : let mut secret_key = 0;
107 :
108 : loop {
109 0 : match stream.try_next().await.map_err(Error::io)? {
110 0 : Some(Message::BackendKeyData(body)) => {
111 0 : process_id = body.process_id();
112 0 : secret_key = body.secret_key();
113 0 : }
114 : // These values are currently not used by `Client`/`Connection`. Ignore them.
115 0 : Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
116 0 : Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
117 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
118 0 : Some(_) => return Err(Error::unexpected_message()),
119 0 : None => return Err(Error::closed()),
120 : }
121 : }
122 0 : }
|