Line data Source code
1 : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
2 : //! the async stream based on (and buffered with) BytesMut. All functions are
3 : //! cancellation safe.
4 : //!
//! It is similar to what `tokio_util::codec::Framed` with an appropriate codec
//! provides, but the `FramedReader` and `FramedWriter` read/write parts can be
//! used separately without `split` from `futures::stream::StreamExt` (which
//! allocates a [Box] in polling internally); `tokio::io::split` is used for
//! splitting instead. Errors are also more specific than a codec's single error
//! type: reads return `ConnectionError`, while buffering a message returns
//! `ProtocolError` and flushing returns `io::Error`.
11 : //!
12 : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
13 : use std::future::Future;
14 : use std::io::{self, ErrorKind};
15 :
16 : use bytes::{Buf, BytesMut};
17 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
18 :
19 : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
20 :
/// Initial capacity, in bytes, of both the read and the write buffer of a `Framed`.
const INITIAL_CAPACITY: usize = 8192;
22 :
23 : /// Error on postgres connection: either IO (physical transport error) or
24 : /// protocol violation.
25 : #[derive(thiserror::Error, Debug)]
26 : pub enum ConnectionError {
27 : #[error(transparent)]
28 : Io(#[from] io::Error),
29 : #[error(transparent)]
30 : Protocol(#[from] ProtocolError),
31 : }
32 :
33 : impl ConnectionError {
34 : /// Proxy stream.rs uses only io::Error; provide it.
35 1 : pub fn into_io_error(self) -> io::Error {
36 1 : match self {
37 1 : ConnectionError::Io(io) => io,
38 0 : ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
39 : }
40 1 : }
41 : }
42 :
43 : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
44 : /// messages.
45 : pub struct Framed<S> {
46 : pub stream: S,
47 : pub read_buf: BytesMut,
48 : pub write_buf: BytesMut,
49 : }
50 :
51 : impl<S> Framed<S> {
52 27 : pub fn new(stream: S) -> Self {
53 27 : Self {
54 27 : stream,
55 27 : read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
56 27 : write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
57 27 : }
58 27 : }
59 :
60 : /// Get a shared reference to the underlying stream.
61 35 : pub fn get_ref(&self) -> &S {
62 35 : &self.stream
63 35 : }
64 :
65 : /// Deconstruct into the underlying stream and read buffer.
66 7 : pub fn into_inner(self) -> (S, BytesMut) {
67 7 : (self.stream, self.read_buf)
68 7 : }
69 :
70 : /// Return new Framed with stream type transformed by async f, for TLS
71 : /// upgrade.
72 1 : pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
73 1 : where
74 1 : F: FnOnce(S) -> Fut,
75 1 : Fut: Future<Output = Result<S2, E>>,
76 1 : {
77 1 : let stream = f(self.stream).await?;
78 1 : Ok(Framed {
79 1 : stream,
80 1 : read_buf: self.read_buf,
81 1 : write_buf: self.write_buf,
82 1 : })
83 0 : }
84 : }
85 :
86 : impl<S: AsyncRead + Unpin> Framed<S> {
87 45 : pub async fn read_startup_message(
88 45 : &mut self,
89 45 : ) -> Result<Option<FeStartupPacket>, ConnectionError> {
90 45 : read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
91 0 : }
92 :
93 30 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
94 30 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
95 0 : }
96 : }
97 :
98 : impl<S: AsyncWrite + Unpin> Framed<S> {
99 : /// Write next message to the output buffer; doesn't flush.
100 96 : pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
101 96 : BeMessage::write(&mut self.write_buf, msg)
102 96 : }
103 :
104 : /// Flush out the buffer. This function is cancellation safe: it can be
105 : /// interrupted and flushing will be continued in the next call.
106 60 : pub async fn flush(&mut self) -> Result<(), io::Error> {
107 60 : flush(&mut self.stream, &mut self.write_buf).await
108 0 : }
109 :
110 : /// Flush out the buffer and shutdown the stream.
111 0 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
112 0 : shutdown(&mut self.stream, &mut self.write_buf).await
113 0 : }
114 : }
115 :
116 : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
117 : /// Split into owned read and write parts. Beware of potential issues with
118 : /// using halves in different tasks on TLS stream:
119 : /// <https://github.com/tokio-rs/tls/issues/40>
120 0 : pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
121 0 : let (read_half, write_half) = tokio::io::split(self.stream);
122 0 : let reader = FramedReader {
123 0 : stream: read_half,
124 0 : read_buf: self.read_buf,
125 0 : };
126 0 : let writer = FramedWriter {
127 0 : stream: write_half,
128 0 : write_buf: self.write_buf,
129 0 : };
130 0 : (reader, writer)
131 0 : }
132 :
133 : /// Join read and write parts back.
134 0 : pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
135 0 : Self {
136 0 : stream: reader.stream.unsplit(writer.stream),
137 0 : read_buf: reader.read_buf,
138 0 : write_buf: writer.write_buf,
139 0 : }
140 0 : }
141 : }
142 :
143 : /// Read-only version of `Framed`.
144 : pub struct FramedReader<S> {
145 : stream: ReadHalf<S>,
146 : read_buf: BytesMut,
147 : }
148 :
149 : impl<S: AsyncRead + Unpin> FramedReader<S> {
150 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
151 0 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
152 0 : }
153 : }
154 :
155 : /// Write-only version of `Framed`.
156 : pub struct FramedWriter<S> {
157 : stream: WriteHalf<S>,
158 : write_buf: BytesMut,
159 : }
160 :
161 : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
162 : /// Write next message to the output buffer; doesn't flush.
163 0 : pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
164 0 : BeMessage::write(&mut self.write_buf, msg)
165 0 : }
166 :
167 : /// Flush out the buffer. This function is cancellation safe: it can be
168 : /// interrupted and flushing will be continued in the next call.
169 0 : pub async fn flush(&mut self) -> Result<(), io::Error> {
170 0 : flush(&mut self.stream, &mut self.write_buf).await
171 0 : }
172 :
173 : /// Flush out the buffer and shutdown the stream.
174 0 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
175 0 : shutdown(&mut self.stream, &mut self.write_buf).await
176 0 : }
177 : }
178 :
179 : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
180 : /// don't have remaining data in the buffer. This function is cancellation safe:
181 : /// you can drop future which is not yet complete and finalize reading message
182 : /// with the next call.
183 : ///
184 : /// Parametrized to allow reading startup or usual message, having different
185 : /// format.
186 75 : async fn read_message<S: AsyncRead + Unpin, M, P>(
187 75 : stream: &mut S,
188 75 : read_buf: &mut BytesMut,
189 75 : parse: P,
190 75 : ) -> Result<Option<M>, ConnectionError>
191 75 : where
192 75 : P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
193 75 : {
194 : loop {
195 147 : if let Some(msg) = parse(read_buf)? {
196 72 : return Ok(Some(msg));
197 75 : }
198 75 : // If we can't build a frame yet, try to read more data and try again.
199 75 : // Make sure we've got room for at least one byte to read to ensure
200 75 : // that we don't get a spurious 0 that looks like EOF.
201 75 : read_buf.reserve(1);
202 75 : if stream.read_buf(read_buf).await? == 0 {
203 0 : if read_buf.has_remaining() {
204 0 : return Err(io::Error::new(
205 0 : ErrorKind::UnexpectedEof,
206 0 : "EOF with unprocessed data in the buffer",
207 0 : )
208 0 : .into());
209 : } else {
210 0 : return Ok(None); // clean EOF
211 : }
212 0 : }
213 : }
214 0 : }
215 :
216 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
217 60 : async fn flush<S: AsyncWrite + Unpin>(
218 60 : stream: &mut S,
219 60 : write_buf: &mut BytesMut,
220 60 : ) -> Result<(), io::Error> {
221 120 : while write_buf.has_remaining() {
222 60 : let bytes_written = stream.write_buf(write_buf).await?;
223 60 : if bytes_written == 0 {
224 0 : return Err(io::Error::new(
225 0 : ErrorKind::WriteZero,
226 0 : "failed to write message",
227 0 : ));
228 0 : }
229 : }
230 60 : stream.flush().await
231 0 : }
232 :
233 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
234 0 : async fn shutdown<S: AsyncWrite + Unpin>(
235 0 : stream: &mut S,
236 0 : write_buf: &mut BytesMut,
237 0 : ) -> Result<(), io::Error> {
238 0 : flush(stream, write_buf).await?;
239 0 : stream.shutdown().await
240 0 : }
|