Line data Source code
1 : mod tls;
2 :
3 : use std::fmt::Debug;
4 : use std::io;
5 : use std::net::{IpAddr, SocketAddr};
6 :
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
10 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
11 : use postgres_client::tls::MakeTlsConnect;
12 : use postgres_client::{NoTls, RawCancelToken, RawConnection};
13 : use postgres_protocol::message::backend::NoticeResponseBody;
14 : use thiserror::Error;
15 : use tokio::net::{TcpStream, lookup_host};
16 : use tracing::{debug, error, info, warn};
17 :
18 : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
19 : use crate::auth::parse_endpoint_param;
20 : use crate::cancellation::CancelClosure;
21 : use crate::compute::tls::TlsError;
22 : use crate::config::ComputeConfig;
23 : use crate::context::RequestContext;
24 : use crate::control_plane::client::ApiLockError;
25 : use crate::control_plane::errors::WakeComputeError;
26 : use crate::control_plane::messages::MetricsAuxInfo;
27 : use crate::error::{ReportableError, UserFacingError};
28 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
29 : use crate::pqproto::StartupMessageParams;
30 : use crate::proxy::neon_option;
31 : use crate::types::Host;
32 :
/// Generic message shown to clients when a compute connection cannot be made.
/// It is also used as the prefix of the `Display` output of `PostgresError` and of
/// some `ConnectionError` variants.
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";

/// Errors raised while running the PostgreSQL startup/authentication handshake
/// against a compute node (returned by `AuthInfo::authenticate`).
#[derive(Debug, Error)]
pub(crate) enum PostgresError {
    /// This error doesn't seem to reveal any secrets; for instance,
    /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    Postgres(#[from] postgres_client::Error),
}
42 :
43 : impl UserFacingError for PostgresError {
44 0 : fn to_string_client(&self) -> String {
45 0 : match self {
46 : // This helps us drop irrelevant library-specific prefixes.
47 : // TODO: propagate severity level and other parameters.
48 0 : PostgresError::Postgres(err) => match err.as_db_error() {
49 0 : Some(err) => {
50 0 : let msg = err.message();
51 :
52 0 : if msg.starts_with("unsupported startup parameter: ")
53 0 : || msg.starts_with("unsupported startup parameter in options: ")
54 : {
55 0 : format!(
56 0 : "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
57 : )
58 : } else {
59 0 : msg.to_owned()
60 : }
61 : }
62 0 : None => err.to_string(),
63 : },
64 : }
65 0 : }
66 : }
67 :
68 : impl ReportableError for PostgresError {
69 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
70 0 : match self {
71 0 : PostgresError::Postgres(e) if e.as_db_error().is_some() => {
72 0 : crate::error::ErrorKind::Postgres
73 : }
74 0 : PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
75 : }
76 0 : }
77 : }
78 :
/// Errors raised while obtaining a connection to a compute node.
///
/// `ConnectInfo::connect` produces the `TlsError` variant via `?`. The other
/// variants are filled through their `#[from]` conversions at call sites
/// outside this view (TODO confirm).
#[derive(Debug, Error)]
pub(crate) enum ConnectionError {
    /// TCP connect or TLS setup with the compute node failed (see `connect_raw`).
    #[error("{COULD_NOT_CONNECT}: {0}")]
    TlsError(#[from] TlsError),

    /// Error reported by the control plane's wake-compute step.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    WakeComputeError(#[from] WakeComputeError),

    /// A connection permit could not be acquired (`ApiLockError`); reported to
    /// the client as too many concurrent connection attempts.
    #[error("error acquiring resource permit: {0}")]
    TooManyConnectionAttempts(#[from] ApiLockError),
}
90 :
91 : impl UserFacingError for ConnectionError {
92 0 : fn to_string_client(&self) -> String {
93 0 : match self {
94 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
95 : ConnectionError::TooManyConnectionAttempts(_) => {
96 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
97 : }
98 0 : ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
99 : }
100 0 : }
101 : }
102 :
103 : impl ReportableError for ConnectionError {
104 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
105 0 : match self {
106 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
107 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
108 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
109 : }
110 0 : }
111 : }
112 :
// NOTE(review): the const parameter `32` presumably matches the SHA-256 key
// length in bytes — confirm against `postgres_client::config::ScramKeys`.
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;

/// Credentials the proxy presents to the compute node (applied to the client
/// config in `AuthInfo::enrich`).
#[derive(Clone)]
pub enum Auth {
    /// Raw password bytes, passed to `Config::password`.
    /// Only used during console-redirect.
    Password(Vec<u8>),
    /// Pre-computed SCRAM keys, passed as `AuthKeys::ScramSha256`.
    /// Used by sql-over-http, ws, tcp.
    Scram(Box<ScramKeys>),
}
123 :
/// A config for authenticating to the compute node.
///
/// Built by `for_console_redirect` or `with_auth_keys`, extended with client
/// startup parameters by `set_startup_params`, and applied to a
/// `postgres_client::Config` by `enrich` during `authenticate`.
pub(crate) struct AuthInfo {
    /// None for local-proxy, as we use trust-based localhost auth.
    /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
    /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
    auth: Option<Auth>,
    /// Startup parameters forwarded to compute (each one is applied with
    /// `Config::set_param` in `enrich`).
    server_params: StartupMessageParams,

