LCOV - code coverage report
Current view: top level - proxy/src/proxy - copy_bidirectional.rs (source / functions) Coverage Total Hit
Test: 050dd70dd490b28fffe527eae9fb8a1222b5c59c.info Lines: 89.0 % 182 162
Test Date: 2024-06-25 21:28:46 Functions: 23.1 % 52 12

            Line data    Source code
       1              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
       2              : use tracing::info;
       3              : 
       4              : use std::future::poll_fn;
       5              : use std::io;
       6              : use std::pin::Pin;
       7              : use std::task::{ready, Context, Poll};
       8              : 
       9              : #[derive(Debug)]
      10              : enum TransferState {
      11              :     Running(CopyBuffer),
      12              :     ShuttingDown(u64),
      13              :     Done(u64),
      14              : }
      15              : 
      16           10 : fn transfer_one_direction<A, B>(
      17           10 :     cx: &mut Context<'_>,
      18           10 :     state: &mut TransferState,
      19           10 :     r: &mut A,
      20           10 :     w: &mut B,
      21           10 : ) -> Poll<io::Result<u64>>
      22           10 : where
      23           10 :     A: AsyncRead + AsyncWrite + Unpin + ?Sized,
      24           10 :     B: AsyncRead + AsyncWrite + Unpin + ?Sized,
      25           10 : {
      26           10 :     let mut r = Pin::new(r);
      27           10 :     let mut w = Pin::new(w);
      28           24 :     loop {
      29           24 :         match state {
      30            8 :             TransferState::Running(buf) => {
      31            8 :                 let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
      32            6 :                 *state = TransferState::ShuttingDown(count);
      33              :             }
      34            8 :             TransferState::ShuttingDown(count) => {
      35            8 :                 ready!(w.as_mut().poll_shutdown(cx))?;
      36            8 :                 *state = TransferState::Done(*count);
      37              :             }
      38            8 :             TransferState::Done(count) => return Poll::Ready(Ok(*count)),
      39              :         }
      40              :     }
      41           10 : }
      42              : 
      43            8 : #[tracing::instrument(skip_all)]
      44              : pub async fn copy_bidirectional_client_compute<Client, Compute>(
      45              :     client: &mut Client,
      46              :     compute: &mut Compute,
      47              : ) -> Result<(u64, u64), std::io::Error>
      48              : where
      49              :     Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
      50              :     Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
      51              : {
      52              :     let mut client_to_compute = TransferState::Running(CopyBuffer::new());
      53              :     let mut compute_to_client = TransferState::Running(CopyBuffer::new());
      54              : 
      55            4 :     poll_fn(|cx| {
      56            4 :         let mut client_to_compute_result =
      57            4 :             transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
      58            4 :         let mut compute_to_client_result =
      59            4 :             transfer_one_direction(cx, &mut compute_to_client, compute, client)?;
      60              : 
      61              :         // Early termination checks from compute to client.
      62            4 :         if let TransferState::Done(_) = compute_to_client {
      63            4 :             if let TransferState::Running(buf) = &client_to_compute {
      64            2 :                 info!("Compute is done, terminate client");
      65              :                 // Initiate shutdown
      66            2 :                 client_to_compute = TransferState::ShuttingDown(buf.amt);
      67            2 :                 client_to_compute_result =
      68            2 :                     transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
      69            2 :             }
      70            0 :         }
      71              : 
      72              :         // Early termination checks from compute to client.
      73            4 :         if let TransferState::Done(_) = client_to_compute {
      74            4 :             if let TransferState::Running(buf) = &compute_to_client {
      75            0 :                 info!("Client is done, terminate compute");
      76              :                 // Initiate shutdown
      77            0 :                 compute_to_client = TransferState::ShuttingDown(buf.amt);
      78            0 :                 compute_to_client_result =
      79            0 :                     transfer_one_direction(cx, &mut compute_to_client, client, compute)?;
      80            4 :             }
      81            0 :         }
      82              : 
      83              :         // It is not a problem if ready! returns early ... (comment remains the same)
      84            4 :         let client_to_compute = ready!(client_to_compute_result);
      85            4 :         let compute_to_client = ready!(compute_to_client_result);
      86              : 
      87            4 :         Poll::Ready(Ok((client_to_compute, compute_to_client)))
      88            4 :     })
      89              :     .await
      90              : }
      91              : 
      92              : #[derive(Debug)]
      93              : pub(super) struct CopyBuffer {
      94              :     read_done: bool,
      95              :     need_flush: bool,
      96              :     pos: usize,
      97              :     cap: usize,
      98              :     amt: u64,
      99              :     buf: Box<[u8]>,
     100              : }
     101              : const DEFAULT_BUF_SIZE: usize = 1024;
     102              : 
     103              : impl CopyBuffer {
     104            8 :     pub(super) fn new() -> Self {
     105            8 :         Self {
     106            8 :             read_done: false,
     107            8 :             need_flush: false,
     108            8 :             pos: 0,
     109            8 :             cap: 0,
     110            8 :             amt: 0,
     111            8 :             buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
     112            8 :         }
     113            8 :     }
     114              : 
     115           16 :     fn poll_fill_buf<R>(
     116           16 :         &mut self,
     117           16 :         cx: &mut Context<'_>,
     118           16 :         reader: Pin<&mut R>,
     119           16 :     ) -> Poll<io::Result<()>>
     120           16 :     where
     121           16 :         R: AsyncRead + ?Sized,
     122           16 :     {
     123           16 :         let me = &mut *self;
     124           16 :         let mut buf = ReadBuf::new(&mut me.buf);
     125           16 :         buf.set_filled(me.cap);
     126           16 : 
     127           16 :         let res = reader.poll_read(cx, &mut buf);
     128           16 :         if let Poll::Ready(Ok(())) = res {
     129           14 :             let filled_len = buf.filled().len();
     130           14 :             me.read_done = me.cap == filled_len;
     131           14 :             me.cap = filled_len;
     132           14 :         }
     133           16 :         res
     134           16 :     }
     135              : 
     136           10 :     fn poll_write_buf<R, W>(
     137           10 :         &mut self,
     138           10 :         cx: &mut Context<'_>,
     139           10 :         mut reader: Pin<&mut R>,
     140           10 :         mut writer: Pin<&mut W>,
     141           10 :     ) -> Poll<io::Result<usize>>
     142           10 :     where
     143           10 :         R: AsyncRead + ?Sized,
     144           10 :         W: AsyncWrite + ?Sized,
     145           10 :     {
     146           10 :         let me = &mut *self;
     147           10 :         match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
     148              :             Poll::Pending => {
     149              :                 // Top up the buffer towards full if we can read a bit more
     150              :                 // data - this should improve the chances of a large write
     151            2 :                 if !me.read_done && me.cap < me.buf.len() {
     152            2 :                     ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
     153            0 :                 }
     154            0 :                 Poll::Pending
     155              :             }
     156            8 :             res => res,
     157              :         }
     158           10 :     }
     159              : 
     160            8 :     pub(super) fn poll_copy<R, W>(
     161            8 :         &mut self,
     162            8 :         cx: &mut Context<'_>,
     163            8 :         mut reader: Pin<&mut R>,
     164            8 :         mut writer: Pin<&mut W>,
     165            8 :     ) -> Poll<io::Result<u64>>
     166            8 :     where
     167            8 :         R: AsyncRead + ?Sized,
     168            8 :         W: AsyncWrite + ?Sized,
     169            8 :     {
     170           14 :         loop {
     171           14 :             // If our buffer is empty, then we need to read some data to
     172           14 :             // continue.
     173           14 :             if self.pos == self.cap && !self.read_done {
     174           14 :                 self.pos = 0;
     175           14 :                 self.cap = 0;
     176           14 : 
     177           14 :                 match self.poll_fill_buf(cx, reader.as_mut()) {
     178           14 :                     Poll::Ready(Ok(())) => (),
     179            0 :                     Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
     180              :                     Poll::Pending => {
     181              :                         // Try flushing when the reader has no progress to avoid deadlock
     182              :                         // when the reader depends on buffered writer.
     183            0 :                         if self.need_flush {
     184            0 :                             ready!(writer.as_mut().poll_flush(cx))?;
     185            0 :                             self.need_flush = false;
     186            0 :                         }
     187              : 
     188            0 :                         return Poll::Pending;
     189              :                     }
     190              :                 }
     191            0 :             }
     192              : 
     193              :             // If our buffer has some data, let's write it out!
     194           22 :             while self.pos < self.cap {
     195           10 :                 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
     196            8 :                 if i == 0 {
     197            0 :                     return Poll::Ready(Err(io::Error::new(
     198            0 :                         io::ErrorKind::WriteZero,
     199            0 :                         "write zero byte into writer",
     200            0 :                     )));
     201            8 :                 } else {
     202            8 :                     self.pos += i;
     203            8 :                     self.amt += i as u64;
     204            8 :                     self.need_flush = true;
     205            8 :                 }
     206              :             }
     207              : 
     208              :             // If pos larger than cap, this loop will never stop.
     209              :             // In particular, user's wrong poll_write implementation returning
     210              :             // incorrect written length may lead to thread blocking.
     211           12 :             debug_assert!(
     212           12 :                 self.pos <= self.cap,
     213            0 :                 "writer returned length larger than input slice"
     214              :             );
     215              : 
     216              :             // If we've written all the data and we've seen EOF, flush out the
     217              :             // data and finish the transfer.
     218           12 :             if self.pos == self.cap && self.read_done {
     219            6 :                 ready!(writer.as_mut().poll_flush(cx))?;
     220            6 :                 return Poll::Ready(Ok(self.amt));
     221            6 :             }
     222              :         }
     223            8 :     }
     224              : }
     225              : 
     226              : #[cfg(test)]
     227              : mod tests {
     228              :     use super::*;
     229              :     use tokio::io::AsyncWriteExt;
     230              : 
     231              :     #[tokio::test]
     232            2 :     async fn test_client_to_compute() {
     233            2 :         let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
     234            2 :         let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
     235            2 : 
     236            2 :         // Simulate 'a' finishing while there's still data for 'b'
     237            2 :         client_client.write_all(b"hello").await.unwrap();
     238            2 :         client_client.shutdown().await.unwrap();
     239            2 :         compute_client.write_all(b"Neon").await.unwrap();
     240            2 :         compute_client.shutdown().await.unwrap();
     241            2 : 
     242            2 :         let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
     243            2 :             .await
     244            2 :             .unwrap();
     245            2 : 
     246            2 :         // Assert correct transferred amounts
     247            2 :         let (client_to_compute_count, compute_to_client_count) = result;
     248            2 :         assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
     249            2 :         assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
     250            2 :     }
     251              : 
     252              :     #[tokio::test]
     253            2 :     async fn test_compute_to_client() {
     254            2 :         let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
     255            2 :         let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
     256            2 : 
     257            2 :         // Simulate 'a' finishing while there's still data for 'b'
     258            2 :         compute_client.write_all(b"hello").await.unwrap();
     259            2 :         compute_client.shutdown().await.unwrap();
     260            2 :         client_client
     261            2 :             .write_all(b"Neon Serverless Postgres")
     262            2 :             .await
     263            2 :             .unwrap();
     264            2 : 
     265            2 :         let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
     266            2 :             .await
     267            2 :             .unwrap();
     268            2 : 
     269            2 :         // Assert correct transferred amounts
     270            2 :         let (client_to_compute_count, compute_to_client_count) = result;
     271            2 :         assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
     272            2 :         assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
     273            2 :     }
     274              : }
        

Generated by: LCOV version 2.1-beta