Line data Source code
1 : use crate::config::TlsServerEndPoint;
2 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
3 : use crate::metrics::Metrics;
4 : use bytes::BytesMut;
5 :
6 : use pq_proto::framed::{ConnectionError, Framed};
7 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
8 : use rustls::ServerConfig;
9 : use std::pin::Pin;
10 : use std::sync::Arc;
11 : use std::{io, task};
12 : use thiserror::Error;
13 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14 : use tokio_rustls::server::TlsStream;
15 :
16 : /// Stream wrapper which implements libpq's protocol.
17 : ///
18 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
19 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
20 : /// to pass random malformed bytes through the connection).
21 : pub struct PqStream<S> {
22 : pub(crate) framed: Framed<S>,
23 : }
24 :
25 : impl<S> PqStream<S> {
26 : /// Construct a new libpq protocol wrapper.
27 25 : pub fn new(stream: S) -> Self {
28 25 : Self {
29 25 : framed: Framed::new(stream),
30 25 : }
31 25 : }
32 :
33 : /// Extract the underlying stream and read buffer.
34 0 : pub fn into_inner(self) -> (S, BytesMut) {
35 0 : self.framed.into_inner()
36 0 : }
37 :
38 : /// Get a shared reference to the underlying stream.
39 35 : pub(crate) fn get_ref(&self) -> &S {
40 35 : self.framed.get_ref()
41 35 : }
42 : }
43 :
/// Build the error reported when the framer yields no message (`None`)
/// instead of a frame.
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
47 :
48 : impl<S: AsyncRead + Unpin> PqStream<S> {
49 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
50 42 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
51 42 : self.framed
52 42 : .read_startup_message()
53 7 : .await
54 42 : .map_err(ConnectionError::into_io_error)?
55 42 : .ok_or_else(err_connection)
56 42 : }
57 :
58 26 : async fn read_message(&mut self) -> io::Result<FeMessage> {
59 26 : self.framed
60 26 : .read_message()
61 26 : .await
62 26 : .map_err(ConnectionError::into_io_error)?
63 25 : .ok_or_else(err_connection)
64 26 : }
65 :
66 26 : pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
67 26 : match self.read_message().await? {
68 25 : FeMessage::PasswordMessage(msg) => Ok(msg),
69 0 : bad => Err(io::Error::new(
70 0 : io::ErrorKind::InvalidData,
71 0 : format!("unexpected message type: {bad:?}"),
72 0 : )),
73 : }
74 26 : }
75 : }
76 :
77 : #[derive(Debug)]
78 : pub struct ReportedError {
79 : source: anyhow::Error,
80 : error_kind: ErrorKind,
81 : }
82 :
83 : impl std::fmt::Display for ReportedError {
84 1 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 1 : self.source.fmt(f)
86 1 : }
87 : }
88 :
89 : impl std::error::Error for ReportedError {
90 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
91 0 : self.source.source()
92 0 : }
93 : }
94 :
95 : impl ReportableError for ReportedError {
96 0 : fn get_error_kind(&self) -> ErrorKind {
97 0 : self.error_kind
98 0 : }
99 : }
100 :
101 : impl<S: AsyncWrite + Unpin> PqStream<S> {
102 : /// Write the message into an internal buffer, but don't flush the underlying stream.
103 77 : pub(crate) fn write_message_noflush(
104 77 : &mut self,
105 77 : message: &BeMessage<'_>,
106 77 : ) -> io::Result<&mut Self> {
107 77 : self.framed
108 77 : .write_message(message)
109 77 : .map_err(ProtocolError::into_io_error)?;
110 77 : Ok(self)
111 77 : }
112 :
113 : /// Write the message into an internal buffer and flush it.
114 60 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
115 60 : self.write_message_noflush(message)?;
116 60 : self.flush().await?;
117 60 : Ok(self)
118 60 : }
119 :
120 : /// Flush the output buffer into the underlying stream.
121 60 : pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
122 60 : self.framed.flush().await?;
123 60 : Ok(self)
124 60 : }
125 :
126 : /// Write the error message using [`Self::write_message`], then re-throw it.
127 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
128 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
129 1 : pub async fn throw_error_str<T>(
130 1 : &mut self,
131 1 : msg: &'static str,
132 1 : error_kind: ErrorKind,
133 1 : ) -> Result<T, ReportedError> {
134 1 : tracing::info!(
135 0 : kind = error_kind.to_metric_label(),
136 0 : msg,
137 0 : "forwarding error to user"
138 : );
139 :
140 : // already error case, ignore client IO error
141 1 : let _: Result<_, std::io::Error> = self
142 1 : .write_message(&BeMessage::ErrorResponse(msg, None))
143 0 : .await;
144 :
145 1 : Err(ReportedError {
146 1 : source: anyhow::anyhow!(msg),
147 1 : error_kind,
148 1 : })
149 1 : }
150 :
151 : /// Write the error message using [`Self::write_message`], then re-throw it.
152 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
153 0 : pub(crate) async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
154 0 : where
155 0 : E: UserFacingError + Into<anyhow::Error>,
156 0 : {
157 0 : let error_kind = error.get_error_kind();
158 0 : let msg = error.to_string_client();
159 0 : tracing::info!(
160 0 : kind=error_kind.to_metric_label(),
161 0 : error=%error,
162 0 : msg,
163 0 : "forwarding error to user"
164 : );
165 :
166 : // already error case, ignore client IO error
167 0 : let _: Result<_, std::io::Error> = self
168 0 : .write_message(&BeMessage::ErrorResponse(&msg, None))
169 0 : .await;
170 :
171 0 : Err(ReportedError {
172 0 : source: anyhow::anyhow!(error),
173 0 : error_kind,
174 0 : })
175 0 : }
176 : }
177 :
178 : /// Wrapper for upgrading raw streams into secure streams.
179 : pub enum Stream<S> {
180 : /// We always begin with a raw stream,
181 : /// which may then be upgraded into a secure stream.
182 : Raw { raw: S },
183 : Tls {
184 : /// We box [`TlsStream`] since it can be quite large.
185 : tls: Box<TlsStream<S>>,
186 : /// Channel binding parameter
187 : tls_server_end_point: TlsServerEndPoint,
188 : },
189 : }
190 :
191 : impl<S: Unpin> Unpin for Stream<S> {}
192 :
193 : impl<S> Stream<S> {
194 : /// Construct a new instance from a raw stream.
195 25 : pub fn from_raw(raw: S) -> Self {
196 25 : Self::Raw { raw }
197 25 : }
198 :
199 : /// Return SNI hostname when it's available.
200 0 : pub fn sni_hostname(&self) -> Option<&str> {
201 0 : match self {
202 0 : Stream::Raw { .. } => None,
203 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
204 : }
205 0 : }
206 :
207 15 : pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
208 15 : match self {
209 3 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
210 : Stream::Tls {
211 12 : tls_server_end_point,
212 12 : ..
213 12 : } => *tls_server_end_point,
214 : }
215 15 : }
216 : }
217 :
218 0 : #[derive(Debug, Error)]
219 : #[error("Can't upgrade TLS stream")]
220 : pub enum StreamUpgradeError {
221 : #[error("Bad state reached: can't upgrade TLS stream")]
222 : AlreadyTls,
223 :
224 : #[error("Can't upgrade stream: IO error: {0}")]
225 : Io(#[from] io::Error),
226 : }
227 :
228 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
229 : /// If possible, upgrade raw stream into a secure TLS-based stream.
230 0 : pub async fn upgrade(
231 0 : self,
232 0 : cfg: Arc<ServerConfig>,
233 0 : record_handshake_error: bool,
234 0 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
235 0 : match self {
236 0 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
237 0 : .accept(raw)
238 0 : .await
239 0 : .inspect_err(|_| {
240 0 : if record_handshake_error {
241 0 : Metrics::get().proxy.tls_handshake_failures.inc();
242 0 : }
243 0 : })?),
244 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
245 : }
246 0 : }
247 : }
248 :
249 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
250 158 : fn poll_read(
251 158 : mut self: Pin<&mut Self>,
252 158 : context: &mut task::Context<'_>,
253 158 : buf: &mut ReadBuf<'_>,
254 158 : ) -> task::Poll<io::Result<()>> {
255 158 : match &mut *self {
256 30 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
257 128 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
258 : }
259 158 : }
260 : }
261 :
262 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
263 76 : fn poll_write(
264 76 : mut self: Pin<&mut Self>,
265 76 : context: &mut task::Context<'_>,
266 76 : buf: &[u8],
267 76 : ) -> task::Poll<io::Result<usize>> {
268 76 : match &mut *self {
269 27 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
270 49 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
271 : }
272 76 : }
273 :
274 76 : fn poll_flush(
275 76 : mut self: Pin<&mut Self>,
276 76 : context: &mut task::Context<'_>,
277 76 : ) -> task::Poll<io::Result<()>> {
278 76 : match &mut *self {
279 27 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
280 49 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
281 : }
282 76 : }
283 :
284 0 : fn poll_shutdown(
285 0 : mut self: Pin<&mut Self>,
286 0 : context: &mut task::Context<'_>,
287 0 : ) -> task::Poll<io::Result<()>> {
288 0 : match &mut *self {
289 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
290 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
291 : }
292 0 : }
293 : }
|