LCOV - code coverage report
Current view: top level - proxy/src/pglb - inprocess.rs (source / functions) Coverage Total Hit
Test: aca806cab4756d7eb6a304846130f4a73a5d5393.info Lines: 85.5 % 62 53
Test Date: 2025-04-24 20:31:15 Functions: 70.0 % 10 7

            Line data    Source code
       1              : #![allow(dead_code, reason = "TODO: work in progress")]
       2              : 
       3              : use std::pin::{Pin, pin};
       4              : use std::sync::Arc;
       5              : use std::sync::atomic::{AtomicUsize, Ordering};
       6              : use std::task::{Context, Poll};
       7              : use std::{fmt, io};
       8              : 
       9              : use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
      10              : use tokio::sync::mpsc;
      11              : 
      12              : const STREAM_CHANNEL_SIZE: usize = 16;
      13              : const MAX_STREAM_BUFFER_SIZE: usize = 4096;
      14              : 
      15              : #[derive(Debug)]
      16              : pub struct Connection {
      17              :     stream_sender: mpsc::Sender<Stream>,
      18              :     stream_receiver: mpsc::Receiver<Stream>,
      19              :     stream_id_counter: Arc<AtomicUsize>,
      20              : }
      21              : 
      22              : impl Connection {
      23            1 :     pub fn new() -> (Connection, Connection) {
      24            1 :         let (sender_a, receiver_a) = mpsc::channel(STREAM_CHANNEL_SIZE);
      25            1 :         let (sender_b, receiver_b) = mpsc::channel(STREAM_CHANNEL_SIZE);
      26            1 : 
      27            1 :         let stream_id_counter = Arc::new(AtomicUsize::new(1));
      28            1 : 
      29            1 :         let conn_a = Connection {
      30            1 :             stream_sender: sender_a,
      31            1 :             stream_receiver: receiver_b,
      32            1 :             stream_id_counter: Arc::clone(&stream_id_counter),
      33            1 :         };
      34            1 :         let conn_b = Connection {
      35            1 :             stream_sender: sender_b,
      36            1 :             stream_receiver: receiver_a,
      37            1 :             stream_id_counter,
      38            1 :         };
      39            1 : 
      40            1 :         (conn_a, conn_b)
      41            1 :     }
      42              : 
      43              :     #[inline]
      44            1 :     fn next_stream_id(&self) -> StreamId {
      45            1 :         StreamId(self.stream_id_counter.fetch_add(1, Ordering::Relaxed))
      46            1 :     }
      47              : 
      48              :     #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
      49              :     pub async fn open_stream(&self) -> io::Result<Stream> {
      50              :         let (local, remote) = tokio::io::duplex(MAX_STREAM_BUFFER_SIZE);
      51              :         let stream_id = self.next_stream_id();
      52              :         tracing::Span::current().record("stream_id", stream_id.0);
      53              : 
      54              :         let local = Stream {
      55              :             inner: local,
      56              :             id: stream_id,
      57              :         };
      58              :         let remote = Stream {
      59              :             inner: remote,
      60              :             id: stream_id,
      61              :         };
      62              : 
      63              :         self.stream_sender
      64              :             .send(remote)
      65              :             .await
      66              :             .map_err(io::Error::other)?;
      67              : 
      68              :         Ok(local)
      69              :     }
      70              : 
      71              :     #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
      72              :     pub async fn accept_stream(&mut self) -> io::Result<Option<Stream>> {
      73            1 :         Ok(self.stream_receiver.recv().await.inspect(|stream| {
      74            1 :             tracing::Span::current().record("stream_id", stream.id.0);
      75            1 :         }))
      76              :     }
      77              : }
      78              : 
      79              : #[derive(Copy, Clone, Debug)]
      80              : pub struct StreamId(usize);
      81              : 
      82              : impl fmt::Display for StreamId {
      83              :     #[inline]
      84            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      85            0 :         write!(f, "{}", self.0)
      86            0 :     }
      87              : }
      88              : 
      89              : // TODO: Proper closing. Currently Streams can outlive their Connections.
      90              : // Carry WeakSender and check strong_count?
      91              : #[derive(Debug)]
      92              : pub struct Stream {
      93              :     inner: DuplexStream,
      94              :     id: StreamId,
      95              : }
      96              : 
      97              : impl Stream {
      98              :     #[inline]
      99            0 :     pub fn id(&self) -> StreamId {
     100            0 :         self.id
     101            0 :     }
     102              : }
     103              : 
     104              : impl AsyncRead for Stream {
     105              :     #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
     106              :     #[inline]
     107              :     fn poll_read(
     108              :         mut self: Pin<&mut Self>,
     109              :         cx: &mut Context<'_>,
     110              :         buf: &mut ReadBuf<'_>,
     111              :     ) -> Poll<io::Result<()>> {
     112              :         pin!(&mut self.inner).poll_read(cx, buf)
     113              :     }
     114              : }
     115              : 
     116              : impl AsyncWrite for Stream {
     117              :     #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
     118              :     #[inline]
     119              :     fn poll_write(
     120              :         mut self: Pin<&mut Self>,
     121              :         cx: &mut Context<'_>,
     122              :         buf: &[u8],
     123              :     ) -> Poll<Result<usize, io::Error>> {
     124              :         pin!(&mut self.inner).poll_write(cx, buf)
     125              :     }
     126              : 
     127              :     #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
     128              :     #[inline]
     129              :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
     130              :         pin!(&mut self.inner).poll_flush(cx)
     131              :     }
     132              : 
     133              :     #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
     134              :     #[inline]
     135              :     fn poll_shutdown(
     136              :         mut self: Pin<&mut Self>,
     137              :         cx: &mut Context<'_>,
     138              :     ) -> Poll<Result<(), io::Error>> {
     139              :         pin!(&mut self.inner).poll_shutdown(cx)
     140              :     }
     141              : 
     142              :     #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
     143              :     #[inline]
     144              :     fn poll_write_vectored(
     145              :         mut self: Pin<&mut Self>,
     146              :         cx: &mut Context<'_>,
     147              :         bufs: &[io::IoSlice<'_>],
     148              :     ) -> Poll<Result<usize, io::Error>> {
     149              :         pin!(&mut self.inner).poll_write_vectored(cx, bufs)
     150              :     }
     151              : 
     152              :     #[inline]
     153            0 :     fn is_write_vectored(&self) -> bool {
     154            0 :         self.inner.is_write_vectored()
     155            0 :     }
     156              : }
     157              : 
     158              : #[cfg(test)]
     159              : mod tests {
     160              :     use tokio::io::{AsyncReadExt, AsyncWriteExt};
     161              : 
     162              :     use super::*;
     163              : 
     164              :     #[tokio::test]
     165            1 :     async fn test_simple_roundtrip() {
     166            1 :         let (client, mut server) = Connection::new();
     167            1 : 
     168            1 :         let server_task = tokio::spawn(async move {
     169            2 :             while let Some(mut stream) = server.accept_stream().await.unwrap() {
     170            1 :                 tokio::spawn(async move {
     171            1 :                     let mut buf = [0; 64];
     172            1 :                     loop {
     173            2 :                         match stream.read(&mut buf).await.unwrap() {
     174            1 :                             0 => break,
     175            1 :                             n => stream.write(&buf[..n]).await.unwrap(),
     176            1 :                         };
     177            1 :                     }
     178            1 :                 });
     179            1 :             }
     180            1 :         });
     181            1 : 
     182            1 :         let mut stream = client.open_stream().await.unwrap();
     183            1 :         stream.write_all(b"hello!").await.unwrap();
     184            1 :         let mut buf = [0; 64];
     185            1 :         let n = stream.read(&mut buf).await.unwrap();
     186            1 :         assert_eq!(n, 6);
     187            1 :         assert_eq!(&buf[..n], b"hello!");
     188            1 : 
     189            1 :         drop(stream);
     190            1 :         drop(client);
     191            1 :         server_task.await.unwrap();
     192            1 :     }
     193              : }
        

Generated by: LCOV version 2.1-beta