LCOV - code coverage report
Current view: top level - proxy/src/proxy/tests - mitm.rs (source / functions) Coverage Total Hit
Test: 6a14b070dc6eeeeb359cfa8817925ac37a02fab4.info Lines: 97.9 % 187 183
Test Date: 2025-03-31 22:46:13 Functions: 100.0 % 23 23

            Line data    Source code
       1              : //! Man-in-the-middle tests
       2              : //!
       3              : //! Channel binding should prevent a proxy server
       4              : //! *that has access to create valid certificates*
       5              : //! from controlling the TLS connection.
       6              : 
       7              : use std::fmt::Debug;
       8              : 
       9              : use bytes::{Bytes, BytesMut};
      10              : use futures::{SinkExt, StreamExt};
      11              : use postgres_client::tls::TlsConnect;
      12              : use postgres_protocol::message::frontend;
      13              : use tokio::io::{AsyncReadExt, DuplexStream};
      14              : use tokio_util::codec::{Decoder, Encoder};
      15              : 
      16              : use super::*;
      17              : 
      18              : enum Intercept {
      19              :     None,
      20              :     Methods,
      21              :     SASLResponse,
      22              : }
      23              : 
      24            7 : async fn proxy_mitm(
      25            7 :     intercept: Intercept,
      26            7 : ) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
      27            7 :     let (end_server1, client1) = tokio::io::duplex(1024);
      28            7 :     let (server2, end_client2) = tokio::io::duplex(1024);
      29            7 : 
      30            7 :     let (client_config1, server_config1) =
      31            7 :         generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
      32            7 :     let (client_config2, server_config2) =
      33            7 :         generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
      34            7 : 
      35            7 :     tokio::spawn(async move {
      36              :         // begin handshake with end_server
      37            7 :         let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
      38            7 :         let (end_client, startup) = match handshake(
      39            7 :             &RequestContext::test(),
      40            7 :             client1,
      41            7 :             Some(&server_config1),
      42            7 :             false,
      43            7 :         )
      44            7 :         .await
      45            7 :         .unwrap()
      46              :         {
      47            7 :             HandshakeData::Startup(stream, params) => (stream, params),
      48            0 :             HandshakeData::Cancel(_) => panic!("cancellation not supported"),
      49              :         };
      50              : 
      51            7 :         let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
      52            7 :         let (end_client, buf) = end_client.framed.into_inner();
      53            7 :         assert!(buf.is_empty());
      54            7 :         let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
      55            7 : 
      56            7 :         // give the end_server the startup parameters
      57            7 :         let mut buf = BytesMut::new();
      58            7 :         frontend::startup_message(
      59            7 :             &postgres_protocol::message::frontend::StartupMessageParams {
      60            7 :                 params: startup.params.into(),
      61            7 :             },
      62            7 :             &mut buf,
      63            7 :         )
      64            7 :         .unwrap();
      65            7 :         end_server.send(buf.freeze()).await.unwrap();
      66              : 
      67              :         // proxy messages between end_client and end_server
      68              :         loop {
      69           34 :             tokio::select! {
      70           34 :                 message = end_server.next() => {
      71           22 :                     match message {
      72           16 :                         Some(Ok(message)) => {
      73              :                             // intercept the SASL mechanism list from the server and advertise only SCRAM-SHA-256 ;)
      74           16 :                             if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
      75            2 :                                 end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
      76            2 :                                 continue;
      77           14 :                             }
      78           14 :                             end_client.send(message).await.unwrap();
      79              :                         }
      80            6 :                         _ => break,
      81              :                     }
      82              :                 }
      83           34 :                 message = end_client.next() => {
      84           12 :                     match message {
      85           11 :                         Some(Ok(message)) => {
      86              :                             // intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
      87           11 :                             if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
      88            2 :                                 let sasl_message = &message[1+4+19+4..];
      89            2 :                                 let mut new_message = b"n,,".to_vec();
      90            2 :                                 new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
      91            2 : 
      92            2 :                                 let mut buf = BytesMut::new();
      93            2 :                                 frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
      94            2 : 
      95            2 :                                 end_server.send(buf.freeze()).await.unwrap();
      96            2 :                                 continue;
      97            9 :                             }
      98            9 :                             end_server.send(message).await.unwrap();
      99              :                         }
     100            1 :                         _ => break,
     101              :                     }
     102              :                 }
     103            0 :                 else => { break }
     104              :             }
     105              :         }
     106            7 :     });
     107            7 : 
     108            7 :     (end_server1, end_client2, client_config1, server_config2)
     109            7 : }
     110              : 
     111              : /// taken from tokio-postgres
     112            7 : pub(crate) async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
     113            7 : where
     114            7 :     S: AsyncRead + AsyncWrite + Unpin,
     115            7 :     T: TlsConnect<S>,
     116            7 :     T::Error: Debug,
     117            7 : {
     118            7 :     let mut buf = BytesMut::new();
     119            7 :     frontend::ssl_request(&mut buf);
     120            7 :     stream.write_all(&buf).await.unwrap();
     121            7 : 
     122            7 :     let mut buf = [0];
     123            7 :     stream.read_exact(&mut buf).await.unwrap();
     124            7 : 
     125            7 :     assert!(buf[0] == b'S', "ssl not supported by server");
     126              : 
     127            7 :     tls.connect(stream).await.unwrap()
     128            7 : }
     129              : 
     130              : struct PgFrame;
     131              : impl Decoder for PgFrame {
     132              :     type Item = Bytes;
     133              :     type Error = std::io::Error;
     134              : 
     135           51 :     fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
     136           51 :         if src.len() < 5 {
     137           24 :             src.reserve(5 - src.len());
     138           24 :             return Ok(None);
     139           27 :         }
     140           27 :         let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
     141           27 :         if src.len() < len {
     142            0 :             src.reserve(len - src.len());
     143            0 :             return Ok(None);
     144           27 :         }
     145           27 :         Ok(Some(src.split_to(len).freeze()))
     146           51 :     }
     147              : }
     148              : impl Encoder<Bytes> for PgFrame {
     149              :     type Error = std::io::Error;
     150              : 
     151           34 :     fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
     152           34 :         dst.extend_from_slice(&item);
     153           34 :         Ok(())
     154           34 :     }
     155              : }
     156              : 
     157              : /// If the client doesn't support channel bindings, it can be exploited.
     158              : #[tokio::test]
     159            1 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
     160            1 :     let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
     161            1 :     let proxy = tokio::spawn(dummy_proxy(
     162            1 :         client,
     163            1 :         Some(server_config),
     164            1 :         Scram::new("password").await?,
     165            1 :     ));
     166            1 : 
     167            1 :     let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
     168            1 :         .channel_binding(postgres_client::config::ChannelBinding::Disable)
     169            1 :         .user("user")
     170            1 :         .dbname("db")
     171            1 :         .password("password")
     172            1 :         .ssl_mode(SslMode::Require)
     173            1 :         .connect_raw(server, client_config.make_tls_connect()?)
     174            1 :         .await?;
     175            1 : 
     176            1 :     proxy.await?
     177            1 : }
     178              : 
     179              : /// If the client chooses SCRAM-PLUS, it will fail
     180              : #[tokio::test]
     181            1 : async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
     182            1 :     connect_failure(
     183            1 :         Intercept::None,
     184            1 :         postgres_client::config::ChannelBinding::Prefer,
     185            1 :     )
     186            1 :     .await
     187            1 : }
     188              : 
     189              : /// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
     190              : #[tokio::test]
     191            1 : async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
     192            1 :     connect_failure(
     193            1 :         Intercept::Methods,
     194            1 :         postgres_client::config::ChannelBinding::Prefer,
     195            1 :     )
     196            1 :     .await
     197            1 : }
     198              : 
     199              : /// If the MITM pretends like the client doesn't support channel bindings, it will fail
     200              : #[tokio::test]
     201            1 : async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
     202            1 :     connect_failure(
     203            1 :         Intercept::SASLResponse,
     204            1 :         postgres_client::config::ChannelBinding::Prefer,
     205            1 :     )
     206            1 :     .await
     207            1 : }
     208              : 
     209              : /// If the client requires SCRAM-PLUS, it will fail
     210              : #[tokio::test]
     211            1 : async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
     212            1 :     connect_failure(
     213            1 :         Intercept::None,
     214            1 :         postgres_client::config::ChannelBinding::Require,
     215            1 :     )
     216            1 :     .await
     217            1 : }
     218              : 
     219              : /// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
     220              : #[tokio::test]
     221            1 : async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
     222            1 :     connect_failure(
     223            1 :         Intercept::Methods,
     224            1 :         postgres_client::config::ChannelBinding::Require,
     225            1 :     )
     226            1 :     .await
     227            1 : }
     228              : 
     229              : /// If the client requires SCRAM-PLUS, and the MITM rewrites its SASL response to drop channel binding, it will fail
     230              : #[tokio::test]
     231            1 : async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
     232            1 :     connect_failure(
     233            1 :         Intercept::SASLResponse,
     234            1 :         postgres_client::config::ChannelBinding::Require,
     235            1 :     )
     236            1 :     .await
     237            1 : }
     238              : 
     239            6 : async fn connect_failure(
     240            6 :     intercept: Intercept,
     241            6 :     channel_binding: postgres_client::config::ChannelBinding,
     242            6 : ) -> anyhow::Result<()> {
     243            6 :     let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
     244            6 :     let proxy = tokio::spawn(dummy_proxy(
     245            6 :         client,
     246            6 :         Some(server_config),
     247            6 :         Scram::new("password").await?,
     248              :     ));
     249              : 
     250            6 :     let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
     251            6 :         .channel_binding(channel_binding)
     252            6 :         .user("user")
     253            6 :         .dbname("db")
     254            6 :         .password("password")
     255            6 :         .ssl_mode(SslMode::Require)
     256            6 :         .connect_raw(server, client_config.make_tls_connect()?)
     257            6 :         .await
     258            6 :         .err()
     259            6 :         .context("client shouldn't be able to connect")?;
     260              : 
     261            6 :     let _server_err = proxy
     262            6 :         .await?
     263            6 :         .err()
     264            6 :         .context("server shouldn't accept client")?;
     265              : 
     266            6 :     Ok(())
     267            6 : }
        

Generated by: LCOV version 2.1-beta