mod flush;
use std::sync::Arc;

pub(crate) use flush::FlushControl;
use flush::FlushHandle;
use tokio_epoll_uring::IoBuf;

use super::io_buf_aligned::IoBufAligned;
use super::io_buf_ext::{FullSlice, IoBufExt};
use crate::context::RequestContext;
use crate::virtual_file::{IoBuffer, IoBufferMut};

pub(crate) trait CheapCloneForRead {
    /// Returns a cheap clone of the buffer.
    fn cheap_clone(&self) -> Self;
}

impl CheapCloneForRead for IoBuffer {
    fn cheap_clone(&self) -> Self {
        // Cheap clone over an `Arc`.
        self.clone()
    }
}

/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
/// The owned buffers need to be aligned due to Direct IO requirements.
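///
/// The buffer is handed back to the caller alongside the IO result, so it can be
/// recycled regardless of the outcome. A minimal sketch of an implementor
/// (`NullWriter` is hypothetical, for illustration only; see `RecorderWriter` in
/// this module's tests for a real one):
///
/// ```ignore
/// #[derive(Debug)]
/// struct NullWriter;
///
/// impl OwnedAsyncWriter for NullWriter {
///     async fn write_all_at<Buf: IoBufAligned + Send>(
///         &self,
///         buf: FullSlice<Buf>,
///         _offset: u64,
///         _ctx: &RequestContext,
///     ) -> (FullSlice<Buf>, std::io::Result<()>) {
///         // Discard the data; hand the buffer back so the caller can reuse it.
///         (buf, Ok(()))
///     }
/// }
/// ```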
pub trait OwnedAsyncWriter {
    fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        offset: u64,
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = (FullSlice<Buf>, std::io::Result<()>)> + Send;
}

/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
// TODO(yuchen): For large writes, implementing buffer bypass for the aligned parts of the write
// could be beneficial to throughput, since we would avoid copying the majority of the data into the internal buffer.
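///
/// A minimal usage sketch (an `ignore`d doctest; `writer`, `gate`, and `ctx` are
/// assumed to be set up elsewhere, since `RequestContext` and gate construction
/// are project-specific):
///
/// ```ignore
/// let mut w = BufferedWriter::<IoBufferMut, _>::new(
///     Arc::new(writer),
///     || IoBufferMut::with_capacity(64 * 1024),
///     gate.enter()?,
///     &ctx,
///     tracing::info_span!("flush_task"),
/// );
/// // Small writes accumulate in the mutable buffer; each time it fills up,
/// // it is handed off to the background flush task.
/// w.write_buffered_borrowed(b"hello", &ctx).await?;
/// let (bytes_submitted, writer) = w.flush_and_into_inner(&ctx).await?;
/// ```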
pub struct BufferedWriter<B: Buffer, W> {
    writer: Arc<W>,
    /// Invariant: this always remains `Some(buf)`, except
    /// - while IO is ongoing => goes back to `Some` once the IO completes successfully
    /// - after an IO error => stays `None` forever
    ///
    /// In these exceptional cases, it's `None`.
    mutable: Option<B>,
    /// A handle to the background flush task for writing data to disk.
    flush_handle: FlushHandle<B::IoBuf, W>,
    /// The number of bytes submitted to the background task.
    bytes_submitted: u64,
}

impl<B, Buf, W> BufferedWriter<B, W>
where
    B: Buffer<IoBuf = Buf> + Send + 'static,
    Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    /// Creates a new buffered writer.
    ///
    /// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
    pub fn new(
        writer: Arc<W>,
        buf_new: impl Fn() -> B,
        gate_guard: utils::sync::gate::GateGuard,
        ctx: &RequestContext,
        flush_task_span: tracing::Span,
    ) -> Self {
        Self {
            writer: writer.clone(),
            mutable: Some(buf_new()),
            flush_handle: FlushHandle::spawn_new(
                writer,
                buf_new(),
                gate_guard,
                ctx.attached_child(),
                flush_task_span,
            ),
            bytes_submitted: 0,
        }
    }

    pub fn as_inner(&self) -> &W {
        &self.writer
    }

    /// Returns the number of bytes submitted to the background flush task.
    pub fn bytes_submitted(&self) -> u64 {
        self.bytes_submitted
    }

    /// Panics if used after any of the write paths returned an error.
    pub fn inspect_mutable(&self) -> &B {
        self.mutable()
    }

    /// Gets a reference to the maybe-flushed read-only buffer.
    /// Returns `None` if the writer has not submitted any flush request.
    pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
        self.flush_handle.maybe_flushed.as_ref()
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn flush_and_into_inner(
        mut self,
        ctx: &RequestContext,
    ) -> std::io::Result<(u64, Arc<W>)> {
        self.flush(ctx).await?;

        let Self {
            mutable: buf,
            writer,
            mut flush_handle,
            bytes_submitted: bytes_amount,
        } = self;
        flush_handle.shutdown().await?;
        assert!(buf.is_some());
        Ok((bytes_amount, writer))
    }

    /// Gets a reference to the mutable in-memory buffer.
    #[inline(always)]
    fn mutable(&self) -> &B {
        self.mutable
            .as_ref()
            .expect("must not use after we returned an error")
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn write_buffered_borrowed(
        &mut self,
        chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<usize> {
        let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
        if let Some(control) = control {
            control.release().await;
        }
        Ok(len)
    }

    /// In addition to the number of bytes submitted by this write, also returns a handle
    /// that can control the flush behavior.
    pub(crate) async fn write_buffered_borrowed_controlled(
        &mut self,
        mut chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, Option<FlushControl>)> {
        let chunk_len = chunk.len();
        let mut control: Option<FlushControl> = None;
        while !chunk.is_empty() {
            let buf = self.mutable.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = chunk.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&chunk[..n]);
            chunk = &chunk[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                // Release the control handle from the previous flush before submitting the next one.
                if let Some(control) = control.take() {
                    control.release().await;
                }
                control = self.flush(ctx).await?;
            }
        }
        Ok((chunk_len, control))
    }

    #[must_use = "caller must explicitly check the flush control"]
    async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
        let buf = self.mutable.take().expect("must not use after an error");
        let buf_len = buf.pending();
        if buf_len == 0 {
            self.mutable = Some(buf);
            return Ok(None);
        }
        let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
        self.bytes_submitted += u64::try_from(buf_len).unwrap();
        self.mutable = Some(recycled);
        Ok(Some(flush_control))
    }
}

/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
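///
/// The intended lifecycle, as a sketch (an `ignore`d doctest; the IO step is elided
/// because it goes through [`OwnedAsyncWriter`], and `CAP` / `iobuf` are placeholders):
///
/// ```ignore
/// let mut buf = IoBufferMut::with_capacity(CAP);
/// // Fill the buffer, never beyond its capacity.
/// buf.extend_from_slice(b"data");
/// assert!(buf.pending() <= buf.cap());
/// // Detach the pending bytes as a FullSlice and hand them to the IO path...
/// let slice = buf.flush();
/// // ...then, once the write completed and the IoBuffer `iobuf` came back,
/// // recycle it for the next round of writes.
/// let buf = IoBufferMut::reuse_after_flush(iobuf);
/// ```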
pub trait Buffer {
    type IoBuf: IoBuf;

    /// Capacity of the buffer. Must not change over the lifetime of `self`.
    fn cap(&self) -> usize;

    /// Add data to the buffer.
    /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
    /// panics if `other.len() > self.cap() - self.pending()`.
    fn extend_from_slice(&mut self, other: &[u8]);

    /// Number of bytes in the buffer.
    fn pending(&self) -> usize;

    /// Turns `self` into a [`FullSlice`] of the pending data
    /// so we can use [`tokio_epoll_uring`] to write it to disk.
    fn flush(self) -> FullSlice<Self::IoBuf>;

    /// After the write to disk is done and we have gotten back the slice,
    /// [`BufferedWriter`] uses this method to re-use the io buffer.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
}

impl Buffer for IoBufferMut {
    type IoBuf = IoBuffer;

    fn cap(&self) -> usize {
        self.capacity()
    }

    fn extend_from_slice(&mut self, other: &[u8]) {
        if self.len() + other.len() > self.cap() {
            panic!("Buffer capacity exceeded");
        }

        IoBufferMut::extend_from_slice(self, other);
    }

    fn pending(&self) -> usize {
        self.len()
    }

    fn flush(self) -> FullSlice<Self::IoBuf> {
        self.freeze().slice_len()
    }

    /// The caller must ensure that `iobuf` has only one strong reference before invoking this method.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
        let mut recycled = iobuf
            .into_mut()
            .expect("buffer should only have one strong reference");
        recycled.clear();
        recycled
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    #[derive(Default, Debug)]
    struct RecorderWriter {
        /// Records the written bytes and their write offsets.
        writes: Mutex<Vec<(Vec<u8>, u64)>>,
    }

    impl RecorderWriter {
        /// Gets the recorded bytes, dropping the write offsets.
        fn get_writes(&self) -> Vec<Vec<u8>> {
            self.writes
                .lock()
                .unwrap()
                .iter()
                .map(|(buf, _)| buf.clone())
                .collect()
        }
    }

    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all_at<Buf: IoBufAligned + Send>(
            &self,
            buf: FullSlice<Buf>,
            offset: u64,
            _: &RequestContext,
        ) -> (FullSlice<Buf>, std::io::Result<()>) {
            self.writes
                .lock()
                .unwrap()
                .push((Vec::from(&buf[..]), offset));
            (buf, Ok(()))
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
        let ctx = test_ctx();
        let ctx = &ctx;
        let recorder = Arc::new(RecorderWriter::default());
        let gate = utils::sync::gate::Gate::default();
        let mut writer = BufferedWriter::<_, RecorderWriter>::new(
            recorder,
            || IoBufferMut::with_capacity(2),
            gate.enter()?,
            ctx,
            tracing::Span::none(),
        );

        writer.write_buffered_borrowed(b"abc", ctx).await?;
        writer.write_buffered_borrowed(b"", ctx).await?;
        writer.write_buffered_borrowed(b"d", ctx).await?;
        writer.write_buffered_borrowed(b"e", ctx).await?;
        writer.write_buffered_borrowed(b"fg", ctx).await?;
        writer.write_buffered_borrowed(b"hi", ctx).await?;
        writer.write_buffered_borrowed(b"j", ctx).await?;
        writer.write_buffered_borrowed(b"klmno", ctx).await?;

        let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.get_writes(),
            {
                let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
                expect
            }
            .iter()
            .map(|v| v[..].to_vec())
            .collect::<Vec<_>>()
        );
        Ok(())
    }
}