Line data Source code
1 : use std::fmt::Debug;
2 : use std::io;
3 : use std::net::SocketAddr;
4 : use std::time::Duration;
5 :
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use postgres_client::tls::MakeTlsConnect;
9 : use postgres_client::{CancelToken, RawConnection};
10 : use postgres_protocol::message::backend::NoticeResponseBody;
11 : use pq_proto::StartupMessageParams;
12 : use rustls::pki_types::InvalidDnsNameError;
13 : use thiserror::Error;
14 : use tokio::net::{TcpStream, lookup_host};
15 : use tracing::{debug, error, info, warn};
16 :
17 : use crate::auth::backend::ComputeUserInfo;
18 : use crate::auth::parse_endpoint_param;
19 : use crate::cancellation::CancelClosure;
20 : use crate::config::ComputeConfig;
21 : use crate::context::RequestContext;
22 : use crate::control_plane::client::ApiLockError;
23 : use crate::control_plane::errors::WakeComputeError;
24 : use crate::control_plane::messages::MetricsAuxInfo;
25 : use crate::error::{ReportableError, UserFacingError};
26 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
27 : use crate::proxy::neon_option;
28 : use crate::tls::postgres_rustls::MakeRustlsConnect;
29 : use crate::types::Host;
30 :
31 : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
32 :
33 : #[derive(Debug, Error)]
34 : pub(crate) enum ConnectionError {
35 : /// This error doesn't seem to reveal any secrets; for instance,
36 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
37 : #[error("{COULD_NOT_CONNECT}: {0}")]
38 : Postgres(#[from] postgres_client::Error),
39 :
40 : #[error("{COULD_NOT_CONNECT}: {0}")]
41 : CouldNotConnect(#[from] io::Error),
42 :
43 : #[error("{COULD_NOT_CONNECT}: {0}")]
44 : TlsError(#[from] InvalidDnsNameError),
45 :
46 : #[error("{COULD_NOT_CONNECT}: {0}")]
47 : WakeComputeError(#[from] WakeComputeError),
48 :
49 : #[error("error acquiring resource permit: {0}")]
50 : TooManyConnectionAttempts(#[from] ApiLockError),
51 : }
52 :
53 : impl UserFacingError for ConnectionError {
54 0 : fn to_string_client(&self) -> String {
55 0 : match self {
56 : // This helps us drop irrelevant library-specific prefixes.
57 : // TODO: propagate severity level and other parameters.
58 0 : ConnectionError::Postgres(err) => match err.as_db_error() {
59 0 : Some(err) => {
60 0 : let msg = err.message();
61 0 :
62 0 : if msg.starts_with("unsupported startup parameter: ")
63 0 : || msg.starts_with("unsupported startup parameter in options: ")
64 : {
65 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
66 : } else {
67 0 : msg.to_owned()
68 : }
69 : }
70 0 : None => err.to_string(),
71 : },
72 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
73 : ConnectionError::TooManyConnectionAttempts(_) => {
74 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
75 : }
76 0 : _ => COULD_NOT_CONNECT.to_owned(),
77 : }
78 0 : }
79 : }
80 :
81 : impl ReportableError for ConnectionError {
82 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
83 0 : match self {
84 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
85 0 : crate::error::ErrorKind::Postgres
86 : }
87 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
88 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
89 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
90 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
91 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
92 : }
93 0 : }
94 : }
95 :
96 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
97 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
98 :
99 : /// A config for establishing a connection to compute node.
100 : /// Eventually, `postgres_client` will be replaced with something better.
101 : /// Newtype allows us to implement methods on top of it.
102 : #[derive(Clone)]
103 : pub(crate) struct ConnCfg(Box<postgres_client::Config>);
104 :
105 : /// Creation and initialization routines.
106 : impl ConnCfg {
107 10 : pub(crate) fn new(host: String, port: u16) -> Self {
108 10 : Self(Box::new(postgres_client::Config::new(host, port)))
109 10 : }
110 :
111 : /// Reuse password or auth keys from the other config.
112 4 : pub(crate) fn reuse_password(&mut self, other: Self) {
113 4 : if let Some(password) = other.get_password() {
114 4 : self.password(password);
115 4 : }
116 :
117 4 : if let Some(keys) = other.get_auth_keys() {
118 0 : self.auth_keys(keys);
119 4 : }
120 4 : }
121 :
122 0 : pub(crate) fn get_host(&self) -> Host {
123 0 : match self.0.get_host() {
124 0 : postgres_client::config::Host::Tcp(s) => s.into(),
125 0 : }
126 0 : }
127 :
128 : /// Apply startup message params to the connection config.
129 0 : pub(crate) fn set_startup_params(
130 0 : &mut self,
131 0 : params: &StartupMessageParams,
132 0 : arbitrary_params: bool,
133 0 : ) {
134 0 : if !arbitrary_params {
135 0 : self.set_param("client_encoding", "UTF8");
136 0 : }
137 0 : for (k, v) in params.iter() {
138 0 : match k {
139 : // Only set `user` if it's not present in the config.
140 : // Console redirect auth flow takes username from the console's response.
141 0 : "user" if self.user_is_set() => {}
142 0 : "database" if self.db_is_set() => {}
143 0 : "options" => {
144 0 : if let Some(options) = filtered_options(v) {
145 0 : self.set_param(k, &options);
146 0 : }
147 : }
148 0 : "user" | "database" | "application_name" | "replication" => {
149 0 : self.set_param(k, v);
150 0 : }
151 :
152 : // if we allow arbitrary params, then we forward them through.
153 : // this is a flag for a period of backwards compatibility
154 0 : k if arbitrary_params => {
155 0 : self.set_param(k, v);
156 0 : }
157 0 : _ => {}
158 : }
159 : }
160 0 : }
161 : }
162 :
163 : impl std::ops::Deref for ConnCfg {
164 : type Target = postgres_client::Config;
165 :
166 8 : fn deref(&self) -> &Self::Target {
167 8 : &self.0
168 8 : }
169 : }
170 :
171 : /// For now, let's make it easier to setup the config.
172 : impl std::ops::DerefMut for ConnCfg {
173 10 : fn deref_mut(&mut self) -> &mut Self::Target {
174 10 : &mut self.0
175 10 : }
176 : }
177 :
178 : impl ConnCfg {
179 : /// Establish a raw TCP connection to the compute node.
180 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
181 : use postgres_client::config::Host;
182 :
183 : // wrap TcpStream::connect with timeout
184 0 : let connect_with_timeout = |addrs| {
185 0 : tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
186 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
187 0 : Err(_) => Err(io::Error::new(
188 0 : io::ErrorKind::TimedOut,
189 0 : format!("exceeded connection timeout {timeout:?}"),
190 0 : )),
191 0 : })
192 0 : };
193 :
194 0 : let connect_once = |addrs| {
195 0 : debug!("trying to connect to compute node at {addrs:?}");
196 0 : connect_with_timeout(addrs).and_then(|stream| async {
197 0 : let socket_addr = stream.peer_addr()?;
198 0 : let socket = socket2::SockRef::from(&stream);
199 0 : // Disable Nagle's algorithm to not introduce latency between
200 0 : // client and compute.
201 0 : socket.set_nodelay(true)?;
202 : // This prevents load balancer from severing the connection.
203 0 : socket.set_keepalive(true)?;
204 0 : Ok((socket_addr, stream))
205 0 : })
206 0 : };
207 :
208 : // We can't reuse connection establishing logic from `postgres_client` here,
209 : // because it has no means for extracting the underlying socket which we
210 : // require for our business.
211 0 : let port = self.0.get_port();
212 0 : let host = self.0.get_host();
213 0 :
214 0 : let host = match host {
215 0 : Host::Tcp(host) => host.as_str(),
216 : };
217 :
218 0 : let addrs = match self.0.get_host_addr() {
219 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
220 0 : None => lookup_host((host, port)).await?.collect(),
221 : };
222 :
223 0 : match connect_once(&*addrs).await {
224 0 : Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
225 0 : Err(err) => {
226 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
227 0 : Err(err)
228 : }
229 : }
230 0 : }
231 : }
232 :
233 : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
234 :
235 : pub(crate) struct PostgresConnection {
236 : /// Socket connected to a compute node.
237 : pub(crate) stream:
238 : postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
239 : /// PostgreSQL connection parameters.
240 : pub(crate) params: std::collections::HashMap<String, String>,
241 : /// Query cancellation token.
242 : pub(crate) cancel_closure: CancelClosure,
243 : /// Labels for proxy's metrics.
244 : pub(crate) aux: MetricsAuxInfo,
245 : /// Notices received from compute after authenticating
246 : pub(crate) delayed_notice: Vec<NoticeResponseBody>,
247 :
248 : _guage: NumDbConnectionsGuard<'static>,
249 : }
250 :
251 : impl ConnCfg {
252 : /// Connect to a corresponding compute node.
253 0 : pub(crate) async fn connect(
254 0 : &self,
255 0 : ctx: &RequestContext,
256 0 : aux: MetricsAuxInfo,
257 0 : config: &ComputeConfig,
258 0 : user_info: ComputeUserInfo,
259 0 : ) -> Result<PostgresConnection, ConnectionError> {
260 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
261 0 : let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
262 0 : drop(pause);
263 0 :
264 0 : let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
265 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
266 0 : &mut mk_tls,
267 0 : host,
268 0 : )?;
269 :
270 : // connect_raw() will not use TLS if sslmode is "disable"
271 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
272 0 : let connection = self.0.connect_raw(stream, tls).await?;
273 0 : drop(pause);
274 0 :
275 0 : let RawConnection {
276 0 : stream,
277 0 : parameters,
278 0 : delayed_notice,
279 0 : process_id,
280 0 : secret_key,
281 0 : } = connection;
282 0 :
283 0 : tracing::Span::current().record("pid", tracing::field::display(process_id));
284 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
285 0 : let stream = stream.into_inner();
286 0 :
287 0 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
288 0 : info!(
289 0 : cold_start_info = ctx.cold_start_info().as_str(),
290 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
291 0 : self.0.get_ssl_mode(),
292 0 : ctx.get_proxy_latency(),
293 0 : ctx.get_testodrome_id().unwrap_or_default(),
294 : );
295 :
296 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
297 : // Yet another reason to rework the connection establishing code.
298 0 : let cancel_closure = CancelClosure::new(
299 0 : socket_addr,
300 0 : CancelToken {
301 0 : socket_config: None,
302 0 : ssl_mode: self.0.get_ssl_mode(),
303 0 : process_id,
304 0 : secret_key,
305 0 : },
306 0 : host.to_string(),
307 0 : user_info,
308 0 : );
309 0 :
310 0 : let connection = PostgresConnection {
311 0 : stream,
312 0 : params: parameters,
313 0 : delayed_notice,
314 0 : cancel_closure,
315 0 : aux,
316 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
317 0 : };
318 0 :
319 0 : Ok(connection)
320 0 : }
321 : }
322 :
323 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
324 6 : fn filtered_options(options: &str) -> Option<String> {
325 6 : #[allow(unstable_name_collisions)]
326 6 : let options: String = StartupMessageParams::parse_options_raw(options)
327 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
328 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
329 6 : .collect();
330 6 :
331 6 : // Don't even bother with empty options.
332 6 : if options.is_empty() {
333 3 : return None;
334 3 : }
335 3 :
336 3 : Some(options)
337 6 : }
338 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filtered_options() {
        let cases: &[(&str, Option<&str>)] = &[
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            (r"\ project=foo \ ", Some(r"\ \ ")),
            ("project = foo", Some("project = foo")),
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for &(input, expected) in cases {
            assert_eq!(filtered_options(input).as_deref(), expected, "input: {input:?}");
        }
    }
}
|