Line data Source code
1 : use std::io;
2 : use std::net::SocketAddr;
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use once_cell::sync::OnceCell;
9 : use postgres_client::tls::MakeTlsConnect;
10 : use postgres_client::{CancelToken, RawConnection};
11 : use postgres_protocol::message::backend::NoticeResponseBody;
12 : use pq_proto::StartupMessageParams;
13 : use rustls::client::danger::ServerCertVerifier;
14 : use rustls::crypto::ring;
15 : use rustls::pki_types::InvalidDnsNameError;
16 : use thiserror::Error;
17 : use tokio::net::TcpStream;
18 : use tracing::{debug, error, info, warn};
19 :
20 : use crate::auth::parse_endpoint_param;
21 : use crate::cancellation::CancelClosure;
22 : use crate::context::RequestContext;
23 : use crate::control_plane::client::ApiLockError;
24 : use crate::control_plane::errors::WakeComputeError;
25 : use crate::control_plane::messages::MetricsAuxInfo;
26 : use crate::error::{ReportableError, UserFacingError};
27 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
28 : use crate::postgres_rustls::MakeRustlsConnect;
29 : use crate::proxy::neon_option;
30 : use crate::types::Host;
31 :
/// Generic failure text for compute connection problems.
///
/// Used as the prefix of several `ConnectionError` display messages and as the
/// client-facing fallback returned by `ConnectionError::to_string_client`.
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
33 :
34 : #[derive(Debug, Error)]
35 : pub(crate) enum ConnectionError {
36 : /// This error doesn't seem to reveal any secrets; for instance,
37 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
38 : #[error("{COULD_NOT_CONNECT}: {0}")]
39 : Postgres(#[from] postgres_client::Error),
40 :
41 : #[error("{COULD_NOT_CONNECT}: {0}")]
42 : CouldNotConnect(#[from] io::Error),
43 :
44 : #[error("Couldn't load native TLS certificates: {0:?}")]
45 : TlsCertificateError(Vec<rustls_native_certs::Error>),
46 :
47 : #[error("{COULD_NOT_CONNECT}: {0}")]
48 : TlsError(#[from] InvalidDnsNameError),
49 :
50 : #[error("{COULD_NOT_CONNECT}: {0}")]
51 : WakeComputeError(#[from] WakeComputeError),
52 :
53 : #[error("error acquiring resource permit: {0}")]
54 : TooManyConnectionAttempts(#[from] ApiLockError),
55 : }
56 :
57 : impl UserFacingError for ConnectionError {
58 0 : fn to_string_client(&self) -> String {
59 0 : match self {
60 : // This helps us drop irrelevant library-specific prefixes.
61 : // TODO: propagate severity level and other parameters.
62 0 : ConnectionError::Postgres(err) => match err.as_db_error() {
63 0 : Some(err) => {
64 0 : let msg = err.message();
65 0 :
66 0 : if msg.starts_with("unsupported startup parameter: ")
67 0 : || msg.starts_with("unsupported startup parameter in options: ")
68 : {
69 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
70 : } else {
71 0 : msg.to_owned()
72 : }
73 : }
74 0 : None => err.to_string(),
75 : },
76 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
77 : ConnectionError::TooManyConnectionAttempts(_) => {
78 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
79 : }
80 0 : _ => COULD_NOT_CONNECT.to_owned(),
81 : }
82 0 : }
83 : }
84 :
85 : impl ReportableError for ConnectionError {
86 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
87 0 : match self {
88 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
89 0 : crate::error::ErrorKind::Postgres
90 : }
91 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
92 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
93 0 : ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
94 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
95 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
96 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
97 : }
98 0 : }
99 : }
100 :
101 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
102 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
103 :
104 : /// A config for establishing a connection to compute node.
105 : /// Eventually, `postgres_client` will be replaced with something better.
106 : /// Newtype allows us to implement methods on top of it.
107 : #[derive(Clone)]
108 : pub(crate) struct ConnCfg(Box<postgres_client::Config>);
109 :
110 : /// Creation and initialization routines.
111 : impl ConnCfg {
112 10 : pub(crate) fn new(host: String, port: u16) -> Self {
113 10 : Self(Box::new(postgres_client::Config::new(host, port)))
114 10 : }
115 :
116 : /// Reuse password or auth keys from the other config.
117 4 : pub(crate) fn reuse_password(&mut self, other: Self) {
118 4 : if let Some(password) = other.get_password() {
119 4 : self.password(password);
120 4 : }
121 :
122 4 : if let Some(keys) = other.get_auth_keys() {
123 0 : self.auth_keys(keys);
124 4 : }
125 4 : }
126 :
127 0 : pub(crate) fn get_host(&self) -> Host {
128 0 : match self.0.get_host() {
129 0 : postgres_client::config::Host::Tcp(s) => s.into(),
130 0 : }
131 0 : }
132 :
133 : /// Apply startup message params to the connection config.
134 0 : pub(crate) fn set_startup_params(
135 0 : &mut self,
136 0 : params: &StartupMessageParams,
137 0 : arbitrary_params: bool,
138 0 : ) {
139 0 : if !arbitrary_params {
140 0 : self.set_param("client_encoding", "UTF8");
141 0 : }
142 0 : for (k, v) in params.iter() {
143 0 : match k {
144 : // Only set `user` if it's not present in the config.
145 : // Console redirect auth flow takes username from the console's response.
146 0 : "user" if self.user_is_set() => continue,
147 0 : "database" if self.db_is_set() => continue,
148 0 : "options" => {
149 0 : if let Some(options) = filtered_options(v) {
150 0 : self.set_param(k, &options);
151 0 : }
152 : }
153 0 : "user" | "database" | "application_name" | "replication" => {
154 0 : self.set_param(k, v);
155 0 : }
156 :
157 : // if we allow arbitrary params, then we forward them through.
158 : // this is a flag for a period of backwards compatibility
159 0 : k if arbitrary_params => {
160 0 : self.set_param(k, v);
161 0 : }
162 0 : _ => {}
163 : }
164 : }
165 0 : }
166 : }
167 :
168 : impl std::ops::Deref for ConnCfg {
169 : type Target = postgres_client::Config;
170 :
171 8 : fn deref(&self) -> &Self::Target {
172 8 : &self.0
173 8 : }
174 : }
175 :
176 : /// For now, let's make it easier to setup the config.
177 : impl std::ops::DerefMut for ConnCfg {
178 10 : fn deref_mut(&mut self) -> &mut Self::Target {
179 10 : &mut self.0
180 10 : }
181 : }
182 :
183 : impl ConnCfg {
184 : /// Establish a raw TCP connection to the compute node.
185 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
186 : use postgres_client::config::Host;
187 :
188 : // wrap TcpStream::connect with timeout
189 0 : let connect_with_timeout = |host, port| {
190 0 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
191 0 : move |res| match res {
192 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
193 0 : Err(_) => Err(io::Error::new(
194 0 : io::ErrorKind::TimedOut,
195 0 : format!("exceeded connection timeout {timeout:?}"),
196 0 : )),
197 0 : },
198 0 : )
199 0 : };
200 :
201 0 : let connect_once = |host, port| {
202 0 : debug!("trying to connect to compute node at {host}:{port}");
203 0 : connect_with_timeout(host, port).and_then(|socket| async {
204 0 : let socket_addr = socket.peer_addr()?;
205 : // This prevents load balancer from severing the connection.
206 0 : socket2::SockRef::from(&socket).set_keepalive(true)?;
207 0 : Ok((socket_addr, socket))
208 0 : })
209 0 : };
210 :
211 : // We can't reuse connection establishing logic from `postgres_client` here,
212 : // because it has no means for extracting the underlying socket which we
213 : // require for our business.
214 0 : let port = self.0.get_port();
215 0 : let host = self.0.get_host();
216 0 :
217 0 : let host = match host {
218 0 : Host::Tcp(host) => host.as_str(),
219 0 : };
220 0 :
221 0 : match connect_once(host, port).await {
222 0 : Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
223 0 : Err(err) => {
224 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
225 0 : Err(err)
226 : }
227 : }
228 0 : }
229 : }
230 :
231 : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
232 :
233 : pub(crate) struct PostgresConnection {
234 : /// Socket connected to a compute node.
235 : pub(crate) stream:
236 : postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
237 : /// PostgreSQL connection parameters.
238 : pub(crate) params: std::collections::HashMap<String, String>,
239 : /// Query cancellation token.
240 : pub(crate) cancel_closure: CancelClosure,
241 : /// Labels for proxy's metrics.
242 : pub(crate) aux: MetricsAuxInfo,
243 : /// Notices received from compute after authenticating
244 : pub(crate) delayed_notice: Vec<NoticeResponseBody>,
245 :
246 : _guage: NumDbConnectionsGuard<'static>,
247 : }
248 :
249 : impl ConnCfg {
250 : /// Connect to a corresponding compute node.
251 0 : pub(crate) async fn connect(
252 0 : &self,
253 0 : ctx: &RequestContext,
254 0 : allow_self_signed_compute: bool,
255 0 : aux: MetricsAuxInfo,
256 0 : timeout: Duration,
257 0 : ) -> Result<PostgresConnection, ConnectionError> {
258 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
259 0 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
260 0 : drop(pause);
261 :
262 0 : let client_config = if allow_self_signed_compute {
263 : // Allow all certificates for creating the connection
264 0 : let verifier = Arc::new(AcceptEverythingVerifier);
265 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
266 0 : .with_safe_default_protocol_versions()
267 0 : .expect("ring should support the default protocol versions")
268 0 : .dangerous()
269 0 : .with_custom_certificate_verifier(verifier)
270 : } else {
271 0 : let root_store = TLS_ROOTS
272 0 : .get_or_try_init(load_certs)
273 0 : .map_err(ConnectionError::TlsCertificateError)?
274 0 : .clone();
275 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
276 0 : .with_safe_default_protocol_versions()
277 0 : .expect("ring should support the default protocol versions")
278 0 : .with_root_certificates(root_store)
279 : };
280 0 : let client_config = client_config.with_no_client_auth();
281 0 :
282 0 : let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
283 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
284 0 : &mut mk_tls,
285 0 : host,
286 0 : )?;
287 :
288 : // connect_raw() will not use TLS if sslmode is "disable"
289 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
290 0 : let connection = self.0.connect_raw(stream, tls).await?;
291 0 : drop(pause);
292 0 :
293 0 : let RawConnection {
294 0 : stream,
295 0 : parameters,
296 0 : delayed_notice,
297 0 : process_id,
298 0 : secret_key,
299 0 : } = connection;
300 0 :
301 0 : tracing::Span::current().record("pid", tracing::field::display(process_id));
302 0 : let stream = stream.into_inner();
303 0 :
304 0 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
305 0 : info!(
306 0 : cold_start_info = ctx.cold_start_info().as_str(),
307 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
308 0 : self.0.get_ssl_mode()
309 : );
310 :
311 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
312 : // Yet another reason to rework the connection establishing code.
313 0 : let cancel_closure = CancelClosure::new(
314 0 : socket_addr,
315 0 : CancelToken {
316 0 : socket_config: None,
317 0 : ssl_mode: self.0.get_ssl_mode(),
318 0 : process_id,
319 0 : secret_key,
320 0 : },
321 0 : vec![],
322 0 : );
323 0 :
324 0 : let connection = PostgresConnection {
325 0 : stream,
326 0 : params: parameters,
327 0 : delayed_notice,
328 0 : cancel_closure,
329 0 : aux,
330 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
331 0 : };
332 0 :
333 0 : Ok(connection)
334 0 : }
335 : }
336 :
337 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
338 6 : fn filtered_options(options: &str) -> Option<String> {
339 6 : #[allow(unstable_name_collisions)]
340 6 : let options: String = StartupMessageParams::parse_options_raw(options)
341 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
342 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
343 6 : .collect();
344 6 :
345 6 : // Don't even bother with empty options.
346 6 : if options.is_empty() {
347 3 : return None;
348 3 : }
349 3 :
350 3 : Some(options)
351 6 : }
352 :
353 0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
354 0 : let der_certs = rustls_native_certs::load_native_certs();
355 0 :
356 0 : if !der_certs.errors.is_empty() {
357 0 : return Err(der_certs.errors);
358 0 : }
359 0 :
360 0 : let mut store = rustls::RootCertStore::empty();
361 0 : store.add_parsable_certificates(der_certs.certs);
362 0 : Ok(Arc::new(store))
363 0 : }
364 : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
365 :
366 : #[derive(Debug)]
367 : struct AcceptEverythingVerifier;
368 : impl ServerCertVerifier for AcceptEverythingVerifier {
369 0 : fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
370 : use rustls::SignatureScheme;
371 : // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
372 0 : vec![
373 0 : SignatureScheme::ECDSA_NISTP521_SHA512,
374 0 : SignatureScheme::ECDSA_NISTP384_SHA384,
375 0 : SignatureScheme::ECDSA_NISTP256_SHA256,
376 0 : SignatureScheme::RSA_PSS_SHA512,
377 0 : SignatureScheme::RSA_PSS_SHA384,
378 0 : SignatureScheme::RSA_PSS_SHA256,
379 0 : SignatureScheme::ED25519,
380 0 : ]
381 0 : }
382 0 : fn verify_server_cert(
383 0 : &self,
384 0 : _end_entity: &rustls::pki_types::CertificateDer<'_>,
385 0 : _intermediates: &[rustls::pki_types::CertificateDer<'_>],
386 0 : _server_name: &rustls::pki_types::ServerName<'_>,
387 0 : _ocsp_response: &[u8],
388 0 : _now: rustls::pki_types::UnixTime,
389 0 : ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
390 0 : Ok(rustls::client::danger::ServerCertVerified::assertion())
391 0 : }
392 0 : fn verify_tls12_signature(
393 0 : &self,
394 0 : _message: &[u8],
395 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
396 0 : _dss: &rustls::DigitallySignedStruct,
397 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
398 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
399 0 : }
400 0 : fn verify_tls13_signature(
401 0 : &self,
402 0 : _message: &[u8],
403 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
404 0 : _dss: &rustls::DigitallySignedStruct,
405 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
406 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
407 0 : }
408 : }
409 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which parts of the `options` string survive filtering.
    #[test]
    fn test_filtered_options() {
        // (input, expected result of `filtered_options`)
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            // Escaped whitespace is kept as a separate option.
            (r"\ project=foo \ ", Some(r"\ \ ")),
            // Not a `project=...` flag once split on whitespace, so it is kept.
            ("project = foo", Some("project = foo")),
            // Neon-specific options are dropped, the rest is kept.
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(filtered_options(input).as_deref(), expected);
        }
    }
}
|