Line data Source code
1 : mod tls;
2 :
3 : use std::fmt::Debug;
4 : use std::io;
5 : use std::net::{IpAddr, SocketAddr};
6 :
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use postgres_client::config::{AuthKeys, SslMode};
10 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
11 : use postgres_client::tls::MakeTlsConnect;
12 : use postgres_client::{CancelToken, NoTls, RawConnection};
13 : use postgres_protocol::message::backend::NoticeResponseBody;
14 : use thiserror::Error;
15 : use tokio::net::{TcpStream, lookup_host};
16 : use tracing::{debug, error, info, warn};
17 :
18 : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
19 : use crate::auth::parse_endpoint_param;
20 : use crate::cancellation::CancelClosure;
21 : use crate::compute::tls::TlsError;
22 : use crate::config::ComputeConfig;
23 : use crate::context::RequestContext;
24 : use crate::control_plane::client::ApiLockError;
25 : use crate::control_plane::errors::WakeComputeError;
26 : use crate::control_plane::messages::MetricsAuxInfo;
27 : use crate::error::{ReportableError, UserFacingError};
28 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
29 : use crate::pqproto::StartupMessageParams;
30 : use crate::proxy::neon_option;
31 : use crate::types::Host;
32 :
33 : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
34 :
35 : #[derive(Debug, Error)]
36 : pub(crate) enum ConnectionError {
37 : /// This error doesn't seem to reveal any secrets; for instance,
38 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
39 : #[error("{COULD_NOT_CONNECT}: {0}")]
40 : Postgres(#[from] postgres_client::Error),
41 :
42 : #[error("{COULD_NOT_CONNECT}: {0}")]
43 : TlsError(#[from] TlsError),
44 :
45 : #[error("{COULD_NOT_CONNECT}: {0}")]
46 : WakeComputeError(#[from] WakeComputeError),
47 :
48 : #[error("error acquiring resource permit: {0}")]
49 : TooManyConnectionAttempts(#[from] ApiLockError),
50 : }
51 :
52 : impl UserFacingError for ConnectionError {
53 0 : fn to_string_client(&self) -> String {
54 0 : match self {
55 : // This helps us drop irrelevant library-specific prefixes.
56 : // TODO: propagate severity level and other parameters.
57 0 : ConnectionError::Postgres(err) => match err.as_db_error() {
58 0 : Some(err) => {
59 0 : let msg = err.message();
60 0 :
61 0 : if msg.starts_with("unsupported startup parameter: ")
62 0 : || msg.starts_with("unsupported startup parameter in options: ")
63 : {
64 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
65 : } else {
66 0 : msg.to_owned()
67 : }
68 : }
69 0 : None => err.to_string(),
70 : },
71 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
72 : ConnectionError::TooManyConnectionAttempts(_) => {
73 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
74 : }
75 0 : ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
76 : }
77 0 : }
78 : }
79 :
80 : impl ReportableError for ConnectionError {
81 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
82 0 : match self {
83 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
84 0 : crate::error::ErrorKind::Postgres
85 : }
86 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
87 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
88 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
89 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
90 : }
91 0 : }
92 : }
93 :
94 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
95 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
96 :
97 : #[derive(Clone)]
98 : pub enum Auth {
99 : /// Only used during console-redirect.
100 : Password(Vec<u8>),
101 : /// Used by sql-over-http, ws, tcp.
102 : Scram(Box<ScramKeys>),
103 : }
104 :
105 : /// A config for authenticating to the compute node.
106 : pub(crate) struct AuthInfo {
107 : /// None for local-proxy, as we use trust-based localhost auth.
108 : /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
109 : /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
110 : auth: Option<Auth>,
111 : server_params: StartupMessageParams,
112 :
113 : /// Console redirect sets user and database, we shouldn't re-use those from the params.
114 : skip_db_user: bool,
115 : }
116 :
117 : /// Contains only the data needed to establish a secure connection to compute.
118 : #[derive(Clone)]
119 : pub struct ConnectInfo {
120 : pub host_addr: Option<IpAddr>,
121 : pub host: Host,
122 : pub port: u16,
123 : pub ssl_mode: SslMode,
124 : }
125 :
126 : /// Creation and initialization routines.
127 : impl AuthInfo {
128 0 : pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
129 0 : let mut server_params = StartupMessageParams::default();
130 0 : server_params.insert("database", db);
131 0 : server_params.insert("user", user);
132 0 : Self {
133 0 : auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
134 0 : server_params,
135 0 : skip_db_user: true,
136 0 : }
137 0 : }
138 :
139 0 : pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
140 0 : Self {
141 0 : auth: match keys {
142 0 : ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
143 0 : Some(Auth::Scram(Box::new(auth_keys)))
144 : }
145 0 : ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
146 : },
147 0 : server_params: StartupMessageParams::default(),
148 0 : skip_db_user: false,
149 0 : }
150 0 : }
151 : }
152 :
153 : impl ConnectInfo {
154 0 : pub fn to_postgres_client_config(&self) -> postgres_client::Config {
155 0 : let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
156 0 : config.ssl_mode(self.ssl_mode);
157 0 : if let Some(host_addr) = self.host_addr {
158 0 : config.set_host_addr(host_addr);
159 0 : }
160 0 : config
161 0 : }
162 : }
163 :
164 : impl AuthInfo {
165 0 : fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
166 0 : match &self.auth {
167 0 : Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
168 0 : Some(Auth::Password(pw)) => config.password(pw),
169 0 : None => &mut config,
170 : };
171 0 : for (k, v) in self.server_params.iter() {
172 0 : config.set_param(k, v);
173 0 : }
174 0 : config
175 0 : }
176 :
177 : /// Apply startup message params to the connection config.
178 0 : pub(crate) fn set_startup_params(
179 0 : &mut self,
180 0 : params: &StartupMessageParams,
181 0 : arbitrary_params: bool,
182 0 : ) {
183 0 : if !arbitrary_params {
184 0 : self.server_params.insert("client_encoding", "UTF8");
185 0 : }
186 0 : for (k, v) in params.iter() {
187 0 : match k {
188 : // Only set `user` if it's not present in the config.
189 : // Console redirect auth flow takes username from the console's response.
190 0 : "user" | "database" if self.skip_db_user => {}
191 0 : "options" => {
192 0 : if let Some(options) = filtered_options(v) {
193 0 : self.server_params.insert(k, &options);
194 0 : }
195 : }
196 0 : "user" | "database" | "application_name" | "replication" => {
197 0 : self.server_params.insert(k, v);
198 0 : }
199 :
200 : // if we allow arbitrary params, then we forward them through.
201 : // this is a flag for a period of backwards compatibility
202 0 : k if arbitrary_params => {
203 0 : self.server_params.insert(k, v);
204 0 : }
205 0 : _ => {}
206 : }
207 : }
208 0 : }
209 : }
210 :
211 : impl ConnectInfo {
212 : /// Establish a raw TCP+TLS connection to the compute node.
213 0 : async fn connect_raw(
214 0 : &self,
215 0 : config: &ComputeConfig,
216 0 : ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
217 0 : let timeout = config.timeout;
218 0 :
219 0 : // wrap TcpStream::connect with timeout
220 0 : let connect_with_timeout = |addrs| {
221 0 : tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
222 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
223 0 : Err(_) => Err(io::Error::new(
224 0 : io::ErrorKind::TimedOut,
225 0 : format!("exceeded connection timeout {timeout:?}"),
226 0 : )),
227 0 : })
228 0 : };
229 :
230 0 : let connect_once = |addrs| {
231 0 : debug!("trying to connect to compute node at {addrs:?}");
232 0 : connect_with_timeout(addrs).and_then(|stream| async {
233 0 : let socket_addr = stream.peer_addr()?;
234 0 : let socket = socket2::SockRef::from(&stream);
235 0 : // Disable Nagle's algorithm to not introduce latency between
236 0 : // client and compute.
237 0 : socket.set_nodelay(true)?;
238 : // This prevents load balancer from severing the connection.
239 0 : socket.set_keepalive(true)?;
240 0 : Ok((socket_addr, stream))
241 0 : })
242 0 : };
243 :
244 : // We can't reuse connection establishing logic from `postgres_client` here,
245 : // because it has no means for extracting the underlying socket which we
246 : // require for our business.
247 0 : let port = self.port;
248 0 : let host = &*self.host;
249 :
250 0 : let addrs = match self.host_addr {
251 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
252 0 : None => lookup_host((host, port)).await?.collect(),
253 : };
254 :
255 0 : match connect_once(&*addrs).await {
256 0 : Ok((sockaddr, stream)) => Ok((
257 0 : sockaddr,
258 0 : tls::connect_tls(stream, self.ssl_mode, config, host).await?,
259 : )),
260 0 : Err(err) => {
261 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
262 0 : Err(TlsError::Connection(err))
263 : }
264 : }
265 0 : }
266 : }
267 :
268 : type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
269 :
270 : pub(crate) struct PostgresConnection {
271 : /// Socket connected to a compute node.
272 : pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
273 : /// PostgreSQL connection parameters.
274 : pub(crate) params: std::collections::HashMap<String, String>,
275 : /// Query cancellation token.
276 : pub(crate) cancel_closure: CancelClosure,
277 : /// Labels for proxy's metrics.
278 : pub(crate) aux: MetricsAuxInfo,
279 : /// Notices received from compute after authenticating
280 : pub(crate) delayed_notice: Vec<NoticeResponseBody>,
281 :
282 : _guage: NumDbConnectionsGuard<'static>,
283 : }
284 :
285 : impl ConnectInfo {
286 : /// Connect to a corresponding compute node.
287 0 : pub(crate) async fn connect(
288 0 : &self,
289 0 : ctx: &RequestContext,
290 0 : aux: MetricsAuxInfo,
291 0 : auth: &AuthInfo,
292 0 : config: &ComputeConfig,
293 0 : user_info: ComputeUserInfo,
294 0 : ) -> Result<PostgresConnection, ConnectionError> {
295 0 : let mut tmp_config = auth.enrich(self.to_postgres_client_config());
296 0 : // we setup SSL early in `ConnectInfo::connect_raw`.
297 0 : tmp_config.ssl_mode(SslMode::Disable);
298 0 :
299 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
300 0 : let (socket_addr, stream) = self.connect_raw(config).await?;
301 0 : let connection = tmp_config.connect_raw(stream, NoTls).await?;
302 0 : drop(pause);
303 0 :
304 0 : let RawConnection {
305 0 : stream,
306 0 : parameters,
307 0 : delayed_notice,
308 0 : process_id,
309 0 : secret_key,
310 0 : } = connection;
311 0 :
312 0 : tracing::Span::current().record("pid", tracing::field::display(process_id));
313 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
314 0 : let MaybeTlsStream::Raw(stream) = stream.into_inner();
315 0 :
316 0 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
317 0 : info!(
318 0 : cold_start_info = ctx.cold_start_info().as_str(),
319 0 : "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
320 0 : self.host,
321 0 : self.ssl_mode,
322 0 : ctx.get_proxy_latency(),
323 0 : ctx.get_testodrome_id().unwrap_or_default(),
324 : );
325 :
326 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
327 : // Yet another reason to rework the connection establishing code.
328 0 : let cancel_closure = CancelClosure::new(
329 0 : socket_addr,
330 0 : CancelToken {
331 0 : socket_config: None,
332 0 : ssl_mode: self.ssl_mode,
333 0 : process_id,
334 0 : secret_key,
335 0 : },
336 0 : self.host.to_string(),
337 0 : user_info,
338 0 : );
339 0 :
340 0 : let connection = PostgresConnection {
341 0 : stream,
342 0 : params: parameters,
343 0 : delayed_notice,
344 0 : cancel_closure,
345 0 : aux,
346 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
347 0 : };
348 0 :
349 0 : Ok(connection)
350 0 : }
351 : }
352 :
353 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
354 6 : fn filtered_options(options: &str) -> Option<String> {
355 6 : #[allow(unstable_name_collisions)]
356 6 : let options: String = StartupMessageParams::parse_options_raw(options)
357 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
358 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
359 6 : .collect();
360 6 :
361 6 : // Don't even bother with empty options.
362 6 : if options.is_empty() {
363 3 : return None;
364 3 : }
365 3 :
366 3 : Some(options)
367 6 : }
368 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of which `options` survive filtering.
    #[test]
    fn test_filtered_options() {
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            // Escaped whitespace counts as option content and is kept.
            (r"\ project=foo \ ", Some(r"\ \ ")),
            // With spaces around `=`, none of the pieces is an endpoint selector.
            ("project = foo", Some("project = foo")),
            // Neon-specific options are stripped; the rest is kept.
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                filtered_options(input).as_deref(),
                expected,
                "unexpected result for input {input:?}"
            );
        }
    }
}
|