LCOV - code coverage report
Current view: top level - libs/utils/src - measured_stream.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 98.5 % 66 65
Test Date: 2024-02-12 20:26:03 Functions: 61.4 % 70 43

            Line data    Source code
       1              : use pin_project_lite::pin_project;
       2              : use std::io::Read;
       3              : use std::pin::Pin;
       4              : use std::{io, task};
       5              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
       6              : 
       7              : pin_project! {
       8              :     /// This stream tracks all writes and calls user provided
       9              :     /// callback when the underlying stream is flushed.
      10              :     pub struct MeasuredStream<S, R, W> {
      11              :         #[pin]
      12              :         stream: S,
      13              :         write_count: usize,
      14              :         inc_read_count: R,
      15              :         inc_write_count: W,
      16              :     }
      17              : }
      18              : 
      19              : impl<S, R, W> MeasuredStream<S, R, W> {
      20         3372 :     pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
      21         3372 :         Self {
      22         3372 :             stream,
      23         3372 :             write_count: 0,
      24         3372 :             inc_read_count,
      25         3372 :             inc_write_count,
      26         3372 :         }
      27         3372 :     }
      28              : }
      29              : 
      30              : impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
      31      9496370 :     fn poll_read(
      32      9496370 :         self: Pin<&mut Self>,
      33      9496370 :         context: &mut task::Context<'_>,
      34      9496370 :         buf: &mut ReadBuf<'_>,
      35      9496370 :     ) -> task::Poll<io::Result<()>> {
      36      9496370 :         let this = self.project();
      37      9496370 :         let filled = buf.filled().len();
      38      9496370 :         this.stream.poll_read(context, buf).map_ok(|()| {
      39      2939974 :             let cnt = buf.filled().len() - filled;
      40      2939974 :             // Increment the read count.
      41      2939974 :             (this.inc_read_count)(cnt);
      42      9496370 :         })
      43      9496370 :     }
      44              : }
      45              : 
      46              : impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
      47      2503958 :     fn poll_write(
      48      2503958 :         self: Pin<&mut Self>,
      49      2503958 :         context: &mut task::Context<'_>,
      50      2503958 :         buf: &[u8],
      51      2503958 :     ) -> task::Poll<io::Result<usize>> {
      52      2503958 :         let this = self.project();
      53      2503958 :         this.stream.poll_write(context, buf).map_ok(|cnt| {
      54      2480619 :             // Increment the write count.
      55      2480619 :             *this.write_count += cnt;
      56      2480619 :             cnt
      57      2503958 :         })
      58      2503958 :     }
      59              : 
      60      2481465 :     fn poll_flush(
      61      2481465 :         self: Pin<&mut Self>,
      62      2481465 :         context: &mut task::Context<'_>,
      63      2481465 :     ) -> task::Poll<io::Result<()>> {
      64      2481465 :         let this = self.project();
      65      2481465 :         this.stream.poll_flush(context).map_ok(|()| {
      66      2481465 :             // Call the user provided callback and reset the write count.
      67      2481465 :             (this.inc_write_count)(*this.write_count);
      68      2481465 :             *this.write_count = 0;
      69      2481465 :         })
      70      2481465 :     }
      71              : 
      72         2687 :     fn poll_shutdown(
      73         2687 :         self: Pin<&mut Self>,
      74         2687 :         context: &mut task::Context<'_>,
      75         2687 :     ) -> task::Poll<io::Result<()>> {
      76         2687 :         self.project().stream.poll_shutdown(context)
      77         2687 :     }
      78              : }
      79              : 
      80              : /// Wrapper for a reader that counts bytes read.
      81              : ///
      82              : /// Similar to MeasuredStream but it's one way and it's sync
      83              : pub struct MeasuredReader<R: Read> {
      84              :     inner: R,
      85              :     byte_count: usize,
      86              : }
      87              : 
      88              : impl<R: Read> MeasuredReader<R> {
      89          572 :     pub fn new(reader: R) -> Self {
      90          572 :         Self {
      91          572 :             inner: reader,
      92          572 :             byte_count: 0,
      93          572 :         }
      94          572 :     }
      95              : 
      96          572 :     pub fn get_byte_count(&self) -> usize {
      97          572 :         self.byte_count
      98          572 :     }
      99              : }
     100              : 
     101              : impl<R: Read> Read for MeasuredReader<R> {
     102       857369 :     fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
     103       857369 :         let result = self.inner.read(buf);
     104       857369 :         if let Ok(n_bytes) = result {
     105       857369 :             self.byte_count += n_bytes
     106            0 :         }
     107       857369 :         result
     108       857369 :     }
     109              : }
        

Generated by: LCOV version 2.1-beta