LCOV - code coverage report
Current view: top level - pageserver/src/virtual_file/owned_buffers_io - write.rs (source / functions) Coverage Total Hit
Test: b837401fb09d2d9818b70e630fdb67e9799b7b0d.info Lines: 89.2 % 130 116
Test Date: 2024-04-18 15:32:49 Functions: 92.3 % 26 24

            Line data    Source code
       1              : use bytes::BytesMut;
       2              : use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
       3              : 
       4              : /// A trait for doing owned-buffer write IO.
       5              : /// Think [`tokio::io::AsyncWrite`] but with owned buffers.
       6              : pub trait OwnedAsyncWriter {
       7              :     async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
       8              :         &mut self,
       9              :         buf: B,
      10              :     ) -> std::io::Result<(usize, B::Buf)>;
      11              : }
      12              : 
      13              : /// A wrapper aorund an [`OwnedAsyncWriter`] that batches smaller writers
      14              : /// into `BUFFER_SIZE`-sized writes.
      15              : ///
      16              : /// # Passthrough Of Large Writers
      17              : ///
      18              : /// Buffered writes larger than the `BUFFER_SIZE` cause the internal
      19              : /// buffer to be flushed, even if it is not full yet. Then, the large
      20              : /// buffered write is passed through to the unerlying [`OwnedAsyncWriter`].
      21              : ///
      22              : /// This pass-through is generally beneficial for throughput, but if
      23              : /// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
      24              : /// unlimited large writes may cause latency or fairness issues.
      25              : ///
      26              : /// In such cases, a different implementation that always buffers in memory
      27              : /// may be preferable.
      28              : pub struct BufferedWriter<const BUFFER_SIZE: usize, W> {
      29              :     writer: W,
      30              :     // invariant: always remains Some(buf)
      31              :     // with buf.capacity() == BUFFER_SIZE except
      32              :     // - while IO is ongoing => goes back to Some() once the IO completed successfully
      33              :     // - after an IO error => stays `None` forever
      34              :     // In these exceptional cases, it's `None`.
      35              :     buf: Option<BytesMut>,
      36              : }
      37              : 
      38              : impl<const BUFFER_SIZE: usize, W> BufferedWriter<BUFFER_SIZE, W>
      39              : where
      40              :     W: OwnedAsyncWriter,
      41              : {
      42            9 :     pub fn new(writer: W) -> Self {
      43            9 :         Self {
      44            9 :             writer,
      45            9 :             buf: Some(BytesMut::with_capacity(BUFFER_SIZE)),
      46            9 :         }
      47            9 :     }
      48              : 
      49            9 :     pub async fn flush_and_into_inner(mut self) -> std::io::Result<W> {
      50            9 :         self.flush().await?;
      51            9 :         let Self { buf, writer } = self;
      52            9 :         assert!(buf.is_some());
      53            9 :         Ok(writer)
      54            9 :     }
      55              : 
      56           44 :     pub async fn write_buffered<B: IoBuf>(&mut self, chunk: Slice<B>) -> std::io::Result<()>
      57           44 :     where
      58           44 :         B: IoBuf + Send,
      59           44 :     {
      60           44 :         // avoid memcpy for the middle of the chunk
      61           44 :         if chunk.len() >= BUFFER_SIZE {
      62            8 :             self.flush().await?;
      63              :             // do a big write, bypassing `buf`
      64            8 :             assert_eq!(
      65            8 :                 self.buf
      66            8 :                     .as_ref()
      67            8 :                     .expect("must not use after an error")
      68            8 :                     .len(),
      69            8 :                 0
      70            8 :             );
      71            8 :             let chunk_len = chunk.len();
      72            8 :             let (nwritten, chunk) = self.writer.write_all(chunk).await?;
      73            8 :             assert_eq!(nwritten, chunk_len);
      74            8 :             drop(chunk);
      75            8 :             return Ok(());
      76           36 :         }
      77           36 :         // in-memory copy the < BUFFER_SIZED tail of the chunk
      78           36 :         assert!(chunk.len() < BUFFER_SIZE);
      79           36 :         let mut chunk = &chunk[..];
      80           70 :         while !chunk.is_empty() {
      81           34 :             let buf = self.buf.as_mut().expect("must not use after an error");
      82           34 :             let need = BUFFER_SIZE - buf.len();
      83           34 :             let have = chunk.len();
      84           34 :             let n = std::cmp::min(need, have);
      85           34 :             buf.extend_from_slice(&chunk[..n]);
      86           34 :             chunk = &chunk[n..];
      87           34 :             if buf.len() >= BUFFER_SIZE {
      88            6 :                 assert_eq!(buf.len(), BUFFER_SIZE);
      89            6 :                 self.flush().await?;
      90           28 :             }
      91              :         }
      92           36 :         assert!(chunk.is_empty(), "by now we should have drained the chunk");
      93           36 :         Ok(())
      94           44 :     }
      95              : 
      96           23 :     async fn flush(&mut self) -> std::io::Result<()> {
      97           23 :         let buf = self.buf.take().expect("must not use after an error");
      98           23 :         if buf.is_empty() {
      99           10 :             self.buf = Some(buf);
     100           10 :             return std::io::Result::Ok(());
     101           13 :         }
     102           13 :         let buf_len = buf.len();
     103           13 :         let (nwritten, mut buf) = self.writer.write_all(buf).await?;
     104           13 :         assert_eq!(nwritten, buf_len);
     105           13 :         buf.clear();
     106           13 :         self.buf = Some(buf);
     107           13 :         Ok(())
     108           23 :     }
     109              : }
     110              : 
     111              : impl OwnedAsyncWriter for Vec<u8> {
     112            0 :     async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
     113            0 :         &mut self,
     114            0 :         buf: B,
     115            0 :     ) -> std::io::Result<(usize, B::Buf)> {
     116            0 :         let nbytes = buf.bytes_init();
     117            0 :         if nbytes == 0 {
     118            0 :             return Ok((0, Slice::into_inner(buf.slice_full())));
     119            0 :         }
     120            0 :         let buf = buf.slice(0..nbytes);
     121            0 :         self.extend_from_slice(&buf[..]);
     122            0 :         Ok((buf.len(), Slice::into_inner(buf)))
     123            0 :     }
     124              : }
     125              : 
     126              : #[cfg(test)]
     127              : mod tests {
     128              :     use super::*;
     129              : 
     130              :     #[derive(Default)]
     131              :     struct RecorderWriter {
     132              :         writes: Vec<Vec<u8>>,
     133              :     }
     134              :     impl OwnedAsyncWriter for RecorderWriter {
     135           18 :         async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
     136           18 :             &mut self,
     137           18 :             buf: B,
     138           18 :         ) -> std::io::Result<(usize, B::Buf)> {
     139           18 :             let nbytes = buf.bytes_init();
     140           18 :             if nbytes == 0 {
     141            0 :                 self.writes.push(vec![]);
     142            0 :                 return Ok((0, Slice::into_inner(buf.slice_full())));
     143           18 :             }
     144           18 :             let buf = buf.slice(0..nbytes);
     145           18 :             self.writes.push(Vec::from(&buf[..]));
     146           18 :             Ok((buf.len(), Slice::into_inner(buf)))
     147           18 :         }
     148              :     }
     149              : 
     150              :     macro_rules! write {
     151              :         ($writer:ident, $data:literal) => {{
     152              :             $writer
     153              :                 .write_buffered(::bytes::Bytes::from_static($data).slice_full())
     154              :                 .await?;
     155              :         }};
     156              :     }
     157              : 
     158              :     #[tokio::test]
     159            2 :     async fn test_buffered_writes_only() -> std::io::Result<()> {
     160            2 :         let recorder = RecorderWriter::default();
     161            2 :         let mut writer = BufferedWriter::<2, _>::new(recorder);
     162            2 :         write!(writer, b"a");
     163            2 :         write!(writer, b"b");
     164            2 :         write!(writer, b"c");
     165            2 :         write!(writer, b"d");
     166            2 :         write!(writer, b"e");
     167            2 :         let recorder = writer.flush_and_into_inner().await?;
     168            2 :         assert_eq!(
     169            2 :             recorder.writes,
     170            2 :             vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
     171            2 :         );
     172            2 :         Ok(())
     173            2 :     }
     174              : 
     175              :     #[tokio::test]
     176            2 :     async fn test_passthrough_writes_only() -> std::io::Result<()> {
     177            2 :         let recorder = RecorderWriter::default();
     178            2 :         let mut writer = BufferedWriter::<2, _>::new(recorder);
     179            2 :         write!(writer, b"abc");
     180            2 :         write!(writer, b"de");
     181            2 :         write!(writer, b"");
     182            2 :         write!(writer, b"fghijk");
     183            2 :         let recorder = writer.flush_and_into_inner().await?;
     184            2 :         assert_eq!(
     185            2 :             recorder.writes,
     186            2 :             vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
     187            2 :         );
     188            2 :         Ok(())
     189            2 :     }
     190              : 
     191              :     #[tokio::test]
     192            2 :     async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
     193            2 :         let recorder = RecorderWriter::default();
     194            2 :         let mut writer = BufferedWriter::<2, _>::new(recorder);
     195            2 :         write!(writer, b"a");
     196            2 :         write!(writer, b"bc");
     197            2 :         write!(writer, b"d");
     198            2 :         write!(writer, b"e");
     199            2 :         let recorder = writer.flush_and_into_inner().await?;
     200            2 :         assert_eq!(
     201            2 :             recorder.writes,
     202            2 :             vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
     203            2 :         );
     204            2 :         Ok(())
     205            2 :     }
     206              : }
        

Generated by: LCOV version 2.1-beta