Line data Source code
1 : use std::io;
2 : use std::net::SocketAddr;
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use once_cell::sync::OnceCell;
9 : use postgres_client::tls::MakeTlsConnect;
10 : use postgres_client::{CancelToken, RawConnection};
11 : use postgres_protocol::message::backend::NoticeResponseBody;
12 : use pq_proto::StartupMessageParams;
13 : use rustls::crypto::ring;
14 : use rustls::pki_types::InvalidDnsNameError;
15 : use thiserror::Error;
16 : use tokio::net::TcpStream;
17 : use tracing::{debug, error, info, warn};
18 :
19 : use crate::auth::parse_endpoint_param;
20 : use crate::cancellation::CancelClosure;
21 : use crate::context::RequestContext;
22 : use crate::control_plane::client::ApiLockError;
23 : use crate::control_plane::errors::WakeComputeError;
24 : use crate::control_plane::messages::MetricsAuxInfo;
25 : use crate::error::{ReportableError, UserFacingError};
26 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
27 : use crate::postgres_rustls::MakeRustlsConnect;
28 : use crate::proxy::neon_option;
29 : use crate::types::Host;
30 :
31 : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
32 :
33 : #[derive(Debug, Error)]
34 : pub(crate) enum ConnectionError {
35 : /// This error doesn't seem to reveal any secrets; for instance,
36 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
37 : #[error("{COULD_NOT_CONNECT}: {0}")]
38 : Postgres(#[from] postgres_client::Error),
39 :
40 : #[error("{COULD_NOT_CONNECT}: {0}")]
41 : CouldNotConnect(#[from] io::Error),
42 :
43 : #[error("Couldn't load native TLS certificates: {0:?}")]
44 : TlsCertificateError(Vec<rustls_native_certs::Error>),
45 :
46 : #[error("{COULD_NOT_CONNECT}: {0}")]
47 : TlsError(#[from] InvalidDnsNameError),
48 :
49 : #[error("{COULD_NOT_CONNECT}: {0}")]
50 : WakeComputeError(#[from] WakeComputeError),
51 :
52 : #[error("error acquiring resource permit: {0}")]
53 : TooManyConnectionAttempts(#[from] ApiLockError),
54 : }
55 :
56 : impl UserFacingError for ConnectionError {
57 0 : fn to_string_client(&self) -> String {
58 0 : match self {
59 : // This helps us drop irrelevant library-specific prefixes.
60 : // TODO: propagate severity level and other parameters.
61 0 : ConnectionError::Postgres(err) => match err.as_db_error() {
62 0 : Some(err) => {
63 0 : let msg = err.message();
64 0 :
65 0 : if msg.starts_with("unsupported startup parameter: ")
66 0 : || msg.starts_with("unsupported startup parameter in options: ")
67 : {
68 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
69 : } else {
70 0 : msg.to_owned()
71 : }
72 : }
73 0 : None => err.to_string(),
74 : },
75 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
76 : ConnectionError::TooManyConnectionAttempts(_) => {
77 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
78 : }
79 0 : _ => COULD_NOT_CONNECT.to_owned(),
80 : }
81 0 : }
82 : }
83 :
84 : impl ReportableError for ConnectionError {
85 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
86 0 : match self {
87 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
88 0 : crate::error::ErrorKind::Postgres
89 : }
90 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
91 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
92 0 : ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
93 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
94 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
95 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
96 : }
97 0 : }
98 : }
99 :
100 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
101 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
102 :
103 : /// A config for establishing a connection to compute node.
104 : /// Eventually, `postgres_client` will be replaced with something better.
105 : /// Newtype allows us to implement methods on top of it.
106 : #[derive(Clone)]
107 : pub(crate) struct ConnCfg(Box<postgres_client::Config>);
108 :
109 : /// Creation and initialization routines.
110 : impl ConnCfg {
111 10 : pub(crate) fn new(host: String, port: u16) -> Self {
112 10 : Self(Box::new(postgres_client::Config::new(host, port)))
113 10 : }
114 :
115 : /// Reuse password or auth keys from the other config.
116 4 : pub(crate) fn reuse_password(&mut self, other: Self) {
117 4 : if let Some(password) = other.get_password() {
118 4 : self.password(password);
119 4 : }
120 :
121 4 : if let Some(keys) = other.get_auth_keys() {
122 0 : self.auth_keys(keys);
123 4 : }
124 4 : }
125 :
126 0 : pub(crate) fn get_host(&self) -> Host {
127 0 : match self.0.get_host() {
128 0 : postgres_client::config::Host::Tcp(s) => s.into(),
129 0 : }
130 0 : }
131 :
132 : /// Apply startup message params to the connection config.
133 0 : pub(crate) fn set_startup_params(
134 0 : &mut self,
135 0 : params: &StartupMessageParams,
136 0 : arbitrary_params: bool,
137 0 : ) {
138 0 : if !arbitrary_params {
139 0 : self.set_param("client_encoding", "UTF8");
140 0 : }
141 0 : for (k, v) in params.iter() {
142 0 : match k {
143 : // Only set `user` if it's not present in the config.
144 : // Console redirect auth flow takes username from the console's response.
145 0 : "user" if self.user_is_set() => continue,
146 0 : "database" if self.db_is_set() => continue,
147 0 : "options" => {
148 0 : if let Some(options) = filtered_options(v) {
149 0 : self.set_param(k, &options);
150 0 : }
151 : }
152 0 : "user" | "database" | "application_name" | "replication" => {
153 0 : self.set_param(k, v);
154 0 : }
155 :
156 : // if we allow arbitrary params, then we forward them through.
157 : // this is a flag for a period of backwards compatibility
158 0 : k if arbitrary_params => {
159 0 : self.set_param(k, v);
160 0 : }
161 0 : _ => {}
162 : }
163 : }
164 0 : }
165 : }
166 :
167 : impl std::ops::Deref for ConnCfg {
168 : type Target = postgres_client::Config;
169 :
170 8 : fn deref(&self) -> &Self::Target {
171 8 : &self.0
172 8 : }
173 : }
174 :
175 : /// For now, let's make it easier to setup the config.
176 : impl std::ops::DerefMut for ConnCfg {
177 10 : fn deref_mut(&mut self) -> &mut Self::Target {
178 10 : &mut self.0
179 10 : }
180 : }
181 :
182 : impl ConnCfg {
183 : /// Establish a raw TCP connection to the compute node.
184 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
185 : use postgres_client::config::Host;
186 :
187 : // wrap TcpStream::connect with timeout
188 0 : let connect_with_timeout = |host, port| {
189 0 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
190 0 : move |res| match res {
191 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
192 0 : Err(_) => Err(io::Error::new(
193 0 : io::ErrorKind::TimedOut,
194 0 : format!("exceeded connection timeout {timeout:?}"),
195 0 : )),
196 0 : },
197 0 : )
198 0 : };
199 :
200 0 : let connect_once = |host, port| {
201 0 : debug!("trying to connect to compute node at {host}:{port}");
202 0 : connect_with_timeout(host, port).and_then(|socket| async {
203 0 : let socket_addr = socket.peer_addr()?;
204 : // This prevents load balancer from severing the connection.
205 0 : socket2::SockRef::from(&socket).set_keepalive(true)?;
206 0 : Ok((socket_addr, socket))
207 0 : })
208 0 : };
209 :
210 : // We can't reuse connection establishing logic from `postgres_client` here,
211 : // because it has no means for extracting the underlying socket which we
212 : // require for our business.
213 0 : let port = self.0.get_port();
214 0 : let host = self.0.get_host();
215 0 :
216 0 : let host = match host {
217 0 : Host::Tcp(host) => host.as_str(),
218 0 : };
219 0 :
220 0 : match connect_once(host, port).await {
221 0 : Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
222 0 : Err(err) => {
223 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
224 0 : Err(err)
225 : }
226 : }
227 0 : }
228 : }
229 :
230 : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
231 :
232 : pub(crate) struct PostgresConnection {
233 : /// Socket connected to a compute node.
234 : pub(crate) stream:
235 : postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
236 : /// PostgreSQL connection parameters.
237 : pub(crate) params: std::collections::HashMap<String, String>,
238 : /// Query cancellation token.
239 : pub(crate) cancel_closure: CancelClosure,
240 : /// Labels for proxy's metrics.
241 : pub(crate) aux: MetricsAuxInfo,
242 : /// Notices received from compute after authenticating
243 : pub(crate) delayed_notice: Vec<NoticeResponseBody>,
244 :
245 : _guage: NumDbConnectionsGuard<'static>,
246 : }
247 :
248 : impl ConnCfg {
249 : /// Connect to a corresponding compute node.
250 0 : pub(crate) async fn connect(
251 0 : &self,
252 0 : ctx: &RequestContext,
253 0 : aux: MetricsAuxInfo,
254 0 : timeout: Duration,
255 0 : ) -> Result<PostgresConnection, ConnectionError> {
256 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
257 0 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
258 0 : drop(pause);
259 :
260 0 : let root_store = TLS_ROOTS
261 0 : .get_or_try_init(load_certs)
262 0 : .map_err(ConnectionError::TlsCertificateError)?
263 0 : .clone();
264 0 :
265 0 : let client_config =
266 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
267 0 : .with_safe_default_protocol_versions()
268 0 : .expect("ring should support the default protocol versions")
269 0 : .with_root_certificates(root_store)
270 0 : .with_no_client_auth();
271 0 :
272 0 : let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
273 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
274 0 : &mut mk_tls,
275 0 : host,
276 0 : )?;
277 :
278 : // connect_raw() will not use TLS if sslmode is "disable"
279 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
280 0 : let connection = self.0.connect_raw(stream, tls).await?;
281 0 : drop(pause);
282 0 :
283 0 : let RawConnection {
284 0 : stream,
285 0 : parameters,
286 0 : delayed_notice,
287 0 : process_id,
288 0 : secret_key,
289 0 : } = connection;
290 0 :
291 0 : tracing::Span::current().record("pid", tracing::field::display(process_id));
292 0 : let stream = stream.into_inner();
293 0 :
294 0 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
295 0 : info!(
296 0 : cold_start_info = ctx.cold_start_info().as_str(),
297 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
298 0 : self.0.get_ssl_mode()
299 : );
300 :
301 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
302 : // Yet another reason to rework the connection establishing code.
303 0 : let cancel_closure = CancelClosure::new(
304 0 : socket_addr,
305 0 : CancelToken {
306 0 : socket_config: None,
307 0 : ssl_mode: self.0.get_ssl_mode(),
308 0 : process_id,
309 0 : secret_key,
310 0 : },
311 0 : vec![],
312 0 : host.to_string(),
313 0 : );
314 0 :
315 0 : let connection = PostgresConnection {
316 0 : stream,
317 0 : params: parameters,
318 0 : delayed_notice,
319 0 : cancel_closure,
320 0 : aux,
321 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
322 0 : };
323 0 :
324 0 : Ok(connection)
325 0 : }
326 : }
327 :
328 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
329 6 : fn filtered_options(options: &str) -> Option<String> {
330 6 : #[allow(unstable_name_collisions)]
331 6 : let options: String = StartupMessageParams::parse_options_raw(options)
332 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
333 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
334 6 : .collect();
335 6 :
336 6 : // Don't even bother with empty options.
337 6 : if options.is_empty() {
338 3 : return None;
339 3 : }
340 3 :
341 3 : Some(options)
342 6 : }
343 :
344 0 : pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
345 0 : let der_certs = rustls_native_certs::load_native_certs();
346 0 :
347 0 : if !der_certs.errors.is_empty() {
348 0 : return Err(der_certs.errors);
349 0 : }
350 0 :
351 0 : let mut store = rustls::RootCertStore::empty();
352 0 : store.add_parsable_certificates(der_certs.certs);
353 0 : Ok(Arc::new(store))
354 0 : }
355 : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
356 :
357 : #[cfg(test)]
358 : mod tests {
359 : use super::*;
360 :
361 : #[test]
362 1 : fn test_filtered_options() {
363 1 : // Empty options is unlikely to be useful anyway.
364 1 : let params = "";
365 1 : assert_eq!(filtered_options(params), None);
366 :
367 : // It's likely that clients will only use options to specify endpoint/project.
368 1 : let params = "project=foo";
369 1 : assert_eq!(filtered_options(params), None);
370 :
371 : // Same, because unescaped whitespaces are no-op.
372 1 : let params = " project=foo ";
373 1 : assert_eq!(filtered_options(params).as_deref(), None);
374 :
375 1 : let params = r"\ project=foo \ ";
376 1 : assert_eq!(filtered_options(params).as_deref(), Some(r"\ \ "));
377 :
378 1 : let params = "project = foo";
379 1 : assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
380 :
381 1 : let params = "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true";
382 1 : assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
383 1 : }
384 : }
|