Line data Source code
1 : use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
2 : use thiserror::Error;
3 : use tokio::io::{AsyncRead, AsyncWrite};
4 : use tracing::info;
5 :
6 : use crate::{
7 : config::TlsConfig,
8 : error::ReportableError,
9 : proxy::ERR_INSECURE_CONNECTION,
10 : stream::{PqStream, Stream, StreamUpgradeError},
11 : };
12 :
13 2 : #[derive(Error, Debug)]
14 : pub enum HandshakeError {
15 : #[error("data is sent before server replied with EncryptionResponse")]
16 : EarlyData,
17 :
18 : #[error("protocol violation")]
19 : ProtocolViolation,
20 :
21 : #[error("missing certificate")]
22 : MissingCertificate,
23 :
24 : #[error("{0}")]
25 : StreamUpgradeError(#[from] StreamUpgradeError),
26 :
27 : #[error("{0}")]
28 : Io(#[from] std::io::Error),
29 :
30 : #[error("{0}")]
31 : ReportedError(#[from] crate::stream::ReportedError),
32 : }
33 :
34 : impl ReportableError for HandshakeError {
35 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
36 0 : match self {
37 0 : HandshakeError::EarlyData => crate::error::ErrorKind::User,
38 0 : HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
39 : // This error should not happen, but will if we have no default certificate and
40 : // the client sends no SNI extension.
41 : // If they provide SNI then we can be sure there is a certificate that matches.
42 0 : HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
43 0 : HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
44 0 : StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
45 0 : StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
46 : },
47 0 : HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
48 0 : HandshakeError::ReportedError(e) => e.get_error_kind(),
49 : }
50 0 : }
51 : }
52 :
53 : pub enum HandshakeData<S> {
54 : Startup(PqStream<Stream<S>>, StartupMessageParams),
55 : Cancel(CancelKeyData),
56 : }
57 :
58 : /// Establish a (most probably, secure) connection with the client.
59 : /// For better testing experience, `stream` can be any object satisfying the traits.
60 : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
61 : /// we also take an extra care of propagating only the select handshake errors to client.
62 88 : #[tracing::instrument(skip_all)]
63 : pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
64 : stream: S,
65 : mut tls: Option<&TlsConfig>,
66 : record_handshake_error: bool,
67 : ) -> Result<HandshakeData<S>, HandshakeError> {
68 : // Client may try upgrading to each protocol only once
69 : let (mut tried_ssl, mut tried_gss) = (false, false);
70 :
71 : let mut stream = PqStream::new(Stream::from_raw(stream));
72 : loop {
73 : let msg = stream.read_startup_packet().await?;
74 : info!("received {msg:?}");
75 :
76 : use FeStartupPacket::*;
77 : match msg {
78 : SslRequest => match stream.get_ref() {
79 : Stream::Raw { .. } if !tried_ssl => {
80 : tried_ssl = true;
81 :
82 : // We can't perform TLS handshake without a config
83 : let enc = tls.is_some();
84 : stream.write_message(&Be::EncryptionResponse(enc)).await?;
85 : if let Some(tls) = tls.take() {
86 : // Upgrade raw stream into a secure TLS-backed stream.
87 : // NOTE: We've consumed `tls`; this fact will be used later.
88 :
89 : let (raw, read_buf) = stream.into_inner();
90 : // TODO: Normally, client doesn't send any data before
91 : // server says TLS handshake is ok and read_buf is empy.
92 : // However, you could imagine pipelining of postgres
93 : // SSLRequest + TLS ClientHello in one hunk similar to
94 : // pipelining in our node js driver. We should probably
95 : // support that by chaining read_buf with the stream.
96 : if !read_buf.is_empty() {
97 : return Err(HandshakeError::EarlyData);
98 : }
99 : let tls_stream = raw
100 : .upgrade(tls.to_server_config(), record_handshake_error)
101 : .await?;
102 :
103 : let (_, tls_server_end_point) = tls
104 : .cert_resolver
105 : .resolve(tls_stream.get_ref().1.server_name())
106 : .ok_or(HandshakeError::MissingCertificate)?;
107 :
108 : stream = PqStream::new(Stream::Tls {
109 : tls: Box::new(tls_stream),
110 : tls_server_end_point,
111 : });
112 : }
113 : }
114 : _ => return Err(HandshakeError::ProtocolViolation),
115 : },
116 : GssEncRequest => match stream.get_ref() {
117 : Stream::Raw { .. } if !tried_gss => {
118 : tried_gss = true;
119 :
120 : // Currently, we don't support GSSAPI
121 : stream.write_message(&Be::EncryptionResponse(false)).await?;
122 : }
123 : _ => return Err(HandshakeError::ProtocolViolation),
124 : },
125 : StartupMessage { params, .. } => {
126 : // Check that the config has been consumed during upgrade
127 : // OR we didn't provide it at all (for dev purposes).
128 : if tls.is_some() {
129 : return stream
130 : .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
131 : .await?;
132 : }
133 :
134 : info!(session_type = "normal", "successful handshake");
135 : break Ok(HandshakeData::Startup(stream, params));
136 : }
137 : CancelRequest(cancel_key_data) => {
138 : info!(session_type = "cancellation", "successful handshake");
139 : break Ok(HandshakeData::Cancel(cancel_key_data));
140 : }
141 : }
142 : }
143 : }
|