Line data Source code
1 : use bytes::BytesMut;
2 : use tokio_epoll_uring::IoBuf;
3 :
4 : use crate::context::RequestContext;
5 :
6 : use super::io_buf_ext::{FullSlice, IoBufExt};
7 :
8 : /// A trait for doing owned-buffer write IO.
9 : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
10 : pub trait OwnedAsyncWriter {
11 : async fn write_all<Buf: IoBuf + Send>(
12 : &mut self,
13 : buf: FullSlice<Buf>,
14 : ctx: &RequestContext,
15 : ) -> std::io::Result<(usize, FullSlice<Buf>)>;
16 : }
17 :
/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
///
/// # Passthrough Of Large Writes
///
/// Calls to [`BufferedWriter::write_buffered`] whose chunk is at least [`Buffer::cap`]
/// bytes long cause the internal buffer to be flushed prematurely so that the large
/// buffered write is passed through to the underlying [`OwnedAsyncWriter`].
/// [`BufferedWriter::write_buffered_borrowed`] never passes through; it always copies
/// into the internal buffer.
///
/// This pass-through is generally beneficial for throughput, but if
/// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
/// unlimited large writes may cause latency or fairness issues.
///
/// In such cases, a different implementation that always buffers in memory
/// may be preferable.
pub struct BufferedWriter<B, W> {
    /// The destination that receives flushed buffers and passthrough writes.
    writer: W,
    /// Invariant: always `Some(buf)` except
    /// - while IO is ongoing => goes back to `Some` once the IO completed successfully
    /// - after an IO error => stays `None` forever
    ///
    /// In these exceptional cases, it's `None`.
    buf: Option<B>,
}
42 :
43 : impl<B, Buf, W> BufferedWriter<B, W>
44 : where
45 : B: Buffer<IoBuf = Buf> + Send,
46 : Buf: IoBuf + Send,
47 : W: OwnedAsyncWriter,
48 : {
49 1285 : pub fn new(writer: W, buf: B) -> Self {
50 1285 : Self {
51 1285 : writer,
52 1285 : buf: Some(buf),
53 1285 : }
54 1285 : }
55 :
56 695386 : pub fn as_inner(&self) -> &W {
57 695386 : &self.writer
58 695386 : }
59 :
60 : /// Panics if used after any of the write paths returned an error
61 694240 : pub fn inspect_buffer(&self) -> &B {
62 694240 : self.buf()
63 694240 : }
64 :
65 : #[cfg_attr(target_os = "macos", allow(dead_code))]
66 11 : pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
67 11 : self.flush(ctx).await?;
68 :
69 11 : let Self { buf, writer } = self;
70 11 : assert!(buf.is_some());
71 11 : Ok(writer)
72 11 : }
73 :
74 : #[inline(always)]
75 694320 : fn buf(&self) -> &B {
76 694320 : self.buf
77 694320 : .as_ref()
78 694320 : .expect("must not use after we returned an error")
79 694320 : }
80 :
81 : /// Guarantees that if Ok() is returned, all bytes in `chunk` have been accepted.
82 : #[cfg_attr(target_os = "macos", allow(dead_code))]
83 44 : pub async fn write_buffered<S: IoBuf + Send>(
84 44 : &mut self,
85 44 : chunk: FullSlice<S>,
86 44 : ctx: &RequestContext,
87 44 : ) -> std::io::Result<(usize, FullSlice<S>)> {
88 44 : let chunk = chunk.into_raw_slice();
89 44 :
90 44 : let chunk_len = chunk.len();
91 44 : // avoid memcpy for the middle of the chunk
92 44 : if chunk.len() >= self.buf().cap() {
93 8 : self.flush(ctx).await?;
94 : // do a big write, bypassing `buf`
95 8 : assert_eq!(
96 8 : self.buf
97 8 : .as_ref()
98 8 : .expect("must not use after an error")
99 8 : .pending(),
100 8 : 0
101 8 : );
102 8 : let (nwritten, chunk) = self
103 8 : .writer
104 8 : .write_all(FullSlice::must_new(chunk), ctx)
105 0 : .await?;
106 8 : assert_eq!(nwritten, chunk_len);
107 8 : return Ok((nwritten, chunk));
108 36 : }
109 36 : // in-memory copy the < BUFFER_SIZED tail of the chunk
110 36 : assert!(chunk.len() < self.buf().cap());
111 36 : let mut slice = &chunk[..];
112 70 : while !slice.is_empty() {
113 34 : let buf = self.buf.as_mut().expect("must not use after an error");
114 34 : let need = buf.cap() - buf.pending();
115 34 : let have = slice.len();
116 34 : let n = std::cmp::min(need, have);
117 34 : buf.extend_from_slice(&slice[..n]);
118 34 : slice = &slice[n..];
119 34 : if buf.pending() >= buf.cap() {
120 6 : assert_eq!(buf.pending(), buf.cap());
121 6 : self.flush(ctx).await?;
122 28 : }
123 : }
124 36 : assert!(slice.is_empty(), "by now we should have drained the chunk");
125 36 : Ok((chunk_len, FullSlice::must_new(chunk)))
126 44 : }
127 :
128 : /// Strictly less performant variant of [`Self::write_buffered`] that allows writing borrowed data.
129 : ///
130 : /// It is less performant because we always have to copy the borrowed data into the internal buffer
131 : /// before we can do the IO. The [`Self::write_buffered`] can avoid this, which is more performant
132 : /// for large writes.
133 5000826 : pub async fn write_buffered_borrowed(
134 5000826 : &mut self,
135 5000826 : mut chunk: &[u8],
136 5000826 : ctx: &RequestContext,
137 5000826 : ) -> std::io::Result<usize> {
138 5000826 : let chunk_len = chunk.len();
139 10008250 : while !chunk.is_empty() {
140 5007424 : let buf = self.buf.as_mut().expect("must not use after an error");
141 5007424 : let need = buf.cap() - buf.pending();
142 5007424 : let have = chunk.len();
143 5007424 : let n = std::cmp::min(need, have);
144 5007424 : buf.extend_from_slice(&chunk[..n]);
145 5007424 : chunk = &chunk[n..];
146 5007424 : if buf.pending() >= buf.cap() {
147 6608 : assert_eq!(buf.pending(), buf.cap());
148 6608 : self.flush(ctx).await?;
149 5000816 : }
150 : }
151 5000826 : Ok(chunk_len)
152 5000826 : }
153 :
154 6633 : async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
155 6633 : let buf = self.buf.take().expect("must not use after an error");
156 6633 : let buf_len = buf.pending();
157 6633 : if buf_len == 0 {
158 10 : self.buf = Some(buf);
159 10 : return Ok(());
160 6623 : }
161 6623 : let slice = buf.flush();
162 6623 : let (nwritten, slice) = self.writer.write_all(slice, ctx).await?;
163 6623 : assert_eq!(nwritten, buf_len);
164 6623 : self.buf = Some(Buffer::reuse_after_flush(
165 6623 : slice.into_raw_slice().into_inner(),
166 6623 : ));
167 6623 : Ok(())
168 6633 : }
169 : }
170 :
171 : /// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
172 : pub trait Buffer {
173 : type IoBuf: IoBuf;
174 :
175 : /// Capacity of the buffer. Must not change over the lifetime `self`.`
176 : fn cap(&self) -> usize;
177 :
178 : /// Add data to the buffer.
179 : /// Panics if there is not enough room to accomodate `other`'s content, i.e.,
180 : /// panics if `other.len() > self.cap() - self.pending()`.
181 : fn extend_from_slice(&mut self, other: &[u8]);
182 :
183 : /// Number of bytes in the buffer.
184 : fn pending(&self) -> usize;
185 :
186 : /// Turns `self` into a [`FullSlice`] of the pending data
187 : /// so we can use [`tokio_epoll_uring`] to write it to disk.
188 : fn flush(self) -> FullSlice<Self::IoBuf>;
189 :
190 : /// After the write to disk is done and we have gotten back the slice,
191 : /// [`BufferedWriter`] uses this method to re-use the io buffer.
192 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
193 : }
194 :
195 : impl Buffer for BytesMut {
196 : type IoBuf = BytesMut;
197 :
198 : #[inline(always)]
199 10021610 : fn cap(&self) -> usize {
200 10021610 : self.capacity()
201 10021610 : }
202 :
203 5007458 : fn extend_from_slice(&mut self, other: &[u8]) {
204 5007458 : BytesMut::extend_from_slice(self, other)
205 5007458 : }
206 :
207 : #[inline(always)]
208 10722401 : fn pending(&self) -> usize {
209 10722401 : self.len()
210 10722401 : }
211 :
212 6623 : fn flush(self) -> FullSlice<BytesMut> {
213 6623 : self.slice_len()
214 6623 : }
215 :
216 6623 : fn reuse_after_flush(mut iobuf: BytesMut) -> Self {
217 6623 : iobuf.clear();
218 6623 : iobuf
219 6623 : }
220 : }
221 :
222 : impl OwnedAsyncWriter for Vec<u8> {
223 0 : async fn write_all<Buf: IoBuf + Send>(
224 0 : &mut self,
225 0 : buf: FullSlice<Buf>,
226 0 : _: &RequestContext,
227 0 : ) -> std::io::Result<(usize, FullSlice<Buf>)> {
228 0 : self.extend_from_slice(&buf[..]);
229 0 : Ok((buf.len(), buf))
230 0 : }
231 : }
232 :
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    /// A writer that records each `write_all` call as its own entry, so tests can
    /// observe how `BufferedWriter` batched or passed through the data.
    #[derive(Default)]
    struct RecorderWriter {
        writes: Vec<Vec<u8>>,
    }

    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all<Buf: IoBuf + Send>(
            &mut self,
            buf: FullSlice<Buf>,
            _: &RequestContext,
        ) -> std::io::Result<(usize, FullSlice<Buf>)> {
            let recorded: Vec<u8> = buf[..].to_vec();
            self.writes.push(recorded);
            Ok((buf.len(), buf))
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    /// A fresh writer with a 2-byte buffer in front of a `RecorderWriter`.
    fn recording_writer() -> BufferedWriter<BytesMut, RecorderWriter> {
        BufferedWriter::new(RecorderWriter::default(), BytesMut::with_capacity(2))
    }

    /// Converts expected byte strings into the shape of `RecorderWriter::writes`.
    fn expected(writes: &[&[u8]]) -> Vec<Vec<u8>> {
        writes.iter().map(|w| w.to_vec()).collect()
    }

    /// Pushes a `'static` byte string through `write_buffered`, returning early on error.
    macro_rules! write_static {
        ($writer:ident, $data:expr) => {{
            $writer
                .write_buffered(::bytes::Bytes::from_static($data).slice_len(), &test_ctx())
                .await?;
        }};
    }

    #[tokio::test]
    async fn test_buffered_writes_only() -> std::io::Result<()> {
        let mut writer = recording_writer();
        let chunks: [&'static [u8]; 5] = [b"a", b"b", b"c", b"d", b"e"];
        for chunk in chunks {
            write_static!(writer, chunk);
        }
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(recorder.writes, expected(&[b"ab", b"cd", b"e"]));
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_writes_only() -> std::io::Result<()> {
        let mut writer = recording_writer();
        let chunks: [&'static [u8]; 4] = [b"abc", b"de", b"", b"fghijk"];
        for chunk in chunks {
            write_static!(writer, chunk);
        }
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(recorder.writes, expected(&[b"abc", b"de", b"fghijk"]));
        Ok(())
    }

    #[tokio::test]
    async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
        let mut writer = recording_writer();
        let chunks: [&'static [u8]; 4] = [b"a", b"bc", b"d", b"e"];
        for chunk in chunks {
            write_static!(writer, chunk);
        }
        let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
        assert_eq!(recorder.writes, expected(&[b"a", b"bc", b"de"]));
        Ok(())
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
        let ctx = &test_ctx();
        let mut writer = recording_writer();

        let chunks: [&[u8]; 7] = [b"abc", b"d", b"e", b"fg", b"hi", b"j", b"klmno"];
        for chunk in chunks {
            writer.write_buffered_borrowed(chunk, ctx).await?;
        }

        let recorder = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.writes,
            expected(&[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"])
        );
        Ok(())
    }
}
|