Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param,
3 : cancellation::CancelClosure,
4 : console::{errors::WakeComputeError, messages::MetricsAuxInfo},
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::{Metrics, NumDbConnectionsGuard},
8 : proxy::neon_option,
9 : };
10 : use futures::{FutureExt, TryFutureExt};
11 : use itertools::Itertools;
12 : use pq_proto::StartupMessageParams;
13 : use std::{io, net::SocketAddr, time::Duration};
14 : use thiserror::Error;
15 : use tokio::net::TcpStream;
16 : use tokio_postgres::tls::MakeTlsConnect;
17 : use tracing::{error, info, warn};
18 :
19 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
20 :
21 0 : #[derive(Debug, Error)]
22 : pub enum ConnectionError {
23 : /// This error doesn't seem to reveal any secrets; for instance,
24 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
25 : #[error("{COULD_NOT_CONNECT}: {0}")]
26 : Postgres(#[from] tokio_postgres::Error),
27 :
28 : #[error("{COULD_NOT_CONNECT}: {0}")]
29 : CouldNotConnect(#[from] io::Error),
30 :
31 : #[error("{COULD_NOT_CONNECT}: {0}")]
32 : TlsError(#[from] native_tls::Error),
33 :
34 : #[error("{COULD_NOT_CONNECT}: {0}")]
35 : WakeComputeError(#[from] WakeComputeError),
36 : }
37 :
38 : impl UserFacingError for ConnectionError {
39 0 : fn to_string_client(&self) -> String {
40 0 : use ConnectionError::*;
41 0 : match self {
42 : // This helps us drop irrelevant library-specific prefixes.
43 : // TODO: propagate severity level and other parameters.
44 0 : Postgres(err) => match err.as_db_error() {
45 0 : Some(err) => {
46 0 : let msg = err.message();
47 0 :
48 0 : if msg.starts_with("unsupported startup parameter: ")
49 0 : || msg.starts_with("unsupported startup parameter in options: ")
50 : {
51 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
52 : } else {
53 0 : msg.to_owned()
54 : }
55 : }
56 0 : None => err.to_string(),
57 : },
58 0 : WakeComputeError(err) => err.to_string_client(),
59 0 : _ => COULD_NOT_CONNECT.to_owned(),
60 : }
61 0 : }
62 : }
63 :
64 : impl ReportableError for ConnectionError {
65 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
66 0 : match self {
67 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
68 0 : crate::error::ErrorKind::Postgres
69 : }
70 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
71 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
72 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
73 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
74 : }
75 0 : }
76 : }
77 :
78 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
79 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
80 :
81 : /// A config for establishing a connection to compute node.
82 : /// Eventually, `tokio_postgres` will be replaced with something better.
83 : /// Newtype allows us to implement methods on top of it.
84 : #[derive(Clone, Default)]
85 : pub struct ConnCfg(Box<tokio_postgres::Config>);
86 :
87 : /// Creation and initialization routines.
88 : impl ConnCfg {
89 20 : pub fn new() -> Self {
90 20 : Self::default()
91 20 : }
92 :
93 : /// Reuse password or auth keys from the other config.
94 8 : pub fn reuse_password(&mut self, other: Self) {
95 8 : if let Some(password) = other.get_password() {
96 8 : self.password(password);
97 8 : }
98 :
99 8 : if let Some(keys) = other.get_auth_keys() {
100 0 : self.auth_keys(keys);
101 8 : }
102 8 : }
103 :
104 : /// Apply startup message params to the connection config.
105 0 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
106 : // Only set `user` if it's not present in the config.
107 : // Link auth flow takes username from the console's response.
108 0 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
109 0 : self.user(user);
110 0 : }
111 :
112 : // Only set `dbname` if it's not present in the config.
113 : // Link auth flow takes dbname from the console's response.
114 0 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
115 0 : self.dbname(dbname);
116 0 : }
117 :
118 : // Don't add `options` if they were only used for specifying a project.
119 : // Connection pools don't support `options`, because they affect backend startup.
120 0 : if let Some(options) = filtered_options(params) {
121 0 : self.options(&options);
122 0 : }
123 :
124 0 : if let Some(app_name) = params.get("application_name") {
125 0 : self.application_name(app_name);
126 0 : }
127 :
128 : // TODO: This is especially ugly...
129 0 : if let Some(replication) = params.get("replication") {
130 : use tokio_postgres::config::ReplicationMode;
131 0 : match replication {
132 0 : "true" | "on" | "yes" | "1" => {
133 0 : self.replication_mode(ReplicationMode::Physical);
134 0 : }
135 0 : "database" => {
136 0 : self.replication_mode(ReplicationMode::Logical);
137 0 : }
138 0 : _other => {}
139 : }
140 0 : }
141 :
142 : // TODO: extend the list of the forwarded startup parameters.
143 : // Currently, tokio-postgres doesn't allow us to pass
144 : // arbitrary parameters, but the ones above are a good start.
145 : //
146 : // This and the reverse params problem can be better addressed
147 : // in a bespoke connection machinery (a new library for that sake).
148 0 : }
149 : }
150 :
151 : impl std::ops::Deref for ConnCfg {
152 : type Target = tokio_postgres::Config;
153 :
154 16 : fn deref(&self) -> &Self::Target {
155 16 : &self.0
156 16 : }
157 : }
158 :
159 : /// For now, let's make it easier to setup the config.
160 : impl std::ops::DerefMut for ConnCfg {
161 20 : fn deref_mut(&mut self) -> &mut Self::Target {
162 20 : &mut self.0
163 20 : }
164 : }
165 :
166 : impl ConnCfg {
167 : /// Establish a raw TCP connection to the compute node.
168 0 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
169 0 : use tokio_postgres::config::Host;
170 0 :
171 0 : // wrap TcpStream::connect with timeout
172 0 : let connect_with_timeout = |host, port| {
173 0 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
174 0 : move |res| match res {
175 0 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
176 0 : Err(_) => Err(io::Error::new(
177 0 : io::ErrorKind::TimedOut,
178 0 : format!("exceeded connection timeout {timeout:?}"),
179 0 : )),
180 0 : },
181 0 : )
182 0 : };
183 :
184 0 : let connect_once = |host, port| {
185 0 : info!("trying to connect to compute node at {host}:{port}");
186 0 : connect_with_timeout(host, port).and_then(|socket| async {
187 0 : let socket_addr = socket.peer_addr()?;
188 : // This prevents load balancer from severing the connection.
189 0 : socket2::SockRef::from(&socket).set_keepalive(true)?;
190 0 : Ok((socket_addr, socket))
191 0 : })
192 0 : };
193 :
194 : // We can't reuse connection establishing logic from `tokio_postgres` here,
195 : // because it has no means for extracting the underlying socket which we
196 : // require for our business.
197 0 : let mut connection_error = None;
198 0 : let ports = self.0.get_ports();
199 0 : let hosts = self.0.get_hosts();
200 0 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
201 0 : if ports.len() > 1 && ports.len() != hosts.len() {
202 0 : return Err(io::Error::new(
203 0 : io::ErrorKind::Other,
204 0 : format!(
205 0 : "bad compute config, \
206 0 : ports and hosts entries' count does not match: {:?}",
207 0 : self.0
208 0 : ),
209 0 : ));
210 0 : }
211 :
212 0 : for (i, host) in hosts.iter().enumerate() {
213 0 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
214 0 : let host = match host {
215 0 : Host::Tcp(host) => host.as_str(),
216 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
217 : };
218 :
219 0 : match connect_once(host, *port).await {
220 0 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
221 0 : Err(err) => {
222 0 : // We can't throw an error here, as there might be more hosts to try.
223 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
224 0 : connection_error = Some(err);
225 : }
226 : }
227 : }
228 :
229 0 : Err(connection_error.unwrap_or_else(|| {
230 0 : io::Error::new(
231 0 : io::ErrorKind::Other,
232 0 : format!("bad compute config: {:?}", self.0),
233 0 : )
234 0 : }))
235 0 : }
236 : }
237 :
238 : pub struct PostgresConnection {
239 : /// Socket connected to a compute node.
240 : pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
241 : tokio::net::TcpStream,
242 : postgres_native_tls::TlsStream<tokio::net::TcpStream>,
243 : >,
244 : /// PostgreSQL connection parameters.
245 : pub params: std::collections::HashMap<String, String>,
246 : /// Query cancellation token.
247 : pub cancel_closure: CancelClosure,
248 : /// Labels for proxy's metrics.
249 : pub aux: MetricsAuxInfo,
250 :
251 : _guage: NumDbConnectionsGuard<'static>,
252 : }
253 :
254 : impl ConnCfg {
255 : /// Connect to a corresponding compute node.
256 0 : pub async fn connect(
257 0 : &self,
258 0 : ctx: &mut RequestMonitoring,
259 0 : allow_self_signed_compute: bool,
260 0 : aux: MetricsAuxInfo,
261 0 : timeout: Duration,
262 0 : ) -> Result<PostgresConnection, ConnectionError> {
263 0 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
264 :
265 0 : let tls_connector = native_tls::TlsConnector::builder()
266 0 : .danger_accept_invalid_certs(allow_self_signed_compute)
267 0 : .build()
268 0 : .unwrap();
269 0 : let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
270 0 : let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
271 :
272 : // connect_raw() will not use TLS if sslmode is "disable"
273 0 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
274 0 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
275 0 : let stream = connection.stream.into_inner();
276 0 :
277 0 : info!(
278 0 : cold_start_info = ctx.cold_start_info.as_str(),
279 0 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
280 0 : self.0.get_ssl_mode()
281 0 : );
282 :
283 : // This is very ugly but as of now there's no better way to
284 : // extract the connection parameters from tokio-postgres' connection.
285 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
286 0 : let params = connection.parameters;
287 0 :
288 0 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
289 0 : // Yet another reason to rework the connection establishing code.
290 0 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
291 0 :
292 0 : let connection = PostgresConnection {
293 0 : stream,
294 0 : params,
295 0 : cancel_closure,
296 0 : aux,
297 0 : _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
298 0 : };
299 0 :
300 0 : Ok(connection)
301 0 : }
302 : }
303 :
304 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
305 12 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
306 : #[allow(unstable_name_collisions)]
307 12 : let options: String = params
308 12 : .options_raw()?
309 26 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
310 12 : .intersperse(" ") // TODO: use impl from std once it's stabilized
311 12 : .collect();
312 12 :
313 12 : // Don't even bother with empty options.
314 12 : if options.is_empty() {
315 6 : return None;
316 6 : }
317 6 :
318 6 : Some(options)
319 12 : }
320 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which `options` survive filtering of proxy-specific flags.
    #[test]
    fn test_filtered_options() {
        // Empty options is unlikely to be useful anyway.
        let params = StartupMessageParams::new([("options", "")]);
        assert_eq!(filtered_options(&params), None);

        // It's likely that clients will only use options to specify endpoint/project.
        let params = StartupMessageParams::new([("options", "project=foo")]);
        assert_eq!(filtered_options(&params), None);

        // Same, because unescaped whitespaces are no-op.
        let params = StartupMessageParams::new([("options", " project=foo ")]);
        assert_eq!(filtered_options(&params).as_deref(), None);

        // Escaped whitespaces are kept as part of the surrounding options.
        let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
        assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));

        // With spaces around `=`, this is not a project selector and is kept.
        let params = StartupMessageParams::new([("options", "project = foo")]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));

        // The neon-specific options are dropped; the rest is forwarded.
        let params = StartupMessageParams::new([(
            "options",
            "project = foo neon_endpoint_type:read_write neon_lsn:0/2",
        )]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
    }
}
|