Line data Source code
1 : use std::io::Read;
2 : use std::pin::Pin;
3 : use std::{io, task};
4 :
5 : use pin_project_lite::pin_project;
6 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7 :
8 : pin_project! {
9 : /// This stream tracks all writes and calls user provided
10 : /// callback when the underlying stream is flushed.
11 : pub struct MeasuredStream<S, R, W> {
12 : #[pin]
13 : stream: S,
14 : write_count: usize,
15 : inc_read_count: R,
16 : inc_write_count: W,
17 : }
18 : }
19 :
20 : impl<S, R, W> MeasuredStream<S, R, W> {
21 0 : pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
22 0 : Self {
23 0 : stream,
24 0 : write_count: 0,
25 0 : inc_read_count,
26 0 : inc_write_count,
27 0 : }
28 0 : }
29 : }
30 :
31 : impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
32 0 : fn poll_read(
33 0 : self: Pin<&mut Self>,
34 0 : context: &mut task::Context<'_>,
35 0 : buf: &mut ReadBuf<'_>,
36 0 : ) -> task::Poll<io::Result<()>> {
37 0 : let this = self.project();
38 0 : let filled = buf.filled().len();
39 0 : this.stream.poll_read(context, buf).map_ok(|()| {
40 0 : let cnt = buf.filled().len() - filled;
41 0 : // Increment the read count.
42 0 : (this.inc_read_count)(cnt);
43 0 : })
44 0 : }
45 : }
46 :
47 : impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
48 0 : fn poll_write(
49 0 : self: Pin<&mut Self>,
50 0 : context: &mut task::Context<'_>,
51 0 : buf: &[u8],
52 0 : ) -> task::Poll<io::Result<usize>> {
53 0 : let this = self.project();
54 0 : this.stream.poll_write(context, buf).map_ok(|cnt| {
55 0 : // Increment the write count.
56 0 : *this.write_count += cnt;
57 0 : cnt
58 0 : })
59 0 : }
60 :
61 0 : fn poll_flush(
62 0 : self: Pin<&mut Self>,
63 0 : context: &mut task::Context<'_>,
64 0 : ) -> task::Poll<io::Result<()>> {
65 0 : let this = self.project();
66 0 : this.stream.poll_flush(context).map_ok(|()| {
67 0 : // Call the user provided callback and reset the write count.
68 0 : (this.inc_write_count)(*this.write_count);
69 0 : *this.write_count = 0;
70 0 : })
71 0 : }
72 :
73 0 : fn poll_shutdown(
74 0 : self: Pin<&mut Self>,
75 0 : context: &mut task::Context<'_>,
76 0 : ) -> task::Poll<io::Result<()>> {
77 0 : self.project().stream.poll_shutdown(context)
78 0 : }
79 : }
80 :
/// Wrapper around a synchronous [`Read`] implementation that counts the bytes read
/// through it.
///
/// This is the one-way, blocking counterpart of `MeasuredStream`: every successful
/// [`Read::read`] call adds the number of bytes it returned to a running total, which
/// can be queried with [`MeasuredReader::get_byte_count`]. Failed reads leave the total
/// unchanged.
#[derive(Debug)]
pub struct MeasuredReader<R> {
    inner: R,
    byte_count: usize,
}

impl<R> MeasuredReader<R> {
    /// Wraps `reader` with a byte count starting at zero.
    pub fn new(reader: R) -> Self {
        Self {
            inner: reader,
            byte_count: 0,
        }
    }

    /// Returns the total number of bytes successfully read through this wrapper.
    pub fn get_byte_count(&self) -> usize {
        self.byte_count
    }

    /// Returns a shared reference to the wrapped reader.
    pub fn get_ref(&self) -> &R {
        &self.inner
    }

    /// Consumes the wrapper and returns the wrapped reader.
    pub fn into_inner(self) -> R {
        self.inner
    }
}

impl<R: Read> Read for MeasuredReader<R> {
    /// Reads from the wrapped reader and adds the number of bytes read to the count.
    ///
    /// Errors from the wrapped reader are returned unchanged and are not counted.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n_bytes = self.inner.read(buf)?;
        self.byte_count += n_bytes;
        Ok(n_bytes)
    }
}
|