Line data Source code
1 : mod flush;
2 : use std::sync::Arc;
3 :
4 : use flush::FlushHandle;
5 : use tokio_epoll_uring::IoBuf;
6 :
7 : use crate::{
8 : context::RequestContext,
9 : virtual_file::{IoBuffer, IoBufferMut},
10 : };
11 :
12 : use super::{
13 : io_buf_aligned::IoBufAligned,
14 : io_buf_ext::{FullSlice, IoBufExt},
15 : };
16 :
17 : pub(crate) use flush::FlushControl;
18 :
19 : pub(crate) trait CheapCloneForRead {
20 : /// Returns a cheap clone of the buffer.
21 : fn cheap_clone(&self) -> Self;
22 : }
23 :
24 : impl CheapCloneForRead for IoBuffer {
25 13246 : fn cheap_clone(&self) -> Self {
26 13246 : // Cheap clone over an `Arc`.
27 13246 : self.clone()
28 13246 : }
29 : }
30 :
31 : /// A trait for doing owned-buffer write IO.
32 : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
33 : /// The owned buffers need to be aligned due to Direct IO requirements.
34 : pub trait OwnedAsyncWriter {
35 : fn write_all_at<Buf: IoBufAligned + Send>(
36 : &self,
37 : buf: FullSlice<Buf>,
38 : offset: u64,
39 : ctx: &RequestContext,
40 : ) -> impl std::future::Future<Output = std::io::Result<FullSlice<Buf>>> + Send;
41 : }
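// A minimal sketch of an `OwnedAsyncWriter` implementor, assuming a purely
// in-memory destination (hypothetical `VecWriter`; the `RecorderWriter` in the
// tests below plays a similar role). The important property is that the owned
// buffer is handed back on success so the caller can recycle it.
#[derive(Default, Debug)]
struct VecWriter {
    data: std::sync::Mutex<Vec<u8>>,
}

impl OwnedAsyncWriter for VecWriter {
    async fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        offset: u64,
        _ctx: &RequestContext,
    ) -> std::io::Result<FullSlice<Buf>> {
        let mut data = self.data.lock().unwrap();
        // Sketch assumes sequential writes starting at offset 0.
        assert_eq!(offset as usize, data.len());
        data.extend_from_slice(&buf[..]);
        Ok(buf) // return ownership so the caller can reuse the buffer
    }
}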
42 :
43 : /// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
44 : /// small writes into larger writes of size [`Buffer::cap`].
45 : // TODO(yuchen): For large writes, implementing buffer bypass for the aligned parts of the write could be beneficial to throughput,
46 : // since we would avoid copying the majority of the data into the internal buffer.
47 : pub struct BufferedWriter<B: Buffer, W> {
48 : writer: Arc<W>,
49 : /// Invariant: always remains `Some(buf)` except
50 : /// - while IO is ongoing => goes back to `Some(..)` once the IO completes successfully
51 : /// - after an IO error => stays `None` forever
52 : ///
53 : /// In these exceptional cases, it's `None`.
54 : mutable: Option<B>,
55 : /// A handle to the background flush task for writing data to disk.
56 : flush_handle: FlushHandle<B::IoBuf, W>,
57 : /// The number of bytes submitted to the background task.
58 : bytes_submitted: u64,
59 : }
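// Usage sketch (hypothetical `copy_into` helper; the 64 KiB capacity is an
// arbitrary example value): small chunks accumulate in the mutable buffer and
// reach the underlying `OwnedAsyncWriter` only in `cap()`-sized batches,
// submitted to the background flush task.
async fn copy_into<W>(
    chunks: &[&[u8]],
    writer: Arc<W>,
    gate_guard: utils::sync::gate::GateGuard,
    ctx: &RequestContext,
) -> std::io::Result<u64>
where
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    let mut buffered = BufferedWriter::<IoBufferMut, W>::new(
        writer,
        || IoBufferMut::with_capacity(64 * 1024),
        gate_guard,
        ctx,
    );
    for chunk in chunks {
        // Copies into the in-memory buffer; a write to `W` only happens when
        // the buffer fills up.
        buffered.write_buffered_borrowed(chunk, ctx).await?;
    }
    // Flush the remainder and shut down the background flush task.
    let (bytes_submitted, _writer) = buffered.flush_and_into_inner(ctx).await?;
    Ok(bytes_submitted)
}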
60 :
61 : impl<B, Buf, W> BufferedWriter<B, W>
62 : where
63 : B: Buffer<IoBuf = Buf> + Send + 'static,
64 : Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
65 : W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
66 : {
67 : /// Creates a new buffered writer.
68 : ///
69 : /// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
70 2630 : pub fn new(
71 2630 : writer: Arc<W>,
72 2630 : buf_new: impl Fn() -> B,
73 2630 : gate_guard: utils::sync::gate::GateGuard,
74 2630 : ctx: &RequestContext,
75 2630 : ) -> Self {
76 2630 : Self {
77 2630 : writer: writer.clone(),
78 2630 : mutable: Some(buf_new()),
79 2630 : flush_handle: FlushHandle::spawn_new(
80 2630 : writer,
81 2630 : buf_new(),
82 2630 : gate_guard,
83 2630 : ctx.attached_child(),
84 2630 : ),
85 2630 : bytes_submitted: 0,
86 2630 : }
87 2630 : }
88 :
89 22681 : pub fn as_inner(&self) -> &W {
90 22681 : &self.writer
91 22681 : }
92 :
93 : /// Returns the number of bytes submitted to the background flush task.
94 997247 : pub fn bytes_submitted(&self) -> u64 {
95 997247 : self.bytes_submitted
96 997247 : }
97 :
98 : /// Panics if used after any of the write paths returned an error
99 997267 : pub fn inspect_mutable(&self) -> &B {
100 997267 : self.mutable()
101 997267 : }
102 :
103 : /// Gets a reference to the maybe flushed read-only buffer.
104 : /// Returns `None` if the writer has not submitted any flush request.
105 997255 : pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
106 997255 : self.flush_handle.maybe_flushed.as_ref()
107 997255 : }
108 :
109 : #[cfg_attr(target_os = "macos", allow(dead_code))]
110 18 : pub async fn flush_and_into_inner(
111 18 : mut self,
112 18 : ctx: &RequestContext,
113 18 : ) -> std::io::Result<(u64, Arc<W>)> {
114 18 : self.flush(ctx).await?;
115 :
116 : let Self {
117 18 : mutable: buf,
118 18 : writer,
119 18 : mut flush_handle,
120 18 : bytes_submitted: bytes_amount,
121 18 : } = self;
122 18 : flush_handle.shutdown().await?;
123 18 : assert!(buf.is_some());
124 18 : Ok((bytes_amount, writer))
125 18 : }
126 :
127 : /// Gets a reference to the mutable in-memory buffer.
128 : #[inline(always)]
129 997267 : fn mutable(&self) -> &B {
130 997267 : self.mutable
131 997267 : .as_ref()
132 997267 : .expect("must not use after we returned an error")
133 997267 : }
134 :
135 : #[cfg_attr(target_os = "macos", allow(dead_code))]
136 116 : pub async fn write_buffered_borrowed(
137 116 : &mut self,
138 116 : chunk: &[u8],
139 116 : ctx: &RequestContext,
140 116 : ) -> std::io::Result<usize> {
141 116 : let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
142 116 : if let Some(control) = control {
143 24 : control.release().await;
144 92 : }
145 116 : Ok(len)
146 116 : }
147 :
148 : /// In addition to the number of bytes submitted by this write, also returns a handle that can control the flush behavior.
149 9609868 : pub(crate) async fn write_buffered_borrowed_controlled(
150 9609868 : &mut self,
151 9609868 : mut chunk: &[u8],
152 9609868 : ctx: &RequestContext,
153 9609868 : ) -> std::io::Result<(usize, Option<FlushControl>)> {
154 9609868 : let chunk_len = chunk.len();
155 9609868 : let mut control: Option<FlushControl> = None;
156 19232936 : while !chunk.is_empty() {
157 9623068 : let buf = self.mutable.as_mut().expect("must not use after an error");
158 9623068 : let need = buf.cap() - buf.pending();
159 9623068 : let have = chunk.len();
160 9623068 : let n = std::cmp::min(need, have);
161 9623068 : buf.extend_from_slice(&chunk[..n]);
162 9623068 : chunk = &chunk[n..];
163 9623068 : if buf.pending() >= buf.cap() {
164 13228 : assert_eq!(buf.pending(), buf.cap());
165 13228 : if let Some(control) = control.take() {
166 2132 : control.release().await;
167 11096 : }
168 13228 : control = self.flush(ctx).await?;
169 9609840 : }
170 : }
171 9609868 : Ok((chunk_len, control))
172 9609868 : }
173 :
174 : #[must_use = "caller must explicitly check the flush control"]
175 13246 : async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
176 13246 : let buf = self.mutable.take().expect("must not use after an error");
177 13246 : let buf_len = buf.pending();
178 13246 : if buf_len == 0 {
179 0 : self.mutable = Some(buf);
180 0 : return Ok(None);
181 13246 : }
182 13246 : let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
183 13246 : self.bytes_submitted += u64::try_from(buf_len).unwrap();
184 13246 : self.mutable = Some(recycled);
185 13246 : Ok(Some(flush_control))
186 13246 : }
187 : }
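// Sketch of the controlled write path (hypothetical helper, same bounds as the
// impl above): unlike `write_buffered_borrowed`, the caller keeps the returned
// `FlushControl` and chooses when to let the background flush proceed, e.g.
// after doing its own bookkeeping for the submitted bytes.
async fn write_then_release<B, Buf, W>(
    writer: &mut BufferedWriter<B, W>,
    chunk: &[u8],
    ctx: &RequestContext,
) -> std::io::Result<usize>
where
    B: Buffer<IoBuf = Buf> + Send + 'static,
    Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    let (len, control) = writer.write_buffered_borrowed_controlled(chunk, ctx).await?;
    if let Some(control) = control {
        // Caller-side bookkeeping could go here, before the flush is released.
        control.release().await;
    }
    Ok(len)
}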
188 :
189 : /// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
190 : pub trait Buffer {
191 : type IoBuf: IoBuf;
192 :
193 : /// Capacity of the buffer. Must not change over the lifetime of `self`.
194 : fn cap(&self) -> usize;
195 :
196 : /// Add data to the buffer.
197 : /// Panics if there is not enough room to accomodate `other`'s content, i.e.,
198 : /// panics if `other.len() > self.cap() - self.pending()`.
199 : fn extend_from_slice(&mut self, other: &[u8]);
200 :
201 : /// Number of bytes in the buffer.
202 : fn pending(&self) -> usize;
203 :
204 : /// Turns `self` into a [`FullSlice`] of the pending data
205 : /// so we can use [`tokio_epoll_uring`] to write it to disk.
206 : fn flush(self) -> FullSlice<Self::IoBuf>;
207 :
208 : /// After the write to disk is done and we have gotten back the slice,
209 : /// [`BufferedWriter`] uses this method to re-use the io buffer.
210 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
211 : }
212 :
213 : impl Buffer for IoBufferMut {
214 : type IoBuf = IoBuffer;
215 :
216 28882432 : fn cap(&self) -> usize {
217 28882432 : self.capacity()
218 28882432 : }
219 :
220 9623068 : fn extend_from_slice(&mut self, other: &[u8]) {
221 9623068 : if self.len() + other.len() > self.cap() {
222 0 : panic!("Buffer capacity exceeded");
223 9623068 : }
224 9623068 :
225 9623068 : IoBufferMut::extend_from_slice(self, other);
226 9623068 : }
227 :
228 20269857 : fn pending(&self) -> usize {
229 20269857 : self.len()
230 20269857 : }
231 :
232 15861 : fn flush(self) -> FullSlice<Self::IoBuf> {
233 15861 : self.freeze().slice_len()
234 15861 : }
235 :
236 : /// Caller should make sure that `iobuf` only has one strong reference before invoking this method.
237 13246 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
238 13246 : let mut recycled = iobuf
239 13246 : .into_mut()
240 13246 : .expect("buffer should only have one strong reference");
241 13246 : recycled.clear();
242 13246 : recycled
243 13246 : }
244 : }
245 :
246 : #[cfg(test)]
247 : mod tests {
248 : use std::sync::Mutex;
249 :
250 : use super::*;
251 : use crate::context::{DownloadBehavior, RequestContext};
252 : use crate::task_mgr::TaskKind;
253 :
254 : #[derive(Default, Debug)]
255 : struct RecorderWriter {
256 : /// record bytes and write offsets.
257 : writes: Mutex<Vec<(Vec<u8>, u64)>>,
258 : }
259 :
260 : impl RecorderWriter {
261 : /// Gets recorded bytes and write offsets.
262 4 : fn get_writes(&self) -> Vec<Vec<u8>> {
263 4 : self.writes
264 4 : .lock()
265 4 : .unwrap()
266 4 : .iter()
267 32 : .map(|(buf, _)| buf.clone())
268 4 : .collect()
269 4 : }
270 : }
271 :
272 : impl OwnedAsyncWriter for RecorderWriter {
273 32 : async fn write_all_at<Buf: IoBufAligned + Send>(
274 32 : &self,
275 32 : buf: FullSlice<Buf>,
276 32 : offset: u64,
277 32 : _: &RequestContext,
278 32 : ) -> std::io::Result<FullSlice<Buf>> {
279 32 : self.writes
280 32 : .lock()
281 32 : .unwrap()
282 32 : .push((Vec::from(&buf[..]), offset));
283 32 : Ok(buf)
284 32 : }
285 : }
286 :
287 4 : fn test_ctx() -> RequestContext {
288 4 : RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
289 4 : }
290 :
291 : #[tokio::test]
292 4 : async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
293 4 : let ctx = test_ctx();
294 4 : let ctx = &ctx;
295 4 : let recorder = Arc::new(RecorderWriter::default());
296 4 : let gate = utils::sync::gate::Gate::default();
297 4 : let mut writer = BufferedWriter::<_, RecorderWriter>::new(
298 4 : recorder,
299 8 : || IoBufferMut::with_capacity(2),
300 4 : gate.enter()?,
301 4 : ctx,
302 4 : );
303 4 :
304 4 : writer.write_buffered_borrowed(b"abc", ctx).await?;
305 4 : writer.write_buffered_borrowed(b"", ctx).await?;
306 4 : writer.write_buffered_borrowed(b"d", ctx).await?;
307 4 : writer.write_buffered_borrowed(b"e", ctx).await?;
308 4 : writer.write_buffered_borrowed(b"fg", ctx).await?;
309 4 : writer.write_buffered_borrowed(b"hi", ctx).await?;
310 4 : writer.write_buffered_borrowed(b"j", ctx).await?;
311 4 : writer.write_buffered_borrowed(b"klmno", ctx).await?;
312 4 :
313 4 : let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
314 4 : assert_eq!(
315 4 : recorder.get_writes(),
316 4 : {
317 4 : let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
318 4 : expect
319 4 : }
320 4 : .iter()
321 32 : .map(|v| v[..].to_vec())
322 4 : .collect::<Vec<_>>()
323 4 : );
324 4 : Ok(())
325 4 : }
326 : }