LCOV - code coverage report
Current view: top level - libs/postgres_backend/tests - simple_select.rs (source / functions) Coverage Total Hit
Test: 249f165943bd2c492f96a3f7d250276e4addca1a.info Lines: 96.7 % 120 116
Test Date: 2024-11-20 18:39:52 Functions: 100.0 % 14 14

            Line data    Source code
       1              : /// Test postgres_backend_async with tokio_postgres
       2              : use once_cell::sync::Lazy;
       3              : use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
       4              : use pq_proto::{BeMessage, RowDescriptor};
       5              : use rustls::crypto::ring;
       6              : use std::io::Cursor;
       7              : use std::sync::Arc;
       8              : use tokio::io::{AsyncRead, AsyncWrite};
       9              : use tokio::net::{TcpListener, TcpStream};
      10              : use tokio_postgres::config::SslMode;
      11              : use tokio_postgres::tls::MakeTlsConnect;
      12              : use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
      13              : use tokio_postgres_rustls::MakeRustlsConnect;
      14              : use tokio_util::sync::CancellationToken;
      15              : 
      16              : // generate client, server test streams
      17            2 : async fn make_tcp_pair() -> (TcpStream, TcpStream) {
      18            2 :     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
      19            2 :     let addr = listener.local_addr().unwrap();
      20            2 :     let client_stream = TcpStream::connect(addr).await.unwrap();
      21            2 :     let (server_stream, _) = listener.accept().await.unwrap();
      22            2 :     (client_stream, server_stream)
      23            2 : }
      24              : 
      25              : struct TestHandler {}
      26              : 
      27              : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
      28              :     // return single col 'hey' for any query
      29            2 :     async fn process_query(
      30            2 :         &mut self,
      31            2 :         pgb: &mut PostgresBackend<IO>,
      32            2 :         _query_string: &str,
      33            2 :     ) -> Result<(), QueryError> {
      34            2 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
      35            2 :             b"hey",
      36            2 :         )]))?
      37            2 :         .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
      38            2 :         .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
      39            2 :         Ok(())
      40            2 :     }
      41              : }
      42              : 
      43              : // test that basic select works
      44              : #[tokio::test]
      45            1 : async fn simple_select() {
      46            1 :     let (client_sock, server_sock) = make_tcp_pair().await;
      47            1 : 
      48            1 :     // create and run pgbackend
      49            1 :     let pgbackend =
      50            1 :         PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
      51            1 : 
      52            1 :     tokio::spawn(async move {
      53            1 :         let mut handler = TestHandler {};
      54            2 :         pgbackend.run(&mut handler, &CancellationToken::new()).await
      55            1 :     });
      56            1 : 
      57            1 :     let conf = Config::new();
      58            1 :     let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
      59            1 :     // The connection object performs the actual communication with the database,
      60            1 :     // so spawn it off to run on its own.
      61            1 :     tokio::spawn(async move {
      62            1 :         if let Err(e) = connection.await {
      63            0 :             eprintln!("connection error: {}", e);
      64            0 :         }
      65            1 :     });
      66            1 : 
      67            1 :     let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
      68            1 :     if let SimpleQueryMessage::Row(row) = first_val {
      69            1 :         let first_col = row.get(0).expect("first column");
      70            1 :         assert_eq!(first_col, "hey");
      71            1 :     } else {
      72            1 :         panic!("expected SimpleQueryMessage::Row");
      73            1 :     }
      74            1 : }
      75              : 
      76            1 : static KEY: Lazy<rustls::pki_types::PrivateKeyDer<'static>> = Lazy::new(|| {
      77            1 :     let mut cursor = Cursor::new(include_bytes!("key.pem"));
      78            1 :     let key = rustls_pemfile::rsa_private_keys(&mut cursor)
      79            1 :         .next()
      80            1 :         .unwrap()
      81            1 :         .unwrap();
      82            1 :     rustls::pki_types::PrivateKeyDer::Pkcs1(key)
      83            1 : });
      84              : 
      85            1 : static CERT: Lazy<rustls::pki_types::CertificateDer<'static>> = Lazy::new(|| {
      86            1 :     let mut cursor = Cursor::new(include_bytes!("cert.pem"));
      87            1 :     let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap();
      88            1 :     cert
      89            1 : });
      90              : 
      91              : // test that basic select with ssl works
      92              : #[tokio::test]
      93            1 : async fn simple_select_ssl() {
      94            1 :     let (client_sock, server_sock) = make_tcp_pair().await;
      95            1 : 
      96            1 :     let server_cfg =
      97            1 :         rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      98            1 :             .with_safe_default_protocol_versions()
      99            1 :             .expect("aws_lc_rs should support the default protocol versions")
     100            1 :             .with_no_client_auth()
     101            1 :             .with_single_cert(vec![CERT.clone()], KEY.clone_key())
     102            1 :             .unwrap();
     103            1 :     let tls_config = Some(Arc::new(server_cfg));
     104            1 :     let pgbackend =
     105            1 :         PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
     106            1 : 
     107            1 :     tokio::spawn(async move {
     108            1 :         let mut handler = TestHandler {};
     109            5 :         pgbackend.run(&mut handler, &CancellationToken::new()).await
     110            1 :     });
     111            1 : 
     112            1 :     let client_cfg =
     113            1 :         rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
     114            1 :             .with_safe_default_protocol_versions()
     115            1 :             .expect("aws_lc_rs should support the default protocol versions")
     116            1 :             .with_root_certificates({
     117            1 :                 let mut store = rustls::RootCertStore::empty();
     118            1 :                 store.add(CERT.clone()).unwrap();
     119            1 :                 store
     120            1 :             })
     121            1 :             .with_no_client_auth();
     122            1 :     let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
     123            1 :     let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
     124            1 :         &mut make_tls_connect,
     125            1 :         "localhost",
     126            1 :     )
     127            1 :     .expect("make_tls_connect");
     128            1 : 
     129            1 :     let mut conf = Config::new();
     130            1 :     conf.ssl_mode(SslMode::Require);
     131            1 :     let (client, connection) = conf
     132            1 :         .connect_raw(client_sock, tls_connect)
     133            4 :         .await
     134            1 :         .expect("connect");
     135            1 :     // The connection object performs the actual communication with the database,
     136            1 :     // so spawn it off to run on its own.
     137            1 :     tokio::spawn(async move {
     138            1 :         if let Err(e) = connection.await {
     139            0 :             eprintln!("connection error: {}", e);
     140            0 :         }
     141            1 :     });
     142            1 : 
     143            1 :     let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
     144            1 :     if let SimpleQueryMessage::Row(row) = first_val {
     145            1 :         let first_col = row.get(0).expect("first column");
     146            1 :         assert_eq!(first_col, "hey");
     147            1 :     } else {
     148            1 :         panic!("expected SimpleQueryMessage::Row");
     149            1 :     }
     150            1 : }
        

Generated by: LCOV version 2.1-beta