Line data Source code
1 : use std::io::Cursor;
2 : use std::sync::Arc;
3 :
4 : /// Test postgres_backend_async with tokio_postgres
5 : use once_cell::sync::Lazy;
6 : use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
7 : use pq_proto::{BeMessage, RowDescriptor};
8 : use rustls::crypto::ring;
9 : use tokio::io::{AsyncRead, AsyncWrite};
10 : use tokio::net::{TcpListener, TcpStream};
11 : use tokio_postgres::config::SslMode;
12 : use tokio_postgres::tls::MakeTlsConnect;
13 : use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
14 : use tokio_postgres_rustls::MakeRustlsConnect;
15 : use tokio_util::sync::CancellationToken;
16 :
17 : // generate client, server test streams
18 2 : async fn make_tcp_pair() -> (TcpStream, TcpStream) {
19 2 : let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
20 2 : let addr = listener.local_addr().unwrap();
21 2 : let client_stream = TcpStream::connect(addr).await.unwrap();
22 2 : let (server_stream, _) = listener.accept().await.unwrap();
23 2 : (client_stream, server_stream)
24 2 : }
25 :
26 : struct TestHandler {}
27 :
28 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
29 : // return single col 'hey' for any query
30 2 : async fn process_query(
31 2 : &mut self,
32 2 : pgb: &mut PostgresBackend<IO>,
33 2 : _query_string: &str,
34 2 : ) -> Result<(), QueryError> {
35 2 : pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
36 2 : b"hey",
37 2 : )]))?
38 2 : .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
39 2 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
40 2 : Ok(())
41 2 : }
42 : }
43 :
44 : // test that basic select works
45 : #[tokio::test]
46 1 : async fn simple_select() {
47 1 : let (client_sock, server_sock) = make_tcp_pair().await;
48 1 :
49 1 : // create and run pgbackend
50 1 : let pgbackend =
51 1 : PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
52 1 :
53 1 : tokio::spawn(async move {
54 1 : let mut handler = TestHandler {};
55 1 : pgbackend.run(&mut handler, &CancellationToken::new()).await
56 1 : });
57 1 :
58 1 : let conf = Config::new();
59 1 : let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
60 1 : // The connection object performs the actual communication with the database,
61 1 : // so spawn it off to run on its own.
62 1 : tokio::spawn(async move {
63 1 : if let Err(e) = connection.await {
64 0 : eprintln!("connection error: {}", e);
65 0 : }
66 1 : });
67 1 :
68 1 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
69 1 : if let SimpleQueryMessage::Row(row) = first_val {
70 1 : let first_col = row.get(0).expect("first column");
71 1 : assert_eq!(first_col, "hey");
72 1 : } else {
73 1 : panic!("expected SimpleQueryMessage::Row");
74 1 : }
75 1 : }
76 :
77 1 : static KEY: Lazy<rustls::pki_types::PrivateKeyDer<'static>> = Lazy::new(|| {
78 1 : let mut cursor = Cursor::new(include_bytes!("key.pem"));
79 1 : let key = rustls_pemfile::rsa_private_keys(&mut cursor)
80 1 : .next()
81 1 : .unwrap()
82 1 : .unwrap();
83 1 : rustls::pki_types::PrivateKeyDer::Pkcs1(key)
84 1 : });
85 :
86 1 : static CERT: Lazy<rustls::pki_types::CertificateDer<'static>> = Lazy::new(|| {
87 1 : let mut cursor = Cursor::new(include_bytes!("cert.pem"));
88 1 : let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap();
89 1 : cert
90 1 : });
91 :
92 : // test that basic select with ssl works
93 : #[tokio::test]
94 1 : async fn simple_select_ssl() {
95 1 : let (client_sock, server_sock) = make_tcp_pair().await;
96 1 :
97 1 : let server_cfg =
98 1 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
99 1 : .with_safe_default_protocol_versions()
100 1 : .expect("aws_lc_rs should support the default protocol versions")
101 1 : .with_no_client_auth()
102 1 : .with_single_cert(vec![CERT.clone()], KEY.clone_key())
103 1 : .unwrap();
104 1 : let tls_config = Some(Arc::new(server_cfg));
105 1 : let pgbackend =
106 1 : PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
107 1 :
108 1 : tokio::spawn(async move {
109 1 : let mut handler = TestHandler {};
110 1 : pgbackend.run(&mut handler, &CancellationToken::new()).await
111 1 : });
112 1 :
113 1 : let client_cfg =
114 1 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
115 1 : .with_safe_default_protocol_versions()
116 1 : .expect("aws_lc_rs should support the default protocol versions")
117 1 : .with_root_certificates({
118 1 : let mut store = rustls::RootCertStore::empty();
119 1 : store.add(CERT.clone()).unwrap();
120 1 : store
121 1 : })
122 1 : .with_no_client_auth();
123 1 : let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
124 1 : let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
125 1 : &mut make_tls_connect,
126 1 : "localhost",
127 1 : )
128 1 : .expect("make_tls_connect");
129 1 :
130 1 : let mut conf = Config::new();
131 1 : conf.ssl_mode(SslMode::Require);
132 1 : let (client, connection) = conf
133 1 : .connect_raw(client_sock, tls_connect)
134 1 : .await
135 1 : .expect("connect");
136 1 : // The connection object performs the actual communication with the database,
137 1 : // so spawn it off to run on its own.
138 1 : tokio::spawn(async move {
139 1 : if let Err(e) = connection.await {
140 0 : eprintln!("connection error: {}", e);
141 0 : }
142 1 : });
143 1 :
144 1 : let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
145 1 : if let SimpleQueryMessage::Row(row) = first_val {
146 1 : let first_col = row.get(0).expect("first column");
147 1 : assert_eq!(first_col, "hey");
148 1 : } else {
149 1 : panic!("expected SimpleQueryMessage::Row");
150 1 : }
151 1 : }
|