LCOV - code coverage report
Current view: top level - pageserver/src/virtual_file/owned_buffers_io - write.rs (source / functions)
Test: 5fe7fa8d483b39476409aee736d6d5e32728bfac.info
Test Date: 2025-03-12 16:10:49
Coverage: Lines: 98.4 % (179 of 182 hit); Functions: 100.0 % (39 of 39 hit)

mod flush;
use std::sync::Arc;

pub(crate) use flush::FlushControl;
use flush::FlushHandle;
use tokio_epoll_uring::IoBuf;

use super::io_buf_aligned::IoBufAligned;
use super::io_buf_ext::{FullSlice, IoBufExt};
use crate::context::RequestContext;
use crate::virtual_file::{IoBuffer, IoBufferMut};

pub(crate) trait CheapCloneForRead {
    /// Returns a cheap clone of the buffer.
    fn cheap_clone(&self) -> Self;
}

impl CheapCloneForRead for IoBuffer {
    fn cheap_clone(&self) -> Self {
        // Cheap clone over an `Arc`.
        self.clone()
    }
}
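
// Illustrative sketch: what an implementor of `CheapCloneForRead` looks like
// for any Arc-backed buffer. `SharedBuf` is a hypothetical type used only for
// illustration; the point is that `cheap_clone` bumps a refcount and never
// copies the underlying bytes.
#[cfg(test)]
#[derive(Debug)]
struct SharedBuf(Arc<[u8]>);

#[cfg(test)]
impl CheapCloneForRead for SharedBuf {
    fn cheap_clone(&self) -> Self {
        // Arc::clone is an atomic refcount increment; the bytes are shared.
        Self(Arc::clone(&self.0))
    }
}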

/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
/// The owned buffers need to be aligned due to Direct IO requirements.
pub trait OwnedAsyncWriter {
    fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        offset: u64,
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = (FullSlice<Buf>, std::io::Result<()>)> + Send;
}
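
// Illustrative sketch: a minimal `OwnedAsyncWriter` that discards its input.
// `NullWriter` is a hypothetical type for illustration only. It demonstrates
// the owned-buffer contract: the buffer is moved into the write and handed
// back together with the result so the caller can reuse it.
#[cfg(test)]
#[derive(Debug)]
struct NullWriter;

#[cfg(test)]
impl OwnedAsyncWriter for NullWriter {
    async fn write_all_at<Buf: IoBufAligned + Send>(
        &self,
        buf: FullSlice<Buf>,
        _offset: u64,
        _ctx: &RequestContext,
    ) -> (FullSlice<Buf>, std::io::Result<()>) {
        // Pretend the write succeeded; return the buffer for reuse.
        (buf, Ok(()))
    }
}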

/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
// TODO(yuchen): For large writes, implementing buffer bypass for the aligned parts of the write could be beneficial to throughput,
// since we would avoid copying the majority of the data into the internal buffer.
pub struct BufferedWriter<B: Buffer, W> {
    writer: Arc<W>,
    /// Invariant: always remains `Some(buf)` except
    /// - while an IO is ongoing => goes back to `Some(buf)` once the IO completes successfully
    /// - after an IO error => stays `None` forever
    ///
    /// In these exceptional cases, it is `None`.
    mutable: Option<B>,
    /// A handle to the background flush task for writing data to disk.
    flush_handle: FlushHandle<B::IoBuf, W>,
    /// The number of bytes submitted to the background task.
    bytes_submitted: u64,
}
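
// Worked example of the batching behavior (assuming `Buffer::cap() == 2`, as
// in the unit test at the bottom of this file): writing b"abc" copies "ab"
// into the mutable buffer, submits that buffer to the background flush task,
// and copies the remaining "c" into the recycled buffer, where it stays
// pending until a later flush. The most recently submitted buffer remains
// readable through `inspect_maybe_flushed` in the meantime.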

impl<B, Buf, W> BufferedWriter<B, W>
where
    B: Buffer<IoBuf = Buf> + Send + 'static,
    Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
    W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
    /// Creates a new buffered writer.
    ///
    /// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
    pub fn new(
        writer: Arc<W>,
        buf_new: impl Fn() -> B,
        gate_guard: utils::sync::gate::GateGuard,
        ctx: &RequestContext,
        flush_task_span: tracing::Span,
    ) -> Self {
        Self {
            writer: writer.clone(),
            mutable: Some(buf_new()),
            flush_handle: FlushHandle::spawn_new(
                writer,
                buf_new(),
                gate_guard,
                ctx.attached_child(),
                flush_task_span,
            ),
            bytes_submitted: 0,
        }
    }

    pub fn as_inner(&self) -> &W {
        &self.writer
    }

    /// Returns the number of bytes submitted to the background flush task.
    pub fn bytes_submitted(&self) -> u64 {
        self.bytes_submitted
    }

    /// Panics if used after any of the write paths returned an error.
    pub fn inspect_mutable(&self) -> &B {
        self.mutable()
    }

    /// Gets a reference to the maybe-flushed read-only buffer.
    /// Returns `None` if the writer has not submitted any flush request.
    pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
        self.flush_handle.maybe_flushed.as_ref()
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn flush_and_into_inner(
        mut self,
        ctx: &RequestContext,
    ) -> std::io::Result<(u64, Arc<W>)> {
        self.flush(ctx).await?;

        let Self {
            mutable: buf,
            writer,
            mut flush_handle,
            bytes_submitted: bytes_amount,
        } = self;
        flush_handle.shutdown().await?;
        assert!(buf.is_some());
        Ok((bytes_amount, writer))
    }

    /// Gets a reference to the mutable in-memory buffer.
    #[inline(always)]
    fn mutable(&self) -> &B {
        self.mutable
            .as_ref()
            .expect("must not use after we returned an error")
    }

    #[cfg_attr(target_os = "macos", allow(dead_code))]
    pub async fn write_buffered_borrowed(
        &mut self,
        chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<usize> {
        let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
        if let Some(control) = control {
            control.release().await;
        }
        Ok(len)
    }

    /// In addition to the number of bytes submitted by this write, also returns a handle
    /// that can control the flush behavior.
    pub(crate) async fn write_buffered_borrowed_controlled(
        &mut self,
        mut chunk: &[u8],
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, Option<FlushControl>)> {
        let chunk_len = chunk.len();
        let mut control: Option<FlushControl> = None;
        while !chunk.is_empty() {
            let buf = self.mutable.as_mut().expect("must not use after an error");
            let need = buf.cap() - buf.pending();
            let have = chunk.len();
            let n = std::cmp::min(need, have);
            buf.extend_from_slice(&chunk[..n]);
            chunk = &chunk[n..];
            if buf.pending() >= buf.cap() {
                assert_eq!(buf.pending(), buf.cap());
                if let Some(control) = control.take() {
                    control.release().await;
                }
                control = self.flush(ctx).await?;
            }
        }
        Ok((chunk_len, control))
    }
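
    // Illustrative use of the controlled variant; callers that do not need to
    // hold on to the returned [`FlushControl`] should release it promptly,
    // exactly as `write_buffered_borrowed` above does:
    //
    //   let (n, control) = writer.write_buffered_borrowed_controlled(chunk, ctx).await?;
    //   if let Some(control) = control {
    //       control.release().await;
    //   }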

    #[must_use = "caller must explicitly check the flush control"]
    async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
        let buf = self.mutable.take().expect("must not use after an error");
        let buf_len = buf.pending();
        if buf_len == 0 {
            self.mutable = Some(buf);
            return Ok(None);
        }
        let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
        self.bytes_submitted += u64::try_from(buf_len).unwrap();
        self.mutable = Some(recycled);
        Ok(Some(flush_control))
    }
}

/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
pub trait Buffer {
    type IoBuf: IoBuf;

    /// Capacity of the buffer. Must not change over the lifetime of `self`.
    fn cap(&self) -> usize;

    /// Add data to the buffer.
    /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
    /// panics if `other.len() > self.cap() - self.pending()`.
    fn extend_from_slice(&mut self, other: &[u8]);

    /// Number of bytes in the buffer.
    fn pending(&self) -> usize;

    /// Turns `self` into a [`FullSlice`] of the pending data
    /// so we can use [`tokio_epoll_uring`] to write it to disk.
    fn flush(self) -> FullSlice<Self::IoBuf>;

    /// After the write to disk is done and we have gotten back the slice,
    /// [`BufferedWriter`] uses this method to re-use the io buffer.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
}
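
// Illustrative lifecycle of a `Buffer` as driven by [`BufferedWriter`]
// (a sketch of the call order, not literal code from this module; `data`
// and `iobuf` are placeholders):
//
//   buf.extend_from_slice(data);            // fill until pending() == cap()
//   let slice = buf.flush();                // freeze pending data into a FullSlice
//   // ... background task writes `slice` to disk, hands back the raw IoBuf ...
//   let buf = B::reuse_after_flush(iobuf);  // recycle; the buffer is empty again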

impl Buffer for IoBufferMut {
    type IoBuf = IoBuffer;

    fn cap(&self) -> usize {
        self.capacity()
    }

    fn extend_from_slice(&mut self, other: &[u8]) {
        if self.len() + other.len() > self.cap() {
            panic!("Buffer capacity exceeded");
        }

        IoBufferMut::extend_from_slice(self, other);
    }

    fn pending(&self) -> usize {
        self.len()
    }

    fn flush(self) -> FullSlice<Self::IoBuf> {
        self.freeze().slice_len()
    }

    /// The caller should make sure that `iobuf` only has one strong reference before invoking this method.
    fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
        let mut recycled = iobuf
            .into_mut()
            .expect("buffer should only have one strong reference");
        recycled.clear();
        recycled
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::context::{DownloadBehavior, RequestContext};
    use crate::task_mgr::TaskKind;

    #[derive(Default, Debug)]
    struct RecorderWriter {
        /// Records bytes and write offsets.
        writes: Mutex<Vec<(Vec<u8>, u64)>>,
    }

    impl RecorderWriter {
        /// Gets recorded bytes and write offsets.
        fn get_writes(&self) -> Vec<Vec<u8>> {
            self.writes
                .lock()
                .unwrap()
                .iter()
                .map(|(buf, _)| buf.clone())
                .collect()
        }
    }

    impl OwnedAsyncWriter for RecorderWriter {
        async fn write_all_at<Buf: IoBufAligned + Send>(
            &self,
            buf: FullSlice<Buf>,
            offset: u64,
            _: &RequestContext,
        ) -> (FullSlice<Buf>, std::io::Result<()>) {
            self.writes
                .lock()
                .unwrap()
                .push((Vec::from(&buf[..]), offset));
            (buf, Ok(()))
        }
    }

    fn test_ctx() -> RequestContext {
        RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
    }

    #[tokio::test]
    async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
        let ctx = test_ctx();
        let ctx = &ctx;
        let recorder = Arc::new(RecorderWriter::default());
        let gate = utils::sync::gate::Gate::default();
        let mut writer = BufferedWriter::<_, RecorderWriter>::new(
            recorder,
            || IoBufferMut::with_capacity(2),
            gate.enter()?,
            ctx,
            tracing::Span::none(),
        );

        writer.write_buffered_borrowed(b"abc", ctx).await?;
        writer.write_buffered_borrowed(b"", ctx).await?;
        writer.write_buffered_borrowed(b"d", ctx).await?;
        writer.write_buffered_borrowed(b"e", ctx).await?;
        writer.write_buffered_borrowed(b"fg", ctx).await?;
        writer.write_buffered_borrowed(b"hi", ctx).await?;
        writer.write_buffered_borrowed(b"j", ctx).await?;
        writer.write_buffered_borrowed(b"klmno", ctx).await?;

        let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
        assert_eq!(
            recorder.get_writes(),
            {
                let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
                expect
            }
            .iter()
            .map(|v| v[..].to_vec())
            .collect::<Vec<_>>()
        );
        Ok(())
    }
}
        

Generated by: LCOV version 2.1-beta