use bytes::BytesMut;
use tokio_epoll_uring::IoBuf;

use crate::context::RequestContext;

use super::io_buf_ext::{FullSlice, IoBufExt};

/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
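///
/// The buffer is passed by value and handed back to the caller so that its
/// allocation can be reused across writes. A minimal sketch of an
/// implementation (`MySink` is hypothetical; the `Vec<u8>` impl at the bottom
/// of this file is a real one):
///
/// ```ignore
/// impl OwnedAsyncWriter for MySink {
///     async fn write_all<Buf: IoBuf + Send>(
///         &mut self,
///         buf: FullSlice<Buf>,
///         _ctx: &RequestContext,
///     ) -> std::io::Result<(usize, FullSlice<Buf>)> {
///         // Perform the IO while owning `buf`, then hand it back
///         // so the caller can reuse the allocation.
///         let nwritten = buf.len();
///         /* ... write the bytes somewhere ... */
///         Ok((nwritten, buf))
///     }
/// }
/// ```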
pub trait OwnedAsyncWriter {
    async fn write_all<Buf: IoBuf + Send>(
        &mut self,
        buf: FullSlice<Buf>,
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, FullSlice<Buf>)>;
}

/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
///
/// # Passthrough Of Large Writes
///
/// Calls to [`BufferedWriter::write_buffered`] with a chunk at least as large
/// as [`Buffer::cap`] cause the internal buffer to be flushed prematurely so
/// that the large write is passed through directly to the underlying
/// [`OwnedAsyncWriter`].
///
/// This pass-through is generally beneficial for throughput, but if
/// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
/// unlimited large writes may cause latency or fairness issues.
///
/// In such cases, a different implementation that always buffers in memory
/// may be preferable.
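///
/// # Example
///
/// A minimal sketch of intended usage, assuming a valid [`RequestContext`]
/// named `ctx` is in scope (error handling elided):
///
/// ```ignore
/// // `Vec<u8>` implements `OwnedAsyncWriter` (see below), so it can serve as the sink.
/// let sink: Vec<u8> = Vec::new();
/// let mut writer = BufferedWriter::new(sink, BytesMut::with_capacity(64 * 1024));
/// // Small writes accumulate in the internal buffer and are written out
/// // only once the buffer is full ...
/// writer.write_buffered_borrowed(b"small write", ctx).await?;
/// // ... or when the writer is consumed, which flushes the remainder.
/// let sink = writer.flush_and_into_inner(ctx).await?;
/// assert_eq!(&sink, b"small write");
/// ```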
pub struct BufferedWriter<B, W> {
    writer: W,
    /// Invariant: always remains `Some(buf)` except
    /// - while an IO is in flight => goes back to `Some(buf)` once the IO completes successfully
    /// - after an IO error => stays `None` forever
    ///
    /// In these exceptional cases, it is `None`.
    buf: Option<B>,
}

impl<B, Buf, W> BufferedWriter<B, W>
where
    B: Buffer<IoBuf = Buf> + Send,
    Buf: IoBuf + Send,
    W: OwnedAsyncWriter,
{
    pub fn new(writer: W, buf: B) -> Self {
        Self {
            writer,
            buf: Some(buf),
        }
    }

    pub fn as_inner(&self) -> &W {
        &self.writer
    }

    /// Panics if used after any of the write paths has returned an error.
    pub fn inspect_buffer(&self) -> &B {
        self.buf()
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
        self.flush(ctx).await?;

        let Self { buf, writer } = self;
        assert!(buf.is_some());
        Ok(writer)
    }

    #[inline(always)]
    fn buf(&self) -> &B {
        self.buf
            .as_ref()
            .expect("must not use after we returned an error")
    }

    /// Guarantees that if `Ok` is returned, all bytes in `chunk` have been accepted.
    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn write_buffered<S: IoBuf + Send>(
        &mut self,
        chunk: FullSlice<S>,
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, FullSlice<S>)> {
        let chunk = chunk.into_raw_slice();

        let chunk_len = chunk.len();
        // Avoid the memcpy: if the chunk is at least as large as the buffer,
        // flush what we have and pass the chunk through to the underlying writer.
        if chunk.len() >= self.buf().cap() {
            self.flush(ctx).await?;
            // Do a big write, bypassing `buf`.
            assert_eq!(
                self.buf
                    .as_ref()
                    .expect("must not use after an error")
                    .pending(),
                0
            );
            let (nwritten, chunk) = self
                .writer
                .write_all(FullSlice::must_new(chunk), ctx)
                .await?;
            assert_eq!(nwritten, chunk_len);
            return Ok((nwritten, chunk));
        }
        // The chunk is smaller than the buffer capacity: copy it into the
        // buffer, flushing whenever the buffer fills up.
        assert!(chunk.len() < self.buf().cap());
        let mut slice = &chunk[..];
        while !slice.is_empty() {
            let buf = self.buf.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = slice.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&slice[..n]);
            slice = &slice[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                self.flush(ctx).await?;
            }
        }
        assert!(slice.is_empty(), "by now we should have drained the chunk");
        Ok((chunk_len, FullSlice::must_new(chunk)))
    }

    /// Strictly less performant variant of [`Self::write_buffered`] that allows writing borrowed data.
    ///
    /// It is less performant because we always have to copy the borrowed data into the internal buffer
    /// before we can do the IO. [`Self::write_buffered`] can avoid this copy, which makes it more
    /// performant for large writes.
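    ///
    /// A sketch of the difference, assuming a `writer` and a valid `ctx`
    /// (`large_borrowed` and `large_owned` are hypothetical):
    ///
    /// ```ignore
    /// // Borrowed data is always memcpy'd into the internal buffer, even if
    /// // it is larger than the buffer capacity:
    /// writer.write_buffered_borrowed(&large_borrowed[..], ctx).await?;
    /// // An owned chunk of at least the buffer capacity bypasses the buffer:
    /// writer.write_buffered(large_owned.slice_len(), ctx).await?;
    /// ```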
    pub async fn write_buffered_borrowed(
        &mut self,
        mut chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<usize> {
        let chunk_len = chunk.len();
        while !chunk.is_empty() {
            let buf = self.buf.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = chunk.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&chunk[..n]);
            chunk = &chunk[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                self.flush(ctx).await?;
            }
        }
        Ok(chunk_len)
    }

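    /// Flush the pending contents of the internal buffer, if any, to the
    /// underlying writer, then re-arm `self.buf` with the reused, now-empty buffer.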
    async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
        let buf = self.buf.take().expect("must not use after an error");
        let buf_len = buf.pending();
        if buf_len == 0 {
            self.buf = Some(buf);
            return Ok(());
        }
        let slice = buf.flush();
        let (nwritten, slice) = self.writer.write_all(slice, ctx).await?;
        assert_eq!(nwritten, buf_len);
        self.buf = Some(Buffer::reuse_after_flush(
            slice.into_raw_slice().into_inner(),
        ));
        Ok(())
    }
}

/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
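///
/// The lifecycle is: [`Buffer::extend_from_slice`] fills the buffer up to
/// [`Buffer::cap`]; [`Buffer::flush`] consumes it into a [`FullSlice`] for the
/// IO; and [`Buffer::reuse_after_flush`] turns the returned [`Self::IoBuf`]
/// back into an empty buffer so the allocation is recycled.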
pub trait Buffer {
    type IoBuf: IoBuf;

    /// Capacity of the buffer. Must not change over the lifetime of `self`.
    fn cap(&self) -> usize;

    /// Add data to the buffer.
    /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
    /// panics if `other.len() > self.cap() - self.pending()`.
    fn extend_from_slice(&mut self, other: &[u8]);

    /// Number of bytes in the buffer.
    fn pending(&self) -> usize;

    /// Turns `self` into a [`FullSlice`] of the pending data
    /// so we can use [`tokio_epoll_uring`] to write it to disk.
    fn flush(self) -> FullSlice<Self::IoBuf>;

    /// After the write to disk is done and we have gotten back the slice,
    /// [`BufferedWriter`] uses this method to re-use the io buffer.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
}

impl Buffer for BytesMut {
    type IoBuf = BytesMut;

    #[inline(always)]
    fn cap(&self) -> usize {
        self.capacity()
    }

    fn extend_from_slice(&mut self, other: &[u8]) {
        BytesMut::extend_from_slice(self, other)
    }

    #[inline(always)]
    fn pending(&self) -> usize {
        self.len()
    }

    fn flush(self) -> FullSlice<BytesMut> {
        self.slice_len()
    }

    fn reuse_after_flush(mut iobuf: BytesMut) -> Self {
        iobuf.clear();
        iobuf
    }
}

impl OwnedAsyncWriter for Vec<u8> {
    async fn write_all<Buf: IoBuf + Send>(
        &mut self,
        buf: FullSlice<Buf>,
        _: &RequestContext,
    ) -> std::io::Result<(usize, FullSlice<Buf>)> {
        self.extend_from_slice(&buf[..]);
        Ok((buf.len(), buf))
    }
}

#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    #[derive(Default)]
    struct RecorderWriter {
        writes: Vec<Vec<u8>>,
    }
    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all<Buf: IoBuf + Send>(
            &mut self,
            buf: FullSlice<Buf>,
            _: &RequestContext,
        ) -> std::io::Result<(usize, FullSlice<Buf>)> {
            self.writes.push(Vec::from(&buf[..]));
            Ok((buf.len(), buf))
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    macro_rules! write {
        ($writer:ident, $data:literal) => {{
            $writer
                .write_buffered(::bytes::Bytes::from_static($data).slice_len(), &test_ctx())
                .await?;
        }};
    }

    #[tokio::test]
    async fn test_buffered_writes_only() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"a");
        write!(writer, b"b");
        write!(writer, b"c");
        write!(writer, b"d");
        write!(writer, b"e");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_writes_only() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"abc");
        write!(writer, b"de");
        write!(writer, b"");
        write!(writer, b"fghijk");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
        write!(writer, b"a");
        write!(writer, b"bc");
        write!(writer, b"d");
        write!(writer, b"e");
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(
            recorder.writes,
            vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
        let ctx = test_ctx();
        let ctx = &ctx;
        let recorder = RecorderWriter::default();
        let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));

        writer.write_buffered_borrowed(b"abc", ctx).await?;
        writer.write_buffered_borrowed(b"d", ctx).await?;
        writer.write_buffered_borrowed(b"e", ctx).await?;
        writer.write_buffered_borrowed(b"fg", ctx).await?;
        writer.write_buffered_borrowed(b"hi", ctx).await?;
        writer.write_buffered_borrowed(b"j", ctx).await?;
        writer.write_buffered_borrowed(b"klmno", ctx).await?;

        let recorder = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.writes,
            {
                let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
                expect
            }
            .iter()
            .map(|v| v[..].to_vec())
            .collect::<Vec<_>>()
        );
        Ok(())
    }
}