Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param,
3 : cancellation::CancelClosure,
4 : console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::{Metrics, NumDbConnectionsGuard},
8 : proxy::neon_option,
9 : Host,
10 : };
11 : use futures::{FutureExt, TryFutureExt};
12 : use itertools::Itertools;
13 : use once_cell::sync::OnceCell;
14 : use pq_proto::StartupMessageParams;
15 : use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
16 : use std::{io, net::SocketAddr, sync::Arc, time::Duration};
17 : use thiserror::Error;
18 : use tokio::net::TcpStream;
19 : use tokio_postgres::tls::MakeTlsConnect;
20 : use tokio_postgres_rustls::MakeRustlsConnect;
21 : use tracing::{error, info, warn};
22 :
/// Generic failure text for compute connection problems.
///
/// It is the prefix of most `ConnectionError` display messages and is returned
/// verbatim by `to_string_client` for variants that have no more specific
/// client-facing message.
23 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
24 :
/// Failures that can occur while connecting to a compute node.
///
/// Each variant wraps its source error and is constructible from it through
/// `#[from]`, so `?` can be used on the underlying error types. All variants
/// except `TooManyConnectionAttempts` render as `COULD_NOT_CONNECT` followed by
/// the wrapped error.
25 0 : #[derive(Debug, Error)]
26 : pub enum ConnectionError {
27 : /// This error doesn't seem to reveal any secrets; for instance,
28 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
29 : #[error("{COULD_NOT_CONNECT}: {0}")]
30 : Postgres(#[from] tokio_postgres::Error),
31 :
/// An `io::Error`, e.g. from opening the TCP connection.
32 : #[error("{COULD_NOT_CONNECT}: {0}")]
33 : CouldNotConnect(#[from] io::Error),
34 :
/// The host name was rejected as a DNS name while setting up TLS.
35 : #[error("{COULD_NOT_CONNECT}: {0}")]
36 : TlsError(#[from] InvalidDnsNameError),
37 :
/// A wrapped `WakeComputeError`; its own client message and error kind are
/// forwarded by the trait impls for this enum.
38 : #[error("{COULD_NOT_CONNECT}: {0}")]
39 : WakeComputeError(#[from] WakeComputeError),
40 :
/// A wrapped `ApiLockError`; the only variant whose message does not start
/// with `COULD_NOT_CONNECT`.
41 : #[error("error acquiring resource permit: {0}")]
42 : TooManyConnectionAttempts(#[from] ApiLockError),
43 : }
44 :
45 : impl UserFacingError for ConnectionError {
46 0 : fn to_string_client(&self) -> String {
47 0 : use ConnectionError::*;
48 0 : match self {
49 : // This helps us drop irrelevant library-specific prefixes.
50 : // TODO: propagate severity level and other parameters.
51 0 : Postgres(err) => match err.as_db_error() {
52 0 : Some(err) => {
53 0 : let msg = err.message();
54 0 :
55 0 : if msg.starts_with("unsupported startup parameter: ")
56 0 : || msg.starts_with("unsupported startup parameter in options: ")
57 : {
58 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
59 : } else {
60 0 : msg.to_owned()
61 : }
62 : }
63 0 : None => err.to_string(),
64 : },
65 0 : WakeComputeError(err) => err.to_string_client(),
66 : TooManyConnectionAttempts(_) => {
67 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
68 : }
69 0 : _ => COULD_NOT_CONNECT.to_owned(),
70 : }
71 0 : }
72 : }
73 :
74 : impl ReportableError for ConnectionError {
75 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
76 0 : match self {
77 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
78 0 : crate::error::ErrorKind::Postgres
79 : }
80 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
81 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
82 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
83 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
84 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
85 : }
86 0 : }
87 : }
88 :
89 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
///
/// NOTE(review): the const parameter `32` is presumably the key length in
/// bytes (the SHA-256 digest size) — confirm against
/// `tokio_postgres::config::ScramKeys`.
90 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
91 :
92 : /// A config for establishing a connection to compute node.
93 : /// Eventually, `tokio_postgres` will be replaced with something better.
94 : /// Newtype allows us to implement methods on top of it.
///
/// The wrapped `tokio_postgres::Config` is boxed. The `Deref`/`DerefMut`
/// impls in this file make the inner config's methods callable directly on
/// a `ConnCfg`.
95 : #[derive(Clone, Default)]
96 : pub struct ConnCfg(Box<tokio_postgres::Config>);
97 :
98 : /// Creation and initialization routines.
99 : impl ConnCfg {
100 20 : pub fn new() -> Self {
101 20 : Self::default()
102 20 : }
103 :
104 : /// Reuse password or auth keys from the other config.
105 8 : pub fn reuse_password(&mut self, other: Self) {
106 8 : if let Some(password) = other.get_auth() {
107 8 : self.auth(password);
108 8 : }
109 8 : }
110 :
111 0 : pub fn get_host(&self) -> Result<Host, WakeComputeError> {
112 0 : match self.0.get_hosts() {
113 0 : [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
114 : // we should not have multiple address or unix addresses.
115 0 : _ => Err(WakeComputeError::BadComputeAddress(
116 0 : "invalid compute address".into(),
117 0 : )),
118 : }
119 0 : }
120 :
121 : /// Apply startup message params to the connection config.
122 0 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
123 0 : let mut client_encoding = false;
124 0 : for (k, v) in params.iter() {
125 0 : match k {
126 0 : "user" => {
127 : // Only set `user` if it's not present in the config.
128 : // Link auth flow takes username from the console's response.
129 0 : if self.get_user().is_none() {
130 0 : self.user(v);
131 0 : }
132 : }
133 0 : "database" => {
134 : // Only set `dbname` if it's not present in the config.
135 : // Link auth flow takes dbname from the console's response.
136 0 : if self.get_dbname().is_none() {
137 0 : self.dbname(v);
138 0 : }
139 : }
140 0 : "options" => {
141 : // Don't add `options` if they were only used for specifying a project.
142 : // Connection pools don't support `options`, because they affect backend startup.
143 0 : if let Some(options) = filtered_options(v) {
144 0 : self.options(&options);
145 0 : }
146 : }
147 :
148 : // the special ones in tokio-postgres that we don't want being set by the user
149 0 : "dbname" => {}
150 0 : "password" => {}
151 0 : "sslmode" => {}
152 0 : "host" => {}
153 0 : "port" => {}
154 0 : "connect_timeout" => {}
155 0 : "keepalives" => {}
156 0 : "keepalives_idle" => {}
157 0 : "keepalives_interval" => {}
158 0 : "keepalives_retries" => {}
159 0 : "target_session_attrs" => {}
160 0 : "channel_binding" => {}
161 0 : "max_backend_message_size" => {}
162 :
163 0 : "client_encoding" => {
164 0 : client_encoding = true;
165 0 : // only error should be from bad null bytes,
166 0 : // but we've already checked for those.
167 0 : _ = self.param("client_encoding", v);
168 0 : }
169 :
170 0 : _ => {
171 0 : // only error should be from bad null bytes,
172 0 : // but we've already checked for those.
173 0 : _ = self.param(k, v);
174 0 : }
175 : }
176 : }
177 0 : if !client_encoding {
178 0 : // for compatibility since we removed it from tokio-postgres
179 0 : self.param("client_encoding", "UTF8").unwrap();
180 0 : }
181 0 : }
182 : }
183 :
184 : impl std::ops::Deref for ConnCfg {
185 : type Target = tokio_postgres::Config;
186 :
187 8 : fn deref(&self) -> &Self::Target {
188 8 : &self.0
189 8 : }
190 : }
191 :
192 : /// For now, let's make it easier to setup the config.
193 : impl std::ops::DerefMut for ConnCfg {
194 20 : fn deref_mut(&mut self) -> &mut Self::Target {
195 20 : &mut self.0
196 20 : }
197 : }
198 :
199 : impl ConnCfg {
200 : /// Establish a raw TCP connection to the compute node.
201 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
202 0 : use tokio_postgres::config::Host;
203 0 :
204 0 : // wrap TcpStream::connect with timeout
205 0 : let connect_with_timeout = |host, port| {
206 0 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
207 0 : move |res| match res {
208 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
209 0 : Err(_) => Err(io::Error::new(
210 0 : io::ErrorKind::TimedOut,
211 0 : format!("exceeded connection timeout {timeout:?}"),
212 0 : )),
213 0 : },
214 0 : )
215 0 : };
216 :
217 0 : let connect_once = |host, port| {
218 0 : info!("trying to connect to compute node at {host}:{port}");
219 0 : connect_with_timeout(host, port).and_then(|socket| async {
220 0 : let socket_addr = socket.peer_addr()?;
221 0 : // This prevents load balancer from severing the connection.
222 0 : socket2::SockRef::from(&socket).set_keepalive(true)?;
223 0 : Ok((socket_addr, socket))
224 0 : })
225 0 : };
226 :
227 : // We can't reuse connection establishing logic from `tokio_postgres` here,
228 : // because it has no means for extracting the underlying socket which we
229 : // require for our business.
230 0 : let mut connection_error = None;
231 0 : let ports = self.0.get_ports();
232 0 : let hosts = self.0.get_hosts();
233 0 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
234 0 : if ports.len() > 1 && ports.len() != hosts.len() {
235 0 : return Err(io::Error::new(
236 0 : io::ErrorKind::Other,
237 0 : format!(
238 0 : "bad compute config, \
239 0 : ports and hosts entries' count does not match: {:?}",
240 0 : self.0
241 0 : ),
242 0 : ));
243 0 : }
244 :
245 0 : for (i, host) in hosts.iter().enumerate() {
246 0 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
247 0 : let host = match host {
248 0 : Host::Tcp(host) => host.as_str(),
249 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
250 : };
251 :
252 0 : match connect_once(host, *port).await {
253 0 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
254 0 : Err(err) => {
255 0 : // We can't throw an error here, as there might be more hosts to try.
256 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
257 0 : connection_error = Some(err);
258 : }
259 : }
260 : }
261 :
262 0 : Err(connection_error.unwrap_or_else(|| {
263 0 : io::Error::new(
264 0 : io::ErrorKind::Other,
265 0 : format!("bad compute config: {:?}", self.0),
266 0 : )
267 0 : }))
268 0 : }
269 : }
270 :
271 : pub struct PostgresConnection {
272 : /// Socket connected to a compute node.
273 : pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
274 : tokio::net::TcpStream,
275 : tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
276 : >,
277 : /// PostgreSQL connection parameters.
278 : pub params: std::collections::HashMap<String, String>,
279 : /// Query cancellation token.
280 : pub cancel_closure: CancelClosure,
281 : /// Labels for proxy's metrics.
282 : pub aux: MetricsAuxInfo,
283 :
284 : _guage: NumDbConnectionsGuard<'static>,
285 : }
286 :
287 : impl ConnCfg {
288 : /// Connect to a corresponding compute node.
289 0 : pub async fn connect(
290 0 : &self,
291 0 : ctx: &mut RequestMonitoring,
292 0 : allow_self_signed_compute: bool,
293 0 : aux: MetricsAuxInfo,
294 0 : timeout: Duration,
295 0 : ) -> Result<PostgresConnection, ConnectionError> {
296 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
297 0 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
298 0 : drop(pause);
299 :
300 0 : let client_config = if allow_self_signed_compute {
301 : // Allow all certificates for creating the connection
302 0 : let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
303 0 : rustls::ClientConfig::builder()
304 0 : .dangerous()
305 0 : .with_custom_certificate_verifier(verifier)
306 : } else {
307 0 : let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
308 0 : rustls::ClientConfig::builder().with_root_certificates(root_store)
309 : };
310 0 : let client_config = client_config.with_no_client_auth();
311 0 :
312 0 : let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
313 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
314 0 : &mut mk_tls,
315 0 : host,
316 0 : )?;
317 :
318 : // connect_raw() will not use TLS if sslmode is "disable"
319 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
320 0 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
321 0 : drop(pause);
322 0 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
323 0 : let stream = connection.stream.into_inner();
324 0 :
325 0 : info!(
326 0 : cold_start_info = ctx.cold_start_info.as_str(),
327 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
328 0 : self.0.get_ssl_mode()
329 : );
330 :
331 : // This is very ugly but as of now there's no better way to
332 : // extract the connection parameters from tokio-postgres' connection.
333 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
334 0 : let params = connection.parameters;
335 0 :
336 0 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
337 0 : // Yet another reason to rework the connection establishing code.
338 0 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
339 0 :
340 0 : let connection = PostgresConnection {
341 0 : stream,
342 0 : params,
343 0 : cancel_closure,
344 0 : aux,
345 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
346 0 : };
347 0 :
348 0 : Ok(connection)
349 0 : }
350 : }
351 :
352 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
353 12 : fn filtered_options(options: &str) -> Option<String> {
354 12 : #[allow(unstable_name_collisions)]
355 12 : let options: String = StartupMessageParams::parse_options_raw(options)
356 26 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
357 12 : .intersperse(" ") // TODO: use impl from std once it's stabilized
358 12 : .collect();
359 12 :
360 12 : // Don't even bother with empty options.
361 12 : if options.is_empty() {
362 6 : return None;
363 6 : }
364 6 :
365 6 : Some(options)
366 12 : }
367 :
368 0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
369 0 : let der_certs = rustls_native_certs::load_native_certs()?;
370 0 : let mut store = rustls::RootCertStore::empty();
371 0 : store.add_parsable_certificates(der_certs);
372 0 : Ok(Arc::new(store))
373 0 : }
374 : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
375 :
376 : #[derive(Debug)]
377 : struct AcceptEverythingVerifier;
378 : impl ServerCertVerifier for AcceptEverythingVerifier {
379 0 : fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
380 0 : use rustls::SignatureScheme::*;
381 0 : // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
382 0 : vec![
383 0 : ECDSA_NISTP521_SHA512,
384 0 : ECDSA_NISTP384_SHA384,
385 0 : ECDSA_NISTP256_SHA256,
386 0 : RSA_PSS_SHA512,
387 0 : RSA_PSS_SHA384,
388 0 : RSA_PSS_SHA256,
389 0 : ED25519,
390 0 : ]
391 0 : }
392 0 : fn verify_server_cert(
393 0 : &self,
394 0 : _end_entity: &rustls::pki_types::CertificateDer<'_>,
395 0 : _intermediates: &[rustls::pki_types::CertificateDer<'_>],
396 0 : _server_name: &rustls::pki_types::ServerName<'_>,
397 0 : _ocsp_response: &[u8],
398 0 : _now: rustls::pki_types::UnixTime,
399 0 : ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
400 0 : Ok(rustls::client::danger::ServerCertVerified::assertion())
401 0 : }
402 0 : fn verify_tls12_signature(
403 0 : &self,
404 0 : _message: &[u8],
405 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
406 0 : _dss: &rustls::DigitallySignedStruct,
407 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
408 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
409 0 : }
410 0 : fn verify_tls13_signature(
411 0 : &self,
412 0 : _message: &[u8],
413 0 : _cert: &rustls::pki_types::CertificateDer<'_>,
414 0 : _dss: &rustls::DigitallySignedStruct,
415 0 : ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
416 0 : Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
417 0 : }
418 : }
419 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `filtered_options`: input options string and
    /// the expected filtered result.
    #[test]
    fn test_filtered_options() {
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            // Escaped whitespaces are kept.
            (r"\ project=foo \ ", Some(r"\ \ ")),
            // With spaces around `=`, this is not recognized as a project option.
            ("project = foo", Some("project = foo")),
            // Neon-specific options are dropped, the rest is kept.
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2",
                Some("project = foo"),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                filtered_options(input).as_deref(),
                expected,
                "input: {input:?}"
            );
        }
    }
}
|