Line data Source code
1 : use crate::error::UserFacingError;
2 : use anyhow::bail;
3 : use bytes::BytesMut;
4 : use pin_project_lite::pin_project;
5 : use pq_proto::framed::{ConnectionError, Framed};
6 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
7 : use rustls::ServerConfig;
8 : use std::pin::Pin;
9 : use std::sync::Arc;
10 : use std::{io, task};
11 : use thiserror::Error;
12 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 : use tokio_rustls::server::TlsStream;
14 :
15 : /// Stream wrapper which implements libpq's protocol.
16 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
17 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
18 : /// to pass random malformed bytes through the connection).
19 : pub struct PqStream<S> {
20 : framed: Framed<S>,
21 : }
22 :
23 : impl<S> PqStream<S> {
24 : /// Construct a new libpq protocol wrapper.
25 80 : pub fn new(stream: S) -> Self {
26 80 : Self {
27 80 : framed: Framed::new(stream),
28 80 : }
29 80 : }
30 :
31 : /// Extract the underlying stream and read buffer.
32 64 : pub fn into_inner(self) -> (S, BytesMut) {
33 64 : self.framed.into_inner()
34 64 : }
35 :
36 : /// Get a shared reference to the underlying stream.
37 67 : pub fn get_ref(&self) -> &S {
38 67 : self.framed.get_ref()
39 67 : }
40 : }
41 :
/// Build the error used when the framer yields no message (`None`).
///
/// The read helpers treat that case as a lost connection.
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
45 :
46 : impl<S: AsyncRead + Unpin> PqStream<S> {
47 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
48 80 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
49 80 : self.framed
50 80 : .read_startup_message()
51 21 : .await
52 80 : .map_err(ConnectionError::into_io_error)?
53 80 : .ok_or_else(err_connection)
54 80 : }
55 :
56 61 : async fn read_message(&mut self) -> io::Result<FeMessage> {
57 61 : self.framed
58 61 : .read_message()
59 61 : .await
60 61 : .map_err(ConnectionError::into_io_error)?
61 61 : .ok_or_else(err_connection)
62 61 : }
63 :
64 61 : pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
65 61 : match self.read_message().await? {
66 61 : FeMessage::PasswordMessage(msg) => Ok(msg),
67 0 : bad => Err(io::Error::new(
68 0 : io::ErrorKind::InvalidData,
69 0 : format!("unexpected message type: {:?}", bad),
70 0 : )),
71 : }
72 61 : }
73 : }
74 :
75 : impl<S: AsyncWrite + Unpin> PqStream<S> {
76 : /// Write the message into an internal buffer, but don't flush the underlying stream.
77 : pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
78 588 : self.framed
79 588 : .write_message(message)
80 588 : .map_err(ProtocolError::into_io_error)?;
81 588 : Ok(self)
82 588 : }
83 :
84 : /// Write the message into an internal buffer and flush it.
85 167 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
86 167 : self.write_message_noflush(message)?;
87 167 : self.flush().await?;
88 167 : Ok(self)
89 167 : }
90 :
91 : /// Flush the output buffer into the underlying stream.
92 167 : pub async fn flush(&mut self) -> io::Result<&mut Self> {
93 167 : self.framed.flush().await?;
94 167 : Ok(self)
95 167 : }
96 :
97 : /// Write the error message using [`Self::write_message`], then re-throw it.
98 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
99 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
100 5 : pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
101 4 : tracing::info!("forwarding error to user: {error}");
102 5 : self.write_message(&BeMessage::ErrorResponse(error, None))
103 0 : .await?;
104 5 : bail!(error)
105 5 : }
106 :
107 : /// Write the error message using [`Self::write_message`], then re-throw it.
108 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
109 4 : pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
110 4 : where
111 4 : E: UserFacingError + Into<anyhow::Error>,
112 4 : {
113 4 : let msg = error.to_string_client();
114 4 : tracing::info!("forwarding error to user: {msg}");
115 4 : self.write_message(&BeMessage::ErrorResponse(&msg, None))
116 0 : .await?;
117 4 : bail!(error)
118 4 : }
119 : }
120 :
121 : pin_project! {
122 : /// Wrapper for upgrading raw streams into secure streams.
123 : /// NOTE: it should be possible to decompose this object as necessary.
124 : #[project = StreamProj]
125 : pub enum Stream<S> {
126 : /// We always begin with a raw stream,
127 : /// which may then be upgraded into a secure stream.
128 : Raw { #[pin] raw: S },
129 : /// We box [`TlsStream`] since it can be quite large.
130 : Tls { #[pin] tls: Box<TlsStream<S>> },
131 : }
132 : }
133 :
134 : impl<S> Stream<S> {
135 : /// Construct a new instance from a raw stream.
136 44 : pub fn from_raw(raw: S) -> Self {
137 44 : Self::Raw { raw }
138 44 : }
139 :
140 : /// Return SNI hostname when it's available.
141 32 : pub fn sni_hostname(&self) -> Option<&str> {
142 32 : match self {
143 0 : Stream::Raw { .. } => None,
144 32 : Stream::Tls { tls } => tls.get_ref().1.server_name(),
145 : }
146 32 : }
147 : }
148 :
149 0 : #[derive(Debug, Error)]
150 : #[error("Can't upgrade TLS stream")]
151 : pub enum StreamUpgradeError {
152 : #[error("Bad state reached: can't upgrade TLS stream")]
153 : AlreadyTls,
154 :
155 : #[error("Can't upgrade stream: IO error: {0}")]
156 : Io(#[from] io::Error),
157 : }
158 :
159 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
160 : /// If possible, upgrade raw stream into a secure TLS-based stream.
161 37 : pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
162 37 : match self {
163 37 : Stream::Raw { raw } => {
164 74 : let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
165 37 : Ok(Stream::Tls { tls })
166 : }
167 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
168 : }
169 37 : }
170 : }
171 :
172 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
173 390 : fn poll_read(
174 390 : self: Pin<&mut Self>,
175 390 : context: &mut task::Context<'_>,
176 390 : buf: &mut ReadBuf<'_>,
177 390 : ) -> task::Poll<io::Result<()>> {
178 390 : use StreamProj::*;
179 390 : match self.project() {
180 46 : Raw { raw } => raw.poll_read(context, buf),
181 344 : Tls { tls } => tls.poll_read(context, buf),
182 : }
183 390 : }
184 : }
185 :
186 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
187 195 : fn poll_write(
188 195 : self: Pin<&mut Self>,
189 195 : context: &mut task::Context<'_>,
190 195 : buf: &[u8],
191 195 : ) -> task::Poll<io::Result<usize>> {
192 195 : use StreamProj::*;
193 195 : match self.project() {
194 43 : Raw { raw } => raw.poll_write(context, buf),
195 152 : Tls { tls } => tls.poll_write(context, buf),
196 : }
197 195 : }
198 :
199 223 : fn poll_flush(
200 223 : self: Pin<&mut Self>,
201 223 : context: &mut task::Context<'_>,
202 223 : ) -> task::Poll<io::Result<()>> {
203 223 : use StreamProj::*;
204 223 : match self.project() {
205 43 : Raw { raw } => raw.poll_flush(context),
206 180 : Tls { tls } => tls.poll_flush(context),
207 : }
208 223 : }
209 :
210 28 : fn poll_shutdown(
211 28 : self: Pin<&mut Self>,
212 28 : context: &mut task::Context<'_>,
213 28 : ) -> task::Poll<io::Result<()>> {
214 28 : use StreamProj::*;
215 28 : match self.project() {
216 0 : Raw { raw } => raw.poll_shutdown(context),
217 28 : Tls { tls } => tls.poll_shutdown(context),
218 : }
219 28 : }
220 : }
|