Line data Source code
1 : use crate::config::TlsServerEndPoint;
2 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
3 : use bytes::BytesMut;
4 :
5 : use pq_proto::framed::{ConnectionError, Framed};
6 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
7 : use rustls::ServerConfig;
8 : use std::pin::Pin;
9 : use std::sync::Arc;
10 : use std::{io, task};
11 : use thiserror::Error;
12 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 : use tokio_rustls::server::TlsStream;
14 :
15 : /// Stream wrapper which implements libpq's protocol.
16 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
17 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
18 : /// to pass random malformed bytes through the connection).
19 : pub struct PqStream<S> {
20 : pub(crate) framed: Framed<S>,
21 : }
22 :
23 : impl<S> PqStream<S> {
24 : /// Construct a new libpq protocol wrapper.
25 84 : pub fn new(stream: S) -> Self {
26 84 : Self {
27 84 : framed: Framed::new(stream),
28 84 : }
29 84 : }
30 :
31 : /// Extract the underlying stream and read buffer.
32 40 : pub fn into_inner(self) -> (S, BytesMut) {
33 40 : self.framed.into_inner()
34 40 : }
35 :
36 : /// Get a shared reference to the underlying stream.
37 64 : pub fn get_ref(&self) -> &S {
38 64 : self.framed.get_ref()
39 64 : }
40 : }
41 :
/// Build the error used when the framed reader yields no message at all
/// (`None`), which the readers below treat as a lost connection.
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
45 :
46 : impl<S: AsyncRead + Unpin> PqStream<S> {
47 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
48 84 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
49 84 : self.framed
50 84 : .read_startup_message()
51 14 : .await
52 84 : .map_err(ConnectionError::into_io_error)?
53 84 : .ok_or_else(err_connection)
54 84 : }
55 :
56 44 : async fn read_message(&mut self) -> io::Result<FeMessage> {
57 44 : self.framed
58 44 : .read_message()
59 44 : .await
60 44 : .map_err(ConnectionError::into_io_error)?
61 42 : .ok_or_else(err_connection)
62 44 : }
63 :
64 44 : pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
65 44 : match self.read_message().await? {
66 42 : FeMessage::PasswordMessage(msg) => Ok(msg),
67 0 : bad => Err(io::Error::new(
68 0 : io::ErrorKind::InvalidData,
69 0 : format!("unexpected message type: {:?}", bad),
70 0 : )),
71 : }
72 44 : }
73 : }
74 :
75 0 : #[derive(Debug)]
76 : pub struct ReportedError {
77 : source: anyhow::Error,
78 : error_kind: ErrorKind,
79 : }
80 :
81 : impl std::fmt::Display for ReportedError {
82 2 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 2 : self.source.fmt(f)
84 2 : }
85 : }
86 :
87 : impl std::error::Error for ReportedError {
88 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
89 0 : self.source.source()
90 0 : }
91 : }
92 :
93 : impl ReportableError for ReportedError {
94 0 : fn get_error_kind(&self) -> ErrorKind {
95 0 : self.error_kind
96 0 : }
97 : }
98 :
99 : impl<S: AsyncWrite + Unpin> PqStream<S> {
100 : /// Write the message into an internal buffer, but don't flush the underlying stream.
101 138 : pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
102 138 : self.framed
103 138 : .write_message(message)
104 138 : .map_err(ProtocolError::into_io_error)?;
105 138 : Ok(self)
106 138 : }
107 :
108 : /// Write the message into an internal buffer and flush it.
109 110 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
110 110 : self.write_message_noflush(message)?;
111 110 : self.flush().await?;
112 110 : Ok(self)
113 110 : }
114 :
115 : /// Flush the output buffer into the underlying stream.
116 110 : pub async fn flush(&mut self) -> io::Result<&mut Self> {
117 110 : self.framed.flush().await?;
118 110 : Ok(self)
119 110 : }
120 :
121 : /// Write the error message using [`Self::write_message`], then re-throw it.
122 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
123 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
124 2 : pub async fn throw_error_str<T>(
125 2 : &mut self,
126 2 : msg: &'static str,
127 2 : error_kind: ErrorKind,
128 2 : ) -> Result<T, ReportedError> {
129 0 : tracing::info!(
130 0 : kind = error_kind.to_metric_label(),
131 0 : msg,
132 0 : "forwarding error to user"
133 0 : );
134 :
135 : // already error case, ignore client IO error
136 2 : let _: Result<_, std::io::Error> = self
137 2 : .write_message(&BeMessage::ErrorResponse(msg, None))
138 0 : .await;
139 :
140 2 : Err(ReportedError {
141 2 : source: anyhow::anyhow!(msg),
142 2 : error_kind,
143 2 : })
144 2 : }
145 :
146 : /// Write the error message using [`Self::write_message`], then re-throw it.
147 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
148 0 : pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
149 0 : where
150 0 : E: UserFacingError + Into<anyhow::Error>,
151 0 : {
152 0 : let error_kind = error.get_error_kind();
153 0 : let msg = error.to_string_client();
154 0 : tracing::info!(
155 0 : kind=error_kind.to_metric_label(),
156 0 : error=%error,
157 0 : msg,
158 0 : "forwarding error to user"
159 0 : );
160 :
161 : // already error case, ignore client IO error
162 0 : let _: Result<_, std::io::Error> = self
163 0 : .write_message(&BeMessage::ErrorResponse(&msg, None))
164 0 : .await;
165 :
166 0 : Err(ReportedError {
167 0 : source: anyhow::anyhow!(error),
168 0 : error_kind,
169 0 : })
170 0 : }
171 : }
172 :
173 : /// Wrapper for upgrading raw streams into secure streams.
174 : pub enum Stream<S> {
175 : /// We always begin with a raw stream,
176 : /// which may then be upgraded into a secure stream.
177 : Raw { raw: S },
178 : Tls {
179 : /// We box [`TlsStream`] since it can be quite large.
180 : tls: Box<TlsStream<S>>,
181 : /// Channel binding parameter
182 : tls_server_end_point: TlsServerEndPoint,
183 : },
184 : }
185 :
186 : impl<S: Unpin> Unpin for Stream<S> {}
187 :
188 : impl<S> Stream<S> {
189 : /// Construct a new instance from a raw stream.
190 44 : pub fn from_raw(raw: S) -> Self {
191 44 : Self::Raw { raw }
192 44 : }
193 :
194 : /// Return SNI hostname when it's available.
195 0 : pub fn sni_hostname(&self) -> Option<&str> {
196 0 : match self {
197 0 : Stream::Raw { .. } => None,
198 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
199 : }
200 0 : }
201 :
202 24 : pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
203 24 : match self {
204 0 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
205 : Stream::Tls {
206 24 : tls_server_end_point,
207 24 : ..
208 24 : } => *tls_server_end_point,
209 : }
210 24 : }
211 : }
212 :
213 0 : #[derive(Debug, Error)]
214 : #[error("Can't upgrade TLS stream")]
215 : pub enum StreamUpgradeError {
216 : #[error("Bad state reached: can't upgrade TLS stream")]
217 : AlreadyTls,
218 :
219 : #[error("Can't upgrade stream: IO error: {0}")]
220 : Io(#[from] io::Error),
221 : }
222 :
223 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
224 : /// If possible, upgrade raw stream into a secure TLS-based stream.
225 40 : pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
226 40 : match self {
227 94 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
228 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
229 : }
230 40 : }
231 : }
232 :
233 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
234 296 : fn poll_read(
235 296 : mut self: Pin<&mut Self>,
236 296 : context: &mut task::Context<'_>,
237 296 : buf: &mut ReadBuf<'_>,
238 296 : ) -> task::Poll<io::Result<()>> {
239 296 : match &mut *self {
240 44 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
241 252 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
242 : }
243 296 : }
244 : }
245 :
246 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
247 142 : fn poll_write(
248 142 : mut self: Pin<&mut Self>,
249 142 : context: &mut task::Context<'_>,
250 142 : buf: &[u8],
251 142 : ) -> task::Poll<io::Result<usize>> {
252 142 : match &mut *self {
253 44 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
254 98 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
255 : }
256 142 : }
257 :
258 142 : fn poll_flush(
259 142 : mut self: Pin<&mut Self>,
260 142 : context: &mut task::Context<'_>,
261 142 : ) -> task::Poll<io::Result<()>> {
262 142 : match &mut *self {
263 44 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
264 98 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
265 : }
266 142 : }
267 :
268 0 : fn poll_shutdown(
269 0 : mut self: Pin<&mut Self>,
270 0 : context: &mut task::Context<'_>,
271 0 : ) -> task::Poll<io::Result<()>> {
272 0 : match &mut *self {
273 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
274 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
275 : }
276 0 : }
277 : }
|