TLA Line data Source code
1 : use crate::error::UserFacingError;
2 : use anyhow::bail;
3 : use bytes::BytesMut;
4 : use pin_project_lite::pin_project;
5 : use pq_proto::framed::{ConnectionError, Framed};
6 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
7 : use rustls::ServerConfig;
8 : use std::pin::Pin;
9 : use std::sync::Arc;
10 : use std::{io, task};
11 : use thiserror::Error;
12 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 : use tokio_rustls::server::TlsStream;
14 :
15 : /// Stream wrapper which implements libpq's protocol.
16 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
17 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
18 : /// to pass random malformed bytes through the connection).
19 : pub struct PqStream<S> {
20 : framed: Framed<S>,
21 : }
22 :
23 : impl<S> PqStream<S> {
24 : /// Construct a new libpq protocol wrapper.
25 CBC 84 : pub fn new(stream: S) -> Self {
26 84 : Self {
27 84 : framed: Framed::new(stream),
28 84 : }
29 84 : }
30 :
31 : /// Extract the underlying stream and read buffer.
32 68 : pub fn into_inner(self) -> (S, BytesMut) {
33 68 : self.framed.into_inner()
34 68 : }
35 :
36 : /// Get a shared reference to the underlying stream.
37 71 : pub fn get_ref(&self) -> &S {
38 71 : self.framed.get_ref()
39 71 : }
40 : }
41 :
/// Build the I/O error reported when a read yields no message at all.
fn err_connection() -> io::Error {
    let kind = io::ErrorKind::ConnectionAborted;
    io::Error::new(kind, "connection is lost")
}
45 :
46 : impl<S: AsyncRead + Unpin> PqStream<S> {
47 : /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
48 84 : pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
49 84 : self.framed
50 84 : .read_startup_message()
51 24 : .await
52 84 : .map_err(ConnectionError::into_io_error)?
53 84 : .ok_or_else(err_connection)
54 84 : }
55 :
56 65 : async fn read_message(&mut self) -> io::Result<FeMessage> {
57 65 : self.framed
58 65 : .read_message()
59 65 : .await
60 65 : .map_err(ConnectionError::into_io_error)?
61 65 : .ok_or_else(err_connection)
62 65 : }
63 :
64 65 : pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
65 65 : match self.read_message().await? {
66 65 : FeMessage::PasswordMessage(msg) => Ok(msg),
67 UBC 0 : bad => Err(io::Error::new(
68 0 : io::ErrorKind::InvalidData,
69 0 : format!("unexpected message type: {:?}", bad),
70 0 : )),
71 : }
72 CBC 65 : }
73 : }
74 :
75 : impl<S: AsyncWrite + Unpin> PqStream<S> {
76 : /// Write the message into an internal buffer, but don't flush the underlying stream.
77 : pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
78 628 : self.framed
79 628 : .write_message(message)
80 628 : .map_err(ProtocolError::into_io_error)?;
81 628 : Ok(self)
82 628 : }
83 :
84 : /// Write the message into an internal buffer and flush it.
85 177 : pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
86 177 : self.write_message_noflush(message)?;
87 177 : self.flush().await?;
88 177 : Ok(self)
89 177 : }
90 :
91 : /// Flush the output buffer into the underlying stream.
92 177 : pub async fn flush(&mut self) -> io::Result<&mut Self> {
93 177 : self.framed.flush().await?;
94 177 : Ok(self)
95 177 : }
96 :
97 : /// Write the error message using [`Self::write_message`], then re-throw it.
98 : /// Allowing string literals is safe under the assumption they might not contain any runtime info.
99 : /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
100 5 : pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
101 4 : tracing::info!("forwarding error to user: {error}");
102 5 : self.write_message(&BeMessage::ErrorResponse(error, None))
103 UBC 0 : .await?;
104 CBC 5 : bail!(error)
105 5 : }
106 :
107 : /// Write the error message using [`Self::write_message`], then re-throw it.
108 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
109 4 : pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
110 4 : where
111 4 : E: UserFacingError + Into<anyhow::Error>,
112 4 : {
113 4 : let msg = error.to_string_client();
114 4 : tracing::info!("forwarding error to user: {msg}");
115 4 : self.write_message(&BeMessage::ErrorResponse(&msg, None))
116 UBC 0 : .await?;
117 CBC 4 : bail!(error)
118 4 : }
119 : }
120 :
121 : pin_project! {
122 : /// Wrapper for upgrading raw streams into secure streams.
123 : /// NOTE: it should be possible to decompose this object as necessary.
124 : #[project = StreamProj]
125 : pub enum Stream<S> {
126 : /// We always begin with a raw stream,
127 : /// which may then be upgraded into a secure stream.
128 : Raw { #[pin] raw: S },
129 : /// We box [`TlsStream`] since it can be quite large.
130 : Tls { #[pin] tls: Box<TlsStream<S>> },
131 : }
132 : }
133 :
134 : impl<S> Stream<S> {
135 : /// Construct a new instance from a raw stream.
136 46 : pub fn from_raw(raw: S) -> Self {
137 46 : Self::Raw { raw }
138 46 : }
139 :
140 : /// Return SNI hostname when it's available.
141 34 : pub fn sni_hostname(&self) -> Option<&str> {
142 34 : match self {
143 UBC 0 : Stream::Raw { .. } => None,
144 CBC 34 : Stream::Tls { tls } => tls.get_ref().1.server_name(),
145 : }
146 34 : }
147 : }
148 :
149 UBC 0 : #[derive(Debug, Error)]
150 : #[error("Can't upgrade TLS stream")]
151 : pub enum StreamUpgradeError {
152 : #[error("Bad state reached: can't upgrade TLS stream")]
153 : AlreadyTls,
154 :
155 : #[error("Can't upgrade stream: IO error: {0}")]
156 : Io(#[from] io::Error),
157 : }
158 :
159 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
160 : /// If possible, upgrade raw stream into a secure TLS-based stream.
161 CBC 39 : pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
162 39 : match self {
163 39 : Stream::Raw { raw } => {
164 78 : let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
165 39 : Ok(Stream::Tls { tls })
166 : }
167 UBC 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
168 : }
169 CBC 39 : }
170 : }
171 :
172 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
173 418 : fn poll_read(
174 418 : self: Pin<&mut Self>,
175 418 : context: &mut task::Context<'_>,
176 418 : buf: &mut ReadBuf<'_>,
177 418 : ) -> task::Poll<io::Result<()>> {
178 418 : use StreamProj::*;
179 418 : match self.project() {
180 48 : Raw { raw } => raw.poll_read(context, buf),
181 370 : Tls { tls } => tls.poll_read(context, buf),
182 : }
183 418 : }
184 : }
185 :
186 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
187 207 : fn poll_write(
188 207 : self: Pin<&mut Self>,
189 207 : context: &mut task::Context<'_>,
190 207 : buf: &[u8],
191 207 : ) -> task::Poll<io::Result<usize>> {
192 207 : use StreamProj::*;
193 207 : match self.project() {
194 45 : Raw { raw } => raw.poll_write(context, buf),
195 162 : Tls { tls } => tls.poll_write(context, buf),
196 : }
197 207 : }
198 :
199 237 : fn poll_flush(
200 237 : self: Pin<&mut Self>,
201 237 : context: &mut task::Context<'_>,
202 237 : ) -> task::Poll<io::Result<()>> {
203 237 : use StreamProj::*;
204 237 : match self.project() {
205 45 : Raw { raw } => raw.poll_flush(context),
206 192 : Tls { tls } => tls.poll_flush(context),
207 : }
208 237 : }
209 :
210 30 : fn poll_shutdown(
211 30 : self: Pin<&mut Self>,
212 30 : context: &mut task::Context<'_>,
213 30 : ) -> task::Poll<io::Result<()>> {
214 30 : use StreamProj::*;
215 30 : match self.project() {
216 UBC 0 : Raw { raw } => raw.poll_shutdown(context),
217 CBC 30 : Tls { tls } => tls.poll_shutdown(context),
218 : }
219 30 : }
220 : }
|