Line data Source code
1 : use bytes::BytesMut;
2 : use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
3 :
4 : /// A trait for doing owned-buffer write IO.
5 : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
6 : pub trait OwnedAsyncWriter {
7 : async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
8 : &mut self,
9 : buf: B,
10 : ) -> std::io::Result<(usize, B::Buf)>;
11 : }
12 :
13 : /// A wrapper around an [`OwnedAsyncWriter`] that batches smaller writes
14 : /// into `BUFFER_SIZE`-sized writes.
15 : ///
16 : /// # Passthrough Of Large Writes
17 : ///
18 : /// Buffered writes of `BUFFER_SIZE` bytes or more cause the internal
19 : /// buffer to be flushed, even if it is not full yet. Then, the large
20 : /// buffered write is passed through to the underlying [`OwnedAsyncWriter`].
21 : ///
22 : /// This pass-through is generally beneficial for throughput, but if
23 : /// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
24 : /// unlimited large writes may cause latency or fairness issues.
25 : ///
26 : /// In such cases, a different implementation that always buffers in memory
27 : /// may be preferable.
28 : pub struct BufferedWriter<const BUFFER_SIZE: usize, W> {
29 : writer: W,
30 : // invariant: always remains Some(buf)
31 : // with buf.capacity() == BUFFER_SIZE except
32 : // - while IO is ongoing => goes back to Some() once the IO completed successfully
33 : // - after an IO error => stays `None` forever
34 : // In these exceptional cases, it's `None`.
35 : buf: Option<BytesMut>,
36 : }
37 :
38 : impl<const BUFFER_SIZE: usize, W> BufferedWriter<BUFFER_SIZE, W>
39 : where
40 : W: OwnedAsyncWriter,
41 : {
42 9 : pub fn new(writer: W) -> Self {
43 9 : Self {
44 9 : writer,
45 9 : buf: Some(BytesMut::with_capacity(BUFFER_SIZE)),
46 9 : }
47 9 : }
48 :
49 9 : pub async fn flush_and_into_inner(mut self) -> std::io::Result<W> {
50 9 : self.flush().await?;
51 9 : let Self { buf, writer } = self;
52 9 : assert!(buf.is_some());
53 9 : Ok(writer)
54 9 : }
55 :
56 44 : pub async fn write_buffered<B: IoBuf>(&mut self, chunk: Slice<B>) -> std::io::Result<()>
57 44 : where
58 44 : B: IoBuf + Send,
59 44 : {
60 44 : // avoid memcpy for the middle of the chunk
61 44 : if chunk.len() >= BUFFER_SIZE {
62 8 : self.flush().await?;
63 : // do a big write, bypassing `buf`
64 8 : assert_eq!(
65 8 : self.buf
66 8 : .as_ref()
67 8 : .expect("must not use after an error")
68 8 : .len(),
69 8 : 0
70 8 : );
71 8 : let chunk_len = chunk.len();
72 8 : let (nwritten, chunk) = self.writer.write_all(chunk).await?;
73 8 : assert_eq!(nwritten, chunk_len);
74 8 : drop(chunk);
75 8 : return Ok(());
76 36 : }
77 36 : // in-memory copy the < BUFFER_SIZED tail of the chunk
78 36 : assert!(chunk.len() < BUFFER_SIZE);
79 36 : let mut chunk = &chunk[..];
80 70 : while !chunk.is_empty() {
81 34 : let buf = self.buf.as_mut().expect("must not use after an error");
82 34 : let need = BUFFER_SIZE - buf.len();
83 34 : let have = chunk.len();
84 34 : let n = std::cmp::min(need, have);
85 34 : buf.extend_from_slice(&chunk[..n]);
86 34 : chunk = &chunk[n..];
87 34 : if buf.len() >= BUFFER_SIZE {
88 6 : assert_eq!(buf.len(), BUFFER_SIZE);
89 6 : self.flush().await?;
90 28 : }
91 : }
92 36 : assert!(chunk.is_empty(), "by now we should have drained the chunk");
93 36 : Ok(())
94 44 : }
95 :
96 23 : async fn flush(&mut self) -> std::io::Result<()> {
97 23 : let buf = self.buf.take().expect("must not use after an error");
98 23 : if buf.is_empty() {
99 10 : self.buf = Some(buf);
100 10 : return std::io::Result::Ok(());
101 13 : }
102 13 : let buf_len = buf.len();
103 13 : let (nwritten, mut buf) = self.writer.write_all(buf).await?;
104 13 : assert_eq!(nwritten, buf_len);
105 13 : buf.clear();
106 13 : self.buf = Some(buf);
107 13 : Ok(())
108 23 : }
109 : }
110 :
111 : impl OwnedAsyncWriter for Vec<u8> {
112 0 : async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
113 0 : &mut self,
114 0 : buf: B,
115 0 : ) -> std::io::Result<(usize, B::Buf)> {
116 0 : let nbytes = buf.bytes_init();
117 0 : if nbytes == 0 {
118 0 : return Ok((0, Slice::into_inner(buf.slice_full())));
119 0 : }
120 0 : let buf = buf.slice(0..nbytes);
121 0 : self.extend_from_slice(&buf[..]);
122 0 : Ok((buf.len(), Slice::into_inner(buf)))
123 0 : }
124 : }
125 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that keeps a copy of the payload of every `write_all` call,
    /// one entry per call, in call order.
    #[derive(Default)]
    struct RecorderWriter {
        writes: Vec<Vec<u8>>,
    }

    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
            &mut self,
            buf: B,
        ) -> std::io::Result<(usize, B::Buf)> {
            let payload = match buf.bytes_init() {
                0 => {
                    // An empty write is still recorded, as an empty entry.
                    self.writes.push(Vec::new());
                    return Ok((0, Slice::into_inner(buf.slice_full())));
                }
                nbytes => buf.slice(0..nbytes),
            };
            self.writes.push(payload[..].to_vec());
            Ok((payload.len(), Slice::into_inner(payload)))
        }
    }

    /// Feeds one static byte-string literal through `write_buffered`,
    /// propagating any IO error with `?`.
    macro_rules! write {
        ($writer:ident, $data:literal) => {{
            $writer
                .write_buffered(::bytes::Bytes::from_static($data).slice_full())
                .await?;
        }};
    }

    /// Chunks smaller than the buffer are coalesced into buffer-sized writes,
    /// and the remainder goes out on the final flush.
    #[tokio::test]
    async fn test_buffered_writes_only() -> std::io::Result<()> {
        let mut writer = BufferedWriter::<2, _>::new(RecorderWriter::default());
        write!(writer, b"a");
        write!(writer, b"b");
        write!(writer, b"c");
        write!(writer, b"d");
        write!(writer, b"e");
        let RecorderWriter { writes } = writer.flush_and_into_inner().await?;
        let expected: Vec<Vec<u8>> = vec![b"ab".to_vec(), b"cd".to_vec(), b"e".to_vec()];
        assert_eq!(writes, expected);
        Ok(())
    }

    /// Chunks of at least the buffer size reach the inner writer unchanged,
    /// and an empty chunk produces no write at all.
    #[tokio::test]
    async fn test_passthrough_writes_only() -> std::io::Result<()> {
        let mut writer = BufferedWriter::<2, _>::new(RecorderWriter::default());
        write!(writer, b"abc");
        write!(writer, b"de");
        write!(writer, b"");
        write!(writer, b"fghijk");
        let RecorderWriter { writes } = writer.flush_and_into_inner().await?;
        let expected: Vec<Vec<u8>> = vec![b"abc".to_vec(), b"de".to_vec(), b"fghijk".to_vec()];
        assert_eq!(writes, expected);
        Ok(())
    }

    /// A passthrough write first flushes whatever is buffered, even if the
    /// buffer is only partially filled.
    #[tokio::test]
    async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
        let mut writer = BufferedWriter::<2, _>::new(RecorderWriter::default());
        write!(writer, b"a");
        write!(writer, b"bc");
        write!(writer, b"d");
        write!(writer, b"e");
        let RecorderWriter { writes } = writer.flush_and_into_inner().await?;
        let expected: Vec<Vec<u8>> = vec![b"a".to_vec(), b"bc".to_vec(), b"de".to_vec()];
        assert_eq!(writes, expected);
        Ok(())
    }
}
|