|             Line data    Source code 
       1              : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
       2              : //! the async stream based on (and buffered with) BytesMut. All functions are
       3              : //! cancellation safe.
       4              : //!
       5              : //! It is similar to what tokio_util::codec::Framed with appropriate codec
       6              : //! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
       7              : //! separately without using split from futures::stream::StreamExt (which
       8              : //! allocates a [Box] in polling internally). tokio::io::split is used for splitting
//! instead. In addition, errors are reported with more specific messages than a
//! single error type shared by all io calls would give.
      11              : //!
      12              : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
      13              : use std::future::Future;
      14              : use std::io::{self, ErrorKind};
      15              : 
      16              : use bytes::{Buf, BytesMut};
      17              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
      18              : 
      19              : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
      20              : 
      21              : const INITIAL_CAPACITY: usize = 8 * 1024;
      22              : 
      23              : /// Error on postgres connection: either IO (physical transport error) or
      24              : /// protocol violation.
      25              : #[derive(thiserror::Error, Debug)]
      26              : pub enum ConnectionError {
      27              :     #[error(transparent)]
      28              :     Io(#[from] io::Error),
      29              :     #[error(transparent)]
      30              :     Protocol(#[from] ProtocolError),
      31              : }
      32              : 
      33              : impl ConnectionError {
      34              :     /// Proxy stream.rs uses only io::Error; provide it.
      35            0 :     pub fn into_io_error(self) -> io::Error {
      36            0 :         match self {
      37            0 :             ConnectionError::Io(io) => io,
      38            0 :             ConnectionError::Protocol(pe) => io::Error::other(pe.to_string()),
      39              :         }
      40            0 :     }
      41              : }
      42              : 
      43              : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
      44              : /// messages.
      45              : pub struct Framed<S> {
      46              :     pub stream: S,
      47              :     pub read_buf: BytesMut,
      48              :     pub write_buf: BytesMut,
      49              : }
      50              : 
      51              : impl<S> Framed<S> {
      52            2 :     pub fn new(stream: S) -> Self {
      53            2 :         Self {
      54            2 :             stream,
      55            2 :             read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      56            2 :             write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
      57            2 :         }
      58            0 :     }
      59              : 
      60              :     /// Get a shared reference to the underlying stream.
      61            0 :     pub fn get_ref(&self) -> &S {
      62            0 :         &self.stream
      63            0 :     }
      64              : 
      65              :     /// Deconstruct into the underlying stream and read buffer.
      66            0 :     pub fn into_inner(self) -> (S, BytesMut) {
      67            0 :         (self.stream, self.read_buf)
      68            0 :     }
      69              : 
      70              :     /// Return new Framed with stream type transformed by async f, for TLS
      71              :     /// upgrade.
      72            1 :     pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
      73            1 :     where
      74            1 :         F: FnOnce(S) -> Fut,
      75            1 :         Fut: Future<Output = Result<S2, E>>,
      76            0 :     {
      77            1 :         let stream = f(self.stream).await?;
      78            1 :         Ok(Framed {
      79            1 :             stream,
      80            1 :             read_buf: self.read_buf,
      81            1 :             write_buf: self.write_buf,
      82            1 :         })
      83            0 :     }
      84              : }
      85              : 
      86              : impl<S: AsyncRead + Unpin> Framed<S> {
      87            3 :     pub async fn read_startup_message(
      88            3 :         &mut self,
      89            3 :     ) -> Result<Option<FeStartupPacket>, ConnectionError> {
      90            3 :         read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
      91            0 :     }
      92              : 
      93            4 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
      94            4 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
      95            0 :     }
      96              : }
      97              : 
      98              : impl<S: AsyncWrite + Unpin> Framed<S> {
      99              :     /// Write next message to the output buffer; doesn't flush.
     100           19 :     pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     101           19 :         BeMessage::write(&mut self.write_buf, msg)
     102            0 :     }
     103              : 
     104              :     /// Flush out the buffer. This function is cancellation safe: it can be
     105              :     /// interrupted and flushing will be continued in the next call.
     106            5 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     107            5 :         flush(&mut self.stream, &mut self.write_buf).await
     108            0 :     }
     109              : 
     110              :     /// Flush out the buffer and shutdown the stream.
     111            0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     112            0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     113            0 :     }
     114              : }
     115              : 
     116              : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
     117              :     /// Split into owned read and write parts. Beware of potential issues with
     118              :     /// using halves in different tasks on TLS stream:
     119              :     /// <https://github.com/tokio-rs/tls/issues/40>
     120            0 :     pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
     121            0 :         let (read_half, write_half) = tokio::io::split(self.stream);
     122            0 :         let reader = FramedReader {
     123            0 :             stream: read_half,
     124            0 :             read_buf: self.read_buf,
     125            0 :         };
     126            0 :         let writer = FramedWriter {
     127            0 :             stream: write_half,
     128            0 :             write_buf: self.write_buf,
     129            0 :         };
     130            0 :         (reader, writer)
     131            0 :     }
     132              : 
     133              :     /// Join read and write parts back.
     134            0 :     pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
     135            0 :         Self {
     136            0 :             stream: reader.stream.unsplit(writer.stream),
     137            0 :             read_buf: reader.read_buf,
     138            0 :             write_buf: writer.write_buf,
     139            0 :         }
     140            0 :     }
     141              : }
     142              : 
     143              : /// Read-only version of `Framed`.
     144              : pub struct FramedReader<S> {
     145              :     stream: ReadHalf<S>,
     146              :     read_buf: BytesMut,
     147              : }
     148              : 
     149              : impl<S: AsyncRead + Unpin> FramedReader<S> {
     150            0 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     151            0 :         read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
     152            0 :     }
     153              : }
     154              : 
     155              : /// Write-only version of `Framed`.
     156              : pub struct FramedWriter<S> {
     157              :     stream: WriteHalf<S>,
     158              :     write_buf: BytesMut,
     159              : }
     160              : 
     161              : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
     162              :     /// Write next message to the output buffer; doesn't flush.
     163            0 :     pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     164            0 :         BeMessage::write(&mut self.write_buf, msg)
     165            0 :     }
     166              : 
     167              :     /// Flush out the buffer. This function is cancellation safe: it can be
     168              :     /// interrupted and flushing will be continued in the next call.
     169            0 :     pub async fn flush(&mut self) -> Result<(), io::Error> {
     170            0 :         flush(&mut self.stream, &mut self.write_buf).await
     171            0 :     }
     172              : 
     173              :     /// Flush out the buffer and shutdown the stream.
     174            0 :     pub async fn shutdown(&mut self) -> Result<(), io::Error> {
     175            0 :         shutdown(&mut self.stream, &mut self.write_buf).await
     176            0 :     }
     177              : }
     178              : 
     179              : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
     180              : /// don't have remaining data in the buffer. This function is cancellation safe:
     181              : /// you can drop future which is not yet complete and finalize reading message
     182              : /// with the next call.
     183              : ///
     184              : /// Parametrized to allow reading startup or usual message, having different
     185              : /// format.
     186            7 : async fn read_message<S: AsyncRead + Unpin, M, P>(
     187            7 :     stream: &mut S,
     188            7 :     read_buf: &mut BytesMut,
     189            7 :     parse: P,
     190            7 : ) -> Result<Option<M>, ConnectionError>
     191            7 : where
     192            7 :     P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
     193            0 : {
     194              :     loop {
     195           12 :         if let Some(msg) = parse(read_buf)? {
     196            5 :             return Ok(Some(msg));
     197            0 :         }
     198              :         // If we can't build a frame yet, try to read more data and try again.
     199              :         // Make sure we've got room for at least one byte to read to ensure
     200              :         // that we don't get a spurious 0 that looks like EOF.
     201            7 :         read_buf.reserve(1);
     202            7 :         if stream.read_buf(read_buf).await? == 0 {
     203            0 :             if read_buf.has_remaining() {
     204            0 :                 return Err(io::Error::new(
     205            0 :                     ErrorKind::UnexpectedEof,
     206            0 :                     "EOF with unprocessed data in the buffer",
     207            0 :                 )
     208            0 :                 .into());
     209              :             } else {
     210            0 :                 return Ok(None); // clean EOF
     211              :             }
     212            0 :         }
     213              :     }
     214            0 : }
     215              : 
     216              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     217            5 : async fn flush<S: AsyncWrite + Unpin>(
     218            5 :     stream: &mut S,
     219            5 :     write_buf: &mut BytesMut,
     220            5 : ) -> Result<(), io::Error> {
     221           10 :     while write_buf.has_remaining() {
     222            5 :         let bytes_written = stream.write_buf(write_buf).await?;
     223            5 :         if bytes_written == 0 {
     224            0 :             return Err(io::Error::new(
     225            0 :                 ErrorKind::WriteZero,
     226            0 :                 "failed to write message",
     227            0 :             ));
     228            0 :         }
     229              :     }
     230            5 :     stream.flush().await
     231            0 : }
     232              : 
     233              : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
     234            0 : async fn shutdown<S: AsyncWrite + Unpin>(
     235            0 :     stream: &mut S,
     236            0 :     write_buf: &mut BytesMut,
     237            0 : ) -> Result<(), io::Error> {
     238            0 :     flush(stream, write_buf).await?;
     239            0 :     stream.shutdown().await
     240            0 : }
         |