LCOV - code coverage report
Current view: top level - libs/postgres_backend/tests - simple_select.rs (source / functions) Coverage Total Hit
Test: e402c46de0a007db6b48dddbde450ddbb92e6ceb.info Lines: 96.5 % 114 110
Test Date: 2024-06-25 10:31:23 Functions: 100.0 % 14 14

            Line data    Source code
       1              : /// Test postgres_backend_async with tokio_postgres
       2              : use once_cell::sync::Lazy;
       3              : use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
       4              : use pq_proto::{BeMessage, RowDescriptor};
       5              : use std::io::Cursor;
       6              : use std::{future, sync::Arc};
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : use tokio::net::{TcpListener, TcpStream};
       9              : use tokio_postgres::config::SslMode;
      10              : use tokio_postgres::tls::MakeTlsConnect;
      11              : use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
      12              : use tokio_postgres_rustls::MakeRustlsConnect;
      13              : 
      14              : // generate client, server test streams
      15            4 : async fn make_tcp_pair() -> (TcpStream, TcpStream) {
      16            4 :     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
      17            4 :     let addr = listener.local_addr().unwrap();
      18            4 :     let client_stream = TcpStream::connect(addr).await.unwrap();
      19            4 :     let (server_stream, _) = listener.accept().await.unwrap();
      20            4 :     (client_stream, server_stream)
      21            4 : }
      22              : 
      23              : struct TestHandler {}
      24              : 
      25              : #[async_trait::async_trait]
      26              : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
      27              :     // return single col 'hey' for any query
      28            4 :     async fn process_query(
      29            4 :         &mut self,
      30            4 :         pgb: &mut PostgresBackend<IO>,
      31            4 :         _query_string: &str,
      32            4 :     ) -> Result<(), QueryError> {
      33            4 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
      34            4 :             b"hey",
      35            4 :         )]))?
      36            4 :         .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
      37            4 :         .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
      38            4 :         Ok(())
      39            4 :     }
      40              : }
      41              : 
      42              : // test that basic select works
      43              : #[tokio::test]
      44            2 : async fn simple_select() {
      45            2 :     let (client_sock, server_sock) = make_tcp_pair().await;
      46            2 : 
      47            2 :     // create and run pgbackend
      48            2 :     let pgbackend =
      49            2 :         PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
      50            2 : 
      51            2 :     tokio::spawn(async move {
      52            2 :         let mut handler = TestHandler {};
      53            4 :         pgbackend.run(&mut handler, future::pending::<()>).await
      54            2 :     });
      55            2 : 
      56            2 :     let conf = Config::new();
      57            2 :     let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
      58            2 :     // The connection object performs the actual communication with the database,
      59            2 :     // so spawn it off to run on its own.
      60            2 :     tokio::spawn(async move {
      61            2 :         if let Err(e) = connection.await {
      62            0 :             eprintln!("connection error: {}", e);
      63            0 :         }
      64            2 :     });
      65            2 : 
      66            2 :     let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
      67            2 :     if let SimpleQueryMessage::Row(row) = first_val {
      68            2 :         let first_col = row.get(0).expect("first column");
      69            2 :         assert_eq!(first_col, "hey");
      70            2 :     } else {
      71            2 :         panic!("expected SimpleQueryMessage::Row");
      72            2 :     }
      73            2 : }
      74              : 
      75            2 : static KEY: Lazy<rustls::pki_types::PrivateKeyDer<'static>> = Lazy::new(|| {
      76            2 :     let mut cursor = Cursor::new(include_bytes!("key.pem"));
      77            2 :     let key = rustls_pemfile::rsa_private_keys(&mut cursor)
      78            2 :         .next()
      79            2 :         .unwrap()
      80            2 :         .unwrap();
      81            2 :     rustls::pki_types::PrivateKeyDer::Pkcs1(key)
      82            2 : });
      83              : 
      84            2 : static CERT: Lazy<rustls::pki_types::CertificateDer<'static>> = Lazy::new(|| {
      85            2 :     let mut cursor = Cursor::new(include_bytes!("cert.pem"));
      86            2 :     let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap();
      87            2 :     cert
      88            2 : });
      89              : 
      90              : // test that basic select with ssl works
      91              : #[tokio::test]
      92            2 : async fn simple_select_ssl() {
      93            2 :     let (client_sock, server_sock) = make_tcp_pair().await;
      94            2 : 
      95            2 :     let server_cfg = rustls::ServerConfig::builder()
      96            2 :         .with_no_client_auth()
      97            2 :         .with_single_cert(vec![CERT.clone()], KEY.clone_key())
      98            2 :         .unwrap();
      99            2 :     let tls_config = Some(Arc::new(server_cfg));
     100            2 :     let pgbackend =
     101            2 :         PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
     102            2 : 
     103            2 :     tokio::spawn(async move {
     104            2 :         let mut handler = TestHandler {};
     105           10 :         pgbackend.run(&mut handler, future::pending::<()>).await
     106            2 :     });
     107            2 : 
     108            2 :     let client_cfg = rustls::ClientConfig::builder()
     109            2 :         .with_root_certificates({
     110            2 :             let mut store = rustls::RootCertStore::empty();
     111            2 :             store.add(CERT.clone()).unwrap();
     112            2 :             store
     113            2 :         })
     114            2 :         .with_no_client_auth();
     115            2 :     let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
     116            2 :     let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
     117            2 :         &mut make_tls_connect,
     118            2 :         "localhost",
     119            2 :     )
     120            2 :     .expect("make_tls_connect");
     121            2 : 
     122            2 :     let mut conf = Config::new();
     123            2 :     conf.ssl_mode(SslMode::Require);
     124            2 :     let (client, connection) = conf
     125            2 :         .connect_raw(client_sock, tls_connect)
     126            8 :         .await
     127            2 :         .expect("connect");
     128            2 :     // The connection object performs the actual communication with the database,
     129            2 :     // so spawn it off to run on its own.
     130            2 :     tokio::spawn(async move {
     131            2 :         if let Err(e) = connection.await {
     132            0 :             eprintln!("connection error: {}", e);
     133            0 :         }
     134            2 :     });
     135            2 : 
     136            2 :     let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
     137            2 :     if let SimpleQueryMessage::Row(row) = first_val {
     138            2 :         let first_col = row.get(0).expect("first column");
     139            2 :         assert_eq!(first_col, "hey");
     140            2 :     } else {
     141            2 :         panic!("expected SimpleQueryMessage::Row");
     142            2 :     }
     143            2 : }
        

Generated by: LCOV version 2.1-beta