Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
3 : context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
4 : proxy::neon_option,
5 : };
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use metrics::IntCounterPairGuard;
9 : use pq_proto::StartupMessageParams;
10 : use std::{io, net::SocketAddr, time::Duration};
11 : use thiserror::Error;
12 : use tokio::net::TcpStream;
13 : use tokio_postgres::tls::MakeTlsConnect;
14 : use tracing::{error, info, warn};
15 :
16 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
17 :
18 0 : #[derive(Debug, Error)]
19 : pub enum ConnectionError {
20 : /// This error doesn't seem to reveal any secrets; for instance,
21 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
22 : #[error("{COULD_NOT_CONNECT}: {0}")]
23 : Postgres(#[from] tokio_postgres::Error),
24 :
25 : #[error("{COULD_NOT_CONNECT}: {0}")]
26 : CouldNotConnect(#[from] io::Error),
27 :
28 : #[error("{COULD_NOT_CONNECT}: {0}")]
29 : TlsError(#[from] native_tls::Error),
30 :
31 : #[error("{COULD_NOT_CONNECT}: {0}")]
32 : WakeComputeError(#[from] WakeComputeError),
33 : }
34 :
35 : impl UserFacingError for ConnectionError {
36 0 : fn to_string_client(&self) -> String {
37 0 : use ConnectionError::*;
38 0 : match self {
39 : // This helps us drop irrelevant library-specific prefixes.
40 : // TODO: propagate severity level and other parameters.
41 0 : Postgres(err) => match err.as_db_error() {
42 0 : Some(err) => {
43 0 : let msg = err.message();
44 0 :
45 0 : if msg.starts_with("unsupported startup parameter: ")
46 0 : || msg.starts_with("unsupported startup parameter in options: ")
47 : {
48 0 : format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
49 : } else {
50 0 : msg.to_owned()
51 : }
52 : }
53 0 : None => err.to_string(),
54 : },
55 0 : WakeComputeError(err) => err.to_string_client(),
56 0 : _ => COULD_NOT_CONNECT.to_owned(),
57 : }
58 0 : }
59 : }
60 :
61 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
62 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
63 :
64 : /// A config for establishing a connection to compute node.
65 : /// Eventually, `tokio_postgres` will be replaced with something better.
66 : /// Newtype allows us to implement methods on top of it.
67 0 : #[derive(Clone)]
68 : #[repr(transparent)]
69 : pub struct ConnCfg(Box<tokio_postgres::Config>);
70 :
71 : /// Creation and initialization routines.
72 : impl ConnCfg {
73 106 : pub fn new() -> Self {
74 106 : Self(Default::default())
75 106 : }
76 :
77 : /// Reuse password or auth keys from the other config.
78 12 : pub fn reuse_password(&mut self, other: &Self) {
79 12 : if let Some(password) = other.get_password() {
80 0 : self.password(password);
81 12 : }
82 :
83 12 : if let Some(keys) = other.get_auth_keys() {
84 0 : self.auth_keys(keys);
85 12 : }
86 12 : }
87 :
88 : /// Apply startup message params to the connection config.
89 39 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
90 : // Only set `user` if it's not present in the config.
91 : // Link auth flow takes username from the console's response.
92 39 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
93 36 : self.user(user);
94 36 : }
95 :
96 : // Only set `dbname` if it's not present in the config.
97 : // Link auth flow takes dbname from the console's response.
98 39 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
99 36 : self.dbname(dbname);
100 36 : }
101 :
102 : // Don't add `options` if they were only used for specifying a project.
103 : // Connection pools don't support `options`, because they affect backend startup.
104 39 : if let Some(options) = filtered_options(params) {
105 36 : self.options(&options);
106 36 : }
107 :
108 39 : if let Some(app_name) = params.get("application_name") {
109 3 : self.application_name(app_name);
110 36 : }
111 :
112 : // TODO: This is especially ugly...
113 39 : if let Some(replication) = params.get("replication") {
114 : use tokio_postgres::config::ReplicationMode;
115 0 : match replication {
116 0 : "true" | "on" | "yes" | "1" => {
117 0 : self.replication_mode(ReplicationMode::Physical);
118 0 : }
119 0 : "database" => {
120 0 : self.replication_mode(ReplicationMode::Logical);
121 0 : }
122 0 : _other => {}
123 : }
124 39 : }
125 :
126 : // TODO: extend the list of the forwarded startup parameters.
127 : // Currently, tokio-postgres doesn't allow us to pass
128 : // arbitrary parameters, but the ones above are a good start.
129 : //
130 : // This and the reverse params problem can be better addressed
131 : // in a bespoke connection machinery (a new library for that sake).
132 39 : }
133 : }
134 :
135 : impl std::ops::Deref for ConnCfg {
136 : type Target = tokio_postgres::Config;
137 :
138 145 : fn deref(&self) -> &Self::Target {
139 145 : &self.0
140 145 : }
141 : }
142 :
143 : /// For now, let's make it easier to setup the config.
144 : impl std::ops::DerefMut for ConnCfg {
145 232 : fn deref_mut(&mut self) -> &mut Self::Target {
146 232 : &mut self.0
147 232 : }
148 : }
149 :
150 : impl Default for ConnCfg {
151 0 : fn default() -> Self {
152 0 : Self::new()
153 0 : }
154 : }
155 :
156 : impl ConnCfg {
157 : /// Establish a raw TCP connection to the compute node.
158 39 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
159 39 : use tokio_postgres::config::Host;
160 39 :
161 39 : // wrap TcpStream::connect with timeout
162 39 : let connect_with_timeout = |host, port| {
163 39 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
164 39 : move |res| match res {
165 39 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
166 0 : Err(_) => Err(io::Error::new(
167 0 : io::ErrorKind::TimedOut,
168 0 : format!("exceeded connection timeout {timeout:?}"),
169 0 : )),
170 39 : },
171 39 : )
172 39 : };
173 39 :
174 39 : let connect_once = |host, port| {
175 39 : info!("trying to connect to compute node at {host}:{port}");
176 39 : connect_with_timeout(host, port).and_then(|socket| async {
177 39 : let socket_addr = socket.peer_addr()?;
178 : // This prevents load balancer from severing the connection.
179 39 : socket2::SockRef::from(&socket).set_keepalive(true)?;
180 39 : Ok((socket_addr, socket))
181 39 : })
182 39 : };
183 39 :
184 39 : // We can't reuse connection establishing logic from `tokio_postgres` here,
185 39 : // because it has no means for extracting the underlying socket which we
186 39 : // require for our business.
187 39 : let mut connection_error = None;
188 39 : let ports = self.0.get_ports();
189 39 : let hosts = self.0.get_hosts();
190 39 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
191 39 : if ports.len() > 1 && ports.len() != hosts.len() {
192 0 : return Err(io::Error::new(
193 0 : io::ErrorKind::Other,
194 0 : format!(
195 0 : "bad compute config, \
196 0 : ports and hosts entries' count does not match: {:?}",
197 0 : self.0
198 0 : ),
199 0 : ));
200 39 : }
201 :
202 39 : for (i, host) in hosts.iter().enumerate() {
203 39 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
204 39 : let host = match host {
205 39 : Host::Tcp(host) => host.as_str(),
206 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
207 : };
208 :
209 72 : match connect_once(host, *port).await {
210 39 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
211 0 : Err(err) => {
212 : // We can't throw an error here, as there might be more hosts to try.
213 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
214 0 : connection_error = Some(err);
215 : }
216 : }
217 : }
218 :
219 0 : Err(connection_error.unwrap_or_else(|| {
220 0 : io::Error::new(
221 0 : io::ErrorKind::Other,
222 0 : format!("bad compute config: {:?}", self.0),
223 0 : )
224 0 : }))
225 39 : }
226 : }
227 :
228 : pub struct PostgresConnection {
229 : /// Socket connected to a compute node.
230 : pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
231 : tokio::net::TcpStream,
232 : postgres_native_tls::TlsStream<tokio::net::TcpStream>,
233 : >,
234 : /// PostgreSQL connection parameters.
235 : pub params: std::collections::HashMap<String, String>,
236 : /// Query cancellation token.
237 : pub cancel_closure: CancelClosure,
238 :
239 : _guage: IntCounterPairGuard,
240 : }
241 :
242 : impl ConnCfg {
243 : /// Connect to a corresponding compute node.
244 39 : pub async fn connect(
245 39 : &self,
246 39 : ctx: &mut RequestMonitoring,
247 39 : allow_self_signed_compute: bool,
248 39 : timeout: Duration,
249 39 : ) -> Result<PostgresConnection, ConnectionError> {
250 72 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
251 :
252 39 : let tls_connector = native_tls::TlsConnector::builder()
253 39 : .danger_accept_invalid_certs(allow_self_signed_compute)
254 39 : .build()
255 39 : .unwrap();
256 39 : let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
257 39 : let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
258 :
259 : // connect_raw() will not use TLS if sslmode is "disable"
260 42 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
261 39 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
262 39 : let stream = connection.stream.into_inner();
263 :
264 39 : info!(
265 39 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
266 39 : self.0.get_ssl_mode()
267 39 : );
268 :
269 : // This is very ugly but as of now there's no better way to
270 : // extract the connection parameters from tokio-postgres' connection.
271 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
272 39 : let params = connection.parameters;
273 39 :
274 39 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
275 39 : // Yet another reason to rework the connection establishing code.
276 39 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
277 39 :
278 39 : let connection = PostgresConnection {
279 39 : stream,
280 39 : params,
281 39 : cancel_closure,
282 39 : _guage: NUM_DB_CONNECTIONS_GAUGE
283 39 : .with_label_values(&[ctx.protocol])
284 39 : .guard(),
285 39 : };
286 39 :
287 39 : Ok(connection)
288 39 : }
289 : }
290 :
291 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
292 51 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
293 : #[allow(unstable_name_collisions)]
294 51 : let options: String = params
295 51 : .options_raw()?
296 85 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
297 51 : .intersperse(" ") // TODO: use impl from std once it's stabilized
298 51 : .collect();
299 51 :
300 51 : // Don't even bother with empty options.
301 51 : if options.is_empty() {
302 9 : return None;
303 42 : }
304 42 :
305 42 : Some(options)
306 51 : }
307 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which parts of `options` survive `filtered_options`.
    #[test]
    fn test_filtered_options() {
        // Empty options is unlikely to be useful anyway.
        let params = StartupMessageParams::new([("options", "")]);
        assert_eq!(filtered_options(&params), None);

        // It's likely that clients will only use options to specify endpoint/project.
        let params = StartupMessageParams::new([("options", "project=foo")]);
        assert_eq!(filtered_options(&params), None);

        // Same, because unescaped whitespaces are no-op.
        let params = StartupMessageParams::new([("options", " project=foo ")]);
        assert_eq!(filtered_options(&params).as_deref(), None);

        // Escaped whitespaces are kept.
        let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
        assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));

        // With spaces around `=`, this is not treated as a project option.
        let params = StartupMessageParams::new([("options", "project = foo")]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));

        // The `neon_*` entries are dropped while the rest is kept.
        let params = StartupMessageParams::new([(
            "options",
            "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
        )]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
    }
}
|