    /// Channel-binding policy: `Disable` for console redirect, `Prefer` when
    /// built from auth keys.
    channel_binding: ChannelBinding,

    /// Console redirect sets user and database, we shouldn't re-use those from the params.
    skip_db_user: bool,
}
137 :
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
    /// Pre-resolved IP address of `host`. When set, `connect_raw` connects to
    /// it directly instead of resolving `host`, and it is forwarded to
    /// `Config::set_host_addr`.
    pub host_addr: Option<IpAddr>,
    /// Compute hostname. It is resolved when `host_addr` is `None`, passed to
    /// `tls::connect_tls` (presumably for SNI/verification — confirm), and
    /// used in logs.
    pub host: Host,
    /// TCP port of the compute node.
    pub port: u16,
    /// TLS policy passed to `tls::connect_tls` and to the client config.
    pub ssl_mode: SslMode,
}
146 :
147 : /// Creation and initialization routines.
148 : impl AuthInfo {
149 0 : pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
150 0 : let mut server_params = StartupMessageParams::default();
151 0 : server_params.insert("database", db);
152 0 : server_params.insert("user", user);
153 : Self {
154 0 : auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
155 0 : server_params,
156 : skip_db_user: true,
157 : // pg-sni-router is a mitm so this would fail.
158 0 : channel_binding: ChannelBinding::Disable,
159 : }
160 0 : }
161 :
162 0 : pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
163 : Self {
164 0 : auth: match keys {
165 0 : ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
166 0 : Some(Auth::Scram(Box::new(auth_keys)))
167 : }
168 0 : ComputeCredentialKeys::JwtPayload(_) => None,
169 : },
170 0 : server_params: StartupMessageParams::default(),
171 : skip_db_user: false,
172 0 : channel_binding: ChannelBinding::Prefer,
173 : }
174 0 : }
175 : }
176 :
177 : impl ConnectInfo {
178 0 : pub fn to_postgres_client_config(&self) -> postgres_client::Config {
179 0 : let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
180 0 : config.ssl_mode(self.ssl_mode);
181 0 : if let Some(host_addr) = self.host_addr {
182 0 : config.set_host_addr(host_addr);
183 0 : }
184 0 : config
185 0 : }
186 : }
187 :
188 : impl AuthInfo {
189 0 : fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
190 0 : match &self.auth {
191 0 : Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
192 0 : Some(Auth::Password(pw)) => config.password(pw),
193 0 : None => &mut config,
194 : };
195 0 : config.channel_binding(self.channel_binding);
196 0 : for (k, v) in self.server_params.iter() {
197 0 : config.set_param(k, v);
198 0 : }
199 0 : config
200 0 : }
201 :
202 : /// Apply startup message params to the connection config.
203 0 : pub(crate) fn set_startup_params(
204 0 : &mut self,
205 0 : params: &StartupMessageParams,
206 0 : arbitrary_params: bool,
207 0 : ) {
208 0 : if !arbitrary_params {
209 0 : self.server_params.insert("client_encoding", "UTF8");
210 0 : }
211 0 : for (k, v) in params.iter() {
212 0 : match k {
213 : // Only set `user` if it's not present in the config.
214 : // Console redirect auth flow takes username from the console's response.
215 0 : "user" | "database" if self.skip_db_user => {}
216 0 : "options" => {
217 0 : if let Some(options) = filtered_options(v) {
218 0 : self.server_params.insert(k, &options);
219 0 : }
220 : }
221 0 : "user" | "database" | "application_name" | "replication" => {
222 0 : self.server_params.insert(k, v);
223 0 : }
224 :
225 : // if we allow arbitrary params, then we forward them through.
226 : // this is a flag for a period of backwards compatibility
227 0 : k if arbitrary_params => {
228 0 : self.server_params.insert(k, v);
229 0 : }
230 0 : _ => {}
231 : }
232 : }
233 0 : }
234 :
235 0 : pub async fn authenticate(
236 0 : &self,
237 0 : ctx: &RequestContext,
238 0 : compute: &mut ComputeConnection,
239 0 : user_info: &ComputeUserInfo,
240 0 : ) -> Result<PostgresSettings, PostgresError> {
241 : // client config with stubbed connect info.
242 : // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
243 : // utilising pqproto.rs.
244 0 : let mut tmp_config = postgres_client::Config::new(String::new(), 0);
245 : // We have already established SSL if necessary.
246 0 : tmp_config.ssl_mode(SslMode::Disable);
247 0 : let tmp_config = self.enrich(tmp_config);
248 :
249 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
250 0 : let connection = tmp_config
251 0 : .tls_and_authenticate(&mut compute.stream, NoTls)
252 0 : .await?;
253 0 : drop(pause);
254 :
255 : let RawConnection {
256 : stream: _,
257 0 : parameters,
258 0 : delayed_notice,
259 0 : process_id,
260 0 : secret_key,
261 0 : } = connection;
262 :
263 0 : tracing::Span::current().record("pid", tracing::field::display(process_id));
264 :
265 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
266 : // Yet another reason to rework the connection establishing code.
267 0 : let cancel_closure = CancelClosure::new(
268 0 : compute.socket_addr,
269 0 : RawCancelToken {
270 0 : ssl_mode: compute.ssl_mode,
271 0 : process_id,
272 0 : secret_key,
273 0 : },
274 0 : compute.hostname.to_string(),
275 0 : user_info.clone(),
276 : );
277 :
278 0 : Ok(PostgresSettings {
279 0 : params: parameters,
280 0 : cancel_closure,
281 0 : delayed_notice,
282 0 : })
283 0 : }
284 : }
285 :
286 : impl ConnectInfo {
287 : /// Establish a raw TCP+TLS connection to the compute node.
288 0 : async fn connect_raw(
289 0 : &self,
290 0 : config: &ComputeConfig,
291 0 : ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
292 0 : let timeout = config.timeout;
293 :
294 : // wrap TcpStream::connect with timeout
295 0 : let connect_with_timeout = |addrs| {
296 0 : tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
297 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
298 0 : Err(_) => Err(io::Error::new(
299 0 : io::ErrorKind::TimedOut,
300 0 : format!("exceeded connection timeout {timeout:?}"),
301 0 : )),
302 0 : })
303 0 : };
304 :
305 0 : let connect_once = |addrs| {
306 0 : debug!("trying to connect to compute node at {addrs:?}");
307 0 : connect_with_timeout(addrs).and_then(|stream| async {
308 0 : let socket_addr = stream.peer_addr()?;
309 0 : let socket = socket2::SockRef::from(&stream);
310 : // Disable Nagle's algorithm to not introduce latency between
311 : // client and compute.
312 0 : socket.set_nodelay(true)?;
313 : // This prevents load balancer from severing the connection.
314 0 : socket.set_keepalive(true)?;
315 0 : Ok((socket_addr, stream))
316 0 : })
317 0 : };
318 :
319 : // We can't reuse connection establishing logic from `postgres_client` here,
320 : // because it has no means for extracting the underlying socket which we
321 : // require for our business.
322 0 : let port = self.port;
323 0 : let host = &*self.host;
324 :
325 0 : let addrs = match self.host_addr {
326 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
327 0 : None => lookup_host((host, port)).await?.collect(),
328 : };
329 :
330 0 : match connect_once(&*addrs).await {
331 0 : Ok((sockaddr, stream)) => Ok((
332 0 : sockaddr,
333 0 : tls::connect_tls(stream, self.ssl_mode, config, host).await?,
334 : )),
335 0 : Err(err) => {
336 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
337 0 : Err(TlsError::Connection(err))
338 : }
339 : }
340 0 : }
341 : }
342 :
/// Stream type produced by `ComputeConfig`'s `MakeTlsConnect<TcpStream>` impl,
/// i.e. the TLS-wrapped compute socket (presumably rustls-based, per the name — confirm).
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
/// A compute socket that may or may not be TLS-wrapped.
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
345 :
// TODO(conrad): we don't need to parse these.
// These are just immediately forwarded back to the client.
// We could instead stream them out instead of reading them into memory.
/// Session data obtained from compute by `AuthInfo::authenticate`.
pub struct PostgresSettings {
    /// PostgreSQL connection parameters.
    pub params: std::collections::HashMap<String, String>,
    /// Query cancellation token.
    pub cancel_closure: CancelClosure,
    /// Notices received from compute after authenticating
    pub delayed_notice: Vec<NoticeResponseBody>,
}
357 :
/// An established (not yet authenticated) connection to a compute node,
/// returned by `ConnectInfo::connect`.
pub struct ComputeConnection {
    /// Socket connected to a compute node; TLS-wrapped or not depending on `ssl_mode`.
    pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
    /// Labels for proxy's metrics.
    pub aux: MetricsAuxInfo,
    /// Compute hostname; passed to the cancel closure built in `authenticate`.
    pub hostname: Host,
    /// SSL mode used for this connection; reused for the cancel token.
    pub ssl_mode: SslMode,
    /// Peer address of the TCP socket; used as the cancellation target.
    pub socket_addr: SocketAddr,
    /// Guard obtained from `Metrics::get().proxy.db_connections.guard(..)`,
    /// kept alive as long as this connection value.
    /// NOTE: `guage` is a misspelling of "gauge"; kept because the field is public.
    pub guage: NumDbConnectionsGuard<'static>,
}
368 :
369 : impl ConnectInfo {
370 : /// Connect to a corresponding compute node.
371 0 : pub async fn connect(
372 0 : &self,
373 0 : ctx: &RequestContext,
374 0 : aux: &MetricsAuxInfo,
375 0 : config: &ComputeConfig,
376 0 : ) -> Result<ComputeConnection, ConnectionError> {
377 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
378 0 : let (socket_addr, stream) = self.connect_raw(config).await?;
379 0 : drop(pause);
380 :
381 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
382 :
383 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
384 0 : info!(
385 0 : cold_start_info = ctx.cold_start_info().as_str(),
386 0 : "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
387 : self.host,
388 : self.ssl_mode,
389 0 : ctx.get_proxy_latency(),
390 0 : ctx.get_testodrome_id().unwrap_or_default(),
391 : );
392 :
393 0 : let connection = ComputeConnection {
394 0 : stream,
395 0 : socket_addr,
396 0 : hostname: self.host.clone(),
397 0 : ssl_mode: self.ssl_mode,
398 0 : aux: aux.clone(),
399 0 : guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
400 0 : };
401 :
402 0 : Ok(connection)
403 0 : }
404 : }
405 :
406 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
407 6 : fn filtered_options(options: &str) -> Option<String> {
408 : #[allow(unstable_name_collisions)]
409 6 : let options: String = StartupMessageParams::parse_options_raw(options)
410 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
411 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
412 6 : .collect();
413 :
414 : // Don't even bother with empty options.
415 6 : if options.is_empty() {
416 3 : return None;
417 3 : }
418 :
419 3 : Some(options)
420 6 : }
421 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filtered_options() {
        // (raw `options` value, expected filtered result)
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            // Escaped whitespace survives around the dropped option.
            (r"\ project=foo \ ", Some(r"\ \ ")),
            // Spaced-out `project = foo` is not treated as an endpoint parameter.
            ("project = foo", Some("project = foo")),
            // neon_* flags are stripped, the rest is kept.
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for &(input, expected) in &cases {
            assert_eq!(
                filtered_options(input).as_deref(),
                expected,
                "options: {input:?}"
            );
        }
    }
}
|