LCOV - code coverage report
Current view: top level - libs/pq_proto/src - framed.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 89.6 % 125 112
Test Date: 2024-02-12 20:26:03 Functions: 62.8 % 199 125

            Line data    Source code
       1              : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
       2              : //! the async stream based on (and buffered with) BytesMut. All functions are
       3              : //! cancellation safe.
       4              : //!
       5              : //! It is similar to what tokio_util::codec::Framed with appropriate codec
       6              : //! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
       7              : //! separately without using split from futures::stream::StreamExt (which
       8              : //! allocates a [Box] in polling internally). tokio::io::split is used for splitting
       9              : //! instead. Plus we customize error messages more than a single type for all io
      10              : //! calls.
      11              : //!
      12              : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
      13              : use bytes::{Buf, BytesMut};
      14              : use std::{
      15              :     future::Future,
      16              :     io::{self, ErrorKind},
      17              : };
      18              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
      19              : 
      20              : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
      21              : 
      22              : const INITIAL_CAPACITY: usize = 8 * 1024;
      23              : 
      24              : /// Error on postgres connection: either IO (physical transport error) or
      25              : /// protocol violation.
      26          748 : #[derive(thiserror::Error, Debug)]
      27              : pub enum ConnectionError {
      28              :     #[error(transparent)]
      29              :     Io(#[from] io::Error),
      30              :     #[error(transparent)]
      31              :     Protocol(#[from] ProtocolError),
      32              : }
      33              : 
      34              : impl ConnectionError {
      35              :     /// Proxy stream.rs uses only io::Error; provide it.
      36            2 :     pub fn into_io_error(self) -> io::Error {
      37            2 :         match self {
      38            2 :             ConnectionError::Io(io) => io,
      39            0 :             ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
      40              :         }
      41            2 :     }
      42              : }
      43              : 
      44              : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
      45              : /// messages.
      46              : pub struct Framed<S> {
      47              :     stream: S,
      48              :     read_buf: BytesMut,
      49              :     write_buf: BytesMut,
      50              : }
      51              : 
      52              : impl<S> Framed<S> {
      53        14062 :     pub fn new(stream: S) -> Self {
      54        14062 :         Self {
      55        14062 :             stream,
      56        14062 :             read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      57        14062 :             write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      58        14062 :         }
      59        14062 :     }
      60              : 
      61              :     /// Get a shared reference to the underlying stream.
      62          210 :     pub fn get_ref(&self) -> &S {
      63          210 :         &self.stream
      64          210 :     }
      65              : 
      66              :     /// Deconstruct into the underlying stream and read buffer.
      67          148 :     pub fn into_inner(self) -> (S, BytesMut) {
      68          148 :         (self.stream, self.read_buf)
      69          148 :     }
      70              : 
      71              :     /// Return new Framed with stream type transformed by async f, for TLS
      72              :     /// upgrade.
      73            2 :     pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
      74            2 :     where
      75            2 :         F: FnOnce(S) -> Fut,
      76            2 :         Fut: Future<Output = Result<S2, E>>,
      77            2 :     {
      78            4 :         let stream = f(self.stream).await?;
      79            2 :         Ok(Framed {
      80            2 :             stream,
      81            2 :             read_buf: self.read_buf,
      82            2 :             write_buf: self.write_buf,
      83            2 :         })
      84            2 :     }
      85              : }
      86              : 
      87              : impl<S: AsyncRead + Unpin> Framed<S> {
      88        25897 :     pub async fn read_startup_message(
      89        25897 :         &mut self,
      90        25897 :     ) -> Result<Option<FeStartupPacket>, ConnectionError> {
      91        25897 :         read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
      92        25897 :     }
      93              : 
      94      4645001 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
      95      4645001 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
      96      4644796 :     }
      97              : }
      98              : 
      99              : impl<S: AsyncWrite + Unpin> Framed<S> {
     100              :     /// Write next message to the output buffer; doesn't flush.
     101      5596960 :     pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     102      5596960 :         BeMessage::write(&mut self.write_buf, msg)
     103      5596960 :     }
     104              : 
     105              :     /// Flush out the buffer. This function is cancellation safe: it can be
     106              :     /// interrupted and flushing will be continued in the next call.
     107      5586816 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     108      5586816 :         flush(&mut self.stream, &mut self.write_buf).await
     109      5578568 :     }
     110              : 
     111              :     /// Flush out the buffer and shutdown the stream.
     112        13406 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     113        13406 :         shutdown(&mut self.stream, &mut self.write_buf).await
     114        13406 :     }
     115              : }
     116              : 
     117              : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
     118              :     /// Split into owned read and write parts. Beware of potential issues with
     119              :     /// using halves in different tasks on TLS stream:
     120              :     /// <https://github.com/tokio-rs/tls/issues/40>
     121         2604 :     pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
     122         2604 :         let (read_half, write_half) = tokio::io::split(self.stream);
     123         2604 :         let reader = FramedReader {
     124         2604 :             stream: read_half,
     125         2604 :             read_buf: self.read_buf,
     126         2604 :         };
     127         2604 :         let writer = FramedWriter {
     128         2604 :             stream: write_half,
     129         2604 :             write_buf: self.write_buf,
     130         2604 :         };
     131         2604 :         (reader, writer)
     132         2604 :     }
     133              : 
     134              :     /// Join read and write parts back.
     135         2168 :     pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
     136         2168 :         Self {
     137         2168 :             stream: reader.stream.unsplit(writer.stream),
     138         2168 :             read_buf: reader.read_buf,
     139         2168 :             write_buf: writer.write_buf,
     140         2168 :         }
     141         2168 :     }
     142              : }
     143              : 
     144              : /// Read-only version of `Framed`.
     145              : pub struct FramedReader<S> {
     146              :     stream: ReadHalf<S>,
     147              :     read_buf: BytesMut,
     148              : }
     149              : 
     150              : impl<S: AsyncRead + Unpin> FramedReader<S> {
     151      3477763 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     152      6545416 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
     153      3477293 :     }
     154              : }
     155              : 
     156              : /// Write-only version of `Framed`.
     157              : pub struct FramedWriter<S> {
     158              :     stream: WriteHalf<S>,
     159              :     write_buf: BytesMut,
     160              : }
     161              : 
     162              : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
     163              :     /// Write next message to the output buffer; doesn't flush.
     164      2466754 :     pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     165      2466754 :         BeMessage::write(&mut self.write_buf, msg)
     166      2466754 :     }
     167              : 
     168              :     /// Flush out the buffer. This function is cancellation safe: it can be
     169              :     /// interrupted and flushing will be continued in the next call.
     170      2466754 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     171      2466754 :         flush(&mut self.stream, &mut self.write_buf).await
     172      2466738 :     }
     173              : 
     174              :     /// Flush out the buffer and shutdown the stream.
     175            0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     176            0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     177            0 :     }
     178              : }
     179              : 
     180              : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
     181              : /// don't have remaining data in the buffer. This function is cancellation safe:
     182              : /// you can drop future which is not yet complete and finalize reading message
     183              : /// with the next call.
     184              : ///
     185              : /// Parametrized to allow reading startup or usual message, having different
     186              : /// format.
     187      8148661 : async fn read_message<S: AsyncRead + Unpin, M, P>(
     188      8148661 :     stream: &mut S,
     189      8148661 :     read_buf: &mut BytesMut,
     190      8148661 :     parse: P,
     191      8148661 : ) -> Result<Option<M>, ConnectionError>
     192      8148661 : where
     193      8148661 :     P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
     194      8148661 : {
     195              :     loop {
     196     15233966 :         if let Some(msg) = parse(read_buf)? {
     197      8125715 :             return Ok(Some(msg));
     198      7108251 :         }
     199      7108251 :         // If we can't build a frame yet, try to read more data and try again.
     200      7108251 :         // Make sure we've got room for at least one byte to read to ensure
     201      7108251 :         // that we don't get a spurious 0 that looks like EOF.
     202      7108251 :         read_buf.reserve(1);
     203     10810332 :         if stream.read_buf(read_buf).await? == 0 {
     204        21977 :             if read_buf.has_remaining() {
     205            0 :                 return Err(io::Error::new(
     206            0 :                     ErrorKind::UnexpectedEof,
     207            0 :                     "EOF with unprocessed data in the buffer",
     208            0 :                 )
     209            0 :                 .into());
     210              :             } else {
     211        21977 :                 return Ok(None); // clean EOF
     212              :             }
     213      7085305 :         }
     214              :     }
     215      8147986 : }
     216              : 
     217              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     218      8066976 : async fn flush<S: AsyncWrite + Unpin>(
     219      8066976 :     stream: &mut S,
     220      8066976 :     write_buf: &mut BytesMut,
     221      8066976 : ) -> Result<(), io::Error> {
     222     16073917 :     while write_buf.has_remaining() {
     223      8016014 :         let bytes_written = stream.write_buf(write_buf).await?;
     224      8006941 :         if bytes_written == 0 {
     225            0 :             return Err(io::Error::new(
     226            0 :                 ErrorKind::WriteZero,
     227            0 :                 "failed to write message",
     228            0 :             ));
     229      8006941 :         }
     230              :     }
     231      8057903 :     stream.flush().await
     232      8058712 : }
     233              : 
     234              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     235        13406 : async fn shutdown<S: AsyncWrite + Unpin>(
     236        13406 :     stream: &mut S,
     237        13406 :     write_buf: &mut BytesMut,
     238        13406 : ) -> Result<(), io::Error> {
     239        13406 :     flush(stream, write_buf).await?;
     240        13047 :     stream.shutdown().await
     241        13406 : }
        

Generated by: LCOV version 2.1-beta