LCOV - code coverage report
Current view: top level - libs/utils/src - measured_stream.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 0.0 % 66 0
Test Date: 2025-03-12 00:01:28 Functions: 0.0 % 54 0

            Line data    Source code
       1              : use std::io::Read;
       2              : use std::pin::Pin;
       3              : use std::{io, task};
       4              : 
       5              : use pin_project_lite::pin_project;
       6              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
       7              : 
       8              : pin_project! {
       9              :     /// This stream tracks all writes and calls user provided
      10              :     /// callback when the underlying stream is flushed.
      11              :     pub struct MeasuredStream<S, R, W> {
      12              :         #[pin]
      13              :         stream: S,
      14              :         write_count: usize,
      15              :         inc_read_count: R,
      16              :         inc_write_count: W,
      17              :     }
      18              : }
      19              : 
      20              : impl<S, R, W> MeasuredStream<S, R, W> {
      21            0 :     pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
      22            0 :         Self {
      23            0 :             stream,
      24            0 :             write_count: 0,
      25            0 :             inc_read_count,
      26            0 :             inc_write_count,
      27            0 :         }
      28            0 :     }
      29              : }
      30              : 
      31              : impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
      32            0 :     fn poll_read(
      33            0 :         self: Pin<&mut Self>,
      34            0 :         context: &mut task::Context<'_>,
      35            0 :         buf: &mut ReadBuf<'_>,
      36            0 :     ) -> task::Poll<io::Result<()>> {
      37            0 :         let this = self.project();
      38            0 :         let filled = buf.filled().len();
      39            0 :         this.stream.poll_read(context, buf).map_ok(|()| {
      40            0 :             let cnt = buf.filled().len() - filled;
      41            0 :             // Increment the read count.
      42            0 :             (this.inc_read_count)(cnt);
      43            0 :         })
      44            0 :     }
      45              : }
      46              : 
      47              : impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
      48            0 :     fn poll_write(
      49            0 :         self: Pin<&mut Self>,
      50            0 :         context: &mut task::Context<'_>,
      51            0 :         buf: &[u8],
      52            0 :     ) -> task::Poll<io::Result<usize>> {
      53            0 :         let this = self.project();
      54            0 :         this.stream.poll_write(context, buf).map_ok(|cnt| {
      55            0 :             // Increment the write count.
      56            0 :             *this.write_count += cnt;
      57            0 :             cnt
      58            0 :         })
      59            0 :     }
      60              : 
      61            0 :     fn poll_flush(
      62            0 :         self: Pin<&mut Self>,
      63            0 :         context: &mut task::Context<'_>,
      64            0 :     ) -> task::Poll<io::Result<()>> {
      65            0 :         let this = self.project();
      66            0 :         this.stream.poll_flush(context).map_ok(|()| {
      67            0 :             // Call the user provided callback and reset the write count.
      68            0 :             (this.inc_write_count)(*this.write_count);
      69            0 :             *this.write_count = 0;
      70            0 :         })
      71            0 :     }
      72              : 
      73            0 :     fn poll_shutdown(
      74            0 :         self: Pin<&mut Self>,
      75            0 :         context: &mut task::Context<'_>,
      76            0 :     ) -> task::Poll<io::Result<()>> {
      77            0 :         self.project().stream.poll_shutdown(context)
      78            0 :     }
      79              : }
      80              : 
      81              : /// Wrapper for a reader that counts bytes read.
      82              : ///
      83              : /// Similar to MeasuredStream but it's one way and it's sync
      84              : pub struct MeasuredReader<R: Read> {
      85              :     inner: R,
      86              :     byte_count: usize,
      87              : }
      88              : 
      89              : impl<R: Read> MeasuredReader<R> {
      90            0 :     pub fn new(reader: R) -> Self {
      91            0 :         Self {
      92            0 :             inner: reader,
      93            0 :             byte_count: 0,
      94            0 :         }
      95            0 :     }
      96              : 
      97            0 :     pub fn get_byte_count(&self) -> usize {
      98            0 :         self.byte_count
      99            0 :     }
     100              : }
     101              : 
     102              : impl<R: Read> Read for MeasuredReader<R> {
     103            0 :     fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
     104            0 :         let result = self.inner.read(buf);
     105            0 :         if let Ok(n_bytes) = result {
     106            0 :             self.byte_count += n_bytes
     107            0 :         }
     108            0 :         result
     109            0 :     }
     110              : }
        

Generated by: LCOV version 2.1-beta