Line data Source code
1 : /// Test postgres_backend_async with tokio_postgres
2 : use once_cell::sync::Lazy;
3 : use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
4 : use pq_proto::{BeMessage, RowDescriptor};
5 : use std::io::Cursor;
6 : use std::{future, sync::Arc};
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 : use tokio::net::{TcpListener, TcpStream};
9 : use tokio_postgres::config::SslMode;
10 : use tokio_postgres::tls::MakeTlsConnect;
11 : use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
12 : use tokio_postgres_rustls::MakeRustlsConnect;
13 :
14 : // generate client, server test streams
15 4 : async fn make_tcp_pair() -> (TcpStream, TcpStream) {
16 4 : let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
17 4 : let addr = listener.local_addr().unwrap();
18 4 : let client_stream = TcpStream::connect(addr).await.unwrap();
19 4 : let (server_stream, _) = listener.accept().await.unwrap();
20 4 : (client_stream, server_stream)
21 4 : }
22 :
23 : struct TestHandler {}
24 :
25 : #[async_trait::async_trait]
26 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
27 : // return single col 'hey' for any query
28 4 : async fn process_query(
29 4 : &mut self,
30 4 : pgb: &mut PostgresBackend<IO>,
31 4 : _query_string: &str,
32 4 : ) -> Result<(), QueryError> {
33 4 : pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
34 4 : b"hey",
35 4 : )]))?
36 4 : .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
37 4 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
38 4 : Ok(())
39 8 : }
40 : }
41 :
42 : // test that basic select works
43 2 : #[tokio::test]
44 2 : async fn simple_select() {
45 2 : let (client_sock, server_sock) = make_tcp_pair().await;
46 :
47 : // create and run pgbackend
48 2 : let pgbackend =
49 2 : PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
50 2 :
51 2 : tokio::spawn(async move {
52 2 : let mut handler = TestHandler {};
53 4 : pgbackend.run(&mut handler, future::pending::<()>).await
54 2 : });
55 2 :
56 2 : let conf = Config::new();
57 2 : let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
58 2 : // The connection object performs the actual communication with the database,
59 2 : // so spawn it off to run on its own.
60 2 : tokio::spawn(async move {
61 2 : if let Err(e) = connection.await {
62 0 : eprintln!("connection error: {}", e);
63 0 : }
64 2 : });
65 :
66 2 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
67 2 : if let SimpleQueryMessage::Row(row) = first_val {
68 2 : let first_col = row.get(0).expect("first column");
69 2 : assert_eq!(first_col, "hey");
70 : } else {
71 0 : panic!("expected SimpleQueryMessage::Row");
72 : }
73 : }
74 :
75 2 : static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
76 2 : let mut cursor = Cursor::new(include_bytes!("key.pem"));
77 2 : rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
78 2 : });
79 :
80 2 : static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
81 2 : let mut cursor = Cursor::new(include_bytes!("cert.pem"));
82 2 : rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
83 2 : });
84 :
85 : // test that basic select with ssl works
86 2 : #[tokio::test]
87 2 : async fn simple_select_ssl() {
88 2 : let (client_sock, server_sock) = make_tcp_pair().await;
89 :
90 2 : let server_cfg = rustls::ServerConfig::builder()
91 2 : .with_safe_defaults()
92 2 : .with_no_client_auth()
93 2 : .with_single_cert(vec![CERT.clone()], KEY.clone())
94 2 : .unwrap();
95 2 : let tls_config = Some(Arc::new(server_cfg));
96 2 : let pgbackend =
97 2 : PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
98 2 :
99 2 : tokio::spawn(async move {
100 2 : let mut handler = TestHandler {};
101 10 : pgbackend.run(&mut handler, future::pending::<()>).await
102 2 : });
103 2 :
104 2 : let client_cfg = rustls::ClientConfig::builder()
105 2 : .with_safe_defaults()
106 2 : .with_root_certificates({
107 2 : let mut store = rustls::RootCertStore::empty();
108 2 : store.add(&CERT).unwrap();
109 2 : store
110 2 : })
111 2 : .with_no_client_auth();
112 2 : let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
113 2 : let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
114 2 : &mut make_tls_connect,
115 2 : "localhost",
116 2 : )
117 2 : .expect("make_tls_connect");
118 2 :
119 2 : let mut conf = Config::new();
120 2 : conf.ssl_mode(SslMode::Require);
121 2 : let (client, connection) = conf
122 2 : .connect_raw(client_sock, tls_connect)
123 8 : .await
124 2 : .expect("connect");
125 2 : // The connection object performs the actual communication with the database,
126 2 : // so spawn it off to run on its own.
127 2 : tokio::spawn(async move {
128 2 : if let Err(e) = connection.await {
129 0 : eprintln!("connection error: {}", e);
130 0 : }
131 2 : });
132 :
133 2 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
134 2 : if let SimpleQueryMessage::Row(row) = first_val {
135 2 : let first_col = row.get(0).expect("first column");
136 2 : assert_eq!(first_col, "hey");
137 : } else {
138 0 : panic!("expected SimpleQueryMessage::Row");
139 : }
140 : }
|