Line data Source code
1 : use bytes::Buf;
2 : use pq_proto::framed::Framed;
3 : use pq_proto::{
4 : BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
5 : };
6 : use thiserror::Error;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 : use tracing::{debug, info, warn};
9 :
10 : use crate::auth::endpoint_sni;
11 : use crate::config::TlsConfig;
12 : use crate::context::RequestContext;
13 : use crate::error::ReportableError;
14 : use crate::metrics::Metrics;
15 : use crate::proxy::ERR_INSECURE_CONNECTION;
16 : use crate::stream::{PqStream, Stream, StreamUpgradeError};
17 : use crate::tls::PG_ALPN_PROTOCOL;
18 :
19 : #[derive(Error, Debug)]
20 : pub(crate) enum HandshakeError {
21 : #[error("data is sent before server replied with EncryptionResponse")]
22 : EarlyData,
23 :
24 : #[error("protocol violation")]
25 : ProtocolViolation,
26 :
27 : #[error("missing certificate")]
28 : MissingCertificate,
29 :
30 : #[error("{0}")]
31 : StreamUpgradeError(#[from] StreamUpgradeError),
32 :
33 : #[error("{0}")]
34 : Io(#[from] std::io::Error),
35 :
36 : #[error("{0}")]
37 : ReportedError(#[from] crate::stream::ReportedError),
38 : }
39 :
40 : impl ReportableError for HandshakeError {
41 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
42 0 : match self {
43 0 : HandshakeError::EarlyData => crate::error::ErrorKind::User,
44 0 : HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
45 : // This error should not happen, but will if we have no default certificate and
46 : // the client sends no SNI extension.
47 : // If they provide SNI then we can be sure there is a certificate that matches.
48 0 : HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
49 0 : HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
50 0 : StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
51 0 : StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
52 : },
53 0 : HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
54 0 : HandshakeError::ReportedError(e) => e.get_error_kind(),
55 : }
56 0 : }
57 : }
58 :
59 : pub(crate) enum HandshakeData<S> {
60 : Startup(PqStream<Stream<S>>, StartupMessageParams),
61 : Cancel(CancelKeyData),
62 : }
63 :
64 : /// Establish a (most probably, secure) connection with the client.
65 : /// For better testing experience, `stream` can be any object satisfying the traits.
66 : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
67 : /// we also take an extra care of propagating only the select handshake errors to client.
68 : #[tracing::instrument(skip_all)]
69 : pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
70 : ctx: &RequestContext,
71 : stream: S,
72 : mut tls: Option<&TlsConfig>,
73 : record_handshake_error: bool,
74 : ) -> Result<HandshakeData<S>, HandshakeError> {
75 : // Client may try upgrading to each protocol only once
76 : let (mut tried_ssl, mut tried_gss) = (false, false);
77 :
78 : const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
79 : const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
80 :
81 : let mut stream = PqStream::new(Stream::from_raw(stream));
82 : loop {
83 : let msg = stream.read_startup_packet().await?;
84 : match msg {
85 : FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
86 : Stream::Raw { .. } if !tried_ssl => {
87 : tried_ssl = true;
88 :
89 : // We can't perform TLS handshake without a config
90 : let have_tls = tls.is_some();
91 : if !direct {
92 : stream
93 : .write_message(&Be::EncryptionResponse(have_tls))
94 : .await?;
95 : } else if !have_tls {
96 : return Err(HandshakeError::ProtocolViolation);
97 : }
98 :
99 : if let Some(tls) = tls.take() {
100 : // Upgrade raw stream into a secure TLS-backed stream.
101 : // NOTE: We've consumed `tls`; this fact will be used later.
102 :
103 : let Framed {
104 : stream: raw,
105 : read_buf,
106 : write_buf,
107 : } = stream.framed;
108 :
109 : let Stream::Raw { raw } = raw else {
110 : return Err(HandshakeError::StreamUpgradeError(
111 : StreamUpgradeError::AlreadyTls,
112 : ));
113 : };
114 :
115 : let mut read_buf = read_buf.reader();
116 : let mut res = Ok(());
117 : let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
118 20 : .accept_with(raw, |session| {
119 : // push the early data to the tls session
120 20 : while !read_buf.get_ref().is_empty() {
121 0 : match session.read_tls(&mut read_buf) {
122 0 : Ok(_) => {}
123 0 : Err(e) => {
124 0 : res = Err(e);
125 0 : break;
126 : }
127 : }
128 : }
129 20 : });
130 :
131 : res?;
132 :
133 : let read_buf = read_buf.into_inner();
134 : if !read_buf.is_empty() {
135 : return Err(HandshakeError::EarlyData);
136 : }
137 :
138 0 : let tls_stream = accept.await.inspect_err(|_| {
139 0 : if record_handshake_error {
140 0 : Metrics::get().proxy.tls_handshake_failures.inc();
141 0 : }
142 0 : })?;
143 :
144 : let conn_info = tls_stream.get_ref().1;
145 :
146 : // try parse endpoint
147 : let ep = conn_info
148 : .server_name()
149 20 : .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
150 : if let Some(ep) = ep {
151 : ctx.set_endpoint_id(ep);
152 : }
153 :
154 : // check the ALPN, if exists, as required.
155 : match conn_info.alpn_protocol() {
156 : None | Some(PG_ALPN_PROTOCOL) => {}
157 : Some(other) => {
158 : let alpn = String::from_utf8_lossy(other);
159 : warn!(%alpn, "unexpected ALPN");
160 : return Err(HandshakeError::ProtocolViolation);
161 : }
162 : }
163 :
164 : let (_, tls_server_end_point) = tls
165 : .cert_resolver
166 : .resolve(conn_info.server_name())
167 : .ok_or(HandshakeError::MissingCertificate)?;
168 :
169 : stream = PqStream {
170 : framed: Framed {
171 : stream: Stream::Tls {
172 : tls: Box::new(tls_stream),
173 : tls_server_end_point,
174 : },
175 : read_buf,
176 : write_buf,
177 : },
178 : };
179 : }
180 : }
181 : _ => return Err(HandshakeError::ProtocolViolation),
182 : },
183 : FeStartupPacket::GssEncRequest => match stream.get_ref() {
184 : Stream::Raw { .. } if !tried_gss => {
185 : tried_gss = true;
186 :
187 : // Currently, we don't support GSSAPI
188 : stream.write_message(&Be::EncryptionResponse(false)).await?;
189 : }
190 : _ => return Err(HandshakeError::ProtocolViolation),
191 : },
192 : FeStartupPacket::StartupMessage { params, version }
193 : if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
194 : {
195 : // Check that the config has been consumed during upgrade
196 : // OR we didn't provide it at all (for dev purposes).
197 : if tls.is_some() {
198 : return stream
199 : .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
200 : .await?;
201 : }
202 :
203 : // This log highlights the start of the connection.
204 : // This contains useful information for debugging, not logged elsewhere, like role name and endpoint id.
205 : info!(
206 : ?version,
207 : ?params,
208 : session_type = "normal",
209 : "successful handshake"
210 : );
211 : break Ok(HandshakeData::Startup(stream, params));
212 : }
213 : // downgrade protocol version
214 : FeStartupPacket::StartupMessage { params, version }
215 : if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
216 : {
217 : debug!(?version, "unsupported minor version");
218 :
219 : // no protocol extensions are supported.
220 : // <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
221 : let mut unsupported = vec![];
222 : for (k, _) in params.iter() {
223 : if k.starts_with("_pq_.") {
224 : unsupported.push(k);
225 : }
226 : }
227 :
228 : // TODO: remove unsupported options so we don't send them to compute.
229 :
230 : stream
231 : .write_message(&Be::NegotiateProtocolVersion {
232 : version: PG_PROTOCOL_LATEST,
233 : options: &unsupported,
234 : })
235 : .await?;
236 :
237 : info!(
238 : ?version,
239 : ?params,
240 : session_type = "normal",
241 : "successful handshake; unsupported minor version requested"
242 : );
243 : break Ok(HandshakeData::Startup(stream, params));
244 : }
245 : FeStartupPacket::StartupMessage { version, params } => {
246 : warn!(
247 : ?version,
248 : ?params,
249 : session_type = "normal",
250 : "unsuccessful handshake; unsupported version"
251 : );
252 : return Err(HandshakeError::ProtocolViolation);
253 : }
254 : FeStartupPacket::CancelRequest(cancel_key_data) => {
255 : info!(session_type = "cancellation", "successful handshake");
256 : break Ok(HandshakeData::Cancel(cancel_key_data));
257 : }
258 : }
259 : }
260 : }
|