Line data Source code
1 : mod tls;
2 :
3 : use std::fmt::Debug;
4 : use std::io;
5 : use std::net::{IpAddr, SocketAddr};
6 :
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
10 : use postgres_client::connect_raw::StartupStream;
11 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
12 : use postgres_client::tls::MakeTlsConnect;
13 : use thiserror::Error;
14 : use tokio::net::{TcpStream, lookup_host};
15 : use tracing::{debug, error, info, warn};
16 :
17 : use crate::auth::backend::ComputeCredentialKeys;
18 : use crate::auth::parse_endpoint_param;
19 : use crate::compute::tls::TlsError;
20 : use crate::config::ComputeConfig;
21 : use crate::context::RequestContext;
22 : use crate::control_plane::client::ApiLockError;
23 : use crate::control_plane::errors::WakeComputeError;
24 : use crate::control_plane::messages::MetricsAuxInfo;
25 : use crate::error::{ReportableError, UserFacingError};
26 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
27 : use crate::pqproto::StartupMessageParams;
28 : use crate::proxy::connect_compute::TlsNegotiation;
29 : use crate::proxy::neon_option;
30 : use crate::types::Host;
31 :
/// Generic client-facing text for connection failures.
///
/// Used as the prefix of the `Display` output of several error variants in this
/// module, and as the whole client message for TLS/connect failures.
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
33 :
/// Errors reported by the `postgres_client` driver while talking to a compute
/// node (for instance, during `AuthInfo::authenticate`).
#[derive(Debug, Error)]
pub(crate) enum PostgresError {
    /// This error doesn't seem to reveal any secrets; for instance,
    /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    Postgres(#[from] postgres_client::Error),
}
41 :
42 : impl UserFacingError for PostgresError {
43 0 : fn to_string_client(&self) -> String {
44 0 : match self {
45 : // This helps us drop irrelevant library-specific prefixes.
46 : // TODO: propagate severity level and other parameters.
47 0 : PostgresError::Postgres(err) => match err.as_db_error() {
48 0 : Some(err) => {
49 0 : let msg = err.message();
50 :
51 0 : if msg.starts_with("unsupported startup parameter: ")
52 0 : || msg.starts_with("unsupported startup parameter in options: ")
53 : {
54 0 : format!(
55 0 : "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
56 : )
57 : } else {
58 0 : msg.to_owned()
59 : }
60 : }
61 0 : None => err.to_string(),
62 : },
63 : }
64 0 : }
65 : }
66 :
67 : impl ReportableError for PostgresError {
68 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
69 0 : match self {
70 0 : PostgresError::Postgres(e) if e.as_db_error().is_some() => {
71 0 : crate::error::ErrorKind::Postgres
72 : }
73 0 : PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
74 : }
75 0 : }
76 : }
77 :
/// Errors that can occur while establishing a connection to a compute node.
#[derive(Debug, Error)]
pub(crate) enum ConnectionError {
    /// Resolving, connecting to, or negotiating TLS with the compute node failed
    /// (produced by `ConnectInfo::connect_raw`).
    #[error("{COULD_NOT_CONNECT}: {0}")]
    TlsError(#[from] TlsError),

    /// Error propagated from the control plane's wake-compute path.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    WakeComputeError(#[from] WakeComputeError),

    /// A permit to connect could not be acquired (`ApiLockError`).
    #[error("error acquiring resource permit: {0}")]
    TooManyConnectionAttempts(#[from] ApiLockError),

    /// Test-only error whose classification fields are set explicitly.
    #[cfg(test)]
    #[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
    TestError {
        retryable: bool,
        wakeable: bool,
        kind: crate::error::ErrorKind,
    },
}
97 :
98 : impl UserFacingError for ConnectionError {
99 0 : fn to_string_client(&self) -> String {
100 0 : match self {
101 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
102 : ConnectionError::TooManyConnectionAttempts(_) => {
103 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
104 : }
105 0 : ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
106 : #[cfg(test)]
107 0 : ConnectionError::TestError { .. } => self.to_string(),
108 : }
109 0 : }
110 : }
111 :
112 : impl ReportableError for ConnectionError {
113 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
114 0 : match self {
115 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
116 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
117 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
118 : #[cfg(test)]
119 0 : ConnectionError::TestError { kind, .. } => *kind,
120 : }
121 0 : }
122 : }
123 :
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
///
/// NOTE(review): the const parameter `32` is presumably the key length in bytes
/// (the SHA-256 digest size); confirm against `postgres_client::config::ScramKeys`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
126 :
/// Credentials the proxy presents to the compute node (applied to the driver
/// config by `AuthInfo::enrich`).
#[derive(Clone)]
pub enum Auth {
    /// Only used during console-redirect.
    Password(Vec<u8>),
    /// Used by sql-over-http, ws, tcp.
    Scram(Box<ScramKeys>),
}
134 :
/// A config for authenticating to the compute node.
pub(crate) struct AuthInfo {
    /// None for local-proxy, as we use trust-based localhost auth.
    /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
    /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
    auth: Option<Auth>,
    /// Startup parameters forwarded to compute; each one is set on the driver
    /// config by `enrich`.
    server_params: StartupMessageParams,

