LCOV - code coverage report
Current view: top level - libs/pq_proto/src - framed.rs (source / functions) Coverage Total Hit
Test: 09e7485004805bd42b53a0c369170b3228136512.info Lines: 56.8 % 125 71
Test Date: 2024-11-21 18:36:18 Functions: 24.4 % 201 49

            Line data    Source code
       1              : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
       2              : //! the async stream based on (and buffered with) BytesMut. All functions are
       3              : //! cancellation safe.
       4              : //!
       5              : //! It is similar to what tokio_util::codec::Framed with appropriate codec
       6              : //! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
       7              : //! separately without using split from futures::stream::StreamExt (which
       8              : //! allocates a [Box] in polling internally). tokio::io::split is used for splitting
       9              : //! instead. Plus we customize error messages more than a single type for all io
      10              : //! calls.
      11              : //!
      12              : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
      13              : use bytes::{Buf, BytesMut};
      14              : use std::{
      15              :     future::Future,
      16              :     io::{self, ErrorKind},
      17              : };
      18              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
      19              : 
      20              : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
      21              : 
      22              : const INITIAL_CAPACITY: usize = 8 * 1024;
      23              : 
      24              : /// Error on postgres connection: either IO (physical transport error) or
      25              : /// protocol violation.
      26            0 : #[derive(thiserror::Error, Debug)]
      27              : pub enum ConnectionError {
      28              :     #[error(transparent)]
      29              :     Io(#[from] io::Error),
      30              :     #[error(transparent)]
      31              :     Protocol(#[from] ProtocolError),
      32              : }
      33              : 
      34              : impl ConnectionError {
      35              :     /// Proxy stream.rs uses only io::Error; provide it.
      36            1 :     pub fn into_io_error(self) -> io::Error {
      37            1 :         match self {
      38            1 :             ConnectionError::Io(io) => io,
      39            0 :             ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
      40              :         }
      41            1 :     }
      42              : }
      43              : 
      44              : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
      45              : /// messages.
      46              : pub struct Framed<S> {
      47              :     pub stream: S,
      48              :     pub read_buf: BytesMut,
      49              :     pub write_buf: BytesMut,
      50              : }
      51              : 
      52              : impl<S> Framed<S> {
      53           27 :     pub fn new(stream: S) -> Self {
      54           27 :         Self {
      55           27 :             stream,
      56           27 :             read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      57           27 :             write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      58           27 :         }
      59           27 :     }
      60              : 
      61              :     /// Get a shared reference to the underlying stream.
      62           35 :     pub fn get_ref(&self) -> &S {
      63           35 :         &self.stream
      64           35 :     }
      65              : 
      66              :     /// Deconstruct into the underlying stream and read buffer.
      67            7 :     pub fn into_inner(self) -> (S, BytesMut) {
      68            7 :         (self.stream, self.read_buf)
      69            7 :     }
      70              : 
      71              :     /// Return new Framed with stream type transformed by async f, for TLS
      72              :     /// upgrade.
      73            1 :     pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
      74            1 :     where
      75            1 :         F: FnOnce(S) -> Fut,
      76            1 :         Fut: Future<Output = Result<S2, E>>,
      77            1 :     {
      78            2 :         let stream = f(self.stream).await?;
      79            1 :         Ok(Framed {
      80            1 :             stream,
      81            1 :             read_buf: self.read_buf,
      82            1 :             write_buf: self.write_buf,
      83            1 :         })
      84            1 :     }
      85              : }
      86              : 
      87              : impl<S: AsyncRead + Unpin> Framed<S> {
      88           45 :     pub async fn read_startup_message(
      89           45 :         &mut self,
      90           45 :     ) -> Result<Option<FeStartupPacket>, ConnectionError> {
      91           45 :         read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
      92           45 :     }
      93              : 
      94           30 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
      95           30 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
      96           28 :     }
      97              : }
      98              : 
      99              : impl<S: AsyncWrite + Unpin> Framed<S> {
     100              :     /// Write next message to the output buffer; doesn't flush.
     101           96 :     pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     102           96 :         BeMessage::write(&mut self.write_buf, msg)
     103           96 :     }
     104              : 
     105              :     /// Flush out the buffer. This function is cancellation safe: it can be
     106              :     /// interrupted and flushing will be continued in the next call.
     107           65 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     108           65 :         flush(&mut self.stream, &mut self.write_buf).await
     109           65 :     }
     110              : 
     111              :     /// Flush out the buffer and shutdown the stream.
     112            0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     113            0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     114            0 :     }
     115              : }
     116              : 
     117              : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
     118              :     /// Split into owned read and write parts. Beware of potential issues with
     119              :     /// using halves in different tasks on TLS stream:
     120              :     /// <https://github.com/tokio-rs/tls/issues/40>
     121            0 :     pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
     122            0 :         let (read_half, write_half) = tokio::io::split(self.stream);
     123            0 :         let reader = FramedReader {
     124            0 :             stream: read_half,
     125            0 :             read_buf: self.read_buf,
     126            0 :         };
     127            0 :         let writer = FramedWriter {
     128            0 :             stream: write_half,
     129            0 :             write_buf: self.write_buf,
     130            0 :         };
     131            0 :         (reader, writer)
     132            0 :     }
     133              : 
     134              :     /// Join read and write parts back.
     135            0 :     pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
     136            0 :         Self {
     137            0 :             stream: reader.stream.unsplit(writer.stream),
     138            0 :             read_buf: reader.read_buf,
     139            0 :             write_buf: writer.write_buf,
     140            0 :         }
     141            0 :     }
     142              : }
     143              : 
     144              : /// Read-only version of `Framed`.
     145              : pub struct FramedReader<S> {
     146              :     stream: ReadHalf<S>,
     147              :     read_buf: BytesMut,
     148              : }
     149              : 
     150              : impl<S: AsyncRead + Unpin> FramedReader<S> {
     151            0 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     152            0 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
     153            0 :     }
     154              : }
     155              : 
     156              : /// Write-only version of `Framed`.
     157              : pub struct FramedWriter<S> {
     158              :     stream: WriteHalf<S>,
     159              :     write_buf: BytesMut,
     160              : }
     161              : 
     162              : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
     163              :     /// Write next message to the output buffer; doesn't flush.
     164            0 :     pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     165            0 :         BeMessage::write(&mut self.write_buf, msg)
     166            0 :     }
     167              : 
     168              :     /// Flush out the buffer. This function is cancellation safe: it can be
     169              :     /// interrupted and flushing will be continued in the next call.
     170            0 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     171            0 :         flush(&mut self.stream, &mut self.write_buf).await
     172            0 :     }
     173              : 
     174              :     /// Flush out the buffer and shutdown the stream.
     175            0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     176            0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     177            0 :     }
     178              : }
     179              : 
     180              : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
     181              : /// don't have remaining data in the buffer. This function is cancellation safe:
     182              : /// you can drop future which is not yet complete and finalize reading message
     183              : /// with the next call.
     184              : ///
     185              : /// Parametrized to allow reading startup or usual message, having different
     186              : /// format.
     187           75 : async fn read_message<S: AsyncRead + Unpin, M, P>(
     188           75 :     stream: &mut S,
     189           75 :     read_buf: &mut BytesMut,
     190           75 :     parse: P,
     191           75 : ) -> Result<Option<M>, ConnectionError>
     192           75 : where
     193           75 :     P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
     194           75 : {
     195              :     loop {
     196          147 :         if let Some(msg) = parse(read_buf)? {
     197           72 :             return Ok(Some(msg));
     198           75 :         }
     199           75 :         // If we can't build a frame yet, try to read more data and try again.
     200           75 :         // Make sure we've got room for at least one byte to read to ensure
     201           75 :         // that we don't get a spurious 0 that looks like EOF.
     202           75 :         read_buf.reserve(1);
     203           75 :         if stream.read_buf(read_buf).await? == 0 {
     204            0 :             if read_buf.has_remaining() {
     205            0 :                 return Err(io::Error::new(
     206            0 :                     ErrorKind::UnexpectedEof,
     207            0 :                     "EOF with unprocessed data in the buffer",
     208            0 :                 )
     209            0 :                 .into());
     210              :             } else {
     211            0 :                 return Ok(None); // clean EOF
     212              :             }
     213           72 :         }
     214              :     }
     215           73 : }
     216              : 
     217              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     218           65 : async fn flush<S: AsyncWrite + Unpin>(
     219           65 :     stream: &mut S,
     220           65 :     write_buf: &mut BytesMut,
     221           65 : ) -> Result<(), io::Error> {
     222          130 :     while write_buf.has_remaining() {
     223           65 :         let bytes_written = stream.write_buf(write_buf).await?;
     224           65 :         if bytes_written == 0 {
     225            0 :             return Err(io::Error::new(
     226            0 :                 ErrorKind::WriteZero,
     227            0 :                 "failed to write message",
     228            0 :             ));
     229           65 :         }
     230              :     }
     231           65 :     stream.flush().await
     232           65 : }
     233              : 
     234              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     235            0 : async fn shutdown<S: AsyncWrite + Unpin>(
     236            0 :     stream: &mut S,
     237            0 :     write_buf: &mut BytesMut,
     238            0 : ) -> Result<(), io::Error> {
     239            0 :     flush(stream, write_buf).await?;
     240            0 :     stream.shutdown().await
     241            0 : }
        

Generated by: LCOV version 2.1-beta