Line data Source code
1 : use std::pin::Pin;
2 : use std::sync::Arc;
3 : use std::{io, task};
4 :
5 : use bytes::BytesMut;
6 : use pq_proto::framed::{ConnectionError, Framed};
7 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
8 : use rustls::ServerConfig;
9 : use serde::{Deserialize, Serialize};
10 : use thiserror::Error;
11 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12 : use tokio_rustls::server::TlsStream;
13 : use tracing::debug;
14 :
15 : use crate::control_plane::messages::ColdStartInfo;
16 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
17 : use crate::metrics::Metrics;
18 : use crate::tls::TlsServerEndPoint;
19 :
20 : /// Stream wrapper which implements libpq's protocol.
21 : ///
22 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
23 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
24 : /// to pass random malformed bytes through the connection).
25 : pub struct PqStream<S> {
26 : pub(crate) framed: Framed<S>,
27 : }
28 :
29 : impl<S> PqStream<S> {
30 : /// Construct a new libpq protocol wrapper.
31 25 : pub fn new(stream: S) -> Self {
32 25 : Self {
33 25 : framed: Framed::new(stream),
34 25 : }
35 25 : }
36 :
37 : /// Extract the underlying stream and read buffer.
38 0 : pub fn into_inner(self) -> (S, BytesMut) {
39 0 : self.framed.into_inner()
40 0 : }
41 :
42 : /// Get a shared reference to the underlying stream.
43 35 : pub(crate) fn get_ref(&self) -> &S {
44 35 : self.framed.get_ref()
45 35 : }
46 : }
47 :
/// Build the [`io::ErrorKind::ConnectionAborted`] error carrying the
/// message "connection is lost".
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
51 :
52 : impl<S: AsyncRead + Unpin> PqStream<S> {
53 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
54 42 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
55 42 : self.framed
56 42 : .read_startup_message()
57 42 : .await
58 42 : .map_err(ConnectionError::into_io_error)?
59 42 : .ok_or_else(err_connection)
60 42 : }
61 :
62 26 : async fn read_message(&mut self) -> io::Result<FeMessage> {
63 26 : self.framed
64 26 : .read_message()
65 26 : .await
66 26 : .map_err(ConnectionError::into_io_error)?
67 25 : .ok_or_else(err_connection)
68 26 : }
69 :
70 26 : pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
71 26 : match self.read_message().await? {
72 25 : FeMessage::PasswordMessage(msg) => Ok(msg),
73 0 : bad => Err(io::Error::new(
74 0 : io::ErrorKind::InvalidData,
75 0 : format!("unexpected message type: {bad:?}"),
76 0 : )),
77 : }
78 26 : }
79 : }
80 :
81 : #[derive(Debug)]
82 : pub struct ReportedError {
83 : source: anyhow::Error,
84 : error_kind: ErrorKind,
85 : }
86 :
87 : impl std::fmt::Display for ReportedError {
88 1 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 1 : self.source.fmt(f)
90 1 : }
91 : }
92 :
93 : impl std::error::Error for ReportedError {
94 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
95 0 : self.source.source()
96 0 : }
97 : }
98 :
99 : impl ReportableError for ReportedError {
100 0 : fn get_error_kind(&self) -> ErrorKind {
101 0 : self.error_kind
102 0 : }
103 : }
104 :
105 0 : #[derive(Serialize, Deserialize, Debug)]
106 : enum ErrorTag {
107 : #[serde(rename = "proxy")]
108 : Proxy,
109 : #[serde(rename = "compute")]
110 : Compute,
111 : #[serde(rename = "client")]
112 : Client,
113 : #[serde(rename = "controlplane")]
114 : ControlPlane,
115 : #[serde(rename = "other")]
116 : Other,
117 : }
118 :
119 : impl From<ErrorKind> for ErrorTag {
120 0 : fn from(error_kind: ErrorKind) -> Self {
121 0 : match error_kind {
122 0 : ErrorKind::User => Self::Client,
123 0 : ErrorKind::ClientDisconnect => Self::Client,
124 0 : ErrorKind::RateLimit => Self::Proxy,
125 0 : ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
126 0 : ErrorKind::Quota => Self::Proxy,
127 0 : ErrorKind::Service => Self::Proxy,
128 0 : ErrorKind::ControlPlane => Self::ControlPlane,
129 0 : ErrorKind::Postgres => Self::Other,
130 0 : ErrorKind::Compute => Self::Compute,
131 : }
132 0 : }
133 : }
134 :
135 0 : #[derive(Serialize, Deserialize, Debug)]
136 : #[serde(rename_all = "snake_case")]
137 : struct ProbeErrorData {
138 : tag: ErrorTag,
139 : msg: String,
140 : cold_start_info: Option<ColdStartInfo>,
141 : }
142 :
143 : impl<S: AsyncWrite + Unpin> PqStream<S> {
144 : /// Write the message into an internal buffer, but don't flush the underlying stream.
145 77 : pub(crate) fn write_message_noflush(
146 77 : &mut self,
147 77 : message: &BeMessage<'_>,
148 77 : ) -> io::Result<&mut Self> {
149 77 : self.framed
150 77 : .write_message(message)
151 77 : .map_err(ProtocolError::into_io_error)?;
152 77 : Ok(self)
153 77 : }
154 :
155 : /// Write the message into an internal buffer and flush it.
156 54 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
157 54 : self.write_message_noflush(message)?;
158 54 : self.flush().await?;
159 54 : Ok(self)
160 54 : }
161 :
162 : /// Flush the output buffer into the underlying stream.
163 55 : pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
164 55 : self.framed.flush().await?;
165 55 : Ok(self)
166 55 : }
167 :
168 : /// Writes message with the given error kind to the stream.
169 : /// Used only for probe queries
170 1 : async fn write_format_message(
171 1 : &mut self,
172 1 : msg: &str,
173 1 : error_kind: ErrorKind,
174 1 : ctx: Option<&crate::context::RequestContext>,
175 1 : ) -> String {
176 1 : let formatted_msg = match ctx {
177 0 : Some(ctx) if ctx.get_testodrome_id().is_some() => {
178 0 : serde_json::to_string(&ProbeErrorData {
179 0 : tag: ErrorTag::from(error_kind),
180 0 : msg: msg.to_string(),
181 0 : cold_start_info: Some(ctx.cold_start_info()),
182 0 : })
183 0 : .unwrap_or_default()
184 : }
185 1 : _ => msg.to_string(),
186 : };
187 :
188 : // already error case, ignore client IO error
189 1 : self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
190 1 : .await
191 1 : .inspect_err(|e| debug!("write_message failed: {e}"))
192 1 : .ok();
193 1 :
194 1 : formatted_msg
195 1 : }
196 :
197 : /// Write the error message using [`Self::write_format_message`], then re-throw it.
198 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
199 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
200 : /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
201 1 : pub async fn throw_error_str<T>(
202 1 : &mut self,
203 1 : msg: &'static str,
204 1 : error_kind: ErrorKind,
205 1 : ctx: Option<&crate::context::RequestContext>,
206 1 : ) -> Result<T, ReportedError> {
207 1 : self.write_format_message(msg, error_kind, ctx).await;
208 :
209 1 : if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
210 0 : tracing::info!(
211 0 : kind = error_kind.to_metric_label(),
212 0 : msg,
213 0 : "forwarding error to user"
214 : );
215 1 : }
216 :
217 1 : Err(ReportedError {
218 1 : source: anyhow::anyhow!(msg),
219 1 : error_kind,
220 1 : })
221 1 : }
222 :
223 : /// Write the error message using [`Self::write_format_message`], then re-throw it.
224 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
225 : /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
226 0 : pub(crate) async fn throw_error<T, E>(
227 0 : &mut self,
228 0 : error: E,
229 0 : ctx: Option<&crate::context::RequestContext>,
230 0 : ) -> Result<T, ReportedError>
231 0 : where
232 0 : E: UserFacingError + Into<anyhow::Error>,
233 0 : {
234 0 : let error_kind = error.get_error_kind();
235 0 : let msg = error.to_string_client();
236 0 : self.write_format_message(&msg, error_kind, ctx).await;
237 0 : if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
238 0 : tracing::info!(
239 0 : kind=error_kind.to_metric_label(),
240 0 : error=%error,
241 0 : msg,
242 0 : "forwarding error to user",
243 : );
244 0 : }
245 :
246 0 : Err(ReportedError {
247 0 : source: anyhow::anyhow!(error),
248 0 : error_kind,
249 0 : })
250 0 : }
251 : }
252 :
253 : /// Wrapper for upgrading raw streams into secure streams.
254 : pub enum Stream<S> {
255 : /// We always begin with a raw stream,
256 : /// which may then be upgraded into a secure stream.
257 : Raw { raw: S },
258 : Tls {
259 : /// We box [`TlsStream`] since it can be quite large.
260 : tls: Box<TlsStream<S>>,
261 : /// Channel binding parameter
262 : tls_server_end_point: TlsServerEndPoint,
263 : },
264 : }
265 :
266 : impl<S: Unpin> Unpin for Stream<S> {}
267 :
268 : impl<S> Stream<S> {
269 : /// Construct a new instance from a raw stream.
270 25 : pub fn from_raw(raw: S) -> Self {
271 25 : Self::Raw { raw }
272 25 : }
273 :
274 : /// Return SNI hostname when it's available.
275 0 : pub fn sni_hostname(&self) -> Option<&str> {
276 0 : match self {
277 0 : Stream::Raw { .. } => None,
278 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
279 : }
280 0 : }
281 :
282 15 : pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
283 15 : match self {
284 3 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
285 : Stream::Tls {
286 12 : tls_server_end_point,
287 12 : ..
288 12 : } => *tls_server_end_point,
289 : }
290 15 : }
291 : }
292 :
293 : #[derive(Debug, Error)]
294 : #[error("Can't upgrade TLS stream")]
295 : pub enum StreamUpgradeError {
296 : #[error("Bad state reached: can't upgrade TLS stream")]
297 : AlreadyTls,
298 :
299 : #[error("Can't upgrade stream: IO error: {0}")]
300 : Io(#[from] io::Error),
301 : }
302 :
303 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
304 : /// If possible, upgrade raw stream into a secure TLS-based stream.
305 0 : pub async fn upgrade(
306 0 : self,
307 0 : cfg: Arc<ServerConfig>,
308 0 : record_handshake_error: bool,
309 0 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
310 0 : match self {
311 0 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
312 0 : .accept(raw)
313 0 : .await
314 0 : .inspect_err(|_| {
315 0 : if record_handshake_error {
316 0 : Metrics::get().proxy.tls_handshake_failures.inc();
317 0 : }
318 0 : })?),
319 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
320 : }
321 0 : }
322 : }
323 :
324 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
325 154 : fn poll_read(
326 154 : mut self: Pin<&mut Self>,
327 154 : context: &mut task::Context<'_>,
328 154 : buf: &mut ReadBuf<'_>,
329 154 : ) -> task::Poll<io::Result<()>> {
330 154 : match &mut *self {
331 30 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
332 124 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
333 : }
334 154 : }
335 : }
336 :
337 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
338 71 : fn poll_write(
339 71 : mut self: Pin<&mut Self>,
340 71 : context: &mut task::Context<'_>,
341 71 : buf: &[u8],
342 71 : ) -> task::Poll<io::Result<usize>> {
343 71 : match &mut *self {
344 27 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
345 44 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
346 : }
347 71 : }
348 :
349 71 : fn poll_flush(
350 71 : mut self: Pin<&mut Self>,
351 71 : context: &mut task::Context<'_>,
352 71 : ) -> task::Poll<io::Result<()>> {
353 71 : match &mut *self {
354 27 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
355 44 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
356 : }
357 71 : }
358 :
359 0 : fn poll_shutdown(
360 0 : mut self: Pin<&mut Self>,
361 0 : context: &mut task::Context<'_>,
362 0 : ) -> task::Poll<io::Result<()>> {
363 0 : match &mut *self {
364 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
365 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
366 : }
367 0 : }
368 : }
|