Line data Source code
1 : mod tls;
2 :
3 : use std::fmt::Debug;
4 : use std::io;
5 : use std::net::{IpAddr, SocketAddr};
6 :
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
10 : use postgres_client::connect_raw::StartupStream;
11 : use postgres_client::error::SqlState;
12 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
13 : use postgres_client::tls::MakeTlsConnect;
14 : use thiserror::Error;
15 : use tokio::net::{TcpStream, lookup_host};
16 : use tracing::{debug, error, info, warn};
17 :
18 : use crate::auth::backend::ComputeCredentialKeys;
19 : use crate::auth::parse_endpoint_param;
20 : use crate::compute::tls::TlsError;
21 : use crate::config::ComputeConfig;
22 : use crate::context::RequestContext;
23 : use crate::control_plane::client::ApiLockError;
24 : use crate::control_plane::errors::WakeComputeError;
25 : use crate::control_plane::messages::MetricsAuxInfo;
26 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
27 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
28 : use crate::pqproto::StartupMessageParams;
29 : use crate::proxy::connect_compute::TlsNegotiation;
30 : use crate::proxy::neon_option;
31 : use crate::types::Host;
32 :
/// Shared human-readable prefix for compute connection failures. It is interpolated
/// into the `#[error(...)]` messages of the error enums in this file and is returned
/// verbatim to clients for `ConnectionError::TlsError`.
33 : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
34 :
/// Wrapper around `postgres_client::Error` so it can be reported through this
/// crate's `UserFacingError` and `ReportableError` traits (implemented in this file).
35 : #[derive(Debug, Error)]
36 : pub(crate) enum PostgresError {
37 : /// This error doesn't seem to reveal any secrets; for instance,
38 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
39 : #[error("{COULD_NOT_CONNECT}: {0}")]
40 : Postgres(#[from] postgres_client::Error),
41 : }
42 :
43 : impl UserFacingError for PostgresError {
44 0 : fn to_string_client(&self) -> String {
45 0 : match self {
46 : // This helps us drop irrelevant library-specific prefixes.
47 : // TODO: propagate severity level and other parameters.
48 0 : PostgresError::Postgres(err) => match err.as_db_error() {
49 0 : Some(err) => {
50 0 : let msg = err.message();
51 :
52 0 : if msg.starts_with("unsupported startup parameter: ")
53 0 : || msg.starts_with("unsupported startup parameter in options: ")
54 : {
55 0 : format!(
56 0 : "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
57 : )
58 : } else {
59 0 : msg.to_owned()
60 : }
61 : }
62 0 : None => err.to_string(),
63 : },
64 : }
65 0 : }
66 : }
67 :
68 : impl ReportableError for PostgresError {
69 0 : fn get_error_kind(&self) -> ErrorKind {
70 0 : match self {
71 0 : PostgresError::Postgres(err) => match err.as_db_error() {
72 0 : Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
73 0 : Some(_) => ErrorKind::Postgres,
74 0 : None => ErrorKind::Compute,
75 : },
76 : }
77 0 : }
78 : }
79 :
/// Errors returned by `ConnectInfo::connect` while establishing a connection to compute.
80 : #[derive(Debug, Error)]
81 : pub(crate) enum ConnectionError {
/// Wraps a `TlsError`. `connect_raw` in this file also reports plain socket
/// connection failures through this type (as `TlsError::Connection`).
82 : #[error("{COULD_NOT_CONNECT}: {0}")]
83 : TlsError(#[from] TlsError),
84 :
/// Wraps a `WakeComputeError`; its client-facing text and error kind are
/// delegated to the wrapped error.
85 : #[error("{COULD_NOT_CONNECT}: {0}")]
86 : WakeComputeError(#[from] WakeComputeError),
87 :
/// Wraps an `ApiLockError`, i.e. a failure to acquire a resource permit.
88 : #[error("error acquiring resource permit: {0}")]
89 : TooManyConnectionAttempts(#[from] ApiLockError),
90 :
/// Test-only variant whose properties are chosen explicitly by the test.
91 : #[cfg(test)]
92 : #[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
93 : TestError {
94 : retryable: bool,
95 : wakeable: bool,
96 : kind: crate::error::ErrorKind,
97 : },
98 : }
99 :
100 : impl UserFacingError for ConnectionError {
101 0 : fn to_string_client(&self) -> String {
102 0 : match self {
103 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
104 : ConnectionError::TooManyConnectionAttempts(_) => {
105 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
106 : }
107 0 : ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
108 : #[cfg(test)]
109 0 : ConnectionError::TestError { .. } => self.to_string(),
110 : }
111 0 : }
112 : }
113 :
114 : impl ReportableError for ConnectionError {
115 0 : fn get_error_kind(&self) -> ErrorKind {
116 0 : match self {
117 0 : ConnectionError::TlsError(_) => ErrorKind::Compute,
118 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
119 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
120 : #[cfg(test)]
121 0 : ConnectionError::TestError { kind, .. } => *kind,
122 : }
123 0 : }
124 : }
125 :
126 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
// NOTE(review): the const argument `32` is presumably the key length in bytes
// (SHA-256 digests are 32 bytes) — confirm against `postgres_client::config::ScramKeys`.
127 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
128 :
/// Credentials presented to the compute node. `AuthInfo::enrich` installs them
/// into the `postgres_client` config as either a password or SCRAM keys.
129 : #[derive(Clone)]
130 : pub enum Auth {
131 : /// Only used during console-redirect.
132 : Password(Vec<u8>),
133 : /// Used by sql-over-http, ws, tcp.
134 : Scram(Box<ScramKeys>),
135 : }
136 :
137 : /// A config for authenticating to the compute node.
138 : pub(crate) struct AuthInfo {
139 : /// None for local-proxy, as we use trust-based localhost auth.
140 : /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
141 : /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
142 : auth: Option<Auth>,
/// Startup parameters; `enrich` copies every entry into the `postgres_client` config.
143 : server_params: StartupMessageParams,
144 :
/// Channel binding setting that `enrich` passes to the `postgres_client` config.
145 : channel_binding: ChannelBinding,
146 :
147 : /// Console redirect sets user and database, we shouldn't re-use those from the params.
148 : skip_db_user: bool,
149 : }
150 :
151 : /// Contains only the data needed to establish a secure connection to compute.
152 : #[derive(Clone)]
153 : pub struct ConnectInfo {
/// When set, `connect_raw` dials this address directly instead of resolving `host`.
154 : pub host_addr: Option<IpAddr>,
/// Compute hostname: resolved via DNS when `host_addr` is `None`, and also
/// handed to `tls::connect_tls`.
155 : pub host: Host,
/// TCP port used together with `host` / `host_addr`.
156 : pub port: u16,
/// SSL mode forwarded to `tls::connect_tls` and to the `postgres_client` config.
157 : pub ssl_mode: SslMode,
158 : }
159 :
160 : /// Creation and initialization routines.
161 : impl AuthInfo {
162 0 : pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
163 0 : let mut server_params = StartupMessageParams::default();
164 0 : server_params.insert("database", db);
165 0 : server_params.insert("user", user);
166 : Self {
167 0 : auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
168 0 : server_params,
169 : skip_db_user: true,
170 : // pg-sni-router is a mitm so this would fail.
171 0 : channel_binding: ChannelBinding::Disable,
172 : }
173 0 : }
174 :
175 0 : pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
176 : Self {
177 0 : auth: match keys {
178 0 : ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
179 0 : Some(Auth::Scram(Box::new(auth_keys)))
180 : }
181 0 : ComputeCredentialKeys::JwtPayload(_) => None,
182 : },
183 0 : server_params: StartupMessageParams::default(),
184 : skip_db_user: false,
185 0 : channel_binding: ChannelBinding::Prefer,
186 : }
187 0 : }
188 : }
189 :
190 : impl ConnectInfo {
191 0 : pub fn to_postgres_client_config(&self) -> postgres_client::Config {
192 0 : let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
193 0 : config.ssl_mode(self.ssl_mode);
194 0 : if let Some(host_addr) = self.host_addr {
195 0 : config.set_host_addr(host_addr);
196 0 : }
197 0 : config
198 0 : }
199 : }
200 :
201 : impl AuthInfo {
202 0 : fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
203 0 : match &self.auth {
204 0 : Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
205 0 : Some(Auth::Password(pw)) => config.password(pw),
206 0 : None => &mut config,
207 : };
208 0 : config.channel_binding(self.channel_binding);
209 0 : for (k, v) in self.server_params.iter() {
210 0 : config.set_param(k, v);
211 0 : }
212 0 : config
213 0 : }
214 :
215 : /// Apply startup message params to the connection config.
216 0 : pub(crate) fn set_startup_params(
217 0 : &mut self,
218 0 : params: &StartupMessageParams,
219 0 : arbitrary_params: bool,
220 0 : ) {
221 0 : if !arbitrary_params {
222 0 : self.server_params.insert("client_encoding", "UTF8");
223 0 : }
224 0 : for (k, v) in params.iter() {
225 0 : match k {
226 : // Only set `user` if it's not present in the config.
227 : // Console redirect auth flow takes username from the console's response.
228 0 : "user" | "database" if self.skip_db_user => {}
229 0 : "options" => {
230 0 : if let Some(options) = filtered_options(v) {
231 0 : self.server_params.insert(k, &options);
232 0 : }
233 : }
234 0 : "user" | "database" | "application_name" | "replication" => {
235 0 : self.server_params.insert(k, v);
236 0 : }
237 :
238 : // if we allow arbitrary params, then we forward them through.
239 : // this is a flag for a period of backwards compatibility
240 0 : k if arbitrary_params => {
241 0 : self.server_params.insert(k, v);
242 0 : }
243 0 : _ => {}
244 : }
245 : }
246 0 : }
247 :
248 0 : pub async fn authenticate(
249 0 : &self,
250 0 : ctx: &RequestContext,
251 0 : compute: &mut ComputeConnection,
252 0 : ) -> Result<(), PostgresError> {
253 : // client config with stubbed connect info.
254 : // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
255 : // utilising pqproto.rs.
256 0 : let mut tmp_config = postgres_client::Config::new(String::new(), 0);
257 : // We have already established SSL if necessary.
258 0 : tmp_config.ssl_mode(SslMode::Disable);
259 0 : let tmp_config = self.enrich(tmp_config);
260 :
261 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
262 0 : tmp_config.authenticate(&mut compute.stream).await?;
263 0 : drop(pause);
264 :
265 0 : Ok(())
266 0 : }
267 : }
268 :
269 : impl ConnectInfo {
270 : /// Establish a raw TCP+TLS connection to the compute node.
271 0 : async fn connect_raw(
272 0 : &self,
273 0 : config: &ComputeConfig,
274 0 : tls: TlsNegotiation,
275 0 : ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
276 0 : let timeout = config.timeout;
277 :
278 : // wrap TcpStream::connect with timeout
279 0 : let connect_with_timeout = |addrs| {
280 0 : tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
281 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
282 0 : Err(_) => Err(io::Error::new(
283 0 : io::ErrorKind::TimedOut,
284 0 : format!("exceeded connection timeout {timeout:?}"),
285 0 : )),
286 0 : })
287 0 : };
288 :
289 0 : let connect_once = |addrs| {
290 0 : debug!("trying to connect to compute node at {addrs:?}");
291 0 : connect_with_timeout(addrs).and_then(|stream| async {
292 0 : let socket_addr = stream.peer_addr()?;
293 0 : let socket = socket2::SockRef::from(&stream);
294 : // Disable Nagle's algorithm to not introduce latency between
295 : // client and compute.
296 0 : socket.set_nodelay(true)?;
297 : // This prevents load balancer from severing the connection.
298 0 : socket.set_keepalive(true)?;
299 0 : Ok((socket_addr, stream))
300 0 : })
301 0 : };
302 :
303 : // We can't reuse connection establishing logic from `postgres_client` here,
304 : // because it has no means for extracting the underlying socket which we
305 : // require for our business.
306 0 : let port = self.port;
307 0 : let host = &*self.host;
308 :
309 0 : let addrs = match self.host_addr {
310 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
311 0 : None => lookup_host((host, port)).await?.collect(),
312 : };
313 :
314 0 : match connect_once(&*addrs).await {
315 0 : Ok((sockaddr, stream)) => Ok((
316 0 : sockaddr,
317 0 : tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?,
318 : )),
319 0 : Err(err) => {
320 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
321 0 : Err(TlsError::Connection(err))
322 : }
323 : }
324 0 : }
325 : }
326 :
327 : pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
328 : pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
329 :
330 : pub struct ComputeConnection {
331 : /// Socket connected to a compute node.
332 : pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
333 : /// Labels for proxy's metrics.
334 : pub aux: MetricsAuxInfo,
335 : pub hostname: Host,
336 : pub ssl_mode: SslMode,
337 : pub socket_addr: SocketAddr,
338 : pub guage: NumDbConnectionsGuard<'static>,
339 : }
340 :
341 : impl ConnectInfo {
342 : /// Connect to a corresponding compute node.
343 0 : pub async fn connect(
344 0 : &self,
345 0 : ctx: &RequestContext,
346 0 : aux: &MetricsAuxInfo,
347 0 : config: &ComputeConfig,
348 0 : tls: TlsNegotiation,
349 0 : ) -> Result<ComputeConnection, ConnectionError> {
350 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
351 0 : let (socket_addr, stream) = self.connect_raw(config, tls).await?;
352 0 : drop(pause);
353 :
354 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
355 :
356 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
357 0 : info!(
358 0 : cold_start_info = ctx.cold_start_info().as_str(),
359 0 : "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
360 : self.host,
361 : self.ssl_mode,
362 0 : ctx.get_proxy_latency(),
363 0 : ctx.get_testodrome_id().unwrap_or_default(),
364 : );
365 :
366 0 : let stream = StartupStream::new(stream);
367 0 : let connection = ComputeConnection {
368 0 : stream,
369 0 : socket_addr,
370 0 : hostname: self.host.clone(),
371 0 : ssl_mode: self.ssl_mode,
372 0 : aux: aux.clone(),
373 0 : guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
374 0 : };
375 :
376 0 : Ok(connection)
377 0 : }
378 : }
379 :
380 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
381 6 : fn filtered_options(options: &str) -> Option<String> {
382 : #[allow(unstable_name_collisions)]
383 6 : let options: String = StartupMessageParams::parse_options_raw(options)
384 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
385 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
386 6 : .collect();
387 :
388 : // Don't even bother with empty options.
389 6 : if options.is_empty() {
390 3 : return None;
391 3 : }
392 :
393 3 : Some(options)
394 6 : }
395 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `filtered_options`: each row is `(input, expected output)`.
    #[test]
    fn test_filtered_options() {
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            (r"\ project=foo \ ", Some(r"\ \ ")),
            ("project = foo", Some("project = foo")),
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for (params, expected) in cases {
            assert_eq!(filtered_options(params).as_deref(), expected);
        }
    }
}
|