TLA Line data Source code
1 : use pin_project_lite::pin_project;
2 : use std::io::Read;
3 : use std::pin::Pin;
4 : use std::{io, task};
5 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
6 :
7 : pin_project! {
8 : /// This stream tracks all writes and calls user provided
9 : /// callback when the underlying stream is flushed.
10 : pub struct MeasuredStream<S, R, W> {
11 : #[pin]
12 : stream: S,
13 : write_count: usize,
14 : inc_read_count: R,
15 : inc_write_count: W,
16 : }
17 : }
18 :
19 : impl<S, R, W> MeasuredStream<S, R, W> {
20 CBC 3621 : pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
21 3621 : Self {
22 3621 : stream,
23 3621 : write_count: 0,
24 3621 : inc_read_count,
25 3621 : inc_write_count,
26 3621 : }
27 3621 : }
28 : }
29 :
30 : impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
31 11563552 : fn poll_read(
32 11563552 : self: Pin<&mut Self>,
33 11563552 : context: &mut task::Context<'_>,
34 11563552 : buf: &mut ReadBuf<'_>,
35 11563552 : ) -> task::Poll<io::Result<()>> {
36 11563552 : let this = self.project();
37 11563552 : let filled = buf.filled().len();
38 11563552 : this.stream.poll_read(context, buf).map_ok(|()| {
39 3540027 : let cnt = buf.filled().len() - filled;
40 3540027 : // Increment the read count.
41 3540027 : (this.inc_read_count)(cnt);
42 11563552 : })
43 11563552 : }
44 : }
45 :
46 : impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
47 3284718 : fn poll_write(
48 3284718 : self: Pin<&mut Self>,
49 3284718 : context: &mut task::Context<'_>,
50 3284718 : buf: &[u8],
51 3284718 : ) -> task::Poll<io::Result<usize>> {
52 3284718 : let this = self.project();
53 3284718 : this.stream.poll_write(context, buf).map_ok(|cnt| {
54 3263785 : // Increment the write count.
55 3263785 : *this.write_count += cnt;
56 3263785 : cnt
57 3284718 : })
58 3284718 : }
59 :
60 3265241 : fn poll_flush(
61 3265241 : self: Pin<&mut Self>,
62 3265241 : context: &mut task::Context<'_>,
63 3265241 : ) -> task::Poll<io::Result<()>> {
64 3265241 : let this = self.project();
65 3265241 : this.stream.poll_flush(context).map_ok(|()| {
66 3265241 : // Call the user provided callback and reset the write count.
67 3265241 : (this.inc_write_count)(*this.write_count);
68 3265241 : *this.write_count = 0;
69 3265241 : })
70 3265241 : }
71 :
72 2871 : fn poll_shutdown(
73 2871 : self: Pin<&mut Self>,
74 2871 : context: &mut task::Context<'_>,
75 2871 : ) -> task::Poll<io::Result<()>> {
76 2871 : self.project().stream.poll_shutdown(context)
77 2871 : }
78 : }
79 :
/// Wrapper for a reader that counts bytes read.
///
/// This is a one-way, synchronous counterpart of [`MeasuredStream`]. It wraps
/// a [`Read`] implementation and keeps a running total of the bytes returned
/// by successful reads, which is available via `get_byte_count`.
///
/// The `Read` bound lives on the impl blocks that need it rather than on the
/// struct, so naming the type does not require the bound.
#[derive(Debug)]
pub struct MeasuredReader<R> {
    inner: R,
    // Running total of bytes returned by successful `read` calls.
    byte_count: usize,
}
87 :
88 : impl<R: Read> MeasuredReader<R> {
89 634 : pub fn new(reader: R) -> Self {
90 634 : Self {
91 634 : inner: reader,
92 634 : byte_count: 0,
93 634 : }
94 634 : }
95 :
96 632 : pub fn get_byte_count(&self) -> usize {
97 632 : self.byte_count
98 632 : }
99 : }
100 :
101 : impl<R: Read> Read for MeasuredReader<R> {
102 862766 : fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
103 862766 : let result = self.inner.read(buf);
104 862766 : if let Ok(n_bytes) = result {
105 862764 : self.byte_count += n_bytes
106 2 : }
107 862766 : result
108 862766 : }
109 : }
|