Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param,
3 : cancellation::CancelClosure,
4 : console::errors::WakeComputeError,
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::NUM_DB_CONNECTIONS_GAUGE,
8 : proxy::neon_option,
9 : };
10 : use futures::{FutureExt, TryFutureExt};
11 : use itertools::Itertools;
12 : use metrics::IntCounterPairGuard;
13 : use pq_proto::StartupMessageParams;
14 : use std::{io, net::SocketAddr, time::Duration};
15 : use thiserror::Error;
16 : use tokio::net::TcpStream;
17 : use tokio_postgres::tls::MakeTlsConnect;
18 : use tracing::{error, info, warn};
19 :
20 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
21 :
22 0 : #[derive(Debug, Error)]
23 : pub enum ConnectionError {
24 : /// This error doesn't seem to reveal any secrets; for instance,
25 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
26 : #[error("{COULD_NOT_CONNECT}: {0}")]
27 : Postgres(#[from] tokio_postgres::Error),
28 :
29 : #[error("{COULD_NOT_CONNECT}: {0}")]
30 : CouldNotConnect(#[from] io::Error),
31 :
32 : #[error("{COULD_NOT_CONNECT}: {0}")]
33 : TlsError(#[from] native_tls::Error),
34 :
35 : #[error("{COULD_NOT_CONNECT}: {0}")]
36 : WakeComputeError(#[from] WakeComputeError),
37 : }
38 :
39 : impl UserFacingError for ConnectionError {
40 0 : fn to_string_client(&self) -> String {
41 0 : use ConnectionError::*;
42 0 : match self {
43 : // This helps us drop irrelevant library-specific prefixes.
44 : // TODO: propagate severity level and other parameters.
45 0 : Postgres(err) => match err.as_db_error() {
46 0 : Some(err) => {
47 0 : let msg = err.message();
48 0 :
49 0 : if msg.starts_with("unsupported startup parameter: ")
50 0 : || msg.starts_with("unsupported startup parameter in options: ")
51 : {
52 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
53 : } else {
54 0 : msg.to_owned()
55 : }
56 : }
57 0 : None => err.to_string(),
58 : },
59 0 : WakeComputeError(err) => err.to_string_client(),
60 0 : _ => COULD_NOT_CONNECT.to_owned(),
61 : }
62 0 : }
63 : }
64 :
65 : impl ReportableError for ConnectionError {
66 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
67 0 : match self {
68 0 : ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
69 0 : crate::error::ErrorKind::Postgres
70 : }
71 0 : ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
72 0 : ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
73 0 : ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
74 0 : ConnectionError::WakeComputeError(e) => e.get_error_kind(),
75 : }
76 0 : }
77 : }
78 :
79 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
80 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
81 :
82 : /// A config for establishing a connection to compute node.
83 : /// Eventually, `tokio_postgres` will be replaced with something better.
84 : /// Newtype allows us to implement methods on top of it.
85 0 : #[derive(Clone)]
86 : #[repr(transparent)]
87 : pub struct ConnCfg(Box<tokio_postgres::Config>);
88 :
89 : /// Creation and initialization routines.
90 : impl ConnCfg {
91 105 : pub fn new() -> Self {
92 105 : Self(Default::default())
93 105 : }
94 :
95 : /// Reuse password or auth keys from the other config.
96 10 : pub fn reuse_password(&mut self, other: &Self) {
97 10 : if let Some(password) = other.get_password() {
98 0 : self.password(password);
99 10 : }
100 :
101 10 : if let Some(keys) = other.get_auth_keys() {
102 0 : self.auth_keys(keys);
103 10 : }
104 10 : }
105 :
106 : /// Apply startup message params to the connection config.
107 41 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
108 : // Only set `user` if it's not present in the config.
109 : // Link auth flow takes username from the console's response.
110 41 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
111 38 : self.user(user);
112 38 : }
113 :
114 : // Only set `dbname` if it's not present in the config.
115 : // Link auth flow takes dbname from the console's response.
116 41 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
117 38 : self.dbname(dbname);
118 38 : }
119 :
120 : // Don't add `options` if they were only used for specifying a project.
121 : // Connection pools don't support `options`, because they affect backend startup.
122 41 : if let Some(options) = filtered_options(params) {
123 38 : self.options(&options);
124 38 : }
125 :
126 41 : if let Some(app_name) = params.get("application_name") {
127 3 : self.application_name(app_name);
128 38 : }
129 :
130 : // TODO: This is especially ugly...
131 41 : if let Some(replication) = params.get("replication") {
132 : use tokio_postgres::config::ReplicationMode;
133 0 : match replication {
134 0 : "true" | "on" | "yes" | "1" => {
135 0 : self.replication_mode(ReplicationMode::Physical);
136 0 : }
137 0 : "database" => {
138 0 : self.replication_mode(ReplicationMode::Logical);
139 0 : }
140 0 : _other => {}
141 : }
142 41 : }
143 :
144 : // TODO: extend the list of the forwarded startup parameters.
145 : // Currently, tokio-postgres doesn't allow us to pass
146 : // arbitrary parameters, but the ones above are a good start.
147 : //
148 : // This and the reverse params problem can be better addressed
149 : // in a bespoke connection machinery (a new library for that sake).
150 41 : }
151 : }
152 :
153 : impl std::ops::Deref for ConnCfg {
154 : type Target = tokio_postgres::Config;
155 :
156 142 : fn deref(&self) -> &Self::Target {
157 142 : &self.0
158 142 : }
159 : }
160 :
161 : /// For now, let's make it easier to setup the config.
162 : impl std::ops::DerefMut for ConnCfg {
163 279 : fn deref_mut(&mut self) -> &mut Self::Target {
164 279 : &mut self.0
165 279 : }
166 : }
167 :
168 : impl Default for ConnCfg {
169 0 : fn default() -> Self {
170 0 : Self::new()
171 0 : }
172 : }
173 :
174 : impl ConnCfg {
175 : /// Establish a raw TCP connection to the compute node.
176 41 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
177 41 : use tokio_postgres::config::Host;
178 41 :
179 41 : // wrap TcpStream::connect with timeout
180 41 : let connect_with_timeout = |host, port| {
181 41 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
182 41 : move |res| match res {
183 41 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
184 0 : Err(_) => Err(io::Error::new(
185 0 : io::ErrorKind::TimedOut,
186 0 : format!("exceeded connection timeout {timeout:?}"),
187 0 : )),
188 41 : },
189 41 : )
190 41 : };
191 :
192 41 : let connect_once = |host, port| {
193 41 : info!("trying to connect to compute node at {host}:{port}");
194 41 : connect_with_timeout(host, port).and_then(|socket| async {
195 41 : let socket_addr = socket.peer_addr()?;
196 : // This prevents load balancer from severing the connection.
197 41 : socket2::SockRef::from(&socket).set_keepalive(true)?;
198 41 : Ok((socket_addr, socket))
199 82 : })
200 41 : };
201 :
202 : // We can't reuse connection establishing logic from `tokio_postgres` here,
203 : // because it has no means for extracting the underlying socket which we
204 : // require for our business.
205 41 : let mut connection_error = None;
206 41 : let ports = self.0.get_ports();
207 41 : let hosts = self.0.get_hosts();
208 41 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
209 41 : if ports.len() > 1 && ports.len() != hosts.len() {
210 0 : return Err(io::Error::new(
211 0 : io::ErrorKind::Other,
212 0 : format!(
213 0 : "bad compute config, \
214 0 : ports and hosts entries' count does not match: {:?}",
215 0 : self.0
216 0 : ),
217 0 : ));
218 41 : }
219 :
220 41 : for (i, host) in hosts.iter().enumerate() {
221 41 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
222 41 : let host = match host {
223 41 : Host::Tcp(host) => host.as_str(),
224 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
225 : };
226 :
227 77 : match connect_once(host, *port).await {
228 41 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
229 0 : Err(err) => {
230 : // We can't throw an error here, as there might be more hosts to try.
231 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
232 0 : connection_error = Some(err);
233 : }
234 : }
235 : }
236 :
237 0 : Err(connection_error.unwrap_or_else(|| {
238 0 : io::Error::new(
239 0 : io::ErrorKind::Other,
240 0 : format!("bad compute config: {:?}", self.0),
241 0 : )
242 0 : }))
243 41 : }
244 : }
245 :
246 : pub struct PostgresConnection {
247 : /// Socket connected to a compute node.
248 : pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
249 : tokio::net::TcpStream,
250 : postgres_native_tls::TlsStream<tokio::net::TcpStream>,
251 : >,
252 : /// PostgreSQL connection parameters.
253 : pub params: std::collections::HashMap<String, String>,
254 : /// Query cancellation token.
255 : pub cancel_closure: CancelClosure,
256 :
257 : _guage: IntCounterPairGuard,
258 : }
259 :
260 : impl ConnCfg {
261 : /// Connect to a corresponding compute node.
262 41 : pub async fn connect(
263 41 : &self,
264 41 : ctx: &mut RequestMonitoring,
265 41 : allow_self_signed_compute: bool,
266 41 : timeout: Duration,
267 41 : ) -> Result<PostgresConnection, ConnectionError> {
268 77 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
269 :
270 41 : let tls_connector = native_tls::TlsConnector::builder()
271 41 : .danger_accept_invalid_certs(allow_self_signed_compute)
272 41 : .build()
273 41 : .unwrap();
274 41 : let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
275 41 : let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
276 :
277 : // connect_raw() will not use TLS if sslmode is "disable"
278 43 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
279 41 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
280 41 : let stream = connection.stream.into_inner();
281 :
282 41 : info!(
283 41 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
284 41 : self.0.get_ssl_mode()
285 41 : );
286 :
287 : // This is very ugly but as of now there's no better way to
288 : // extract the connection parameters from tokio-postgres' connection.
289 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
290 41 : let params = connection.parameters;
291 41 :
292 41 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
293 41 : // Yet another reason to rework the connection establishing code.
294 41 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
295 41 :
296 41 : let connection = PostgresConnection {
297 41 : stream,
298 41 : params,
299 41 : cancel_closure,
300 41 : _guage: NUM_DB_CONNECTIONS_GAUGE
301 41 : .with_label_values(&[ctx.protocol])
302 41 : .guard(),
303 41 : };
304 41 :
305 41 : Ok(connection)
306 41 : }
307 : }
308 :
309 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
310 53 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
311 : #[allow(unstable_name_collisions)]
312 53 : let options: String = params
313 53 : .options_raw()?
314 87 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
315 53 : .intersperse(" ") // TODO: use impl from std once it's stabilized
316 53 : .collect();
317 53 :
318 53 : // Don't even bother with empty options.
319 53 : if options.is_empty() {
320 9 : return None;
321 44 : }
322 44 :
323 44 : Some(options)
324 53 : }
325 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which `options` survive filtering of proxy-specific flags.
    #[test]
    fn test_filtered_options() {
        // Empty options is unlikely to be useful anyway.
        let params = StartupMessageParams::new([("options", "")]);
        assert_eq!(filtered_options(&params), None);

        // It's likely that clients will only use options to specify endpoint/project.
        let params = StartupMessageParams::new([("options", "project=foo")]);
        assert_eq!(filtered_options(&params), None);

        // Same, because unescaped whitespaces are no-op.
        let params = StartupMessageParams::new([("options", " project=foo ")]);
        assert_eq!(filtered_options(&params).as_deref(), None);

        // Escaped whitespaces are kept as options of their own.
        let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
        assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));

        // With spaces around `=`, this is no longer a project option and is kept.
        let params = StartupMessageParams::new([("options", "project = foo")]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));

        // The `neon_*` options are dropped, the rest is kept.
        let params = StartupMessageParams::new([(
            "options",
            "project = foo neon_endpoint_type:read_write neon_lsn:0/2",
        )]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
    }
}
|