LCOV - code coverage report
Current view: top level - proxy/src/pglb - copy_bidirectional.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 82.1 % 173 142
Test Date: 2025-07-16 12:29:03 Functions: 21.3 % 47 10

            Line data    Source code
       1              : use std::future::poll_fn;
       2              : use std::io;
       3              : use std::pin::Pin;
       4              : use std::task::{Context, Poll, ready};
       5              : 
       6              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
       7              : use tracing::info;
       8              : 
       9              : #[derive(Debug)]
                        /// Progress of a single copy direction, advanced by `transfer_one_direction`.
      10              : enum TransferState {
                            /// Data is still being copied; holds this direction's buffer and counters.
      11              :     Running(CopyBuffer),
                            /// Copying has stopped and the writer's shutdown is being polled; holds the
                            /// number of bytes written so far.
      12              :     ShuttingDown(u64),
                            /// The writer has been shut down; holds the final byte count.
      13              :     Done(u64),
      14              : }
      15              : 
      16              : #[derive(Debug)]
                        /// Identifies which side of a one-directional copy produced an I/O error.
      17              : pub(crate) enum ErrorDirection {
                            /// The error came from polling the reader.
      18              :     Read(io::Error),
                            /// The error came from the writer: a write, flush or shutdown failed, or a
                            /// write reported zero bytes.
      19              :     Write(io::Error),
      20              : }
      21              : 
      22              : impl ErrorSource {
                            /// Classifies an error from the client -> compute copy. In that direction the
                            /// client is the reader and compute is the writer, so a read error is
                            /// attributed to the client and a write error to compute.
      23            0 :     fn from_client(err: ErrorDirection) -> ErrorSource {
      24            0 :         match err {
      25            0 :             ErrorDirection::Read(client) => Self::Client(client),
      26            0 :             ErrorDirection::Write(compute) => Self::Compute(compute),
      27              :         }
      28            0 :     }
                            /// Classifies an error from the compute -> client copy. In that direction
                            /// compute is the reader and the client is the writer, so a read error is
                            /// attributed to compute and a write error to the client.
      29            0 :     fn from_compute(err: ErrorDirection) -> ErrorSource {
      30            0 :         match err {
      31            0 :             ErrorDirection::Write(client) => Self::Client(client),
      32            0 :             ErrorDirection::Read(compute) => Self::Compute(compute),
      33              :         }
      34            0 :     }
      35              : }
      36              : 
      37              : #[derive(Debug)]
                        /// Error returned by `copy_bidirectional_client_compute`, tagged with the peer
                        /// whose stream produced the underlying I/O error.
      38              : pub enum ErrorSource {
                            /// The I/O error occurred on the client stream.
      39              :     Client(io::Error),
                            /// The I/O error occurred on the compute stream.
      40              :     Compute(io::Error),
      41              : }
      42              : 
      43            5 : fn transfer_one_direction<A, B>(
      44            5 :     cx: &mut Context<'_>,
      45            5 :     state: &mut TransferState,
      46            5 :     r: &mut A,
      47            5 :     w: &mut B,
      48            5 : ) -> Poll<Result<u64, ErrorDirection>>
      49            5 : where
      50            5 :     A: AsyncRead + AsyncWrite + Unpin + ?Sized,
      51            5 :     B: AsyncRead + AsyncWrite + Unpin + ?Sized,
      52              : {
                            // Drives one copy direction (read from `r`, write to `w`) as far as it can
                            // go in this poll: `Running` -> `ShuttingDown` -> `Done`.
                            //
                            // Returns `Ready(Ok(count))` once `state` is `Done`, `Ready(Err(_))` if the
                            // copy or the writer's shutdown fails, and `Pending` otherwise. `state` is
                            // updated in place, so the caller can poll again later and resume.
                            //
                            // NOTE(review): this body only reads from `r` and only writes to / shuts
                            // down `w`; the bounds above ask for more than is used here.
      53            5 :     let mut r = Pin::new(r);
      54            5 :     let mut w = Pin::new(w);
      55              :     loop {
      56           12 :         match state {
      57            4 :             TransferState::Running(buf) => {
                                        // `poll_copy` completes once the reader is exhausted and the
                                        // writer has been flushed; move on to shutting the writer down.
      58            4 :                 let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
      59            3 :                 *state = TransferState::ShuttingDown(count);
      60              :             }
      61            4 :             TransferState::ShuttingDown(count) => {
                                        // A failed shutdown is reported as a write-side error.
      62            4 :                 ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
      63            4 :                 *state = TransferState::Done(*count);
      64              :             }
                                    // Terminal state: every later call returns the same count again.
      65            4 :             TransferState::Done(count) => return Poll::Ready(Ok(*count)),
      66              :         }
      67              :     }
      68            5 : }
      69              : 
      70              : #[tracing::instrument(skip_all)]
                        /// Copies data in both directions between `client` and `compute` until both
                        /// directions are done.
                        ///
                        /// Returns `(client_to_compute, compute_to_client)` byte counts on success. On
                        /// failure the I/O error is wrapped in an [`ErrorSource`] naming the peer whose
                        /// stream failed.
                        ///
                        /// As soon as one direction reaches `Done` while the other is still `Running`,
                        /// the other direction is switched to `ShuttingDown` with the number of bytes
                        /// it has written so far; its copy buffer is dropped at that point.
      71              : pub async fn copy_bidirectional_client_compute<Client, Compute>(
      72              :     client: &mut Client,
      73              :     compute: &mut Compute,
      74              : ) -> Result<(u64, u64), ErrorSource>
      75              : where
      76              :     Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
      77              :     Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
      78              : {
      79              :     let mut client_to_compute = TransferState::Running(CopyBuffer::new());
      80              :     let mut compute_to_client = TransferState::Running(CopyBuffer::new());
      81              : 
      82            2 :     poll_fn(|cx| {
                                // Poll both directions on every wake-up; an error in either one ends
                                // the whole copy immediately via `?`.
      83            2 :         let mut client_to_compute_result =
      84            2 :             transfer_one_direction(cx, &mut client_to_compute, client, compute)
      85            2 :                 .map_err(ErrorSource::from_client)?;
      86            2 :         let mut compute_to_client_result =
      87            2 :             transfer_one_direction(cx, &mut compute_to_client, compute, client)
      88            2 :                 .map_err(ErrorSource::from_compute)?;
      89              : 
      90              :         // TODO: 1 info log, with an enum label for close direction.
      91              : 
      92              :         // Early termination checks from compute to client.
      93            2 :         if let TransferState::Done(_) = compute_to_client
      94            2 :             && let TransferState::Running(buf) = &client_to_compute
      95              :         {
      96            1 :             info!("Compute is done, terminate client");
      97              :             // Initiate shutdown
      98            1 :             client_to_compute = TransferState::ShuttingDown(buf.amt);
      99              :             client_to_compute_result =
     100            1 :                 transfer_one_direction(cx, &mut client_to_compute, client, compute)
     101            1 :                     .map_err(ErrorSource::from_client)?;
     102            1 :         }
     103              : 
     104              :         // Early termination checks from client to compute.
     105            2 :         if let TransferState::Done(_) = client_to_compute
     106            2 :             && let TransferState::Running(buf) = &compute_to_client
     107              :         {
     108            0 :             info!("Client is done, terminate compute");
     109              :             // Initiate shutdown
     110            0 :             compute_to_client = TransferState::ShuttingDown(buf.amt);
     111              :             compute_to_client_result =
     112            0 :                 transfer_one_direction(cx, &mut compute_to_client, compute, client)
     113            0 :                     .map_err(ErrorSource::from_compute)?;
     114            2 :         }
     115              : 
     116              :         // It is not a problem if ready! returns early here: a direction that has reached
                                // `Done` keeps returning its byte count on every later poll, so its
                                // result is not lost while we wait for the other direction.
     117            2 :         let client_to_compute = ready!(client_to_compute_result);
     118            2 :         let compute_to_client = ready!(compute_to_client_result);
     119              : 
     120            2 :         Poll::Ready(Ok((client_to_compute, compute_to_client)))
     121            2 :     })
     122              :     .await
     123              : }
     124              : 
     125              : #[derive(Debug)]
                        /// Fixed-size buffer plus bookkeeping for copying bytes from a reader to a
                        /// writer in one direction.
     126              : pub(super) struct CopyBuffer {
                            /// Set when a successful read added no bytes to the buffer; no further reads
                            /// are attempted once this is true.
     127              :     read_done: bool,
                            /// Set after bytes were written; cleared once the writer has been flushed
                            /// while waiting for the reader.
     128              :     need_flush: bool,
                            /// Start of the data in `buf` that has not been written yet.
     129              :     pos: usize,
                            /// End of the data in `buf` that has been read so far.
     130              :     cap: usize,
                            /// Total number of bytes written to the writer.
     131              :     amt: u64,
                            /// Backing storage; bytes in `pos..cap` are waiting to be written.
     132              :     buf: Box<[u8]>,
     133              : }
     134              : const DEFAULT_BUF_SIZE: usize = 1024; // Length in bytes of the buffer allocated by `CopyBuffer::new`.
     135              : 
     136              : impl CopyBuffer {
                            /// Creates an empty buffer of `DEFAULT_BUF_SIZE` zeroed bytes with all
                            /// counters at zero and both flags cleared.
     137            4 :     pub(super) fn new() -> Self {
     138            4 :         Self {
     139            4 :             read_done: false,
     140            4 :             need_flush: false,
     141            4 :             pos: 0,
     142            4 :             cap: 0,
     143            4 :             amt: 0,
     144            4 :             buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
     145            4 :         }
     146            4 :     }
     147              : 
                            /// Polls `reader` once, appending after the `cap` bytes already buffered.
                            ///
                            /// On a successful read `cap` is moved to the new fill level, and
                            /// `read_done` is set when the read added nothing. The result of
                            /// `poll_read` is returned unchanged.
                            ///
                            /// Both call sites check `cap < buf.len()` first: with a full buffer the
                            /// read could not add anything and would be recorded as `read_done`.
     148            8 :     fn poll_fill_buf<R>(
     149            8 :         &mut self,
     150            8 :         cx: &mut Context<'_>,
     151            8 :         reader: Pin<&mut R>,
     152            8 :     ) -> Poll<io::Result<()>>
     153            8 :     where
     154            8 :         R: AsyncRead + ?Sized,
     155              :     {
     156            8 :         let me = &mut *self;
     157            8 :         let mut buf = ReadBuf::new(&mut me.buf);
                                // Mark the bytes we already hold as filled so the read appends to them.
     158            8 :         buf.set_filled(me.cap);
     159              : 
     160            8 :         let res = reader.poll_read(cx, &mut buf);
     161            7 :         if let Poll::Ready(Ok(())) = res {
     162            7 :             let filled_len = buf.filled().len();
                                    // An unchanged fill level means the read produced no bytes.
     163            7 :             me.read_done = me.cap == filled_len;
     164            7 :             me.cap = filled_len;
     165            7 :         }
     166            8 :         res
     167            8 :     }
     168              : 
                            /// Polls `writer` once with the unwritten bytes `buf[pos..cap]`.
                            ///
                            /// A ready write result is passed through, with any error tagged as
                            /// `ErrorDirection::Write`; `pos` is not advanced here. While the writer is
                            /// pending, one extra read is attempted if there is room and the reader is
                            /// not finished: a read error is returned as `ErrorDirection::Read`,
                            /// anything else still yields `Pending`.
     169            5 :     fn poll_write_buf<R, W>(
     170            5 :         &mut self,
     171            5 :         cx: &mut Context<'_>,
     172            5 :         mut reader: Pin<&mut R>,
     173            5 :         mut writer: Pin<&mut W>,
     174            5 :     ) -> Poll<Result<usize, ErrorDirection>>
     175            5 :     where
     176            5 :         R: AsyncRead + ?Sized,
     177            5 :         W: AsyncWrite + ?Sized,
     178              :     {
     179            5 :         let me = &mut *self;
     180            5 :         match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
     181              :             Poll::Pending => {
     182              :                 // Top up the buffer towards full if we can read a bit more
     183              :                 // data - this should improve the chances of a large write
     184            1 :                 if !me.read_done && me.cap < me.buf.len() {
     185            1 :                     ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
     186            0 :                 }
     187            0 :                 Poll::Pending
     188              :             }
     189            4 :             res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
     190              :         }
     191            5 :     }
     192              : 
                            /// Copies from `reader` to `writer` until the reader is exhausted.
                            ///
                            /// Each loop iteration reads (if there is room), writes out everything that
                            /// is buffered, then resets the buffer. Once `read_done` is set and the
                            /// buffer is drained, the writer is flushed and the total number of bytes
                            /// written is returned.
                            ///
                            /// Errors are tagged with the side that failed; a write that reports zero
                            /// bytes is returned as a `WriteZero` write error.
     193            4 :     pub(super) fn poll_copy<R, W>(
     194            4 :         &mut self,
     195            4 :         cx: &mut Context<'_>,
     196            4 :         mut reader: Pin<&mut R>,
     197            4 :         mut writer: Pin<&mut W>,
     198            4 :     ) -> Poll<Result<u64, ErrorDirection>>
     199            4 :     where
     200            4 :         R: AsyncRead + ?Sized,
     201            4 :         W: AsyncWrite + ?Sized,
     202              :     {
     203              :         loop {
     204              :             // If there is some space left in our buffer, then we try to read some
     205              :             // data to continue, thus maximizing the chances of a large write.
     206            7 :             if self.cap < self.buf.len() && !self.read_done {
     207            7 :                 match self.poll_fill_buf(cx, reader.as_mut()) {
     208            7 :                     Poll::Ready(Ok(())) => (),
     209            0 :                     Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
     210              :                     Poll::Pending => {
     211              :                         // Ignore pending reads when our buffer is not empty, because
     212              :                         // we can try to write data immediately.
     213            0 :                         if self.pos == self.cap {
     214              :                             // Try flushing when the reader has no progress to avoid deadlock
     215              :                             // when the reader depends on buffered writer.
     216            0 :                             if self.need_flush {
     217            0 :                                 ready!(writer.as_mut().poll_flush(cx))
     218            0 :                                     .map_err(ErrorDirection::Write)?;
     219            0 :                                 self.need_flush = false;
     220            0 :                             }
     221              : 
     222            0 :                             return Poll::Pending;
     223            0 :                         }
     224              :                     }
     225              :                 }
     226            0 :             }
     227              : 
     228              :             // If our buffer has some data, let's write it out!
     229           11 :             while self.pos < self.cap {
     230            5 :                 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
     231            4 :                 if i == 0 {
     232            0 :                     return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
     233            0 :                         io::ErrorKind::WriteZero,
     234            0 :                         "write zero byte into writer",
     235            0 :                     ))));
     236            4 :                 }
     237            4 :                 self.pos += i;
     238            4 :                 self.amt += i as u64;
     239            4 :                 self.need_flush = true;
     240              :             }
     241              : 
     242              :             // If pos larger than cap, this loop will never stop.
     243              :             // In particular, user's wrong poll_write implementation returning
     244              :             // incorrect written length may lead to thread blocking.
     245            6 :             debug_assert!(
     246            6 :                 self.pos <= self.cap,
     247            0 :                 "writer returned length larger than input slice"
     248              :             );
     249              : 
     250              :             // All data has been written, the buffer can be considered empty again
     251            6 :             self.pos = 0;
     252            6 :             self.cap = 0;
     253              : 
     254              :             // If we've written all the data and we've seen EOF, flush out the
     255              :             // data and finish the transfer.
     256            6 :             if self.read_done {
     257            3 :                 ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
     258            3 :                 return Poll::Ready(Ok(self.amt));
     259            3 :             }
     260              :         }
     261            4 :     }
     262              : }
     263              : 
     264              : #[cfg(test)]
                        // Tests for `copy_bidirectional_client_compute` using in-memory duplex streams
                        // in place of the client and compute connections.
     265              : mod tests {
     266              :     use tokio::io::AsyncWriteExt;
     267              : 
     268              :     use super::*;
     269              : 
     270              :     #[tokio::test]
     271            1 :     async fn test_client_to_compute() {
     272            1 :         let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
     273            1 :         let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
     274              : 
     275              :         // Both peers write their payload and shut down before the copy starts
     276            1 :         client_client.write_all(b"hello").await.unwrap();
     277            1 :         client_client.shutdown().await.unwrap();
     278            1 :         compute_client.write_all(b"Neon").await.unwrap();
     279            1 :         compute_client.shutdown().await.unwrap();
     280              : 
     281            1 :         let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
     282            1 :             .await
     283            1 :             .unwrap();
     284              : 
     285              :         // Assert correct transferred amounts
     286            1 :         let (client_to_compute_count, compute_to_client_count) = result;
     287            1 :         assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
     288            1 :         assert_eq!(compute_to_client_count, 4); // 'Neon' was transferred in full
     289            1 :     }
     290              : 
     291              :     #[tokio::test]
     292            1 :     async fn test_compute_to_client() {
     293            1 :         let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
     294            1 :         let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
     295              : 
     296              :         // Compute writes its payload and shuts down; the client writes data but never shuts down
     297            1 :         compute_client.write_all(b"hello").await.unwrap();
     298            1 :         compute_client.shutdown().await.unwrap();
     299            1 :         client_client
     300            1 :             .write_all(b"Neon Serverless Postgres")
     301            1 :             .await
     302            1 :             .unwrap();
     303              : 
     304            1 :         let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
     305            1 :             .await
     306            1 :             .unwrap();
     307              : 
     308              :         // Assert correct transferred amounts
     309            1 :         let (client_to_compute_count, compute_to_client_count) = result;
     310            1 :         assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
     311            1 :         assert!(client_to_compute_count <= 8); // client data only partially transferred or not at all (the compute duplex was created with size 8)
     312            1 :     }
     313              : }
        

Generated by: LCOV version 2.1-beta