Line data Source code
1 : use std::io;
2 : use std::net::SocketAddr;
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use once_cell::sync::OnceCell;
9 : use pq_proto::StartupMessageParams;
10 : use rustls::client::danger::ServerCertVerifier;
11 : use rustls::crypto::ring;
12 : use rustls::pki_types::InvalidDnsNameError;
13 : use thiserror::Error;
14 : use tokio::net::TcpStream;
15 : use tokio_postgres::tls::MakeTlsConnect;
16 : use tokio_postgres_rustls::MakeRustlsConnect;
17 : use tracing::{error, info, warn};
18 :
19 : use crate::auth::parse_endpoint_param;
20 : use crate::cancellation::CancelClosure;
21 : use crate::context::RequestMonitoring;
22 : use crate::control_plane::client::ApiLockError;
23 : use crate::control_plane::errors::WakeComputeError;
24 : use crate::control_plane::messages::MetricsAuxInfo;
25 : use crate::error::{ReportableError, UserFacingError};
26 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
27 : use crate::proxy::neon_option;
28 : use crate::types::Host;
29 :
30 : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
31 :
32 0 : #[derive(Debug, Error)]
33 : pub(crate) enum ConnectionError {
34 : /// This error doesn't seem to reveal any secrets; for instance,
35 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
36 : #[error("{COULD_NOT_CONNECT}: {0}")]
37 : Postgres(#[from] tokio_postgres::Error),
38 :
39 : #[error("{COULD_NOT_CONNECT}: {0}")]
40 : CouldNotConnect(#[from] io::Error),
41 :
42 : #[error("Couldn't load native TLS certificates: {0:?}")]
43 : TlsCertificateError(Vec<rustls_native_certs::Error>),
44 :
45 : #[error("{COULD_NOT_CONNECT}: {0}")]
46 : TlsError(#[from] InvalidDnsNameError),
47 :
48 : #[error("{COULD_NOT_CONNECT}: {0}")]
49 : WakeComputeError(#[from] WakeComputeError),
50 :
51 : #[error("error acquiring resource permit: {0}")]
52 : TooManyConnectionAttempts(#[from] ApiLockError),
53 : }
54 :
55 : impl UserFacingError for ConnectionError {
56 0 : fn to_string_client(&self) -> String {
57 0 : match self {
58 : // This helps us drop irrelevant library-specific prefixes.
59 : // TODO: propagate severity level and other parameters.
60 0 : ConnectionError::Postgres(err) => match err.as_db_error() {
61 0 : Some(err) => {
62 0 : let msg = err.message();
63 0 :
64 0 : if msg.starts_with("unsupported startup parameter: ")
65 0 : || msg.starts_with("unsupported startup parameter in options: ")
66 : {
67 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
68 : } else {
69 0 : msg.to_owned()
70 : }
71 : }
72 0 : None => err.to_string(),
73 : },
74 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
75 : ConnectionError::TooManyConnectionAttempts(_) => {
76 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
77 : }
78 0 : _ => COULD_NOT_CONNECT.to_owned(),
79 : }
80 0 : }
81 : }
82 :
83 : impl ReportableError for ConnectionError {
84 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
85 0 : match self {
86 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
87 0 : crate::error::ErrorKind::Postgres
88 : }
89 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
90 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
91 0 : ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
92 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
93 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
94 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
95 : }
96 0 : }
97 : }
98 :
99 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
100 : pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
101 :
102 : /// A config for establishing a connection to compute node.
103 : /// Eventually, `tokio_postgres` will be replaced with something better.
104 : /// Newtype allows us to implement methods on top of it.
105 : #[derive(Clone, Default)]
106 : pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
107 :
108 : /// Creation and initialization routines.
109 : impl ConnCfg {
110 10 : pub(crate) fn new() -> Self {
111 10 : Self::default()
112 10 : }
113 :
114 : /// Reuse password or auth keys from the other config.
115 4 : pub(crate) fn reuse_password(&mut self, other: Self) {
116 4 : if let Some(password) = other.get_password() {
117 4 : self.password(password);
118 4 : }
119 :
120 4 : if let Some(keys) = other.get_auth_keys() {
121 0 : self.auth_keys(keys);
122 4 : }
123 4 : }
124 :
125 0 : pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
126 0 : match self.0.get_hosts() {
127 0 : [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
128 : // we should not have multiple address or unix addresses.
129 0 : _ => Err(WakeComputeError::BadComputeAddress(
130 0 : "invalid compute address".into(),
131 0 : )),
132 : }
133 0 : }
134 :
135 : /// Apply startup message params to the connection config.
136 0 : pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
137 : // Only set `user` if it's not present in the config.
138 : // Console redirect auth flow takes username from the console's response.
139 0 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
140 0 : self.user(user);
141 0 : }
142 :
143 : // Only set `dbname` if it's not present in the config.
144 : // Console redirect auth flow takes dbname from the console's response.
145 0 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
146 0 : self.dbname(dbname);
147 0 : }
148 :
149 : // Don't add `options` if they were only used for specifying a project.
150 : // Connection pools don't support `options`, because they affect backend startup.
151 0 : if let Some(options) = filtered_options(params) {
152 0 : self.options(&options);
153 0 : }
154 :
155 0 : if let Some(app_name) = params.get("application_name") {
156 0 : self.application_name(app_name);
157 0 : }
158 :
159 : // TODO: This is especially ugly...
160 0 : if let Some(replication) = params.get("replication") {
161 : use tokio_postgres::config::ReplicationMode;
162 0 : match replication {
163 0 : "true" | "on" | "yes" | "1" => {
164 0 : self.replication_mode(ReplicationMode::Physical);
165 0 : }
166 0 : "database" => {
167 0 : self.replication_mode(ReplicationMode::Logical);
168 0 : }
169 0 : _other => {}
170 : }
171 0 : }
172 :
173 : // TODO: extend the list of the forwarded startup parameters.
174 : // Currently, tokio-postgres doesn't allow us to pass
175 : // arbitrary parameters, but the ones above are a good start.
176 : //
177 : // This and the reverse params problem can be better addressed
178 : // in a bespoke connection machinery (a new library for that sake).
179 0 : }
180 : }
181 :
182 : impl std::ops::Deref for ConnCfg {
183 : type Target = tokio_postgres::Config;
184 :
185 8 : fn deref(&self) -> &Self::Target {
186 8 : &self.0
187 8 : }
188 : }
189 :
190 : /// For now, let's make it easier to setup the config.
191 : impl std::ops::DerefMut for ConnCfg {
192 10 : fn deref_mut(&mut self) -> &mut Self::Target {
193 10 : &mut self.0
194 10 : }
195 : }
196 :
197 : impl ConnCfg {
198 : /// Establish a raw TCP connection to the compute node.
199 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
200 : use tokio_postgres::config::Host;
201 :
202 : // wrap TcpStream::connect with timeout
203 0 : let connect_with_timeout = |host, port| {
204 0 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
205 0 : move |res| match res {
206 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
207 0 : Err(_) => Err(io::Error::new(
208 0 : io::ErrorKind::TimedOut,
209 0 : format!("exceeded connection timeout {timeout:?}"),
210 0 : )),
211 0 : },
212 0 : )
213 0 : };
214 :
215 0 : let connect_once = |host, port| {
216 0 : info!("trying to connect to compute node at {host}:{port}");
217 0 : connect_with_timeout(host, port).and_then(|socket| async {
218 0 : let socket_addr = socket.peer_addr()?;
219 : // This prevents load balancer from severing the connection.
220 0 : socket2::SockRef::from(&socket).set_keepalive(true)?;
221 0 : Ok((socket_addr, socket))
222 0 : })
223 0 : };
224 :
225 : // We can't reuse connection establishing logic from `tokio_postgres` here,
226 : // because it has no means for extracting the underlying socket which we
227 : // require for our business.
228 0 : let mut connection_error = None;
229 0 : let ports = self.0.get_ports();
230 0 : let hosts = self.0.get_hosts();
231 0 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
232 0 : if ports.len() > 1 && ports.len() != hosts.len() {
233 0 : return Err(io::Error::new(
234 0 : io::ErrorKind::Other,
235 0 : format!(
236 0 : "bad compute config, \
237 0 : ports and hosts entries' count does not match: {:?}",
238 0 : self.0
239 0 : ),
240 0 : ));
241 0 : }
242 :
243 0 : for (i, host) in hosts.iter().enumerate() {
244 0 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
245 0 : let host = match host {
246 0 : Host::Tcp(host) => host.as_str(),
247 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
248 : };
249 :
250 0 : match connect_once(host, *port).await {
251 0 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
252 0 : Err(err) => {
253 0 : // We can't throw an error here, as there might be more hosts to try.
254 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
255 0 : connection_error = Some(err);
256 : }
257 : }
258 : }
259 :
260 0 : Err(connection_error.unwrap_or_else(|| {
261 0 : io::Error::new(
262 0 : io::ErrorKind::Other,
263 0 : format!("bad compute config: {:?}", self.0),
264 0 : )
265 0 : }))
266 0 : }
267 : }
268 :
269 : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
270 :
271 : pub(crate) struct PostgresConnection {
272 : /// Socket connected to a compute node.
273 : pub(crate) stream:
274 : tokio_postgres::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
275 : /// PostgreSQL connection parameters.
276 : pub(crate) params: std::collections::HashMap<String, String>,
277 : /// Query cancellation token.
278 : pub(crate) cancel_closure: CancelClosure,
279 : /// Labels for proxy's metrics.
280 : pub(crate) aux: MetricsAuxInfo,
281 :
282 : _guage: NumDbConnectionsGuard<'static>,
283 : }
284 :
285 : impl ConnCfg {
286 : /// Connect to a corresponding compute node.
287 0 : pub(crate) async fn connect(
288 0 : &self,
289 0 : ctx: &RequestMonitoring,
290 0 : allow_self_signed_compute: bool,
291 0 : aux: MetricsAuxInfo,
292 0 : timeout: Duration,
293 0 : ) -> Result<PostgresConnection, ConnectionError> {
294 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
295 0 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
296 0 : drop(pause);
297 :
298 0 : let client_config = if allow_self_signed_compute {
299 : // Allow all certificates for creating the connection
300 0 : let verifier = Arc::new(AcceptEverythingVerifier);
301 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
302 0 : .with_safe_default_protocol_versions()
303 0 : .expect("ring should support the default protocol versions")
304 0 : .dangerous()
305 0 : .with_custom_certificate_verifier(verifier)
306 : } else {
307 0 : let root_store = TLS_ROOTS
308 0 : .get_or_try_init(load_certs)
309 0 : .map_err(ConnectionError::TlsCertificateError)?
310 0 : .clone();
311 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
312 0 : .with_safe_default_protocol_versions()
313 0 : .expect("ring should support the default protocol versions")
314 0 : .with_root_certificates(root_store)
315 : };
316 0 : let client_config = client_config.with_no_client_auth();
317 0 :
318 0 : let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
319 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
320 0 : &mut mk_tls,
321 0 : host,
322 0 : )?;
323 :
324 : // connect_raw() will not use TLS if sslmode is "disable"
325 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
326 0 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
327 0 : drop(pause);
328 0 : tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
329 0 : let stream = connection.stream.into_inner();
330 0 :
331 0 : info!(
332 0 : cold_start_info = ctx.cold_start_info().as_str(),
333 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
334 0 : self.0.get_ssl_mode()
335 : );
336 :
337 : // This is very ugly but as of now there's no better way to
338 : // extract the connection parameters from tokio-postgres' connection.
339 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
340 0 : let params = connection.parameters;
341 0 :
342 0 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
343 0 : // Yet another reason to rework the connection establishing code.
344 0 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
345 0 :
346 0 : let connection = PostgresConnection {
347 0 : stream,
348 0 : params,
349 0 : cancel_closure,
350 0 : aux,
351 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
352 0 : };
353 0 :
354 0 : Ok(connection)
355 0 : }
356 : }
357 :
358 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
359 6 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
360 : #[allow(unstable_name_collisions)]
361 6 : let options: String = params
362 6 : .options_raw()?
363 13 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
364 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
365 6 : .collect();
366 6 :
367 6 : // Don't even bother with empty options.
368 6 : if options.is_empty() {
369 3 : return None;
370 3 : }
371 3 :
372 3 : Some(options)
373 6 : }
374 :
375 0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
376 0 : let der_certs = rustls_native_certs::load_native_certs();
377 0 :
378 0 : if !der_certs.errors.is_empty() {
379 0 : return Err(der_certs.errors);
380 0 : }
381 0 :
382 0 : let mut store = rustls::RootCertStore::empty();
383 0 : store.add_parsable_certificates(der_certs.certs);
384 0 : Ok(Arc::new(store))
385 0 : }
386 : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
387 :
388 : #[derive(Debug)]
389 : struct AcceptEverythingVerifier;
390 : impl ServerCertVerifier for AcceptEverythingVerifier {
391 0 : fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
392 : use rustls::SignatureScheme;
393 : // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
394 0 : vec![
395 0 : SignatureScheme::ECDSA_NISTP521_SHA512,
396 0 : SignatureScheme::ECDSA_NISTP384_SHA384,
397 0 : SignatureScheme::ECDSA_NISTP256_SHA256,
398 0 : SignatureScheme::RSA_PSS_SHA512,
399 0 : SignatureScheme::RSA_PSS_SHA384,
400 0 : SignatureScheme::RSA_PSS_SHA256,
401 0 : SignatureScheme::ED25519,
402 0 : ]
403 0 : }
404 0 : fn verify_server_cert(
405 0 : &self,
406 0 : _end_entity: &rustls::pki_types::CertificateDer<'_>,
407 0 : _intermediates: &[rustls::pki_types::CertificateDer<'_>],
408 0 : _server_name: &rustls::pki_types::ServerName<'_>,
409 0 : _ocsp_response: &[u8],
410 0 : _now: rustls::pki_types::UnixTime,
411 0 : ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
412 0 : Ok(rustls::client::danger::ServerCertVerified::assertion())
413 0 : }
414 0 : fn verify_tls12_signature(
415 0 : &self,
416 0 : _message: &[u8],
417 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
418 0 : _dss: &rustls::DigitallySignedStruct,
419 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
420 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
421 0 : }
422 0 : fn verify_tls13_signature(
423 0 : &self,
424 0 : _message: &[u8],
425 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
426 0 : _dss: &rustls::DigitallySignedStruct,
427 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
428 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
429 0 : }
430 : }
431 :
432 : #[cfg(test)]
433 : mod tests {
434 : use super::*;
435 :
436 : #[test]
437 1 : fn test_filtered_options() {
438 1 : // Empty options is unlikely to be useful anyway.
439 1 : let params = StartupMessageParams::new([("options", "")]);
440 1 : assert_eq!(filtered_options(¶ms), None);
441 :
442 : // It's likely that clients will only use options to specify endpoint/project.
443 1 : let params = StartupMessageParams::new([("options", "project=foo")]);
444 1 : assert_eq!(filtered_options(¶ms), None);
445 :
446 : // Same, because unescaped whitespaces are no-op.
447 1 : let params = StartupMessageParams::new([("options", " project=foo ")]);
448 1 : assert_eq!(filtered_options(¶ms).as_deref(), None);
449 :
450 1 : let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
451 1 : assert_eq!(filtered_options(¶ms).as_deref(), Some(r"\ \ "));
452 :
453 1 : let params = StartupMessageParams::new([("options", "project = foo")]);
454 1 : assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
455 :
456 1 : let params = StartupMessageParams::new([(
457 1 : "options",
458 1 : "project = foo neon_endpoint_type:read_write neon_lsn:0/2",
459 1 : )]);
460 1 : assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
461 1 : }
462 : }
|