mod flush;
use std::sync::Arc;

pub(crate) use flush::FlushControl;
use flush::FlushHandle;
use tokio_epoll_uring::IoBuf;

use super::io_buf_aligned::IoBufAligned;
use super::io_buf_ext::{FullSlice, IoBufExt};
use crate::context::RequestContext;
use crate::virtual_file::{IoBuffer, IoBufferMut};

pub(crate) trait CheapCloneForRead {
    /// Returns a cheap clone of the buffer.
    fn cheap_clone(&self) -> Self;
}

impl CheapCloneForRead for IoBuffer {
    fn cheap_clone(&self) -> Self {
        // Cheap clone over an `Arc`.
        self.clone()
    }
}

/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
/// The owned buffers need to be aligned due to Direct IO requirements.
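///
/// # Example
///
/// A minimal implementor sketch, modeled on the `RecorderWriter` used in this
/// module's tests: it records writes in memory instead of doing disk IO.
/// The names here are illustrative only, not part of the public API.
///
/// ```ignore
/// #[derive(Default, Debug)]
/// struct RecorderWriter {
///     /// Recorded (bytes, offset) pairs.
///     writes: std::sync::Mutex<Vec<(Vec<u8>, u64)>>,
/// }
///
/// impl OwnedAsyncWriter for RecorderWriter {
///     async fn write_all_at<Buf: IoBufAligned + Send>(
///         &self,
///         buf: FullSlice<Buf>,
///         offset: u64,
///         _: &RequestContext,
///     ) -> std::io::Result<FullSlice<Buf>> {
///         // Record the bytes and the offset, then hand the owned buffer back.
///         self.writes.lock().unwrap().push((Vec::from(&buf[..]), offset));
///         Ok(buf)
///     }
/// }
/// ```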
pub trait OwnedAsyncWriter {
    fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        offset: u64,
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = std::io::Result<FullSlice<Buf>>> + Send;
}

/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
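///
/// # Example
///
/// A usage sketch, condensed from this module's tests; the `RecorderWriter`,
/// gate, and `ctx` setup are assumed to exist as they do there.
///
/// ```ignore
/// let recorder = Arc::new(RecorderWriter::default());
/// let gate = utils::sync::gate::Gate::default();
/// // Capacity of 2 bytes: every two buffered bytes become one `write_all_at` call.
/// let mut writer = BufferedWriter::<_, RecorderWriter>::new(
///     recorder,
///     || IoBufferMut::with_capacity(2),
///     gate.enter()?,
///     ctx,
/// );
/// writer.write_buffered_borrowed(b"abc", ctx).await?;
/// writer.write_buffered_borrowed(b"de", ctx).await?;
/// // Flush the remaining partial buffer and get the inner writer back.
/// let (bytes_submitted, recorder) = writer.flush_and_into_inner(ctx).await?;
/// assert_eq!(bytes_submitted, 5);
/// ```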
// TODO(yuchen): For large writes, implementing a buffer bypass for the aligned parts of the write
// could be beneficial to throughput, since we would avoid copying the majority of the data into the internal buffer.
pub struct BufferedWriter<B: Buffer, W> {
    writer: Arc<W>,
    /// Invariant: always remains `Some(buf)` except
    /// - while IO is ongoing => goes back to `Some()` once the IO completes successfully
    /// - after an IO error => stays `None` forever
    ///
    /// In these exceptional cases, it's `None`.
    mutable: Option<B>,
    /// A handle to the background flush task for writing data to disk.
    flush_handle: FlushHandle<B::IoBuf, W>,
    /// The number of bytes submitted to the background task.
    bytes_submitted: u64,
}

impl<B, Buf, W> BufferedWriter<B, W>
where
    B: Buffer<IoBuf = Buf> + Send + 'static,
    Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    /// Creates a new buffered writer.
    ///
    /// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
    pub fn new(
        writer: Arc<W>,
        buf_new: impl Fn() -> B,
        gate_guard: utils::sync::gate::GateGuard,
        ctx: &RequestContext,
    ) -> Self {
        Self {
            writer: writer.clone(),
            mutable: Some(buf_new()),
            flush_handle: FlushHandle::spawn_new(
                writer,
                buf_new(),
                gate_guard,
                ctx.attached_child(),
            ),
            bytes_submitted: 0,
        }
    }

    pub fn as_inner(&self) -> &W {
        &self.writer
    }

    /// Returns the number of bytes submitted to the background flush task.
    pub fn bytes_submitted(&self) -> u64 {
        self.bytes_submitted
    }

    /// Panics if used after any of the write paths has returned an error.
    pub fn inspect_mutable(&self) -> &B {
        self.mutable()
    }

    /// Gets a reference to the maybe flushed read-only buffer.
    /// Returns `None` if the writer has not submitted any flush request.
    pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
        self.flush_handle.maybe_flushed.as_ref()
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn flush_and_into_inner(
        mut self,
        ctx: &RequestContext,
    ) -> std::io::Result<(u64, Arc<W>)> {
        self.flush(ctx).await?;

        let Self {
            mutable: buf,
            writer,
            mut flush_handle,
            bytes_submitted: bytes_amount,
        } = self;
        flush_handle.shutdown().await?;
        assert!(buf.is_some());
        Ok((bytes_amount, writer))
    }

    /// Gets a reference to the mutable in-memory buffer.
    #[inline(always)]
    fn mutable(&self) -> &B {
        self.mutable
            .as_ref()
            .expect("must not use after we returned an error")
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn write_buffered_borrowed(
        &mut self,
        chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<usize> {
        let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
        if let Some(control) = control {
            control.release().await;
        }
        Ok(len)
    }

    /// In addition to the number of bytes submitted in this write, also returns a handle that can be used to control flush behavior.
    pub(crate) async fn write_buffered_borrowed_controlled(
        &mut self,
        mut chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, Option<FlushControl>)> {
        let chunk_len = chunk.len();
        let mut control: Option<FlushControl> = None;
        while !chunk.is_empty() {
            let buf = self.mutable.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = chunk.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&chunk[..n]);
            chunk = &chunk[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                if let Some(control) = control.take() {
                    control.release().await;
                }
                control = self.flush(ctx).await?;
            }
        }
        Ok((chunk_len, control))
    }

    #[must_use = "caller must explicitly check the flush control"]
    async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
        let buf = self.mutable.take().expect("must not use after an error");
        let buf_len = buf.pending();
        if buf_len == 0 {
            self.mutable = Some(buf);
            return Ok(None);
        }
        let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
        self.bytes_submitted += u64::try_from(buf_len).unwrap();
        self.mutable = Some(recycled);
        Ok(Some(flush_control))
    }
}

/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
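///
/// # Example
///
/// A sketch of the fill / flush / reuse cycle using the [`IoBufferMut`] impl below.
/// The disk write itself is elided; the `into_raw_slice().into_inner()` step for
/// recovering the [`IoBuffer`] from the returned [`FullSlice`] is an assumption about
/// how the slice is unwrapped elsewhere in this crate.
///
/// ```ignore
/// let mut buf = IoBufferMut::with_capacity(2);
/// buf.extend_from_slice(b"ab");
/// assert_eq!(buf.pending(), buf.cap());
/// // Freeze the pending bytes into a read-only `FullSlice` for the owned-buffer write.
/// let slice = Buffer::flush(buf);
/// // ... hand `slice` to `OwnedAsyncWriter::write_all_at`, which returns it once the IO is done ...
/// // Unwrap the slice back into the `IoBuffer` (assumed API) and recycle the allocation.
/// let iobuf: IoBuffer = slice.into_raw_slice().into_inner();
/// let recycled = IoBufferMut::reuse_after_flush(iobuf);
/// assert_eq!(recycled.pending(), 0);
/// ```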
pub trait Buffer {
    type IoBuf: IoBuf;

    /// Capacity of the buffer. Must not change over the lifetime of `self`.
    fn cap(&self) -> usize;

    /// Add data to the buffer.
    /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
    /// panics if `other.len() > self.cap() - self.pending()`.
    fn extend_from_slice(&mut self, other: &[u8]);

    /// Number of bytes in the buffer.
    fn pending(&self) -> usize;

    /// Turns `self` into a [`FullSlice`] of the pending data
    /// so we can use [`tokio_epoll_uring`] to write it to disk.
    fn flush(self) -> FullSlice<Self::IoBuf>;

    /// After the write to disk is done and we have gotten back the slice,
    /// [`BufferedWriter`] uses this method to re-use the io buffer.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
}

impl Buffer for IoBufferMut {
    type IoBuf = IoBuffer;

    fn cap(&self) -> usize {
        self.capacity()
    }

    fn extend_from_slice(&mut self, other: &[u8]) {
        if self.len() + other.len() > self.cap() {
            panic!("Buffer capacity exceeded");
        }

        IoBufferMut::extend_from_slice(self, other);
    }

    fn pending(&self) -> usize {
        self.len()
    }

    fn flush(self) -> FullSlice<Self::IoBuf> {
        self.freeze().slice_len()
    }

231 13246 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
        let mut recycled = iobuf
            .into_mut()
            .expect("buffer should only have one strong reference");
        recycled.clear();
        recycled
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    #[derive(Default, Debug)]
    struct RecorderWriter {
251 : writes: Mutex<Vec<(Vec<u8>, u64)>>,
252 : }
253 :
254 : impl RecorderWriter {
255 : /// Gets recorded bytes and write offsets.
256 4 : fn get_writes(&self) -> Vec<Vec<u8>> {
        fn get_writes(&self) -> Vec<Vec<u8>> {
            self.writes
                .lock()
                .unwrap()
                .iter()
                .map(|(buf, _)| buf.clone())
                .collect()
        }
    }

    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all_at<Buf: IoBufAligned + Send>(
            &self,
            buf: FullSlice<Buf>,
            offset: u64,
            _: &RequestContext,
        ) -> std::io::Result<FullSlice<Buf>> {
            self.writes
                .lock()
                .unwrap()
                .push((Vec::from(&buf[..]), offset));
            Ok(buf)
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
        let ctx = test_ctx();
        let ctx = &ctx;
        let recorder = Arc::new(RecorderWriter::default());
        let gate = utils::sync::gate::Gate::default();
        let mut writer = BufferedWriter::<_, RecorderWriter>::new(
            recorder,
            || IoBufferMut::with_capacity(2),
            gate.enter()?,
            ctx,
        );

        writer.write_buffered_borrowed(b"abc", ctx).await?;
        writer.write_buffered_borrowed(b"", ctx).await?;
        writer.write_buffered_borrowed(b"d", ctx).await?;
        writer.write_buffered_borrowed(b"e", ctx).await?;
        writer.write_buffered_borrowed(b"fg", ctx).await?;
        writer.write_buffered_borrowed(b"hi", ctx).await?;
        writer.write_buffered_borrowed(b"j", ctx).await?;
        writer.write_buffered_borrowed(b"klmno", ctx).await?;

        let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.get_writes(),
            {
                let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
                expect
            }
            .iter()
            .map(|v| v[..].to_vec())
            .collect::<Vec<_>>()
        );
        Ok(())
    }
}