Line data Source code
1 : use futures::{FutureExt, TryFutureExt};
2 : use thiserror::Error;
3 : use tokio::io::{AsyncRead, AsyncWrite};
4 : use tracing::{debug, info, warn};
5 :
6 : use crate::auth::endpoint_sni;
7 : use crate::config::TlsConfig;
8 : use crate::context::RequestContext;
9 : use crate::error::ReportableError;
10 : use crate::metrics::Metrics;
11 : use crate::pglb::TlsRequired;
12 : use crate::pqproto::{
13 : BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
14 : };
15 : use crate::stream::{PqStream, Stream, StreamUpgradeError};
16 : use crate::tls::PG_ALPN_PROTOCOL;
17 :
18 : #[derive(Error, Debug)]
19 : pub(crate) enum HandshakeError {
20 : #[error("data is sent before server replied with EncryptionResponse")]
21 : EarlyData,
22 :
23 : #[error("protocol violation")]
24 : ProtocolViolation,
25 :
26 : #[error("{0}")]
27 : StreamUpgradeError(#[from] StreamUpgradeError),
28 :
29 : #[error("{0}")]
30 : Io(#[from] std::io::Error),
31 :
32 : #[error("{0}")]
33 : ReportedError(#[from] crate::stream::ReportedError),
34 : }
35 :
36 : impl ReportableError for HandshakeError {
37 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
38 0 : match self {
39 0 : HandshakeError::EarlyData => crate::error::ErrorKind::User,
40 0 : HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
41 0 : HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
42 0 : StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
43 0 : StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
44 : },
45 0 : HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
46 0 : HandshakeError::ReportedError(e) => e.get_error_kind(),
47 : }
48 0 : }
49 : }
50 :
51 : pub(crate) enum HandshakeData<S> {
52 : Startup(PqStream<Stream<S>>, StartupMessageParams),
53 : Cancel(CancelKeyData),
54 : }
55 :
56 : /// Establish a (most probably, secure) connection with the client.
57 : /// For better testing experience, `stream` can be any object satisfying the traits.
58 : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
59 : /// we also take an extra care of propagating only the select handshake errors to client.
60 : #[tracing::instrument(skip_all)]
61 : pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
62 : ctx: &RequestContext,
63 : stream: S,
64 : mut tls: Option<&TlsConfig>,
65 : record_handshake_error: bool,
66 : ) -> Result<HandshakeData<S>, HandshakeError> {
67 : // Client may try upgrading to each protocol only once
68 : let (mut tried_ssl, mut tried_gss) = (false, false);
69 :
70 : const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
71 : const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
72 :
73 : let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
74 : loop {
75 : match msg {
76 : FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
77 : Stream::Raw { .. } if !tried_ssl => {
78 : tried_ssl = true;
79 :
80 : if let Some(tls) = tls.take() {
81 : // Upgrade raw stream into a secure TLS-backed stream.
82 : // NOTE: We've consumed `tls`; this fact will be used later.
83 :
84 : let mut read_buf;
85 : let raw = if let Some(direct) = &direct {
86 : read_buf = &direct[..];
87 : stream.accept_direct_tls()
88 : } else {
89 : read_buf = &[];
90 : stream.accept_tls().await?
91 : };
92 :
93 : let Stream::Raw { raw } = raw else {
94 : return Err(HandshakeError::StreamUpgradeError(
95 : StreamUpgradeError::AlreadyTls,
96 : ));
97 : };
98 :
99 : let mut res = Ok(());
100 : let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
101 20 : .accept_with(raw, |session| {
102 : // push the early data to the tls session
103 20 : while !read_buf.is_empty() {
104 0 : match session.read_tls(&mut read_buf) {
105 0 : Ok(_) => {}
106 0 : Err(e) => {
107 0 : res = Err(e);
108 0 : break;
109 : }
110 : }
111 : }
112 20 : })
113 : .map_ok(Box::new)
114 : .boxed();
115 :
116 : res?;
117 :
118 : if !read_buf.is_empty() {
119 : return Err(HandshakeError::EarlyData);
120 : }
121 :
122 0 : let tls_stream = accept.await.inspect_err(|_| {
123 0 : if record_handshake_error {
124 0 : Metrics::get().proxy.tls_handshake_failures.inc();
125 0 : }
126 0 : })?;
127 :
128 : let conn_info = tls_stream.get_ref().1;
129 :
130 : // try parse endpoint
131 : let ep = conn_info
132 : .server_name()
133 20 : .and_then(|sni| endpoint_sni(sni, &tls.common_names));
134 : if let Some(ep) = ep {
135 : ctx.set_endpoint_id(ep);
136 : }
137 :
138 : // check the ALPN, if exists, as required.
139 : match conn_info.alpn_protocol() {
140 : None | Some(PG_ALPN_PROTOCOL) => {}
141 : Some(other) => {
142 : let alpn = String::from_utf8_lossy(other);
143 : warn!(%alpn, "unexpected ALPN");
144 : return Err(HandshakeError::ProtocolViolation);
145 : }
146 : }
147 :
148 : let (_, tls_server_end_point) =
149 : tls.cert_resolver.resolve(conn_info.server_name());
150 :
151 : let tls = Stream::Tls {
152 : tls: tls_stream,
153 : tls_server_end_point,
154 : };
155 : (stream, msg) = PqStream::parse_startup(tls).await?;
156 : } else {
157 : if direct.is_some() {
158 : // client sent us a ClientHello already, we can't do anything with it.
159 : return Err(HandshakeError::ProtocolViolation);
160 : }
161 : msg = stream.reject_encryption().await?;
162 : }
163 : }
164 : _ => return Err(HandshakeError::ProtocolViolation),
165 : },
166 : FeStartupPacket::GssEncRequest => match stream.get_ref() {
167 : Stream::Raw { .. } if !tried_gss => {
168 : tried_gss = true;
169 :
170 : // Currently, we don't support GSSAPI
171 : msg = stream.reject_encryption().await?;
172 : }
173 : _ => return Err(HandshakeError::ProtocolViolation),
174 : },
175 : FeStartupPacket::StartupMessage { params, version }
176 : if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
177 : {
178 : // Check that the config has been consumed during upgrade
179 : // OR we didn't provide it at all (for dev purposes).
180 : if tls.is_some() {
181 : Err(stream.throw_error(TlsRequired, None).await)?;
182 : }
183 :
184 : // This log highlights the start of the connection.
185 : // This contains useful information for debugging, not logged elsewhere, like role name and endpoint id.
186 : info!(
187 : ?version,
188 : ?params,
189 : session_type = "normal",
190 : "successful handshake"
191 : );
192 : break Ok(HandshakeData::Startup(stream, params));
193 : }
194 : // downgrade protocol version
195 : FeStartupPacket::StartupMessage { params, version }
196 : if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
197 : {
198 : debug!(?version, "unsupported minor version");
199 :
200 : // no protocol extensions are supported.
201 : // <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
202 : let mut unsupported = vec![];
203 : let mut supported = StartupMessageParams::default();
204 :
205 : for (k, v) in params.iter() {
206 : if k.starts_with("_pq_.") {
207 : unsupported.push(k);
208 : } else {
209 : supported.insert(k, v);
210 : }
211 : }
212 :
213 : stream.write_message(BeMessage::NegotiateProtocolVersion {
214 : version: PG_PROTOCOL_LATEST,
215 : options: &unsupported,
216 : });
217 : stream.flush().await?;
218 :
219 : info!(
220 : ?version,
221 : ?params,
222 : session_type = "normal",
223 : "successful handshake; unsupported minor version requested"
224 : );
225 : break Ok(HandshakeData::Startup(stream, supported));
226 : }
227 : FeStartupPacket::StartupMessage { version, params } => {
228 : warn!(
229 : ?version,
230 : ?params,
231 : session_type = "normal",
232 : "unsuccessful handshake; unsupported version"
233 : );
234 : return Err(HandshakeError::ProtocolViolation);
235 : }
236 : FeStartupPacket::CancelRequest(cancel_key_data) => {
237 : info!(session_type = "cancellation", "successful handshake");
238 : break Ok(HandshakeData::Cancel(cancel_key_data));
239 : }
240 : }
241 : }
242 : }
|