Line data Source code
1 : //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
2 : //! the async stream based on (and buffered with) BytesMut. All functions are
3 : //! cancellation safe.
4 : //!
5 : //! It is similar to what tokio_util::codec::Framed with appropriate codec
6 : //! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
7 : //! separately without using split from futures::stream::StreamExt (which
8 : //! allocates a [Box] in polling internally). tokio::io::split is used for splitting
9 : //! instead. Plus we customize error messages more than a single type for all io
10 : //! calls.
11 : //!
12 : //! [Box]: https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
13 : use bytes::{Buf, BytesMut};
14 : use std::{
15 : future::Future,
16 : io::{self, ErrorKind},
17 : };
18 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
19 :
20 : use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
21 :
/// Capacity, in bytes (8 KiB), that `Framed::new` pre-allocates for each of
/// its read and write buffers.
const INITIAL_CAPACITY: usize = 8 * 1024;
23 :
24 : /// Error on postgres connection: either IO (physical transport error) or
25 : /// protocol violation.
26 748 : #[derive(thiserror::Error, Debug)]
27 : pub enum ConnectionError {
28 : #[error(transparent)]
29 : Io(#[from] io::Error),
30 : #[error(transparent)]
31 : Protocol(#[from] ProtocolError),
32 : }
33 :
34 : impl ConnectionError {
35 : /// Proxy stream.rs uses only io::Error; provide it.
36 2 : pub fn into_io_error(self) -> io::Error {
37 2 : match self {
38 2 : ConnectionError::Io(io) => io,
39 0 : ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
40 : }
41 2 : }
42 : }
43 :
44 : /// Wraps async io `stream`, providing messages to write/flush + read Postgres
45 : /// messages.
46 : pub struct Framed<S> {
47 : stream: S,
48 : read_buf: BytesMut,
49 : write_buf: BytesMut,
50 : }
51 :
52 : impl<S> Framed<S> {
53 14062 : pub fn new(stream: S) -> Self {
54 14062 : Self {
55 14062 : stream,
56 14062 : read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
57 14062 : write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
58 14062 : }
59 14062 : }
60 :
61 : /// Get a shared reference to the underlying stream.
62 210 : pub fn get_ref(&self) -> &S {
63 210 : &self.stream
64 210 : }
65 :
66 : /// Deconstruct into the underlying stream and read buffer.
67 148 : pub fn into_inner(self) -> (S, BytesMut) {
68 148 : (self.stream, self.read_buf)
69 148 : }
70 :
71 : /// Return new Framed with stream type transformed by async f, for TLS
72 : /// upgrade.
73 2 : pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
74 2 : where
75 2 : F: FnOnce(S) -> Fut,
76 2 : Fut: Future<Output = Result<S2, E>>,
77 2 : {
78 4 : let stream = f(self.stream).await?;
79 2 : Ok(Framed {
80 2 : stream,
81 2 : read_buf: self.read_buf,
82 2 : write_buf: self.write_buf,
83 2 : })
84 2 : }
85 : }
86 :
87 : impl<S: AsyncRead + Unpin> Framed<S> {
88 25897 : pub async fn read_startup_message(
89 25897 : &mut self,
90 25897 : ) -> Result<Option<FeStartupPacket>, ConnectionError> {
91 25897 : read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
92 25897 : }
93 :
94 4645001 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
95 4645001 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
96 4644796 : }
97 : }
98 :
99 : impl<S: AsyncWrite + Unpin> Framed<S> {
100 : /// Write next message to the output buffer; doesn't flush.
101 5596960 : pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
102 5596960 : BeMessage::write(&mut self.write_buf, msg)
103 5596960 : }
104 :
105 : /// Flush out the buffer. This function is cancellation safe: it can be
106 : /// interrupted and flushing will be continued in the next call.
107 5586816 : pub async fn flush(&mut self) -> Result<(), io::Error> {
108 5586816 : flush(&mut self.stream, &mut self.write_buf).await
109 5578568 : }
110 :
111 : /// Flush out the buffer and shutdown the stream.
112 13406 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
113 13406 : shutdown(&mut self.stream, &mut self.write_buf).await
114 13406 : }
115 : }
116 :
117 : impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
118 : /// Split into owned read and write parts. Beware of potential issues with
119 : /// using halves in different tasks on TLS stream:
120 : /// <https://github.com/tokio-rs/tls/issues/40>
121 2604 : pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
122 2604 : let (read_half, write_half) = tokio::io::split(self.stream);
123 2604 : let reader = FramedReader {
124 2604 : stream: read_half,
125 2604 : read_buf: self.read_buf,
126 2604 : };
127 2604 : let writer = FramedWriter {
128 2604 : stream: write_half,
129 2604 : write_buf: self.write_buf,
130 2604 : };
131 2604 : (reader, writer)
132 2604 : }
133 :
134 : /// Join read and write parts back.
135 2168 : pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
136 2168 : Self {
137 2168 : stream: reader.stream.unsplit(writer.stream),
138 2168 : read_buf: reader.read_buf,
139 2168 : write_buf: writer.write_buf,
140 2168 : }
141 2168 : }
142 : }
143 :
144 : /// Read-only version of `Framed`.
145 : pub struct FramedReader<S> {
146 : stream: ReadHalf<S>,
147 : read_buf: BytesMut,
148 : }
149 :
150 : impl<S: AsyncRead + Unpin> FramedReader<S> {
151 3477763 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
152 6545416 : read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
153 3477293 : }
154 : }
155 :
156 : /// Write-only version of `Framed`.
157 : pub struct FramedWriter<S> {
158 : stream: WriteHalf<S>,
159 : write_buf: BytesMut,
160 : }
161 :
162 : impl<S: AsyncWrite + Unpin> FramedWriter<S> {
163 : /// Write next message to the output buffer; doesn't flush.
164 2466754 : pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
165 2466754 : BeMessage::write(&mut self.write_buf, msg)
166 2466754 : }
167 :
168 : /// Flush out the buffer. This function is cancellation safe: it can be
169 : /// interrupted and flushing will be continued in the next call.
170 2466754 : pub async fn flush(&mut self) -> Result<(), io::Error> {
171 2466754 : flush(&mut self.stream, &mut self.write_buf).await
172 2466738 : }
173 :
174 : /// Flush out the buffer and shutdown the stream.
175 0 : pub async fn shutdown(&mut self) -> Result<(), io::Error> {
176 0 : shutdown(&mut self.stream, &mut self.write_buf).await
177 0 : }
178 : }
179 :
180 : /// Read next message from the stream. Returns Ok(None), if EOF happened and we
181 : /// don't have remaining data in the buffer. This function is cancellation safe:
182 : /// you can drop future which is not yet complete and finalize reading message
183 : /// with the next call.
184 : ///
185 : /// Parametrized to allow reading startup or usual message, having different
186 : /// format.
187 8148661 : async fn read_message<S: AsyncRead + Unpin, M, P>(
188 8148661 : stream: &mut S,
189 8148661 : read_buf: &mut BytesMut,
190 8148661 : parse: P,
191 8148661 : ) -> Result<Option<M>, ConnectionError>
192 8148661 : where
193 8148661 : P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
194 8148661 : {
195 : loop {
196 15233966 : if let Some(msg) = parse(read_buf)? {
197 8125715 : return Ok(Some(msg));
198 7108251 : }
199 7108251 : // If we can't build a frame yet, try to read more data and try again.
200 7108251 : // Make sure we've got room for at least one byte to read to ensure
201 7108251 : // that we don't get a spurious 0 that looks like EOF.
202 7108251 : read_buf.reserve(1);
203 10810332 : if stream.read_buf(read_buf).await? == 0 {
204 21977 : if read_buf.has_remaining() {
205 0 : return Err(io::Error::new(
206 0 : ErrorKind::UnexpectedEof,
207 0 : "EOF with unprocessed data in the buffer",
208 0 : )
209 0 : .into());
210 : } else {
211 21977 : return Ok(None); // clean EOF
212 : }
213 7085305 : }
214 : }
215 8147986 : }
216 :
217 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
218 8066976 : async fn flush<S: AsyncWrite + Unpin>(
219 8066976 : stream: &mut S,
220 8066976 : write_buf: &mut BytesMut,
221 8066976 : ) -> Result<(), io::Error> {
222 16073917 : while write_buf.has_remaining() {
223 8016014 : let bytes_written = stream.write_buf(write_buf).await?;
224 8006941 : if bytes_written == 0 {
225 0 : return Err(io::Error::new(
226 0 : ErrorKind::WriteZero,
227 0 : "failed to write message",
228 0 : ));
229 8006941 : }
230 : }
231 8057903 : stream.flush().await
232 8058712 : }
233 :
234 : /// Cancellation safe as long as the AsyncWrite is cancellation safe.
235 13406 : async fn shutdown<S: AsyncWrite + Unpin>(
236 13406 : stream: &mut S,
237 13406 : write_buf: &mut BytesMut,
238 13406 : ) -> Result<(), io::Error> {
239 13406 : flush(stream, write_buf).await?;
240 13047 : stream.shutdown().await
241 13406 : }
|