Line data Source code
1 : use std::pin::Pin;
2 : use std::sync::Arc;
3 : use std::{io, task};
4 :
5 : use bytes::BytesMut;
6 : use pq_proto::framed::{ConnectionError, Framed};
7 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
8 : use rustls::ServerConfig;
9 : use thiserror::Error;
10 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11 : use tokio_rustls::server::TlsStream;
12 : use tracing::debug;
13 :
14 : use crate::config::TlsServerEndPoint;
15 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
16 : use crate::metrics::Metrics;
17 :
18 : /// Stream wrapper which implements libpq's protocol.
19 : ///
20 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
21 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
22 : /// to pass random malformed bytes through the connection).
23 : pub struct PqStream<S> {
24 : pub(crate) framed: Framed<S>,
25 : }
26 :
27 : impl<S> PqStream<S> {
28 : /// Construct a new libpq protocol wrapper.
29 25 : pub fn new(stream: S) -> Self {
30 25 : Self {
31 25 : framed: Framed::new(stream),
32 25 : }
33 25 : }
34 :
35 : /// Extract the underlying stream and read buffer.
36 0 : pub fn into_inner(self) -> (S, BytesMut) {
37 0 : self.framed.into_inner()
38 0 : }
39 :
40 : /// Get a shared reference to the underlying stream.
41 35 : pub(crate) fn get_ref(&self) -> &S {
42 35 : self.framed.get_ref()
43 35 : }
44 : }
45 :
/// Build the `ConnectionAborted` I/O error carrying the text
/// "connection is lost".
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
49 :
50 : impl<S: AsyncRead + Unpin> PqStream<S> {
51 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
52 42 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
53 42 : self.framed
54 42 : .read_startup_message()
55 7 : .await
56 42 : .map_err(ConnectionError::into_io_error)?
57 42 : .ok_or_else(err_connection)
58 42 : }
59 :
60 26 : async fn read_message(&mut self) -> io::Result<FeMessage> {
61 26 : self.framed
62 26 : .read_message()
63 26 : .await
64 26 : .map_err(ConnectionError::into_io_error)?
65 25 : .ok_or_else(err_connection)
66 26 : }
67 :
68 26 : pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
69 26 : match self.read_message().await? {
70 25 : FeMessage::PasswordMessage(msg) => Ok(msg),
71 0 : bad => Err(io::Error::new(
72 0 : io::ErrorKind::InvalidData,
73 0 : format!("unexpected message type: {bad:?}"),
74 0 : )),
75 : }
76 26 : }
77 : }
78 :
79 : #[derive(Debug)]
80 : pub struct ReportedError {
81 : source: anyhow::Error,
82 : error_kind: ErrorKind,
83 : }
84 :
85 : impl std::fmt::Display for ReportedError {
86 1 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 1 : self.source.fmt(f)
88 1 : }
89 : }
90 :
91 : impl std::error::Error for ReportedError {
92 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
93 0 : self.source.source()
94 0 : }
95 : }
96 :
97 : impl ReportableError for ReportedError {
98 0 : fn get_error_kind(&self) -> ErrorKind {
99 0 : self.error_kind
100 0 : }
101 : }
102 :
103 : impl<S: AsyncWrite + Unpin> PqStream<S> {
104 : /// Write the message into an internal buffer, but don't flush the underlying stream.
105 77 : pub(crate) fn write_message_noflush(
106 77 : &mut self,
107 77 : message: &BeMessage<'_>,
108 77 : ) -> io::Result<&mut Self> {
109 77 : self.framed
110 77 : .write_message(message)
111 77 : .map_err(ProtocolError::into_io_error)?;
112 77 : Ok(self)
113 77 : }
114 :
115 : /// Write the message into an internal buffer and flush it.
116 60 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
117 60 : self.write_message_noflush(message)?;
118 60 : self.flush().await?;
119 60 : Ok(self)
120 60 : }
121 :
122 : /// Flush the output buffer into the underlying stream.
123 60 : pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
124 60 : self.framed.flush().await?;
125 60 : Ok(self)
126 60 : }
127 :
128 : /// Write the error message using [`Self::write_message`], then re-throw it.
129 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
130 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
131 1 : pub async fn throw_error_str<T>(
132 1 : &mut self,
133 1 : msg: &'static str,
134 1 : error_kind: ErrorKind,
135 1 : ) -> Result<T, ReportedError> {
136 1 : // TODO: only log this for actually interesting errors
137 1 : tracing::info!(
138 0 : kind = error_kind.to_metric_label(),
139 0 : msg,
140 0 : "forwarding error to user"
141 : );
142 :
143 : // already error case, ignore client IO error
144 1 : self.write_message(&BeMessage::ErrorResponse(msg, None))
145 0 : .await
146 1 : .inspect_err(|e| debug!("write_message failed: {e}"))
147 1 : .ok();
148 1 :
149 1 : Err(ReportedError {
150 1 : source: anyhow::anyhow!(msg),
151 1 : error_kind,
152 1 : })
153 1 : }
154 :
155 : /// Write the error message using [`Self::write_message`], then re-throw it.
156 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
157 0 : pub(crate) async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
158 0 : where
159 0 : E: UserFacingError + Into<anyhow::Error>,
160 0 : {
161 0 : let error_kind = error.get_error_kind();
162 0 : let msg = error.to_string_client();
163 0 : tracing::info!(
164 0 : kind=error_kind.to_metric_label(),
165 0 : error=%error,
166 0 : msg,
167 0 : "forwarding error to user"
168 : );
169 :
170 : // already error case, ignore client IO error
171 0 : self.write_message(&BeMessage::ErrorResponse(&msg, None))
172 0 : .await
173 0 : .inspect_err(|e| debug!("write_message failed: {e}"))
174 0 : .ok();
175 0 :
176 0 : Err(ReportedError {
177 0 : source: anyhow::anyhow!(error),
178 0 : error_kind,
179 0 : })
180 0 : }
181 : }
182 :
183 : /// Wrapper for upgrading raw streams into secure streams.
184 : pub enum Stream<S> {
185 : /// We always begin with a raw stream,
186 : /// which may then be upgraded into a secure stream.
187 : Raw { raw: S },
188 : Tls {
189 : /// We box [`TlsStream`] since it can be quite large.
190 : tls: Box<TlsStream<S>>,
191 : /// Channel binding parameter
192 : tls_server_end_point: TlsServerEndPoint,
193 : },
194 : }
195 :
196 : impl<S: Unpin> Unpin for Stream<S> {}
197 :
198 : impl<S> Stream<S> {
199 : /// Construct a new instance from a raw stream.
200 25 : pub fn from_raw(raw: S) -> Self {
201 25 : Self::Raw { raw }
202 25 : }
203 :
204 : /// Return SNI hostname when it's available.
205 0 : pub fn sni_hostname(&self) -> Option<&str> {
206 0 : match self {
207 0 : Stream::Raw { .. } => None,
208 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
209 : }
210 0 : }
211 :
212 15 : pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
213 15 : match self {
214 3 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
215 : Stream::Tls {
216 12 : tls_server_end_point,
217 12 : ..
218 12 : } => *tls_server_end_point,
219 : }
220 15 : }
221 : }
222 :
223 0 : #[derive(Debug, Error)]
224 : #[error("Can't upgrade TLS stream")]
225 : pub enum StreamUpgradeError {
226 : #[error("Bad state reached: can't upgrade TLS stream")]
227 : AlreadyTls,
228 :
229 : #[error("Can't upgrade stream: IO error: {0}")]
230 : Io(#[from] io::Error),
231 : }
232 :
233 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
234 : /// If possible, upgrade raw stream into a secure TLS-based stream.
235 0 : pub async fn upgrade(
236 0 : self,
237 0 : cfg: Arc<ServerConfig>,
238 0 : record_handshake_error: bool,
239 0 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
240 0 : match self {
241 0 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
242 0 : .accept(raw)
243 0 : .await
244 0 : .inspect_err(|_| {
245 0 : if record_handshake_error {
246 0 : Metrics::get().proxy.tls_handshake_failures.inc();
247 0 : }
248 0 : })?),
249 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
250 : }
251 0 : }
252 : }
253 :
254 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
255 155 : fn poll_read(
256 155 : mut self: Pin<&mut Self>,
257 155 : context: &mut task::Context<'_>,
258 155 : buf: &mut ReadBuf<'_>,
259 155 : ) -> task::Poll<io::Result<()>> {
260 155 : match &mut *self {
261 30 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
262 125 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
263 : }
264 155 : }
265 : }
266 :
267 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
268 76 : fn poll_write(
269 76 : mut self: Pin<&mut Self>,
270 76 : context: &mut task::Context<'_>,
271 76 : buf: &[u8],
272 76 : ) -> task::Poll<io::Result<usize>> {
273 76 : match &mut *self {
274 27 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
275 49 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
276 : }
277 76 : }
278 :
279 76 : fn poll_flush(
280 76 : mut self: Pin<&mut Self>,
281 76 : context: &mut task::Context<'_>,
282 76 : ) -> task::Poll<io::Result<()>> {
283 76 : match &mut *self {
284 27 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
285 49 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
286 : }
287 76 : }
288 :
289 0 : fn poll_shutdown(
290 0 : mut self: Pin<&mut Self>,
291 0 : context: &mut task::Context<'_>,
292 0 : ) -> task::Poll<io::Result<()>> {
293 0 : match &mut *self {
294 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
295 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
296 : }
297 0 : }
298 : }
|