LCOV - code coverage report
Current view: top level - pageserver/src/virtual_file/owned_buffers_io - write.rs (source / functions)
Test:         42f947419473a288706e86ecdf7c2863d760d5d7.info
Test Date:    2024-08-02 21:34:27

               Coverage    Total    Hit
Lines:           92.4 %      211    195
Functions:       95.7 %       47     45

            Line data    Source code
       1              : use bytes::BytesMut;
       2              : use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
       3              : 
       4              : use crate::context::RequestContext;
       5              : 
       6              : /// A trait for doing owned-buffer write IO.
       7              : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
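                      : ///
                      : /// Unlike [`tokio::io::AsyncWrite`], the buffer is moved into the call and
                      : /// handed back on completion, so the caller can reuse the allocation.
                      : /// A minimal sketch of the calling convention (not compiled as a doctest;
                      : /// `writer` and `ctx` are assumed to be in scope):
                      : ///
                      : /// ```ignore
                      : /// let buf = bytes::Bytes::from_static(b"hello");
                      : /// let (nwritten, buf) = writer.write_all(buf.slice_full(), ctx).await?;
                      : /// assert_eq!(nwritten, 5);
                      : /// // `buf` is the underlying `Bytes` again, ready for reuse.
                      : /// ```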
       8              : pub trait OwnedAsyncWriter {
       9              :     async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
      10              :         &mut self,
      11              :         buf: B,
      12              :         ctx: &RequestContext,
      13              :     ) -> std::io::Result<(usize, B::Buf)>;
      14              : }
      15              : 
       16              : /// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
      17              : /// small writes into larger writes of size [`Buffer::cap`].
      18              : ///
       19              : /// # Passthrough Of Large Writes
      20              : ///
      21              : /// Calls to [`BufferedWriter::write_buffered`] that are larger than [`Buffer::cap`]
      22              : /// cause the internal buffer to be flushed prematurely so that the large
      23              : /// buffered write is passed through to the underlying [`OwnedAsyncWriter`].
      24              : ///
      25              : /// This pass-through is generally beneficial for throughput, but if
      26              : /// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
      27              : /// unlimited large writes may cause latency or fairness issues.
      28              : ///
      29              : /// In such cases, a different implementation that always buffers in memory
      30              : /// may be preferable.
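                      : ///
                      : /// # Example
                      : ///
                      : /// A minimal sketch of both paths, assuming a 2-byte buffer and the
                      : /// [`OwnedAsyncWriter`] impl for `Vec<u8>` below (not compiled as a doctest;
                      : /// `ctx` is assumed to be in scope):
                      : ///
                      : /// ```ignore
                      : /// let mut w = BufferedWriter::new(Vec::new(), BytesMut::with_capacity(2));
                      : /// // buffered: "a" stays in the internal buffer
                      : /// w.write_buffered(Bytes::from_static(b"a").slice_full(), ctx).await?;
                      : /// // passthrough: "bcde" is >= the buffer capacity, so "a" is flushed
                      : /// // first, then "bcde" is written directly, bypassing the buffer
                      : /// w.write_buffered(Bytes::from_static(b"bcde").slice_full(), ctx).await?;
                      : /// let vec = w.flush_and_into_inner(ctx).await?;
                      : /// assert_eq!(vec, b"abcde");
                      : /// ```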
      31              : pub struct BufferedWriter<B, W> {
      32              :     writer: W,
       33              :     /// Invariant: this is always `Some(buf)`, except
       34              :     /// - while an IO is in flight => becomes `Some(buf)` again once the IO completes successfully
       35              :     /// - after an IO error => stays `None` forever
       36              :     ///
       37              :     /// In both of these exceptional cases, the field is `None`.
      38              :     buf: Option<B>,
      39              : }
      40              : 
      41              : impl<B, Buf, W> BufferedWriter<B, W>
      42              : where
      43              :     B: Buffer<IoBuf = Buf> + Send,
      44              :     Buf: IoBuf + Send,
      45              :     W: OwnedAsyncWriter,
      46              : {
      47         1269 :     pub fn new(writer: W, buf: B) -> Self {
      48         1269 :         Self {
      49         1269 :             writer,
      50         1269 :             buf: Some(buf),
      51         1269 :         }
      52         1269 :     }
      53              : 
      54     10931128 :     pub fn as_inner(&self) -> &W {
      55     10931128 :         &self.writer
      56     10931128 :     }
      57              : 
      58              :     /// Panics if used after any of the write paths returned an error
      59     10730280 :     pub fn inspect_buffer(&self) -> &B {
      60     10730280 :         self.buf()
      61     10730280 :     }
      62              : 
      63              :     #[cfg_attr(target_os = "macos", allow(dead_code))]
      64           11 :     pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
      65           11 :         self.flush(ctx).await?;
      66              : 
      67           11 :         let Self { buf, writer } = self;
      68           11 :         assert!(buf.is_some());
      69           11 :         Ok(writer)
      70           11 :     }
      71              : 
      72              :     #[inline(always)]
      73     10730360 :     fn buf(&self) -> &B {
      74     10730360 :         self.buf
      75     10730360 :             .as_ref()
      76     10730360 :             .expect("must not use after we returned an error")
      77     10730360 :     }
      78              : 
      79              :     #[cfg_attr(target_os = "macos", allow(dead_code))]
      80           44 :     pub async fn write_buffered<S: IoBuf + Send>(
      81           44 :         &mut self,
      82           44 :         chunk: Slice<S>,
      83           44 :         ctx: &RequestContext,
      84           44 :     ) -> std::io::Result<(usize, S)> {
      85           44 :         let chunk_len = chunk.len();
      86           44 :         // avoid memcpy for the middle of the chunk
      87           44 :         if chunk.len() >= self.buf().cap() {
      88            8 :             self.flush(ctx).await?;
      89              :             // do a big write, bypassing `buf`
      90            8 :             assert_eq!(
      91            8 :                 self.buf
      92            8 :                     .as_ref()
      93            8 :                     .expect("must not use after an error")
      94            8 :                     .pending(),
      95            8 :                 0
      96            8 :             );
      97            8 :             let (nwritten, chunk) = self.writer.write_all(chunk, ctx).await?;
      98            8 :             assert_eq!(nwritten, chunk_len);
      99            8 :             return Ok((nwritten, chunk));
     100           36 :         }
     101           36 :         // in-memory copy the < BUFFER_SIZED tail of the chunk
     102           36 :         assert!(chunk.len() < self.buf().cap());
     103           36 :         let mut slice = &chunk[..];
     104           70 :         while !slice.is_empty() {
     105           34 :             let buf = self.buf.as_mut().expect("must not use after an error");
     106           34 :             let need = buf.cap() - buf.pending();
     107           34 :             let have = slice.len();
     108           34 :             let n = std::cmp::min(need, have);
     109           34 :             buf.extend_from_slice(&slice[..n]);
     110           34 :             slice = &slice[n..];
     111           34 :             if buf.pending() >= buf.cap() {
     112            6 :                 assert_eq!(buf.pending(), buf.cap());
     113            6 :                 self.flush(ctx).await?;
     114           28 :             }
     115              :         }
     116           36 :         assert!(slice.is_empty(), "by now we should have drained the chunk");
     117           36 :         Ok((chunk_len, chunk.into_inner()))
     118           44 :     }
     119              : 
     120              :     /// Strictly less performant variant of [`Self::write_buffered`] that allows writing borrowed data.
     121              :     ///
      122              :     /// It is less performant because the borrowed data always has to be copied into the internal
      123              :     /// buffer before the IO can be issued. [`Self::write_buffered`] can avoid this copy, which makes
      124              :     /// it more performant for large writes.
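                      :     ///
                      :     /// A minimal sketch (not compiled as a doctest; `ctx` is assumed to be in scope):
                      :     ///
                      :     /// ```ignore
                      :     /// let mut w = BufferedWriter::new(Vec::new(), BytesMut::with_capacity(2));
                      :     /// // each call copies into the buffer; a flush happens whenever it fills up
                      :     /// w.write_buffered_borrowed(b"abc", ctx).await?; // flushes "ab", buffers "c"
                      :     /// let vec = w.flush_and_into_inner(ctx).await?;  // flushes "c"
                      :     /// assert_eq!(vec, b"abc");
                      :     /// ```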
     125     10221518 :     pub async fn write_buffered_borrowed(
     126     10221518 :         &mut self,
     127     10221518 :         mut chunk: &[u8],
     128     10221518 :         ctx: &RequestContext,
     129     10221518 :     ) -> std::io::Result<usize> {
     130     10221518 :         let chunk_len = chunk.len();
     131     20449642 :         while !chunk.is_empty() {
     132     10228124 :             let buf = self.buf.as_mut().expect("must not use after an error");
     133     10228124 :             let need = buf.cap() - buf.pending();
     134     10228124 :             let have = chunk.len();
     135     10228124 :             let n = std::cmp::min(need, have);
     136     10228124 :             buf.extend_from_slice(&chunk[..n]);
     137     10228124 :             chunk = &chunk[n..];
     138     10228124 :             if buf.pending() >= buf.cap() {
     139         6624 :                 assert_eq!(buf.pending(), buf.cap());
     140         6624 :                 self.flush(ctx).await?;
     141     10221500 :             }
     142              :         }
     143     10221518 :         Ok(chunk_len)
     144     10221518 :     }
     145              : 
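                      :     /// Writes the pending bytes (if any) to the underlying writer, then
                      :     /// restores the `self.buf` invariant with a cleared, reusable buffer.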
     146         6649 :     async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
     147         6649 :         let buf = self.buf.take().expect("must not use after an error");
     148         6649 :         let buf_len = buf.pending();
     149         6649 :         if buf_len == 0 {
     150           10 :             self.buf = Some(buf);
     151           10 :             return Ok(());
     152         6639 :         }
     153         6639 :         let (nwritten, io_buf) = self.writer.write_all(buf.flush(), ctx).await?;
     154         6639 :         assert_eq!(nwritten, buf_len);
     155         6639 :         self.buf = Some(Buffer::reuse_after_flush(io_buf));
     156         6639 :         Ok(())
     157         6649 :     }
     158              : }
     159              : 
     160              : /// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones.
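                      : ///
                      : /// A minimal sketch of the flush/reuse cycle, using the [`BytesMut`] impl
                      : /// below (not compiled as a doctest):
                      : ///
                      : /// ```ignore
                      : /// let mut buf = BytesMut::with_capacity(8);
                      : /// Buffer::extend_from_slice(&mut buf, b"hi");
                      : /// assert_eq!(buf.pending(), 2);
                      : /// let slice = Buffer::flush(buf);   // Slice over the two pending bytes
                      : /// // ... hand `slice` to an `OwnedAsyncWriter` here ...
                      : /// let buf = BytesMut::reuse_after_flush(slice.into_inner());
                      : /// assert_eq!(buf.pending(), 0);     // emptied, allocation reused
                      : /// ```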
     161              : pub trait Buffer {
     162              :     type IoBuf: IoBuf;
     163              : 
      164              :     /// Capacity of the buffer. Must not change over the lifetime of `self`.
     165              :     fn cap(&self) -> usize;
     166              : 
     167              :     /// Add data to the buffer.
      168              :     /// Panics if there is not enough room to accommodate `other`'s content, i.e.,
     169              :     /// panics if `other.len() > self.cap() - self.pending()`.
     170              :     fn extend_from_slice(&mut self, other: &[u8]);
     171              : 
     172              :     /// Number of bytes in the buffer.
     173              :     fn pending(&self) -> usize;
     174              : 
     175              :     /// Turns `self` into a [`tokio_epoll_uring::Slice`] of the pending data
     176              :     /// so we can use [`tokio_epoll_uring`] to write it to disk.
     177              :     fn flush(self) -> Slice<Self::IoBuf>;
     178              : 
     179              :     /// After the write to disk is done and we have gotten back the slice,
     180              :     /// [`BufferedWriter`] uses this method to re-use the io buffer.
     181              :     fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
     182              : }
     183              : 
     184              : impl Buffer for BytesMut {
     185              :     type IoBuf = BytesMut;
     186              : 
     187              :     #[inline(always)]
     188          216 :     fn cap(&self) -> usize {
     189          216 :         self.capacity()
     190          216 :     }
     191              : 
     192           58 :     fn extend_from_slice(&mut self, other: &[u8]) {
     193           58 :         BytesMut::extend_from_slice(self, other)
     194           58 :     }
     195              : 
     196              :     #[inline(always)]
     197          183 :     fn pending(&self) -> usize {
     198          183 :         self.len()
     199          183 :     }
     200              : 
     201           29 :     fn flush(self) -> Slice<BytesMut> {
     202           29 :         if self.is_empty() {
     203            0 :             return self.slice_full();
     204           29 :         }
     205           29 :         let len = self.len();
     206           29 :         self.slice(0..len)
     207           29 :     }
     208              : 
     209           29 :     fn reuse_after_flush(mut iobuf: BytesMut) -> Self {
     210           29 :         iobuf.clear();
     211           29 :         iobuf
     212           29 :     }
     213              : }
     214              : 
     215              : impl OwnedAsyncWriter for Vec<u8> {
     216            0 :     async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
     217            0 :         &mut self,
     218            0 :         buf: B,
     219            0 :         _: &RequestContext,
     220            0 :     ) -> std::io::Result<(usize, B::Buf)> {
     221            0 :         let nbytes = buf.bytes_init();
     222            0 :         if nbytes == 0 {
     223            0 :             return Ok((0, Slice::into_inner(buf.slice_full())));
     224            0 :         }
     225            0 :         let buf = buf.slice(0..nbytes);
     226            0 :         self.extend_from_slice(&buf[..]);
     227            0 :         Ok((buf.len(), Slice::into_inner(buf)))
     228            0 :     }
     229              : }
     230              : 
     231              : #[cfg(test)]
     232              : mod tests {
     233              :     use bytes::BytesMut;
     234              : 
     235              :     use super::*;
     236              :     use crate::context::{DownloadBehavior, RequestContext};
     237              :     use crate::task_mgr::TaskKind;
     238              : 
     239              :     #[derive(Default)]
     240              :     struct RecorderWriter {
     241              :         writes: Vec<Vec<u8>>,
     242              :     }
     243              :     impl OwnedAsyncWriter for RecorderWriter {
     244           34 :         async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
     245           34 :             &mut self,
     246           34 :             buf: B,
     247           34 :             _: &RequestContext,
     248           34 :         ) -> std::io::Result<(usize, B::Buf)> {
     249           34 :             let nbytes = buf.bytes_init();
     250           34 :             if nbytes == 0 {
     251            0 :                 self.writes.push(vec![]);
     252            0 :                 return Ok((0, Slice::into_inner(buf.slice_full())));
     253           34 :             }
     254           34 :             let buf = buf.slice(0..nbytes);
     255           34 :             self.writes.push(Vec::from(&buf[..]));
     256           34 :             Ok((buf.len(), Slice::into_inner(buf)))
     257           34 :         }
     258              :     }
     259              : 
     260           34 :     fn test_ctx() -> RequestContext {
     261           34 :         RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
     262           34 :     }
     263              : 
     264              :     macro_rules! write {
     265              :         ($writer:ident, $data:literal) => {{
     266              :             $writer
     267              :                 .write_buffered(::bytes::Bytes::from_static($data).slice_full(), &test_ctx())
     268              :                 .await?;
     269              :         }};
     270              :     }
     271              : 
     272              :     #[tokio::test]
     273            2 :     async fn test_buffered_writes_only() -> std::io::Result<()> {
     274            2 :         let recorder = RecorderWriter::default();
     275            2 :         let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
     276            2 :         write!(writer, b"a");
     277            2 :         write!(writer, b"b");
     278            2 :         write!(writer, b"c");
     279            2 :         write!(writer, b"d");
     280            2 :         write!(writer, b"e");
     281            2 :         let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
     282            2 :         assert_eq!(
     283            2 :             recorder.writes,
     284            2 :             vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
     285            2 :         );
     286            2 :         Ok(())
     287            2 :     }
     288              : 
     289              :     #[tokio::test]
     290            2 :     async fn test_passthrough_writes_only() -> std::io::Result<()> {
     291            2 :         let recorder = RecorderWriter::default();
     292            2 :         let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
     293            2 :         write!(writer, b"abc");
     294            2 :         write!(writer, b"de");
     295            2 :         write!(writer, b"");
     296            2 :         write!(writer, b"fghijk");
     297            2 :         let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
     298            2 :         assert_eq!(
     299            2 :             recorder.writes,
     300            2 :             vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
     301            2 :         );
     302            2 :         Ok(())
     303            2 :     }
     304              : 
     305              :     #[tokio::test]
     306            2 :     async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
     307            2 :         let recorder = RecorderWriter::default();
     308            2 :         let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
     309            2 :         write!(writer, b"a");
     310            2 :         write!(writer, b"bc");
     311            2 :         write!(writer, b"d");
     312            2 :         write!(writer, b"e");
     313            2 :         let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
     314            2 :         assert_eq!(
     315            2 :             recorder.writes,
     316            2 :             vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
     317            2 :         );
     318            2 :         Ok(())
     319            2 :     }
     320              : 
     321              :     #[tokio::test]
     322            2 :     async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
     323            2 :         let ctx = test_ctx();
     324            2 :         let ctx = &ctx;
     325            2 :         let recorder = RecorderWriter::default();
     326            2 :         let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
     327            2 : 
     328            2 :         writer.write_buffered_borrowed(b"abc", ctx).await?;
     329            2 :         writer.write_buffered_borrowed(b"d", ctx).await?;
     330            2 :         writer.write_buffered_borrowed(b"e", ctx).await?;
     331            2 :         writer.write_buffered_borrowed(b"fg", ctx).await?;
     332            2 :         writer.write_buffered_borrowed(b"hi", ctx).await?;
     333            2 :         writer.write_buffered_borrowed(b"j", ctx).await?;
     334            2 :         writer.write_buffered_borrowed(b"klmno", ctx).await?;
     335            2 : 
     336            2 :         let recorder = writer.flush_and_into_inner(ctx).await?;
     337            2 :         assert_eq!(
     338            2 :             recorder.writes,
     339            2 :             {
     340            2 :                 let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
     341            2 :                 expect
     342            2 :             }
     343            2 :             .iter()
     344           16 :             .map(|v| v[..].to_vec())
     345            2 :             .collect::<Vec<_>>()
     346            2 :         );
     347            2 :         Ok(())
     348            2 :     }
     349              : }
        

Generated by: LCOV version 2.1-beta