Line data Source code
1 : mod flush;
2 : use std::sync::Arc;
3 :
4 : use flush::FlushHandle;
5 : use tokio_epoll_uring::IoBuf;
6 :
7 : use crate::{
8 : context::RequestContext,
9 : virtual_file::{IoBuffer, IoBufferMut},
10 : };
11 :
12 : use super::{
13 : io_buf_aligned::IoBufAligned,
14 : io_buf_ext::{FullSlice, IoBufExt},
15 : };
16 :
17 : pub(crate) use flush::FlushControl;
18 :
19 : pub(crate) trait CheapCloneForRead {
20 : /// Returns a cheap clone of the buffer.
21 : fn cheap_clone(&self) -> Self;
22 : }
23 :
24 : impl CheapCloneForRead for IoBuffer {
25 6619 : fn cheap_clone(&self) -> Self {
26 6619 : // Cheap clone over an `Arc`.
27 6619 : self.clone()
28 6619 : }
29 : }
30 :
31 : /// A trait for doing owned-buffer write IO.
32 : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
33 : /// The owned buffers need to be aligned due to Direct IO requirements.
34 : pub trait OwnedAsyncWriter {
35 : fn write_all_at<Buf: IoBufAligned + Send>(
36 : &self,
37 : buf: FullSlice<Buf>,
38 : offset: u64,
39 : ctx: &RequestContext,
40 : ) -> impl std::future::Future<Output = std::io::Result<FullSlice<Buf>>> + Send;
41 : }
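// Illustrative sketch, not part of this module: a minimal `OwnedAsyncWriter`
// that discards its input. The name `NullWriter` is hypothetical; `RecorderWriter`
// in the tests at the bottom of this file is a real implementor.
#[derive(Debug)]
struct NullWriter;

impl OwnedAsyncWriter for NullWriter {
    async fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        _offset: u64,
        _ctx: &RequestContext,
    ) -> std::io::Result<FullSlice<Buf>> {
        // Pretend the whole aligned buffer was written at `_offset`, then return
        // it so the caller can reuse the allocation.
        Ok(buf)
    }
}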
42 :
43 : /// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
44 : /// small writes into larger writes of size [`Buffer::cap`].
45 : // TODO(yuchen): For large writes, implementing buffer bypass for the aligned parts of the write could be beneficial to throughput,
46 : // since we would avoid copying the majority of the data into the internal buffer.
47 : pub struct BufferedWriter<B: Buffer, W> {
48 : writer: Arc<W>,
49 : /// Invariant: always `Some(buf)`, except
50 : /// - while an IO is in flight => goes back to `Some(buf)` once the IO completes successfully
51 : /// - after an IO error => stays `None` forever
52 : ///
53 : /// In these exceptional cases, it's `None`.
54 : mutable: Option<B>,
55 : /// A handle to the background flush task for writing data to disk.
56 : flush_handle: FlushHandle<B::IoBuf, W>,
57 : /// The number of bytes submitted to the background task.
58 : bytes_submitted: u64,
59 : }
60 :
61 : impl<B, Buf, W> BufferedWriter<B, W>
62 : where
63 : B: Buffer<IoBuf = Buf> + Send + 'static,
64 : Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
65 : W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
66 : {
67 : /// Creates a new buffered writer.
68 : ///
69 : /// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
70 1285 : pub fn new(
71 1285 : writer: Arc<W>,
72 1285 : buf_new: impl Fn() -> B,
73 1285 : gate_guard: utils::sync::gate::GateGuard,
74 1285 : ctx: &RequestContext,
75 1285 : ) -> Self {
76 1285 : Self {
77 1285 : writer: writer.clone(),
78 1285 : mutable: Some(buf_new()),
79 1285 : flush_handle: FlushHandle::spawn_new(
80 1285 : writer,
81 1285 : buf_new(),
82 1285 : gate_guard,
83 1285 : ctx.attached_child(),
84 1285 : ),
85 1285 : bytes_submitted: 0,
86 1285 : }
87 1285 : }
88 :
89 11345 : pub fn as_inner(&self) -> &W {
90 11345 : &self.writer
91 11345 : }
92 :
93 : /// Returns the number of bytes submitted to the background flush task.
94 498416 : pub fn bytes_submitted(&self) -> u64 {
95 498416 : self.bytes_submitted
96 498416 : }
97 :
98 : /// Panics if used after any of the write paths returned an error
99 498426 : pub fn inspect_mutable(&self) -> &B {
100 498426 : self.mutable()
101 498426 : }
102 :
103 : /// Gets a reference to the maybe flushed read-only buffer.
104 : /// Returns `None` if the writer has not submitted any flush request.
105 498420 : pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
106 498420 : self.flush_handle.maybe_flushed.as_ref()
107 498420 : }
108 :
109 : #[cfg_attr(target_os = "macos", allow(dead_code))]
110 5 : pub async fn flush_and_into_inner(
111 5 : mut self,
112 5 : ctx: &RequestContext,
113 5 : ) -> std::io::Result<(u64, Arc<W>)> {
114 5 : self.flush(ctx).await?;
115 :
116 : let Self {
117 5 : mutable: buf,
118 5 : writer,
119 5 : mut flush_handle,
120 5 : bytes_submitted: bytes_amount,
121 5 : } = self;
122 5 : flush_handle.shutdown().await?;
123 5 : assert!(buf.is_some());
124 5 : Ok((bytes_amount, writer))
125 5 : }
126 :
127 : /// Gets a reference to the mutable in-memory buffer.
128 : #[inline(always)]
129 498426 : fn mutable(&self) -> &B {
130 498426 : self.mutable
131 498426 : .as_ref()
132 498426 : .expect("must not use after we returned an error")
133 498426 : }
134 :
135 : #[cfg_attr(target_os = "macos", allow(dead_code))]
136 34 : pub async fn write_buffered_borrowed(
137 34 : &mut self,
138 34 : chunk: &[u8],
139 34 : ctx: &RequestContext,
140 34 : ) -> std::io::Result<usize> {
141 34 : let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
142 34 : if let Some(control) = control {
143 12 : control.release().await;
144 22 : }
145 34 : Ok(len)
146 34 : }
147 :
148 : /// In addition to the number of bytes submitted in this write, returns a handle that can control the flush behavior.
149 4804884 : pub(crate) async fn write_buffered_borrowed_controlled(
150 4804884 : &mut self,
151 4804884 : mut chunk: &[u8],
152 4804884 : ctx: &RequestContext,
153 4804884 : ) -> std::io::Result<(usize, Option<FlushControl>)> {
154 4804884 : let chunk_len = chunk.len();
155 4804884 : let mut control: Option<FlushControl> = None;
156 9616368 : while !chunk.is_empty() {
157 4811484 : let buf = self.mutable.as_mut().expect("must not use after an error");
158 4811484 : let need = buf.cap() - buf.pending();
159 4811484 : let have = chunk.len();
160 4811484 : let n = std::cmp::min(need, have);
161 4811484 : buf.extend_from_slice(&chunk[..n]);
162 4811484 : chunk = &chunk[n..];
163 4811484 : if buf.pending() >= buf.cap() {
164 6614 : assert_eq!(buf.pending(), buf.cap());
165 6614 : if let Some(control) = control.take() {
166 1066 : control.release().await;
167 5548 : }
168 6614 : control = self.flush(ctx).await?;
169 4804870 : }
170 : }
171 4804884 : Ok((chunk_len, control))
172 4804884 : }
173 :
174 : #[must_use = "caller must explicitly check the flush control"]
175 6619 : async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
176 6619 : let buf = self.mutable.take().expect("must not use after an error");
177 6619 : let buf_len = buf.pending();
178 6619 : if buf_len == 0 {
179 0 : self.mutable = Some(buf);
180 0 : return Ok(None);
181 6619 : }
182 6619 : let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
183 6619 : self.bytes_submitted += u64::try_from(buf_len).unwrap();
184 6619 : self.mutable = Some(recycled);
185 6619 : Ok(Some(flush_control))
186 6619 : }
187 : }
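// Illustrative usage sketch, not part of this module: how a caller drives the
// buffered write path end to end. The function name and the 64 KiB capacity are
// hypothetical; the unit test at the bottom of this file exercises the same flow.
#[allow(dead_code)]
async fn buffered_writer_usage_sketch<W>(
    writer: Arc<W>,
    gate: &utils::sync::gate::Gate,
    ctx: &RequestContext,
) -> anyhow::Result<()>
where
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    let mut buffered = BufferedWriter::<IoBufferMut, W>::new(
        writer,
        || IoBufferMut::with_capacity(64 * 1024),
        gate.enter()?,
        ctx,
    );
    // Small writes accumulate in the mutable buffer; whenever it reaches
    // `Buffer::cap`, the full buffer is handed off to the background flush task.
    buffered.write_buffered_borrowed(b"small chunk", ctx).await?;
    // Flush the remaining tail and wait for the background task to finish
    // before taking back the underlying writer.
    let (_bytes_written, _writer) = buffered.flush_and_into_inner(ctx).await?;
    Ok(())
}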
188 :
189 : /// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
190 : pub trait Buffer {
191 : type IoBuf: IoBuf;
192 :
193 : /// Capacity of the buffer. Must not change over the lifetime of `self`.
194 : fn cap(&self) -> usize;
195 :
196 : /// Add data to the buffer.
197 : /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
198 : /// panics if `other.len() > self.cap() - self.pending()`.
199 : fn extend_from_slice(&mut self, other: &[u8]);
200 :
201 : /// Number of bytes in the buffer.
202 : fn pending(&self) -> usize;
203 :
204 : /// Turns `self` into a [`FullSlice`] of the pending data
205 : /// so we can use [`tokio_epoll_uring`] to write it to disk.
206 : fn flush(self) -> FullSlice<Self::IoBuf>;
207 :
208 : /// After the write to disk is done and we have gotten back the slice,
209 : /// [`BufferedWriter`] uses this method to re-use the io buffer.
210 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
211 : }
212 :
213 : impl Buffer for IoBufferMut {
214 : type IoBuf = IoBuffer;
215 :
216 14441066 : fn cap(&self) -> usize {
217 14441066 : self.capacity()
218 14441066 : }
219 :
220 4811484 : fn extend_from_slice(&mut self, other: &[u8]) {
221 4811484 : if self.len() + other.len() > self.cap() {
222 0 : panic!("Buffer capacity exceeded");
223 4811484 : }
224 4811484 :
225 4811484 : IoBufferMut::extend_from_slice(self, other);
226 4811484 : }
227 :
228 10134617 : fn pending(&self) -> usize {
229 10134617 : self.len()
230 10134617 : }
231 :
232 7896 : fn flush(self) -> FullSlice<Self::IoBuf> {
233 7896 : self.freeze().slice_len()
234 7896 : }
235 :
236 : /// Caller must ensure that `iobuf` has only one strong reference before invoking this method.
237 6619 : fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
238 6619 : let mut recycled = iobuf
239 6619 : .into_mut()
240 6619 : .expect("buffer should only have one strong reference");
241 6619 : recycled.clear();
242 6619 : recycled
243 6619 : }
244 : }
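// Illustrative lifecycle sketch, not part of this module: the round trip an
// `IoBufferMut` makes through the `Buffer` trait. Assumes `iobuf` carries exactly
// one strong reference, as `reuse_after_flush` requires; recovering such an
// `IoBuffer` from the flushed `FullSlice` happens inside the flush task and is
// elided here. The function name is hypothetical.
#[allow(dead_code)]
fn buffer_lifecycle_sketch(iobuf: IoBuffer) {
    // A recycled buffer starts out empty but keeps its capacity.
    let mut buf = IoBufferMut::reuse_after_flush(iobuf);
    assert_eq!(buf.pending(), 0);
    // Fill it up to `cap()`; the trait's `extend_from_slice` panics past that point.
    let cap = buf.cap();
    Buffer::extend_from_slice(&mut buf, &vec![0u8; cap]);
    assert_eq!(buf.pending(), buf.cap());
    // Freeze the pending bytes into a `FullSlice` for the owned-buffer write path.
    let _slice: FullSlice<IoBuffer> = Buffer::flush(buf);
}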
245 :
246 : #[cfg(test)]
247 : mod tests {
248 : use std::sync::Mutex;
249 :
250 : use super::*;
251 : use crate::context::{DownloadBehavior, RequestContext};
252 : use crate::task_mgr::TaskKind;
253 :
254 : #[derive(Default, Debug)]
255 : struct RecorderWriter {
256 : /// record bytes and write offsets.
257 : writes: Mutex<Vec<(Vec<u8>, u64)>>,
258 : }
259 :
260 : impl RecorderWriter {
261 : /// Gets recorded bytes and write offsets.
262 2 : fn get_writes(&self) -> Vec<Vec<u8>> {
263 2 : self.writes
264 2 : .lock()
265 2 : .unwrap()
266 2 : .iter()
267 16 : .map(|(buf, _)| buf.clone())
268 2 : .collect()
269 2 : }
270 : }
271 :
272 : impl OwnedAsyncWriter for RecorderWriter {
273 16 : async fn write_all_at<Buf: IoBufAligned + Send>(
274 16 : &self,
275 16 : buf: FullSlice<Buf>,
276 16 : offset: u64,
277 16 : _: &RequestContext,
278 16 : ) -> std::io::Result<FullSlice<Buf>> {
279 16 : self.writes
280 16 : .lock()
281 16 : .unwrap()
282 16 : .push((Vec::from(&buf[..]), offset));
283 16 : Ok(buf)
284 16 : }
285 : }
286 :
287 2 : fn test_ctx() -> RequestContext {
288 2 : RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
289 2 : }
290 :
291 : #[tokio::test]
292 2 : async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
293 2 : let ctx = test_ctx();
294 2 : let ctx = &ctx;
295 2 : let recorder = Arc::new(RecorderWriter::default());
296 2 : let gate = utils::sync::gate::Gate::default();
297 2 : let mut writer = BufferedWriter::<_, RecorderWriter>::new(
298 2 : recorder,
299 4 : || IoBufferMut::with_capacity(2),
300 2 : gate.enter()?,
301 2 : ctx,
302 2 : );
303 2 :
304 2 : writer.write_buffered_borrowed(b"abc", ctx).await?;
305 2 : writer.write_buffered_borrowed(b"", ctx).await?;
306 2 : writer.write_buffered_borrowed(b"d", ctx).await?;
307 2 : writer.write_buffered_borrowed(b"e", ctx).await?;
308 2 : writer.write_buffered_borrowed(b"fg", ctx).await?;
309 2 : writer.write_buffered_borrowed(b"hi", ctx).await?;
310 2 : writer.write_buffered_borrowed(b"j", ctx).await?;
311 2 : writer.write_buffered_borrowed(b"klmno", ctx).await?;
312 2 :
313 2 : let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
314 2 : assert_eq!(
315 2 : recorder.get_writes(),
316 2 : {
317 2 : let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
318 2 : expect
319 2 : }
320 2 : .iter()
321 16 : .map(|v| v[..].to_vec())
322 2 : .collect::<Vec<_>>()
323 2 : );
324 2 : Ok(())
325 2 : }
326 : }