Line data Source code
1 : use crate::config::TlsServerEndPoint;
2 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
3 : use crate::metrics::Metrics;
4 : use bytes::BytesMut;
5 :
6 : use pq_proto::framed::{ConnectionError, Framed};
7 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
8 : use rustls::ServerConfig;
9 : use std::pin::Pin;
10 : use std::sync::Arc;
11 : use std::{io, task};
12 : use thiserror::Error;
13 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14 : use tokio_rustls::server::TlsStream;
15 :
16 : /// Stream wrapper which implements libpq's protocol.
17 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
18 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
19 : /// to pass random malformed bytes through the connection).
20 : pub struct PqStream<S> {
21 : pub(crate) framed: Framed<S>,
22 : }
23 :
24 : impl<S> PqStream<S> {
25 : /// Construct a new libpq protocol wrapper.
26 90 : pub fn new(stream: S) -> Self {
27 90 : Self {
28 90 : framed: Framed::new(stream),
29 90 : }
30 90 : }
31 :
32 : /// Extract the underlying stream and read buffer.
33 40 : pub fn into_inner(self) -> (S, BytesMut) {
34 40 : self.framed.into_inner()
35 40 : }
36 :
37 : /// Get a shared reference to the underlying stream.
38 70 : pub fn get_ref(&self) -> &S {
39 70 : self.framed.get_ref()
40 70 : }
41 : }
42 :
43 0 : fn err_connection() -> io::Error {
44 0 : io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
45 0 : }
46 :
47 : impl<S: AsyncRead + Unpin> PqStream<S> {
48 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
49 84 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
50 84 : self.framed
51 84 : .read_startup_message()
52 14 : .await
53 84 : .map_err(ConnectionError::into_io_error)?
54 84 : .ok_or_else(err_connection)
55 84 : }
56 :
57 52 : async fn read_message(&mut self) -> io::Result<FeMessage> {
58 52 : self.framed
59 52 : .read_message()
60 52 : .await
61 52 : .map_err(ConnectionError::into_io_error)?
62 50 : .ok_or_else(err_connection)
63 52 : }
64 :
65 52 : pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
66 52 : match self.read_message().await? {
67 50 : FeMessage::PasswordMessage(msg) => Ok(msg),
68 0 : bad => Err(io::Error::new(
69 0 : io::ErrorKind::InvalidData,
70 0 : format!("unexpected message type: {:?}", bad),
71 0 : )),
72 : }
73 52 : }
74 : }
75 :
76 : #[derive(Debug)]
77 : pub struct ReportedError {
78 : source: anyhow::Error,
79 : error_kind: ErrorKind,
80 : }
81 :
82 : impl std::fmt::Display for ReportedError {
83 2 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 2 : self.source.fmt(f)
85 2 : }
86 : }
87 :
88 : impl std::error::Error for ReportedError {
89 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
90 0 : self.source.source()
91 0 : }
92 : }
93 :
94 : impl ReportableError for ReportedError {
95 0 : fn get_error_kind(&self) -> ErrorKind {
96 0 : self.error_kind
97 0 : }
98 : }
99 :
100 : impl<S: AsyncWrite + Unpin> PqStream<S> {
101 : /// Write the message into an internal buffer, but don't flush the underlying stream.
102 154 : pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
103 154 : self.framed
104 154 : .write_message(message)
105 154 : .map_err(ProtocolError::into_io_error)?;
106 154 : Ok(self)
107 154 : }
108 :
109 : /// Write the message into an internal buffer and flush it.
110 120 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
111 120 : self.write_message_noflush(message)?;
112 120 : self.flush().await?;
113 120 : Ok(self)
114 120 : }
115 :
116 : /// Flush the output buffer into the underlying stream.
117 120 : pub async fn flush(&mut self) -> io::Result<&mut Self> {
118 120 : self.framed.flush().await?;
119 120 : Ok(self)
120 120 : }
121 :
122 : /// Write the error message using [`Self::write_message`], then re-throw it.
123 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
124 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
125 2 : pub async fn throw_error_str<T>(
126 2 : &mut self,
127 2 : msg: &'static str,
128 2 : error_kind: ErrorKind,
129 2 : ) -> Result<T, ReportedError> {
130 2 : tracing::info!(
131 0 : kind = error_kind.to_metric_label(),
132 0 : msg,
133 0 : "forwarding error to user"
134 : );
135 :
136 : // already error case, ignore client IO error
137 2 : let _: Result<_, std::io::Error> = self
138 2 : .write_message(&BeMessage::ErrorResponse(msg, None))
139 0 : .await;
140 :
141 2 : Err(ReportedError {
142 2 : source: anyhow::anyhow!(msg),
143 2 : error_kind,
144 2 : })
145 2 : }
146 :
147 : /// Write the error message using [`Self::write_message`], then re-throw it.
148 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
149 0 : pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
150 0 : where
151 0 : E: UserFacingError + Into<anyhow::Error>,
152 0 : {
153 0 : let error_kind = error.get_error_kind();
154 0 : let msg = error.to_string_client();
155 0 : tracing::info!(
156 0 : kind=error_kind.to_metric_label(),
157 0 : error=%error,
158 0 : msg,
159 0 : "forwarding error to user"
160 : );
161 :
162 : // already error case, ignore client IO error
163 0 : let _: Result<_, std::io::Error> = self
164 0 : .write_message(&BeMessage::ErrorResponse(&msg, None))
165 0 : .await;
166 :
167 0 : Err(ReportedError {
168 0 : source: anyhow::anyhow!(error),
169 0 : error_kind,
170 0 : })
171 0 : }
172 : }
173 :
174 : /// Wrapper for upgrading raw streams into secure streams.
175 : pub enum Stream<S> {
176 : /// We always begin with a raw stream,
177 : /// which may then be upgraded into a secure stream.
178 : Raw { raw: S },
179 : Tls {
180 : /// We box [`TlsStream`] since it can be quite large.
181 : tls: Box<TlsStream<S>>,
182 : /// Channel binding parameter
183 : tls_server_end_point: TlsServerEndPoint,
184 : },
185 : }
186 :
187 : impl<S: Unpin> Unpin for Stream<S> {}
188 :
189 : impl<S> Stream<S> {
190 : /// Construct a new instance from a raw stream.
191 50 : pub fn from_raw(raw: S) -> Self {
192 50 : Self::Raw { raw }
193 50 : }
194 :
195 : /// Return SNI hostname when it's available.
196 0 : pub fn sni_hostname(&self) -> Option<&str> {
197 0 : match self {
198 0 : Stream::Raw { .. } => None,
199 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
200 : }
201 0 : }
202 :
203 30 : pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
204 30 : match self {
205 6 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
206 : Stream::Tls {
207 24 : tls_server_end_point,
208 24 : ..
209 24 : } => *tls_server_end_point,
210 : }
211 30 : }
212 : }
213 :
214 0 : #[derive(Debug, Error)]
215 : #[error("Can't upgrade TLS stream")]
216 : pub enum StreamUpgradeError {
217 : #[error("Bad state reached: can't upgrade TLS stream")]
218 : AlreadyTls,
219 :
220 : #[error("Can't upgrade stream: IO error: {0}")]
221 : Io(#[from] io::Error),
222 : }
223 :
224 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
225 : /// If possible, upgrade raw stream into a secure TLS-based stream.
226 40 : pub async fn upgrade(
227 40 : self,
228 40 : cfg: Arc<ServerConfig>,
229 40 : record_handshake_error: bool,
230 40 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
231 40 : match self {
232 40 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
233 40 : .accept(raw)
234 94 : .await
235 40 : .inspect_err(|_| {
236 0 : if record_handshake_error {
237 0 : Metrics::get().proxy.tls_handshake_failures.inc()
238 0 : }
239 40 : })?),
240 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
241 : }
242 40 : }
243 : }
244 :
245 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
246 303 : fn poll_read(
247 303 : mut self: Pin<&mut Self>,
248 303 : context: &mut task::Context<'_>,
249 303 : buf: &mut ReadBuf<'_>,
250 303 : ) -> task::Poll<io::Result<()>> {
251 303 : match &mut *self {
252 60 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
253 243 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
254 : }
255 303 : }
256 : }
257 :
258 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
259 152 : fn poll_write(
260 152 : mut self: Pin<&mut Self>,
261 152 : context: &mut task::Context<'_>,
262 152 : buf: &[u8],
263 152 : ) -> task::Poll<io::Result<usize>> {
264 152 : match &mut *self {
265 54 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
266 98 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
267 : }
268 152 : }
269 :
270 152 : fn poll_flush(
271 152 : mut self: Pin<&mut Self>,
272 152 : context: &mut task::Context<'_>,
273 152 : ) -> task::Poll<io::Result<()>> {
274 152 : match &mut *self {
275 54 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
276 98 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
277 : }
278 152 : }
279 :
280 0 : fn poll_shutdown(
281 0 : mut self: Pin<&mut Self>,
282 0 : context: &mut task::Context<'_>,
283 0 : ) -> task::Poll<io::Result<()>> {
284 0 : match &mut *self {
285 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
286 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
287 : }
288 0 : }
289 : }
|