LCOV - differential code coverage report
Current view: top level - libs/pq_proto/src - framed.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 86.9 % 130 113 17 113
Current Date: 2023-10-19 02:04:12 Functions: 59.9 % 187 112 75 112
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
       2                 : //! the async stream based on (and buffered with) BytesMut. All functions are
       3                 : //! cancellation safe.
       4                 : //!
       5                 : //! It is similar to what tokio_util::codec::Framed with an appropriate codec
       6                 : //! provides, but the `FramedReader` and `FramedWriter` read/write parts can be used
       7                 : //! separately without using split from futures::stream::StreamExt (which
       8                 : //! allocates a [Box] in polling internally); tokio::io::split is used for splitting
       9                 : //! instead. We also customize errors more than a single error type for all io
      10                 : //! calls would allow.
      11                 : //!
      12                 : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
      13                 : use bytes::{Buf, BytesMut};
      14                 : use std::{
      15                 :     future::Future,
      16                 :     io::{self, ErrorKind},
      17                 : };
      18                 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
      19                 : 
      20                 : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
      21                 : 
      22                 : const INITIAL_CAPACITY: usize = 8 * 1024;
      23                 : 
      24                 : /// Error on postgres connection: either IO (physical transport error) or
      25                 : /// protocol violation.
      26 CBC         760 : #[derive(thiserror::Error, Debug)]
      27                 : pub enum ConnectionError {
      28                 :     #[error(transparent)]
      29                 :     Io(#[from] io::Error),
      30                 :     #[error(transparent)]
      31                 :     Protocol(#[from] ProtocolError),
      32                 : }
      33                 : 
      34                 : impl ConnectionError {
      35                 :     /// Proxy stream.rs uses only io::Error; provide it.
      36 UBC           0 :     pub fn into_io_error(self) -> io::Error {
      37               0 :         match self {
      38               0 :             ConnectionError::Io(io) => io,
      39               0 :             ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
      40                 :         }
      41               0 :     }
      42                 : }
      43                 : 
      44                 : /// Wraps async io `stream`, providing methods to write/flush and read Postgres
      45                 : /// messages.
      46                 : pub struct Framed<S> {
      47                 :     stream: S,
      48                 :     read_buf: BytesMut,
      49                 :     write_buf: BytesMut,
      50                 : }
      51                 : 
      52                 : impl<S> Framed<S> {
      53 CBC        8769 :     pub fn new(stream: S) -> Self {
      54            8769 :         Self {
      55            8769 :             stream,
      56            8769 :             read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      57            8769 :             write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      58            8769 :         }
      59            8769 :     }
      60                 : 
      61                 :     /// Get a shared reference to the underlying stream.
      62              71 :     pub fn get_ref(&self) -> &S {
      63              71 :         &self.stream
      64              71 :     }
      65                 : 
      66                 :     /// Deconstruct into the underlying stream and read buffer.
      67              68 :     pub fn into_inner(self) -> (S, BytesMut) {
      68              68 :         (self.stream, self.read_buf)
      69              68 :     }
      70                 : 
      71                 :     /// Return a new `Framed` with the stream type transformed by async `f`, for TLS
      72                 :     /// upgrade.
      73               1 :     pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
      74               1 :     where
      75               1 :         F: FnOnce(S) -> Fut,
      76               1 :         Fut: Future<Output = Result<S2, E>>,
      77               1 :     {
      78               2 :         let stream = f(self.stream).await?;
      79               1 :         Ok(Framed {
      80               1 :             stream,
      81               1 :             read_buf: self.read_buf,
      82               1 :             write_buf: self.write_buf,
      83               1 :         })
      84               1 :     }
      85                 : }
      86                 : 
      87                 : impl<S: AsyncRead + Unpin> Framed<S> {
      88           15383 :     pub async fn read_startup_message(
      89           15383 :         &mut self,
      90           15383 :     ) -> Result<Option<FeStartupPacket>, ConnectionError> {
      91           15383 :         read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
      92           15383 :     }
      93                 : 
      94         3831595 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
      95         3831595 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
      96         3831446 :     }
      97                 : }
      98                 : 
      99                 : impl<S: AsyncWrite + Unpin> Framed<S> {
     100                 :     /// Write next message to the output buffer; doesn't flush.
     101         4784406 :     pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     102         4784406 :         BeMessage::write(&mut self.write_buf, msg)
     103         4784406 :     }
     104                 : 
     105                 :     /// Flush out the buffer. This function is cancellation safe: it can be
     106                 :     /// interrupted and flushing will be continued in the next call.
     107         4790278 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     108         4790278 :         flush(&mut self.stream, &mut self.write_buf).await
     109         4782625 :     }
     110                 : 
     111                 :     /// Flush out the buffer and shutdown the stream.
     112            8210 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     113            8210 :         shutdown(&mut self.stream, &mut self.write_buf).await
     114            8210 :     }
     115                 : }
     116                 : 
     117                 : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
     118                 :     /// Split into owned read and write parts. Beware of potential issues with
     119                 :     /// using halves in different tasks on TLS stream:
     120                 :     /// <https://github.com/tokio-rs/tls/issues/40>
     121            2796 :     pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
     122            2796 :         let (read_half, write_half) = tokio::io::split(self.stream);
     123            2796 :         let reader = FramedReader {
     124            2796 :             stream: read_half,
     125            2796 :             read_buf: self.read_buf,
     126            2796 :         };
     127            2796 :         let writer = FramedWriter {
     128            2796 :             stream: write_half,
     129            2796 :             write_buf: self.write_buf,
     130            2796 :         };
     131            2796 :         (reader, writer)
     132            2796 :     }
     133                 : 
     134                 :     /// Join read and write parts back.
     135            2399 :     pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
     136            2399 :         Self {
     137            2399 :             stream: reader.stream.unsplit(writer.stream),
     138            2399 :             read_buf: reader.read_buf,
     139            2399 :             write_buf: writer.write_buf,
     140            2399 :         }
     141            2399 :     }
     142                 : }
     143                 : 
     144                 : /// Read-only version of `Framed`.
     145                 : pub struct FramedReader<S> {
     146                 :     stream: ReadHalf<S>,
     147                 :     read_buf: BytesMut,
     148                 : }
     149                 : 
     150                 : impl<S: AsyncRead + Unpin> FramedReader<S> {
     151         3700009 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     152         8011620 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
     153         3699542 :     }
     154                 : }
     155                 : 
     156                 : /// Write-only version of `Framed`.
     157                 : pub struct FramedWriter<S> {
     158                 :     stream: WriteHalf<S>,
     159                 :     write_buf: BytesMut,
     160                 : }
     161                 : 
     162                 : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
     163                 :     /// Write next message to the output buffer; doesn't flush.
     164         3249357 :     pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     165         3249357 :         BeMessage::write(&mut self.write_buf, msg)
     166         3249357 :     }
     167                 : 
     168                 :     /// Flush out the buffer. This function is cancellation safe: it can be
     169                 :     /// interrupted and flushing will be continued in the next call.
     170         3249357 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     171         3249357 :         flush(&mut self.stream, &mut self.write_buf).await
     172         3249346 :     }
     173                 : 
     174                 :     /// Flush out the buffer and shutdown the stream.
     175 UBC           0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     176               0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     177               0 :     }
     178                 : }
     179                 : 
     180                 : /// Read the next message from the stream. Returns Ok(None) if EOF happened and
     181                 : /// no data remains in the buffer. This function is cancellation safe: you can
     182                 : /// drop a future which is not yet complete and finish reading the message
     183                 : /// with the next call.
     184                 : ///
     185                 : /// Parametrized over the parser to allow reading either startup or regular
     186                 : /// messages, which have different formats.
     187 CBC     7546987 : async fn read_message<S: AsyncRead + Unpin, M, P>(
     188         7546987 :     stream: &mut S,
     189         7546987 :     read_buf: &mut BytesMut,
     190         7546987 :     parse: P,
     191         7546987 : ) -> Result<Option<M>, ConnectionError>
     192         7546987 : where
     193         7546987 :     P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
     194         7546987 : {
     195                 :     loop {
     196        14473703 :         if let Some(msg) = parse(read_buf)? {
     197         7533778 :             return Ok(Some(msg));
     198         6939925 :         }
     199         6939925 :         // If we can't build a frame yet, try to read more data and try again.
     200         6939925 :         // Make sure we've got room for at least one byte to read to ensure
     201         6939925 :         // that we don't get a spurious 0 that looks like EOF.
     202         6939925 :         read_buf.reserve(1);
     203        11449638 :         if stream.read_buf(read_buf).await? == 0 {
     204           12213 :             if read_buf.has_remaining() {
     205 UBC           0 :                 return Err(io::Error::new(
     206               0 :                     ErrorKind::UnexpectedEof,
     207               0 :                     "EOF with unprocessed data in the buffer",
     208               0 :                 )
     209               0 :                 .into());
     210                 :             } else {
     211 CBC       12213 :                 return Ok(None); // clean EOF
     212                 :             }
     213         6926716 :         }
     214                 :     }
     215         7546371 : }
     216                 : 
     217         8047845 : async fn flush<S: AsyncWrite + Unpin>(
     218         8047845 :     stream: &mut S,
     219         8047845 :     write_buf: &mut BytesMut,
     220         8047845 : ) -> Result<(), io::Error> {
     221        16044811 :     while write_buf.has_remaining() {
     222         8005363 :         let bytes_written = stream.write(write_buf.chunk()).await?;
     223         7996966 :         if bytes_written == 0 {
     224 UBC           0 :             return Err(io::Error::new(
     225               0 :                 ErrorKind::WriteZero,
     226               0 :                 "failed to write message",
     227               0 :             ));
     228 CBC     7996966 :         }
     229         7996966 :         // The advanced part will be reclaimed later, likely when data is shifted
     230         7996966 :         // left on the next attempt to write to the buffer and free space is not
     231         7996966 :         // enough.
     232         7996966 :         write_buf.advance(bytes_written);
     233                 :     }
     234         8039448 :     write_buf.clear();
     235         8039448 :     stream.flush().await
     236         8040181 : }
     237                 : 
     238            8210 : async fn shutdown<S: AsyncWrite + Unpin>(
     239            8210 :     stream: &mut S,
     240            8210 :     write_buf: &mut BytesMut,
     241            8210 : ) -> Result<(), io::Error> {
     242            8210 :     flush(stream, write_buf).await?;
     243            7850 :     stream.shutdown().await
     244            8210 : }
        

Generated by: LCOV version 2.1-beta