LCOV - code coverage report
Current view: top level - proxy/src/proxy/tests - mitm.rs (source / functions) Coverage Total Hit
Test: f081ec316c96fa98335efd15ef501745aa4f015d.info Lines: 97.4 % 156 152
Test Date: 2024-06-25 15:11:17 Functions: 100.0 % 23 23

            Line data    Source code
       1              : //! Man-in-the-middle tests
       2              : //!
       3              : //! Channel binding should prevent a proxy server
       4              : //! - that has access to create valid certificates -
       5              : //! from controlling the TLS connection.
       6              : 
       7              : use std::fmt::Debug;
       8              : 
       9              : use super::*;
      10              : use bytes::{Bytes, BytesMut};
      11              : use futures::{SinkExt, StreamExt};
      12              : use postgres_protocol::message::frontend;
      13              : use tokio::io::{AsyncReadExt, DuplexStream};
      14              : use tokio_postgres::tls::TlsConnect;
      15              : use tokio_util::codec::{Decoder, Encoder};
      16              : 
      17              : enum Intercept {
      18              :     None,
      19              :     Methods,
      20              :     SASLResponse,
      21              : }
      22              : 
      23           14 : async fn proxy_mitm(
      24           14 :     intercept: Intercept,
      25           14 : ) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
      26           14 :     let (end_server1, client1) = tokio::io::duplex(1024);
      27           14 :     let (server2, end_client2) = tokio::io::duplex(1024);
      28           14 : 
      29           14 :     let (client_config1, server_config1) =
      30           14 :         generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
      31           14 :     let (client_config2, server_config2) =
      32           14 :         generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
      33           14 : 
      34           14 :     tokio::spawn(async move {
      35              :         // begin handshake with end_server
      36           28 :         let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
      37           14 :         let (end_client, startup) = match handshake(client1, Some(&server_config1), false)
      38           42 :             .await
      39           14 :             .unwrap()
      40              :         {
      41           14 :             HandshakeData::Startup(stream, params) => (stream, params),
      42            0 :             HandshakeData::Cancel(_) => panic!("cancellation not supported"),
      43              :         };
      44              : 
      45           14 :         let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
      46           14 :         let (end_client, buf) = end_client.framed.into_inner();
      47           14 :         assert!(buf.is_empty());
      48           14 :         let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
      49           14 : 
      50           14 :         // give the end_server the startup parameters
      51           14 :         let mut buf = BytesMut::new();
      52           14 :         frontend::startup_message(startup.iter(), &mut buf).unwrap();
      53           14 :         end_server.send(buf.freeze()).await.unwrap();
      54              : 
      55              :         // proxy messages between end_client and end_server
      56           68 :         loop {
      57           68 :             tokio::select! {
      58              :                 message = end_server.next() => {
      59              :                     match message {
      60              :                         Some(Ok(message)) => {
      61              :                             // intercept SASL and return only SCRAM-SHA-256 ;)
      62              :                             if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
      63              :                                 end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
      64              :                                 continue;
      65              :                             }
      66              :                             end_client.send(message).await.unwrap()
      67              :                         }
      68              :                         _ => break,
      69              :                     }
      70              :                 }
      71              :                 message = end_client.next() => {
      72              :                     match message {
      73              :                         Some(Ok(message)) => {
      74              :                             // intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
      75              :                             if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
      76              :                                 let sasl_message = &message[1+4+19+4..];
      77              :                                 let mut new_message = b"n,,".to_vec();
      78              :                                 new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
      79              : 
      80              :                                 let mut buf = BytesMut::new();
      81              :                                 frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
      82              : 
      83              :                                 end_server.send(buf.freeze()).await.unwrap();
      84              :                                 continue;
      85              :                             }
      86              :                             end_server.send(message).await.unwrap()
      87              :                         }
      88              :                         _ => break,
      89              :                     }
      90              :                 }
      91              :                 else => { break }
      92           68 :             }
      93           68 :         }
      94           14 :     });
      95           14 : 
      96           14 :     (end_server1, end_client2, client_config1, server_config2)
      97           14 : }
      98              : 
      99              : /// taken from tokio-postgres
     100           14 : pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
     101           14 : where
     102           14 :     S: AsyncRead + AsyncWrite + Unpin,
     103           14 :     T: TlsConnect<S>,
     104           14 :     T::Error: Debug,
     105           14 : {
     106           14 :     let mut buf = BytesMut::new();
     107           14 :     frontend::ssl_request(&mut buf);
     108           14 :     stream.write_all(&buf).await.unwrap();
     109           14 : 
     110           14 :     let mut buf = [0];
     111           14 :     stream.read_exact(&mut buf).await.unwrap();
     112           14 : 
     113           14 :     if buf[0] != b'S' {
     114            0 :         panic!("ssl not supported by server");
     115           14 :     }
     116           14 : 
     117           14 :     tls.connect(stream).await.unwrap()
     118           14 : }
     119              : 
     120              : struct PgFrame;
     121              : impl Decoder for PgFrame {
     122              :     type Item = Bytes;
     123              :     type Error = std::io::Error;
     124              : 
     125          102 :     fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
     126          102 :         if src.len() < 5 {
     127           48 :             src.reserve(5 - src.len());
     128           48 :             return Ok(None);
     129           54 :         }
     130           54 :         let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
     131           54 :         if src.len() < len {
     132            0 :             src.reserve(len - src.len());
     133            0 :             return Ok(None);
     134           54 :         }
     135           54 :         Ok(Some(src.split_to(len).freeze()))
     136          102 :     }
     137              : }
     138              : impl Encoder<Bytes> for PgFrame {
     139              :     type Error = std::io::Error;
     140              : 
     141           68 :     fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
     142           68 :         dst.extend_from_slice(&item);
     143           68 :         Ok(())
     144           68 :     }
     145              : }
     146              : 
     147              : /// If the client doesn't support channel bindings, it can be exploited.
     148              : #[tokio::test]
     149            2 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
     150            2 :     let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
     151            2 :     let proxy = tokio::spawn(dummy_proxy(
     152            2 :         client,
     153            2 :         Some(server_config),
     154            6 :         Scram::new("password").await?,
     155            2 :     ));
     156            2 : 
     157            2 :     let _client_err = tokio_postgres::Config::new()
     158            2 :         .channel_binding(tokio_postgres::config::ChannelBinding::Disable)
     159            2 :         .user("user")
     160            2 :         .dbname("db")
     161            2 :         .password("password")
     162            2 :         .ssl_mode(SslMode::Require)
     163            2 :         .connect_raw(server, client_config.make_tls_connect()?)
     164           16 :         .await?;
     165            2 : 
     166            2 :     proxy.await?
     167            2 : }
     168              : 
     169              : /// If the client chooses SCRAM-PLUS, it will fail
     170              : #[tokio::test]
     171            2 : async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
     172            2 :     connect_failure(
     173            2 :         Intercept::None,
     174            2 :         tokio_postgres::config::ChannelBinding::Prefer,
     175            2 :     )
     176           22 :     .await
     177            2 : }
     178              : 
     179              : /// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
     180              : #[tokio::test]
     181            2 : async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
     182            2 :     connect_failure(
     183            2 :         Intercept::Methods,
     184            2 :         tokio_postgres::config::ChannelBinding::Prefer,
     185            2 :     )
     186           14 :     .await
     187            2 : }
     188              : 
     189              : /// If the MITM pretends like the client doesn't support channel bindings, it will fail
     190              : #[tokio::test]
     191            2 : async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
     192            2 :     connect_failure(
     193            2 :         Intercept::SASLResponse,
     194            2 :         tokio_postgres::config::ChannelBinding::Prefer,
     195            2 :     )
     196           22 :     .await
     197            2 : }
     198              : 
     199              : /// If the client chooses SCRAM-PLUS, it will fail
     200              : #[tokio::test]
     201            2 : async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
     202            2 :     connect_failure(
     203            2 :         Intercept::None,
     204            2 :         tokio_postgres::config::ChannelBinding::Require,
     205            2 :     )
     206           22 :     .await
     207            2 : }
     208              : 
     209              : /// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
     210              : #[tokio::test]
     211            2 : async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
     212            2 :     connect_failure(
     213            2 :         Intercept::Methods,
     214            2 :         tokio_postgres::config::ChannelBinding::Require,
     215            2 :     )
     216           14 :     .await
     217            2 : }
     218              : 
     219              : /// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
     220              : #[tokio::test]
     221            2 : async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
     222            2 :     connect_failure(
     223            2 :         Intercept::SASLResponse,
     224            2 :         tokio_postgres::config::ChannelBinding::Require,
     225            2 :     )
     226           22 :     .await
     227            2 : }
     228              : 
     229           12 : async fn connect_failure(
     230           12 :     intercept: Intercept,
     231           12 :     channel_binding: tokio_postgres::config::ChannelBinding,
     232           12 : ) -> anyhow::Result<()> {
     233           12 :     let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
     234           12 :     let proxy = tokio::spawn(dummy_proxy(
     235           12 :         client,
     236           12 :         Some(server_config),
     237           36 :         Scram::new("password").await?,
     238              :     ));
     239              : 
     240           12 :     let _client_err = tokio_postgres::Config::new()
     241           12 :         .channel_binding(channel_binding)
     242           12 :         .user("user")
     243           12 :         .dbname("db")
     244           12 :         .password("password")
     245           12 :         .ssl_mode(SslMode::Require)
     246           12 :         .connect_raw(server, client_config.make_tls_connect()?)
     247           78 :         .await
     248           12 :         .err()
     249           12 :         .context("client shouldn't be able to connect")?;
     250              : 
     251           12 :     let _server_err = proxy
     252            2 :         .await?
     253           12 :         .err()
     254           12 :         .context("server shouldn't accept client")?;
     255              : 
     256           12 :     Ok(())
     257           12 : }
        

Generated by: LCOV version 2.1-beta