use bytes::BytesMut;
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};

use crate::context::RequestContext;

/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
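///
/// A minimal sketch of a call site (hypothetical `writer`, `owned_buf`, and `ctx` values):
///
/// ```ignore
/// // the buffer is moved into the write and handed back on completion,
/// // so the caller can reuse the allocation
/// let (nwritten, io_buf) = writer.write_all(owned_buf, ctx).await?;
/// ```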
pub trait OwnedAsyncWriter {
    async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
        &mut self,
        buf: B,
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, B::Buf)>;
}

/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
///
/// # Passthrough Of Large Writes
///
/// Calls to [`BufferedWriter::write_buffered`] with a chunk larger than [`Buffer::cap`]
/// cause the internal buffer to be flushed prematurely so that the large
/// write is passed through directly to the underlying [`OwnedAsyncWriter`].
///
/// This pass-through is generally beneficial for throughput, but, if
/// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
/// unlimited large writes may cause latency or fairness issues.
///
/// In such cases, a different implementation that always buffers in memory
/// may be preferable.
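///
/// # Example
///
/// A minimal usage sketch, assuming a `RequestContext` is at hand as `ctx`
/// (`Vec<u8>` implements [`OwnedAsyncWriter`] at the bottom of this module):
///
/// ```ignore
/// let mut writer = BufferedWriter::new(Vec::new(), BytesMut::with_capacity(64 * 1024));
/// // small writes accumulate in the BytesMut and are flushed in 64 KiB batches
/// writer.write_buffered_borrowed(b"some bytes", ctx).await?;
/// // flushes the remainder and returns the underlying writer
/// let inner: Vec<u8> = writer.flush_and_into_inner(ctx).await?;
/// ```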
pub struct BufferedWriter<B, W> {
    writer: W,
    /// Invariant: this is always `Some(buf)`, except
    /// - while an IO is in flight => it is `None`, and goes back to `Some(buf)` once the IO completes successfully
    /// - after an IO error => it stays `None` forever
    buf: Option<B>,
}

impl<B, Buf, W> BufferedWriter<B, W>
where
    B: Buffer<IoBuf = Buf> + Send,
    Buf: IoBuf + Send,
    W: OwnedAsyncWriter,
{
    pub fn new(writer: W, buf: B) -> Self {
        Self {
            writer,
            buf: Some(buf),
        }
    }

    pub fn as_inner(&self) -> &W {
        &self.writer
    }

    /// Panics if used after any of the write paths returned an error.
    pub fn inspect_buffer(&self) -> &B {
        self.buf()
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
        self.flush(ctx).await?;

        let Self { buf, writer } = self;
        assert!(buf.is_some());
        Ok(writer)
    }

    #[inline(always)]
    fn buf(&self) -> &B {
        self.buf
            .as_ref()
            .expect("must not use after we returned an error")
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn write_buffered<S: IoBuf + Send>(
        &mut self,
        chunk: Slice<S>,
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, S)> {
        let chunk_len = chunk.len();
        // avoid memcpy for the middle of the chunk
        if chunk.len() >= self.buf().cap() {
            self.flush(ctx).await?;
            // do a big write, bypassing `buf`
            assert_eq!(
                self.buf
                    .as_ref()
                    .expect("must not use after an error")
                    .pending(),
                0
            );
            let (nwritten, chunk) = self.writer.write_all(chunk, ctx).await?;
            assert_eq!(nwritten, chunk_len);
            return Ok((nwritten, chunk));
        }
        // in-memory copy the tail of the chunk, which is smaller than the buffer capacity
        assert!(chunk.len() < self.buf().cap());
        let mut slice = &chunk[..];
        while !slice.is_empty() {
            let buf = self.buf.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = slice.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&slice[..n]);
            slice = &slice[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                self.flush(ctx).await?;
            }
        }
        assert!(slice.is_empty(), "by now we should have drained the chunk");
        Ok((chunk_len, chunk.into_inner()))
    }

    /// Strictly less performant variant of [`Self::write_buffered`] that allows writing borrowed data.
    ///
    /// It is less performant because we always have to copy the borrowed data into the internal buffer
    /// before we can do the IO. [`Self::write_buffered`] can avoid that copy, which is more performant
    /// for large writes.
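    ///
    /// A sketch of the difference (hypothetical `writer` and `ctx` values; `Bytes` is `bytes::Bytes`):
    ///
    /// ```ignore
    /// // borrowed data is always memcpy'ed into the internal buffer first
    /// writer.write_buffered_borrowed(b"borrowed", ctx).await?;
    /// // owned data of at least `Buffer::cap` bytes bypasses the buffer entirely
    /// writer.write_buffered(Bytes::from_static(b"owned data").slice_full(), ctx).await?;
    /// ```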
    pub async fn write_buffered_borrowed(
        &mut self,
        mut chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<usize> {
        let chunk_len = chunk.len();
        while !chunk.is_empty() {
            let buf = self.buf.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = chunk.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&chunk[..n]);
            chunk = &chunk[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                self.flush(ctx).await?;
            }
        }
        Ok(chunk_len)
    }

    async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
        let buf = self.buf.take().expect("must not use after an error");
        let buf_len = buf.pending();
        if buf_len == 0 {
            self.buf = Some(buf);
            return Ok(());
        }
        let (nwritten, io_buf) = self.writer.write_all(buf.flush(), ctx).await?;
        assert_eq!(nwritten, buf_len);
        self.buf = Some(Buffer::reuse_after_flush(io_buf));
        Ok(())
    }
}

/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
pub trait Buffer {
    type IoBuf: IoBuf;

    /// Capacity of the buffer. Must not change over the lifetime of `self`.
    fn cap(&self) -> usize;

    /// Add data to the buffer.
    /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
    /// panics if `other.len() > self.cap() - self.pending()`.
    fn extend_from_slice(&mut self, other: &[u8]);

    /// Number of bytes in the buffer.
    fn pending(&self) -> usize;

    /// Turns `self` into a [`tokio_epoll_uring::Slice`] of the pending data
    /// so we can use [`tokio_epoll_uring`] to write it to disk.
    fn flush(self) -> Slice<Self::IoBuf>;

    /// After the write to disk is done and we have gotten back the slice,
    /// [`BufferedWriter`] uses this method to re-use the io buffer.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
}

impl Buffer for BytesMut {
    type IoBuf = BytesMut;

    #[inline(always)]
    fn cap(&self) -> usize {
        self.capacity()
    }

    fn extend_from_slice(&mut self, other: &[u8]) {
        BytesMut::extend_from_slice(self, other)
    }

    #[inline(always)]
    fn pending(&self) -> usize {
        self.len()
    }

    fn flush(self) -> Slice<BytesMut> {
        if self.is_empty() {
            return self.slice_full();
        }
        let len = self.len();
        self.slice(0..len)
    }

    fn reuse_after_flush(mut iobuf: BytesMut) -> Self {
        iobuf.clear();
        iobuf
    }
}

impl OwnedAsyncWriter for Vec<u8> {
    async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
        &mut self,
        buf: B,
        _: &RequestContext,
    ) -> std::io::Result<(usize, B::Buf)> {
        let nbytes = buf.bytes_init();
        if nbytes == 0 {
            return Ok((0, Slice::into_inner(buf.slice_full())));
        }
        let buf = buf.slice(0..nbytes);
        self.extend_from_slice(&buf[..]);
        Ok((buf.len(), Slice::into_inner(buf)))
    }
}

#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    #[derive(Default)]
    struct RecorderWriter {
        writes: Vec<Vec<u8>>,
    }
    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
            &mut self,
            buf: B,
            _: &RequestContext,
        ) -> std::io::Result<(usize, B::Buf)> {
            let nbytes = buf.bytes_init();
            if nbytes == 0 {
                self.writes.push(vec![]);
                return Ok((0, Slice::into_inner(buf.slice_full())));
            }
            let buf = buf.slice(0..nbytes);
            self.writes.push(Vec::from(&buf[..]));
            Ok((buf.len(), Slice::into_inner(buf)))
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    macro_rules! write {
        ($writer:ident, $data:literal) => {{
            $writer
                .write_buffered(::bytes::Bytes::from_static($data).slice_full(), &test_ctx())
                .await?;
        }};
    }

    #[tokio::test]
    async fn test_buffered_writes_only() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"a");
        write!(writer, b"b");
        write!(writer, b"c");
        write!(writer, b"d");
        write!(writer, b"e");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_writes_only() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"abc");
        write!(writer, b"de");
        write!(writer, b"");
        write!(writer, b"fghijk");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"a");
        write!(writer, b"bc");
        write!(writer, b"d");
        write!(writer, b"e");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
        let ctx = test_ctx();
        let ctx = &ctx;
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));

        writer.write_buffered_borrowed(b"abc", ctx).await?;
        writer.write_buffered_borrowed(b"d", ctx).await?;
        writer.write_buffered_borrowed(b"e", ctx).await?;
        writer.write_buffered_borrowed(b"fg", ctx).await?;
        writer.write_buffered_borrowed(b"hi", ctx).await?;
        writer.write_buffered_borrowed(b"j", ctx).await?;
        writer.write_buffered_borrowed(b"klmno", ctx).await?;

        let recorder = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.writes,
            {
                let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
                expect
            }
            .iter()
            .map(|v| v[..].to_vec())
            .collect::<Vec<_>>()
        );
        Ok(())
    }
}