Line data Source code
1 : /// Test postgres_backend_async with tokio_postgres
2 : use once_cell::sync::Lazy;
3 : use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
4 : use pq_proto::{BeMessage, RowDescriptor};
5 : use std::io::Cursor;
6 : use std::sync::Arc;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 : use tokio::net::{TcpListener, TcpStream};
9 : use tokio_postgres::config::SslMode;
10 : use tokio_postgres::tls::MakeTlsConnect;
11 : use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
12 : use tokio_postgres_rustls::MakeRustlsConnect;
13 : use tokio_util::sync::CancellationToken;
14 :
15 : // generate client, server test streams
16 12 : async fn make_tcp_pair() -> (TcpStream, TcpStream) {
17 12 : let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
18 12 : let addr = listener.local_addr().unwrap();
19 12 : let client_stream = TcpStream::connect(addr).await.unwrap();
20 12 : let (server_stream, _) = listener.accept().await.unwrap();
21 12 : (client_stream, server_stream)
22 12 : }
23 :
24 : struct TestHandler {}
25 :
26 : #[async_trait::async_trait]
27 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
28 : // return single col 'hey' for any query
29 12 : async fn process_query(
30 12 : &mut self,
31 12 : pgb: &mut PostgresBackend<IO>,
32 12 : _query_string: &str,
33 12 : ) -> Result<(), QueryError> {
34 12 : pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
35 12 : b"hey",
36 12 : )]))?
37 12 : .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
38 12 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
39 12 : Ok(())
40 12 : }
41 : }
42 :
43 : // test that basic select works
44 : #[tokio::test]
45 6 : async fn simple_select() {
46 6 : let (client_sock, server_sock) = make_tcp_pair().await;
47 6 :
48 6 : // create and run pgbackend
49 6 : let pgbackend =
50 6 : PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
51 6 :
52 6 : tokio::spawn(async move {
53 6 : let mut handler = TestHandler {};
54 12 : pgbackend.run(&mut handler, &CancellationToken::new()).await
55 6 : });
56 6 :
57 6 : let conf = Config::new();
58 6 : let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
59 6 : // The connection object performs the actual communication with the database,
60 6 : // so spawn it off to run on its own.
61 6 : tokio::spawn(async move {
62 6 : if let Err(e) = connection.await {
63 0 : eprintln!("connection error: {}", e);
64 0 : }
65 6 : });
66 6 :
67 6 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
68 6 : if let SimpleQueryMessage::Row(row) = first_val {
69 6 : let first_col = row.get(0).expect("first column");
70 6 : assert_eq!(first_col, "hey");
71 6 : } else {
72 6 : panic!("expected SimpleQueryMessage::Row");
73 6 : }
74 6 : }
75 :
76 6 : static KEY: Lazy<rustls::pki_types::PrivateKeyDer<'static>> = Lazy::new(|| {
77 6 : let mut cursor = Cursor::new(include_bytes!("key.pem"));
78 6 : let key = rustls_pemfile::rsa_private_keys(&mut cursor)
79 6 : .next()
80 6 : .unwrap()
81 6 : .unwrap();
82 6 : rustls::pki_types::PrivateKeyDer::Pkcs1(key)
83 6 : });
84 :
85 6 : static CERT: Lazy<rustls::pki_types::CertificateDer<'static>> = Lazy::new(|| {
86 6 : let mut cursor = Cursor::new(include_bytes!("cert.pem"));
87 6 : let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap();
88 6 : cert
89 6 : });
90 :
91 : // test that basic select with ssl works
92 : #[tokio::test]
93 6 : async fn simple_select_ssl() {
94 6 : let (client_sock, server_sock) = make_tcp_pair().await;
95 6 :
96 6 : let server_cfg = rustls::ServerConfig::builder()
97 6 : .with_no_client_auth()
98 6 : .with_single_cert(vec![CERT.clone()], KEY.clone_key())
99 6 : .unwrap();
100 6 : let tls_config = Some(Arc::new(server_cfg));
101 6 : let pgbackend =
102 6 : PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
103 6 :
104 6 : tokio::spawn(async move {
105 6 : let mut handler = TestHandler {};
106 30 : pgbackend.run(&mut handler, &CancellationToken::new()).await
107 6 : });
108 6 :
109 6 : let client_cfg = rustls::ClientConfig::builder()
110 6 : .with_root_certificates({
111 6 : let mut store = rustls::RootCertStore::empty();
112 6 : store.add(CERT.clone()).unwrap();
113 6 : store
114 6 : })
115 6 : .with_no_client_auth();
116 6 : let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
117 6 : let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
118 6 : &mut make_tls_connect,
119 6 : "localhost",
120 6 : )
121 6 : .expect("make_tls_connect");
122 6 :
123 6 : let mut conf = Config::new();
124 6 : conf.ssl_mode(SslMode::Require);
125 6 : let (client, connection) = conf
126 6 : .connect_raw(client_sock, tls_connect)
127 24 : .await
128 6 : .expect("connect");
129 6 : // The connection object performs the actual communication with the database,
130 6 : // so spawn it off to run on its own.
131 6 : tokio::spawn(async move {
132 6 : if let Err(e) = connection.await {
133 0 : eprintln!("connection error: {}", e);
134 0 : }
135 6 : });
136 6 :
137 6 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
138 6 : if let SimpleQueryMessage::Row(row) = first_val {
139 6 : let first_col = row.get(0).expect("first column");
140 6 : assert_eq!(first_col, "hey");
141 6 : } else {
142 6 : panic!("expected SimpleQueryMessage::Row");
143 6 : }
144 6 : }
|