Line data Source code
1 : use crate::config::TlsServerEndPoint;
2 : use crate::error::UserFacingError;
3 : use anyhow::bail;
4 : use bytes::BytesMut;
5 :
6 : use pq_proto::framed::{ConnectionError, Framed};
7 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
8 : use rustls::ServerConfig;
9 : use std::pin::Pin;
10 : use std::sync::Arc;
11 : use std::{io, task};
12 : use thiserror::Error;
13 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14 : use tokio_rustls::server::TlsStream;
15 :
/// Stream wrapper which implements libpq's protocol.
///
/// Messages are read from and written to the wrapped stream `S` through an
/// inner [`Framed`] codec; outgoing messages are buffered until flushed.
///
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
    /// libpq framing layer over the underlying stream; crate-visible so that
    /// other modules of this crate can reach the codec directly.
    pub(crate) framed: Framed<S>,
}
23 :
24 : impl<S> PqStream<S> {
25 : /// Construct a new libpq protocol wrapper.
26 197 : pub fn new(stream: S) -> Self {
27 197 : Self {
28 197 : framed: Framed::new(stream),
29 197 : }
30 197 : }
31 :
32 : /// Extract the underlying stream and read buffer.
33 130 : pub fn into_inner(self) -> (S, BytesMut) {
34 130 : self.framed.into_inner()
35 130 : }
36 :
37 : /// Get a shared reference to the underlying stream.
38 204 : pub fn get_ref(&self) -> &S {
39 204 : self.framed.get_ref()
40 204 : }
41 : }
42 :
/// Error used when the framer yields no message at all (`None`), which the
/// readers below treat as the connection having been lost.
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
46 :
47 : impl<S: AsyncRead + Unpin> PqStream<S> {
48 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
49 197 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
50 197 : self.framed
51 197 : .read_startup_message()
52 19 : .await
53 197 : .map_err(ConnectionError::into_io_error)?
54 197 : .ok_or_else(err_connection)
55 197 : }
56 :
57 121 : async fn read_message(&mut self) -> io::Result<FeMessage> {
58 121 : self.framed
59 121 : .read_message()
60 121 : .await
61 121 : .map_err(ConnectionError::into_io_error)?
62 119 : .ok_or_else(err_connection)
63 121 : }
64 :
65 121 : pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
66 121 : match self.read_message().await? {
67 119 : FeMessage::PasswordMessage(msg) => Ok(msg),
68 0 : bad => Err(io::Error::new(
69 0 : io::ErrorKind::InvalidData,
70 0 : format!("unexpected message type: {:?}", bad),
71 0 : )),
72 : }
73 121 : }
74 : }
75 :
76 : impl<S: AsyncWrite + Unpin> PqStream<S> {
77 : /// Write the message into an internal buffer, but don't flush the underlying stream.
78 955 : pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
79 955 : self.framed
80 955 : .write_message(message)
81 955 : .map_err(ProtocolError::into_io_error)?;
82 955 : Ok(self)
83 955 : }
84 :
85 : /// Write the message into an internal buffer and flush it.
86 336 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
87 336 : self.write_message_noflush(message)?;
88 336 : self.flush().await?;
89 336 : Ok(self)
90 336 : }
91 :
92 : /// Flush the output buffer into the underlying stream.
93 336 : pub async fn flush(&mut self) -> io::Result<&mut Self> {
94 336 : self.framed.flush().await?;
95 336 : Ok(self)
96 336 : }
97 :
98 : /// Write the error message using [`Self::write_message`], then re-throw it.
99 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
100 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
101 13 : pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
102 11 : tracing::info!("forwarding error to user: {error}");
103 13 : self.write_message(&BeMessage::ErrorResponse(error, None))
104 0 : .await?;
105 13 : bail!(error)
106 13 : }
107 :
108 : /// Write the error message using [`Self::write_message`], then re-throw it.
109 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
110 11 : pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
111 11 : where
112 11 : E: UserFacingError + Into<anyhow::Error>,
113 11 : {
114 11 : let msg = error.to_string_client();
115 11 : tracing::info!("forwarding error to user: {msg}");
116 11 : self.write_message(&BeMessage::ErrorResponse(&msg, None))
117 0 : .await?;
118 11 : bail!(error)
119 11 : }
120 : }
121 :
/// Wrapper for upgrading raw streams into secure streams.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by forwarding to whichever
/// variant is currently active.
pub enum Stream<S> {
    /// We always begin with a raw stream,
    /// which may then be upgraded into a secure stream.
    Raw { raw: S },
    /// A stream carrying a TLS session on top of the raw stream.
    Tls {
        /// We box [`TlsStream`] since it can be quite large.
        tls: Box<TlsStream<S>>,
        /// Channel binding parameter, returned by `Stream::tls_server_end_point`.
        tls_server_end_point: TlsServerEndPoint,
    },
}
134 :
135 : impl<S: Unpin> Unpin for Stream<S> {}
136 :
137 : impl<S> Stream<S> {
138 : /// Construct a new instance from a raw stream.
139 107 : pub fn from_raw(raw: S) -> Self {
140 107 : Self::Raw { raw }
141 107 : }
142 :
143 : /// Return SNI hostname when it's available.
144 51 : pub fn sni_hostname(&self) -> Option<&str> {
145 51 : match self {
146 0 : Stream::Raw { .. } => None,
147 51 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
148 : }
149 51 : }
150 :
151 64 : pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
152 64 : match self {
153 0 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
154 : Stream::Tls {
155 64 : tls_server_end_point,
156 64 : ..
157 64 : } => *tls_server_end_point,
158 : }
159 64 : }
160 : }
161 :
/// Error returned by `Stream::upgrade`.
#[derive(Debug, Error)]
// NOTE(review): every variant below carries its own `#[error]` message, so this
// enum-level message looks like it is never displayed — confirm against
// thiserror's fallback rules before relying on it.
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
    /// The stream is already the `Tls` variant and cannot be upgraded again.
    #[error("Bad state reached: can't upgrade TLS stream")]
    AlreadyTls,

    /// An I/O error, e.g. one raised while accepting the TLS handshake in `upgrade`.
    #[error("Can't upgrade stream: IO error: {0}")]
    Io(#[from] io::Error),
}
171 :
172 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
173 : /// If possible, upgrade raw stream into a secure TLS-based stream.
174 91 : pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
175 91 : match self {
176 195 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
177 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
178 : }
179 91 : }
180 : }
181 :
182 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
183 806 : fn poll_read(
184 806 : mut self: Pin<&mut Self>,
185 806 : context: &mut task::Context<'_>,
186 806 : buf: &mut ReadBuf<'_>,
187 806 : ) -> task::Poll<io::Result<()>> {
188 806 : match &mut *self {
189 107 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
190 699 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
191 : }
192 806 : }
193 : }
194 :
195 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
196 408 : fn poll_write(
197 408 : mut self: Pin<&mut Self>,
198 408 : context: &mut task::Context<'_>,
199 408 : buf: &[u8],
200 408 : ) -> task::Poll<io::Result<usize>> {
201 408 : match &mut *self {
202 106 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
203 302 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
204 : }
205 408 : }
206 :
207 448 : fn poll_flush(
208 448 : mut self: Pin<&mut Self>,
209 448 : context: &mut task::Context<'_>,
210 448 : ) -> task::Poll<io::Result<()>> {
211 448 : match &mut *self {
212 106 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
213 342 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
214 : }
215 448 : }
216 :
217 40 : fn poll_shutdown(
218 40 : mut self: Pin<&mut Self>,
219 40 : context: &mut task::Context<'_>,
220 40 : ) -> task::Poll<io::Result<()>> {
221 40 : match &mut *self {
222 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
223 40 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
224 : }
225 40 : }
226 : }
|