    /// Channel-binding policy applied to the driver config.
    channel_binding: ChannelBinding,

    /// Console redirect sets user and database, we shouldn't re-use those from the params.
    skip_db_user: bool,
}
148 :
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
    /// Pre-resolved address of the compute node. When `None`, `host` is
    /// resolved via DNS when connecting.
    pub host_addr: Option<IpAddr>,
    /// Hostname of the compute node. It is also handed to the TLS negotiation.
    pub host: Host,
    /// TCP port of the compute node.
    pub port: u16,
    /// TLS policy used when connecting.
    pub ssl_mode: SslMode,
}
157 :
158 : /// Creation and initialization routines.
159 : impl AuthInfo {
160 0 : pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
161 0 : let mut server_params = StartupMessageParams::default();
162 0 : server_params.insert("database", db);
163 0 : server_params.insert("user", user);
164 : Self {
165 0 : auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
166 0 : server_params,
167 : skip_db_user: true,
168 : // pg-sni-router is a mitm so this would fail.
169 0 : channel_binding: ChannelBinding::Disable,
170 : }
171 0 : }
172 :
173 0 : pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
174 : Self {
175 0 : auth: match keys {
176 0 : ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
177 0 : Some(Auth::Scram(Box::new(auth_keys)))
178 : }
179 0 : ComputeCredentialKeys::JwtPayload(_) => None,
180 : },
181 0 : server_params: StartupMessageParams::default(),
182 : skip_db_user: false,
183 0 : channel_binding: ChannelBinding::Prefer,
184 : }
185 0 : }
186 : }
187 :
188 : impl ConnectInfo {
189 0 : pub fn to_postgres_client_config(&self) -> postgres_client::Config {
190 0 : let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
191 0 : config.ssl_mode(self.ssl_mode);
192 0 : if let Some(host_addr) = self.host_addr {
193 0 : config.set_host_addr(host_addr);
194 0 : }
195 0 : config
196 0 : }
197 : }
198 :
199 : impl AuthInfo {
200 0 : fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
201 0 : match &self.auth {
202 0 : Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
203 0 : Some(Auth::Password(pw)) => config.password(pw),
204 0 : None => &mut config,
205 : };
206 0 : config.channel_binding(self.channel_binding);
207 0 : for (k, v) in self.server_params.iter() {
208 0 : config.set_param(k, v);
209 0 : }
210 0 : config
211 0 : }
212 :
213 : /// Apply startup message params to the connection config.
214 0 : pub(crate) fn set_startup_params(
215 0 : &mut self,
216 0 : params: &StartupMessageParams,
217 0 : arbitrary_params: bool,
218 0 : ) {
219 0 : if !arbitrary_params {
220 0 : self.server_params.insert("client_encoding", "UTF8");
221 0 : }
222 0 : for (k, v) in params.iter() {
223 0 : match k {
224 : // Only set `user` if it's not present in the config.
225 : // Console redirect auth flow takes username from the console's response.
226 0 : "user" | "database" if self.skip_db_user => {}
227 0 : "options" => {
228 0 : if let Some(options) = filtered_options(v) {
229 0 : self.server_params.insert(k, &options);
230 0 : }
231 : }
232 0 : "user" | "database" | "application_name" | "replication" => {
233 0 : self.server_params.insert(k, v);
234 0 : }
235 :
236 : // if we allow arbitrary params, then we forward them through.
237 : // this is a flag for a period of backwards compatibility
238 0 : k if arbitrary_params => {
239 0 : self.server_params.insert(k, v);
240 0 : }
241 0 : _ => {}
242 : }
243 : }
244 0 : }
245 :
246 0 : pub async fn authenticate(
247 0 : &self,
248 0 : ctx: &RequestContext,
249 0 : compute: &mut ComputeConnection,
250 0 : ) -> Result<(), PostgresError> {
251 : // client config with stubbed connect info.
252 : // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
253 : // utilising pqproto.rs.
254 0 : let mut tmp_config = postgres_client::Config::new(String::new(), 0);
255 : // We have already established SSL if necessary.
256 0 : tmp_config.ssl_mode(SslMode::Disable);
257 0 : let tmp_config = self.enrich(tmp_config);
258 :
259 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
260 0 : tmp_config.authenticate(&mut compute.stream).await?;
261 0 : drop(pause);
262 :
263 0 : Ok(())
264 0 : }
265 : }
266 :
impl ConnectInfo {
    /// Establish a raw TCP+TLS connection to the compute node.
    ///
    /// Connection steps:
    /// - Uses `host_addr` when set, otherwise resolves `host:port` with
    ///   `lookup_host`.
    /// - The TCP connect is bounded by `config.timeout`.
    /// - On success, Nagle's algorithm is disabled and TCP keepalive is enabled.
    /// - TLS is then negotiated via `tls::connect_tls` according to `ssl_mode`.
    ///
    /// Returns the peer socket address and the (possibly TLS-wrapped) stream.
    ///
    /// # Errors
    /// DNS and TLS failures are propagated with `?` as `TlsError`. TCP connect
    /// failures (including the timeout) are logged and returned as
    /// `TlsError::Connection`.
    async fn connect_raw(
        &self,
        config: &ComputeConfig,
        tls: TlsNegotiation,
    ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
        let timeout = config.timeout;

        // wrap TcpStream::connect with timeout
        let connect_with_timeout = |addrs| {
            tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
                Ok(tcpstream_connect_res) => tcpstream_connect_res,
                // An elapsed timer becomes an io::Error so both failure kinds share one type.
                Err(_) => Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    format!("exceeded connection timeout {timeout:?}"),
                )),
            })
        };

        let connect_once = |addrs| {
            debug!("trying to connect to compute node at {addrs:?}");
            connect_with_timeout(addrs).and_then(|stream| async {
                let socket_addr = stream.peer_addr()?;
                let socket = socket2::SockRef::from(&stream);
                // Disable Nagle's algorithm to not introduce latency between
                // client and compute.
                socket.set_nodelay(true)?;
                // This prevents load balancer from severing the connection.
                socket.set_keepalive(true)?;
                Ok((socket_addr, stream))
            })
        };

        // We can't reuse connection establishing logic from `postgres_client` here,
        // because it has no means for extracting the underlying socket which we
        // require for our business.
        let port = self.port;
        let host = &*self.host;

        let addrs = match self.host_addr {
            Some(addr) => vec![SocketAddr::new(addr, port)],
            None => lookup_host((host, port)).await?.collect(),
        };

        // A single connect call covers all candidate addresses. The timeout wraps
        // that whole call, not each address individually.
        match connect_once(&*addrs).await {
            Ok((sockaddr, stream)) => Ok((
                sockaddr,
                tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?,
            )),
            Err(err) => {
                warn!("couldn't connect to compute node at {host}:{port}: {err}");
                Err(TlsError::Connection(err))
            }
        }
    }
}
324 :
/// TLS stream type produced by `ComputeConfig`'s `MakeTlsConnect` impl over a TCP stream.
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
/// A TCP stream to compute that may or may not be wrapped in TLS.
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
327 :
/// An established (possibly TLS-wrapped) connection to a compute node, plus the
/// metadata recorded by `ConnectInfo::connect` when it was opened.
pub struct ComputeConnection {
    /// Socket connected to a compute node.
    pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
    /// Labels for proxy's metrics.
    pub aux: MetricsAuxInfo,
    /// Hostname the connection was made to.
    pub hostname: Host,
    /// TLS mode that was requested for this connection.
    pub ssl_mode: SslMode,
    /// Peer address of the TCP socket.
    pub socket_addr: SocketAddr,
    /// Guard obtained from the `db_connections` metric. It is dropped together
    /// with this struct.
    /// NOTE(review): the field name misspells "gauge". It is public, so renaming
    /// it would break callers; left as-is.
    pub guage: NumDbConnectionsGuard<'static>,
}
338 :
339 : impl ConnectInfo {
340 : /// Connect to a corresponding compute node.
341 0 : pub async fn connect(
342 0 : &self,
343 0 : ctx: &RequestContext,
344 0 : aux: &MetricsAuxInfo,
345 0 : config: &ComputeConfig,
346 0 : tls: TlsNegotiation,
347 0 : ) -> Result<ComputeConnection, ConnectionError> {
348 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
349 0 : let (socket_addr, stream) = self.connect_raw(config, tls).await?;
350 0 : drop(pause);
351 :
352 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
353 :
354 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
355 0 : info!(
356 0 : cold_start_info = ctx.cold_start_info().as_str(),
357 0 : "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
358 : self.host,
359 : self.ssl_mode,
360 0 : ctx.get_proxy_latency(),
361 0 : ctx.get_testodrome_id().unwrap_or_default(),
362 : );
363 :
364 0 : let stream = StartupStream::new(stream);
365 0 : let connection = ComputeConnection {
366 0 : stream,
367 0 : socket_addr,
368 0 : hostname: self.host.clone(),
369 0 : ssl_mode: self.ssl_mode,
370 0 : aux: aux.clone(),
371 0 : guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
372 0 : };
373 :
374 0 : Ok(connection)
375 0 : }
376 : }
377 :
378 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
379 6 : fn filtered_options(options: &str) -> Option<String> {
380 : #[allow(unstable_name_collisions)]
381 6 : let options: String = StartupMessageParams::parse_options_raw(options)
382 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
383 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
384 6 : .collect();
385 :
386 : // Don't even bother with empty options.
387 6 : if options.is_empty() {
388 3 : return None;
389 3 : }
390 :
391 3 : Some(options)
392 6 : }
393 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filtered_options() {
        // (raw `options` value, expected forwarded value)
        let cases: &[(&str, Option<&str>)] = &[
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            (r"\ project=foo \ ", Some(r"\ \ ")),
            ("project = foo", Some("project = foo")),
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for &(input, expected) in cases {
            assert_eq!(filtered_options(input).as_deref(), expected, "input: {input:?}");
        }
    }
}
|