Line data Source code
1 : mod tls;
2 :
3 : use std::fmt::Debug;
4 : use std::io;
5 : use std::net::{IpAddr, SocketAddr};
6 :
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
10 : use postgres_client::connect_raw::StartupStream;
11 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
12 : use postgres_client::tls::MakeTlsConnect;
13 : use thiserror::Error;
14 : use tokio::net::{TcpStream, lookup_host};
15 : use tracing::{debug, error, info, warn};
16 :
17 : use crate::auth::backend::ComputeCredentialKeys;
18 : use crate::auth::parse_endpoint_param;
19 : use crate::compute::tls::TlsError;
20 : use crate::config::ComputeConfig;
21 : use crate::context::RequestContext;
22 : use crate::control_plane::client::ApiLockError;
23 : use crate::control_plane::errors::WakeComputeError;
24 : use crate::control_plane::messages::MetricsAuxInfo;
25 : use crate::error::{ReportableError, UserFacingError};
26 : use crate::metrics::{Metrics, NumDbConnectionsGuard};
27 : use crate::pqproto::StartupMessageParams;
28 : use crate::proxy::neon_option;
29 : use crate::types::Host;
30 :
/// Common prefix of the connection-failure messages produced in this module.
///
/// Also returned verbatim to clients when a TLS error must not be exposed.
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
32 :
33 : #[derive(Debug, Error)]
34 : pub(crate) enum PostgresError {
35 : /// This error doesn't seem to reveal any secrets; for instance,
36 : /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
37 : #[error("{COULD_NOT_CONNECT}: {0}")]
38 : Postgres(#[from] postgres_client::Error),
39 : }
40 :
41 : impl UserFacingError for PostgresError {
42 0 : fn to_string_client(&self) -> String {
43 0 : match self {
44 : // This helps us drop irrelevant library-specific prefixes.
45 : // TODO: propagate severity level and other parameters.
46 0 : PostgresError::Postgres(err) => match err.as_db_error() {
47 0 : Some(err) => {
48 0 : let msg = err.message();
49 :
50 0 : if msg.starts_with("unsupported startup parameter: ")
51 0 : || msg.starts_with("unsupported startup parameter in options: ")
52 : {
53 0 : format!(
54 0 : "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
55 : )
56 : } else {
57 0 : msg.to_owned()
58 : }
59 : }
60 0 : None => err.to_string(),
61 : },
62 : }
63 0 : }
64 : }
65 :
66 : impl ReportableError for PostgresError {
67 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
68 0 : match self {
69 0 : PostgresError::Postgres(e) if e.as_db_error().is_some() => {
70 0 : crate::error::ErrorKind::Postgres
71 : }
72 0 : PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
73 : }
74 0 : }
75 : }
76 :
77 : #[derive(Debug, Error)]
78 : pub(crate) enum ConnectionError {
79 : #[error("{COULD_NOT_CONNECT}: {0}")]
80 : TlsError(#[from] TlsError),
81 :
82 : #[error("{COULD_NOT_CONNECT}: {0}")]
83 : WakeComputeError(#[from] WakeComputeError),
84 :
85 : #[error("error acquiring resource permit: {0}")]
86 : TooManyConnectionAttempts(#[from] ApiLockError),
87 : }
88 :
89 : impl UserFacingError for ConnectionError {
90 0 : fn to_string_client(&self) -> String {
91 0 : match self {
92 0 : ConnectionError::WakeComputeError(err) => err.to_string_client(),
93 : ConnectionError::TooManyConnectionAttempts(_) => {
94 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
95 : }
96 0 : ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
97 : }
98 0 : }
99 : }
100 :
101 : impl ReportableError for ConnectionError {
102 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
103 0 : match self {
104 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
105 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
106 0 : ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
107 : }
108 0 : }
109 : }
110 :
111 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
112 : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
113 :
114 : #[derive(Clone)]
115 : pub enum Auth {
116 : /// Only used during console-redirect.
117 : Password(Vec<u8>),
118 : /// Used by sql-over-http, ws, tcp.
119 : Scram(Box<ScramKeys>),
120 : }
121 :
122 : /// A config for authenticating to the compute node.
123 : pub(crate) struct AuthInfo {
124 : /// None for local-proxy, as we use trust-based localhost auth.
125 : /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
126 : /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
127 : auth: Option<Auth>,
128 : server_params: StartupMessageParams,
129 :
130 : channel_binding: ChannelBinding,
131 :
132 : /// Console redirect sets user and database, we shouldn't re-use those from the params.
133 : skip_db_user: bool,
134 : }
135 :
136 : /// Contains only the data needed to establish a secure connection to compute.
137 : #[derive(Clone)]
138 : pub struct ConnectInfo {
139 : pub host_addr: Option<IpAddr>,
140 : pub host: Host,
141 : pub port: u16,
142 : pub ssl_mode: SslMode,
143 : }
144 :
145 : /// Creation and initialization routines.
146 : impl AuthInfo {
147 0 : pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
148 0 : let mut server_params = StartupMessageParams::default();
149 0 : server_params.insert("database", db);
150 0 : server_params.insert("user", user);
151 : Self {
152 0 : auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
153 0 : server_params,
154 : skip_db_user: true,
155 : // pg-sni-router is a mitm so this would fail.
156 0 : channel_binding: ChannelBinding::Disable,
157 : }
158 0 : }
159 :
160 0 : pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
161 : Self {
162 0 : auth: match keys {
163 0 : ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
164 0 : Some(Auth::Scram(Box::new(auth_keys)))
165 : }
166 0 : ComputeCredentialKeys::JwtPayload(_) => None,
167 : },
168 0 : server_params: StartupMessageParams::default(),
169 : skip_db_user: false,
170 0 : channel_binding: ChannelBinding::Prefer,
171 : }
172 0 : }
173 : }
174 :
175 : impl ConnectInfo {
176 0 : pub fn to_postgres_client_config(&self) -> postgres_client::Config {
177 0 : let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
178 0 : config.ssl_mode(self.ssl_mode);
179 0 : if let Some(host_addr) = self.host_addr {
180 0 : config.set_host_addr(host_addr);
181 0 : }
182 0 : config
183 0 : }
184 : }
185 :
186 : impl AuthInfo {
187 0 : fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
188 0 : match &self.auth {
189 0 : Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
190 0 : Some(Auth::Password(pw)) => config.password(pw),
191 0 : None => &mut config,
192 : };
193 0 : config.channel_binding(self.channel_binding);
194 0 : for (k, v) in self.server_params.iter() {
195 0 : config.set_param(k, v);
196 0 : }
197 0 : config
198 0 : }
199 :
200 : /// Apply startup message params to the connection config.
201 0 : pub(crate) fn set_startup_params(
202 0 : &mut self,
203 0 : params: &StartupMessageParams,
204 0 : arbitrary_params: bool,
205 0 : ) {
206 0 : if !arbitrary_params {
207 0 : self.server_params.insert("client_encoding", "UTF8");
208 0 : }
209 0 : for (k, v) in params.iter() {
210 0 : match k {
211 : // Only set `user` if it's not present in the config.
212 : // Console redirect auth flow takes username from the console's response.
213 0 : "user" | "database" if self.skip_db_user => {}
214 0 : "options" => {
215 0 : if let Some(options) = filtered_options(v) {
216 0 : self.server_params.insert(k, &options);
217 0 : }
218 : }
219 0 : "user" | "database" | "application_name" | "replication" => {
220 0 : self.server_params.insert(k, v);
221 0 : }
222 :
223 : // if we allow arbitrary params, then we forward them through.
224 : // this is a flag for a period of backwards compatibility
225 0 : k if arbitrary_params => {
226 0 : self.server_params.insert(k, v);
227 0 : }
228 0 : _ => {}
229 : }
230 : }
231 0 : }
232 :
233 0 : pub async fn authenticate(
234 0 : &self,
235 0 : ctx: &RequestContext,
236 0 : compute: &mut ComputeConnection,
237 0 : ) -> Result<(), PostgresError> {
238 : // client config with stubbed connect info.
239 : // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
240 : // utilising pqproto.rs.
241 0 : let mut tmp_config = postgres_client::Config::new(String::new(), 0);
242 : // We have already established SSL if necessary.
243 0 : tmp_config.ssl_mode(SslMode::Disable);
244 0 : let tmp_config = self.enrich(tmp_config);
245 :
246 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
247 0 : tmp_config.authenticate(&mut compute.stream).await?;
248 0 : drop(pause);
249 :
250 0 : Ok(())
251 0 : }
252 : }
253 :
254 : impl ConnectInfo {
255 : /// Establish a raw TCP+TLS connection to the compute node.
256 0 : async fn connect_raw(
257 0 : &self,
258 0 : config: &ComputeConfig,
259 0 : ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
260 0 : let timeout = config.timeout;
261 :
262 : // wrap TcpStream::connect with timeout
263 0 : let connect_with_timeout = |addrs| {
264 0 : tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
265 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
266 0 : Err(_) => Err(io::Error::new(
267 0 : io::ErrorKind::TimedOut,
268 0 : format!("exceeded connection timeout {timeout:?}"),
269 0 : )),
270 0 : })
271 0 : };
272 :
273 0 : let connect_once = |addrs| {
274 0 : debug!("trying to connect to compute node at {addrs:?}");
275 0 : connect_with_timeout(addrs).and_then(|stream| async {
276 0 : let socket_addr = stream.peer_addr()?;
277 0 : let socket = socket2::SockRef::from(&stream);
278 : // Disable Nagle's algorithm to not introduce latency between
279 : // client and compute.
280 0 : socket.set_nodelay(true)?;
281 : // This prevents load balancer from severing the connection.
282 0 : socket.set_keepalive(true)?;
283 0 : Ok((socket_addr, stream))
284 0 : })
285 0 : };
286 :
287 : // We can't reuse connection establishing logic from `postgres_client` here,
288 : // because it has no means for extracting the underlying socket which we
289 : // require for our business.
290 0 : let port = self.port;
291 0 : let host = &*self.host;
292 :
293 0 : let addrs = match self.host_addr {
294 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
295 0 : None => lookup_host((host, port)).await?.collect(),
296 : };
297 :
298 0 : match connect_once(&*addrs).await {
299 0 : Ok((sockaddr, stream)) => Ok((
300 0 : sockaddr,
301 0 : tls::connect_tls(stream, self.ssl_mode, config, host).await?,
302 : )),
303 0 : Err(err) => {
304 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
305 0 : Err(TlsError::Connection(err))
306 : }
307 : }
308 0 : }
309 : }
310 :
311 : pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
312 : pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
313 :
314 : pub struct ComputeConnection {
315 : /// Socket connected to a compute node.
316 : pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
317 : /// Labels for proxy's metrics.
318 : pub aux: MetricsAuxInfo,
319 : pub hostname: Host,
320 : pub ssl_mode: SslMode,
321 : pub socket_addr: SocketAddr,
322 : pub guage: NumDbConnectionsGuard<'static>,
323 : }
324 :
325 : impl ConnectInfo {
326 : /// Connect to a corresponding compute node.
327 0 : pub async fn connect(
328 0 : &self,
329 0 : ctx: &RequestContext,
330 0 : aux: &MetricsAuxInfo,
331 0 : config: &ComputeConfig,
332 0 : ) -> Result<ComputeConnection, ConnectionError> {
333 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
334 0 : let (socket_addr, stream) = self.connect_raw(config).await?;
335 0 : drop(pause);
336 :
337 0 : tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
338 :
339 : // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
340 0 : info!(
341 0 : cold_start_info = ctx.cold_start_info().as_str(),
342 0 : "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
343 : self.host,
344 : self.ssl_mode,
345 0 : ctx.get_proxy_latency(),
346 0 : ctx.get_testodrome_id().unwrap_or_default(),
347 : );
348 :
349 0 : let stream = StartupStream::new(stream);
350 0 : let connection = ComputeConnection {
351 0 : stream,
352 0 : socket_addr,
353 0 : hostname: self.host.clone(),
354 0 : ssl_mode: self.ssl_mode,
355 0 : aux: aux.clone(),
356 0 : guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
357 0 : };
358 :
359 0 : Ok(connection)
360 0 : }
361 : }
362 :
363 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
364 6 : fn filtered_options(options: &str) -> Option<String> {
365 : #[allow(unstable_name_collisions)]
366 6 : let options: String = StartupMessageParams::parse_options_raw(options)
367 14 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
368 6 : .intersperse(" ") // TODO: use impl from std once it's stabilized
369 6 : .collect();
370 :
371 : // Don't even bother with empty options.
372 6 : if options.is_empty() {
373 3 : return None;
374 3 : }
375 :
376 3 : Some(options)
377 6 : }
378 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of which `options` survive filtering.
    #[test]
    fn test_filtered_options() {
        let cases: [(&str, Option<&str>); 6] = [
            // Empty options is unlikely to be useful anyway.
            ("", None),
            // It's likely that clients will only use options to specify endpoint/project.
            ("project=foo", None),
            // Same, because unescaped whitespaces are no-op.
            (" project=foo ", None),
            // Escaped whitespace is kept.
            (r"\ project=foo \ ", Some(r"\ \ ")),
            // Spaces around `=` split the option, so none of the pieces is filtered out.
            ("project = foo", Some("project = foo")),
            // The `neon_*` options are stripped while the rest is kept.
            (
                "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true",
                Some("project = foo"),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(filtered_options(input).as_deref(), expected);
        }
    }
}
|