Line data Source code
1 : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
2 : //! the async stream based on (and buffered with) BytesMut. All functions are
3 : //! cancellation safe.
4 : //!
5 : //! It is similar to what tokio_util::codec::Framed with appropriate codec
6 : //! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
7 : //! separately without using split from futures::stream::StreamExt (which
8 : //! allocates a [Box] in polling internally). tokio::io::split is used for splitting
9 : //! instead. We also produce more specific error messages than a single error
10 : //! type shared by all io calls would allow.
11 : //!
12 : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
13 : use bytes::{Buf, BytesMut};
14 : use std::{
15 : future::Future,
16 : io::{self, ErrorKind},
17 : };
18 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
19 :
20 : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
21 :
/// Starting capacity, in bytes, of each of `Framed`'s read and write buffers (8 KiB).
const INITIAL_CAPACITY: usize = 8 << 10;
23 :
24 : /// Error on postgres connection: either IO (physical transport error) or
25 : /// protocol violation.
26 0 : #[derive(thiserror::Error, Debug)]
27 : pub enum ConnectionError {
28 : #[error(transparent)]
29 : Io(#[from] io::Error),
30 : #[error(transparent)]
31 : Protocol(#[from] ProtocolError),
32 : }
33 :
34 : impl ConnectionError {
35 : /// Proxy stream.rs uses only io::Error; provide it.
36 1 : pub fn into_io_error(self) -> io::Error {
37 1 : match self {
38 1 : ConnectionError::Io(io) => io,
39 0 : ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
40 : }
41 1 : }
42 : }
43 :
44 : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
45 : /// messages.
46 : pub struct Framed<S> {
47 : pub stream: S,
48 : pub read_buf: BytesMut,
49 : pub write_buf: BytesMut,
50 : }
51 :
52 : impl<S> Framed<S> {
53 27 : pub fn new(stream: S) -> Self {
54 27 : Self {
55 27 : stream,
56 27 : read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
57 27 : write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
58 27 : }
59 27 : }
60 :
61 : /// Get a shared reference to the underlying stream.
62 35 : pub fn get_ref(&self) -> &S {
63 35 : &self.stream
64 35 : }
65 :
66 : /// Deconstruct into the underlying stream and read buffer.
67 7 : pub fn into_inner(self) -> (S, BytesMut) {
68 7 : (self.stream, self.read_buf)
69 7 : }
70 :
71 : /// Return new Framed with stream type transformed by async f, for TLS
72 : /// upgrade.
73 1 : pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
74 1 : where
75 1 : F: FnOnce(S) -> Fut,
76 1 : Fut: Future<Output = Result<S2, E>>,
77 1 : {
78 2 : let stream = f(self.stream).await?;
79 1 : Ok(Framed {
80 1 : stream,
81 1 : read_buf: self.read_buf,
82 1 : write_buf: self.write_buf,
83 1 : })
84 1 : }
85 : }
86 :
87 : impl<S: AsyncRead + Unpin> Framed<S> {
88 45 : pub async fn read_startup_message(
89 45 : &mut self,
90 45 : ) -> Result<Option<FeStartupPacket>, ConnectionError> {
91 45 : read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
92 45 : }
93 :
94 30 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
95 30 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
96 28 : }
97 : }
98 :
99 : impl<S: AsyncWrite + Unpin> Framed<S> {
100 : /// Write next message to the output buffer; doesn't flush.
101 96 : pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
102 96 : BeMessage::write(&mut self.write_buf, msg)
103 96 : }
104 :
105 : /// Flush out the buffer. This function is cancellation safe: it can be
106 : /// interrupted and flushing will be continued in the next call.
107 65 : pub async fn flush(&mut self) -> Result<(), io::Error> {
108 65 : flush(&mut self.stream, &mut self.write_buf).await
109 65 : }
110 :
111 : /// Flush out the buffer and shutdown the stream.
112 0 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
113 0 : shutdown(&mut self.stream, &mut self.write_buf).await
114 0 : }
115 : }
116 :
117 : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
118 : /// Split into owned read and write parts. Beware of potential issues with
119 : /// using halves in different tasks on TLS stream:
120 : /// <https://github.com/tokio-rs/tls/issues/40>
121 0 : pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
122 0 : let (read_half, write_half) = tokio::io::split(self.stream);
123 0 : let reader = FramedReader {
124 0 : stream: read_half,
125 0 : read_buf: self.read_buf,
126 0 : };
127 0 : let writer = FramedWriter {
128 0 : stream: write_half,
129 0 : write_buf: self.write_buf,
130 0 : };
131 0 : (reader, writer)
132 0 : }
133 :
134 : /// Join read and write parts back.
135 0 : pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
136 0 : Self {
137 0 : stream: reader.stream.unsplit(writer.stream),
138 0 : read_buf: reader.read_buf,
139 0 : write_buf: writer.write_buf,
140 0 : }
141 0 : }
142 : }
143 :
144 : /// Read-only version of `Framed`.
145 : pub struct FramedReader<S> {
146 : stream: ReadHalf<S>,
147 : read_buf: BytesMut,
148 : }
149 :
150 : impl<S: AsyncRead + Unpin> FramedReader<S> {
151 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
152 0 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
153 0 : }
154 : }
155 :
156 : /// Write-only version of `Framed`.
157 : pub struct FramedWriter<S> {
158 : stream: WriteHalf<S>,
159 : write_buf: BytesMut,
160 : }
161 :
162 : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
163 : /// Write next message to the output buffer; doesn't flush.
164 0 : pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
165 0 : BeMessage::write(&mut self.write_buf, msg)
166 0 : }
167 :
168 : /// Flush out the buffer. This function is cancellation safe: it can be
169 : /// interrupted and flushing will be continued in the next call.
170 0 : pub async fn flush(&mut self) -> Result<(), io::Error> {
171 0 : flush(&mut self.stream, &mut self.write_buf).await
172 0 : }
173 :
174 : /// Flush out the buffer and shutdown the stream.
175 0 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
176 0 : shutdown(&mut self.stream, &mut self.write_buf).await
177 0 : }
178 : }
179 :
180 : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
181 : /// don't have remaining data in the buffer. This function is cancellation safe:
182 : /// you can drop future which is not yet complete and finalize reading message
183 : /// with the next call.
184 : ///
185 : /// Parametrized to allow reading startup or usual message, having different
186 : /// format.
187 75 : async fn read_message<S: AsyncRead + Unpin, M, P>(
188 75 : stream: &mut S,
189 75 : read_buf: &mut BytesMut,
190 75 : parse: P,
191 75 : ) -> Result<Option<M>, ConnectionError>
192 75 : where
193 75 : P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
194 75 : {
195 : loop {
196 147 : if let Some(msg) = parse(read_buf)? {
197 72 : return Ok(Some(msg));
198 75 : }
199 75 : // If we can't build a frame yet, try to read more data and try again.
200 75 : // Make sure we've got room for at least one byte to read to ensure
201 75 : // that we don't get a spurious 0 that looks like EOF.
202 75 : read_buf.reserve(1);
203 75 : if stream.read_buf(read_buf).await? == 0 {
204 0 : if read_buf.has_remaining() {
205 0 : return Err(io::Error::new(
206 0 : ErrorKind::UnexpectedEof,
207 0 : "EOF with unprocessed data in the buffer",
208 0 : )
209 0 : .into());
210 : } else {
211 0 : return Ok(None); // clean EOF
212 : }
213 72 : }
214 : }
215 73 : }
216 :
217 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
218 65 : async fn flush<S: AsyncWrite + Unpin>(
219 65 : stream: &mut S,
220 65 : write_buf: &mut BytesMut,
221 65 : ) -> Result<(), io::Error> {
222 130 : while write_buf.has_remaining() {
223 65 : let bytes_written = stream.write_buf(write_buf).await?;
224 65 : if bytes_written == 0 {
225 0 : return Err(io::Error::new(
226 0 : ErrorKind::WriteZero,
227 0 : "failed to write message",
228 0 : ));
229 65 : }
230 : }
231 65 : stream.flush().await
232 65 : }
233 :
234 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
235 0 : async fn shutdown<S: AsyncWrite + Unpin>(
236 0 : stream: &mut S,
237 0 : write_buf: &mut BytesMut,
238 0 : ) -> Result<(), io::Error> {
239 0 : flush(stream, write_buf).await?;
240 0 : stream.shutdown().await
241 0 : }
|