TLA Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
3 : context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
4 : proxy::neon_option,
5 : };
6 : use futures::{FutureExt, TryFutureExt};
7 : use itertools::Itertools;
8 : use metrics::IntCounterPairGuard;
9 : use pq_proto::StartupMessageParams;
10 : use std::{io, net::SocketAddr, time::Duration};
11 : use thiserror::Error;
12 : use tokio::net::TcpStream;
13 : use tokio_postgres::tls::MakeTlsConnect;
14 : use tracing::{error, info, warn};
15 :
16 : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
17 :
18 UBC 0 : #[derive(Debug, Error)]
19 : pub enum ConnectionError {
20 : /// This error doesn't seem to reveal any secrets; for instance,
21 : /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
22 : #[error("{COULD_NOT_CONNECT}: {0}")]
23 : Postgres(#[from] tokio_postgres::Error),
24 :
25 : #[error("{COULD_NOT_CONNECT}: {0}")]
26 : CouldNotConnect(#[from] io::Error),
27 :
28 : #[error("{COULD_NOT_CONNECT}: {0}")]
29 : TlsError(#[from] native_tls::Error),
30 :
31 : #[error("{COULD_NOT_CONNECT}: {0}")]
32 : WakeComputeError(#[from] WakeComputeError),
33 : }
34 :
35 : impl UserFacingError for ConnectionError {
36 0 : fn to_string_client(&self) -> String {
37 0 : use ConnectionError::*;
38 0 : match self {
39 : // This helps us drop irrelevant library-specific prefixes.
40 : // TODO: propagate severity level and other parameters.
41 0 : Postgres(err) => match err.as_db_error() {
42 0 : Some(err) => err.message().to_owned(),
43 0 : None => err.to_string(),
44 : },
45 0 : WakeComputeError(err) => err.to_string_client(),
46 0 : _ => COULD_NOT_CONNECT.to_owned(),
47 : }
48 0 : }
49 : }
50 :
51 : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
52 : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
53 :
54 : /// A config for establishing a connection to compute node.
55 : /// Eventually, `tokio_postgres` will be replaced with something better.
56 : /// Newtype allows us to implement methods on top of it.
57 0 : #[derive(Clone)]
58 : #[repr(transparent)]
59 : pub struct ConnCfg(Box<tokio_postgres::Config>);
60 :
61 : /// Creation and initialization routines.
62 : impl ConnCfg {
63 CBC 92 : pub fn new() -> Self {
64 92 : Self(Default::default())
65 92 : }
66 :
67 : /// Reuse password or auth keys from the other config.
68 7 : pub fn reuse_password(&mut self, other: &Self) {
69 7 : if let Some(password) = other.get_password() {
70 UBC 0 : self.password(password);
71 CBC 7 : }
72 :
73 7 : if let Some(keys) = other.get_auth_keys() {
74 UBC 0 : self.auth_keys(keys);
75 CBC 7 : }
76 7 : }
77 :
78 : /// Apply startup message params to the connection config.
79 38 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
80 : // Only set `user` if it's not present in the config.
81 : // Link auth flow takes username from the console's response.
82 38 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
83 35 : self.user(user);
84 35 : }
85 :
86 : // Only set `dbname` if it's not present in the config.
87 : // Link auth flow takes dbname from the console's response.
88 38 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
89 35 : self.dbname(dbname);
90 35 : }
91 :
92 : // Don't add `options` if they were only used for specifying a project.
93 : // Connection pools don't support `options`, because they affect backend startup.
94 38 : if let Some(options) = filtered_options(params) {
95 35 : self.options(&options);
96 35 : }
97 :
98 38 : if let Some(app_name) = params.get("application_name") {
99 3 : self.application_name(app_name);
100 35 : }
101 :
102 : // TODO: This is especially ugly...
103 38 : if let Some(replication) = params.get("replication") {
104 : use tokio_postgres::config::ReplicationMode;
105 UBC 0 : match replication {
106 0 : "true" | "on" | "yes" | "1" => {
107 0 : self.replication_mode(ReplicationMode::Physical);
108 0 : }
109 0 : "database" => {
110 0 : self.replication_mode(ReplicationMode::Logical);
111 0 : }
112 0 : _other => {}
113 : }
114 CBC 38 : }
115 :
116 : // TODO: extend the list of the forwarded startup parameters.
117 : // Currently, tokio-postgres doesn't allow us to pass
118 : // arbitrary parameters, but the ones above are a good start.
119 : //
120 : // This and the reverse params problem can be better addressed
121 : // in a bespoke connection machinery (a new library for that sake).
122 38 : }
123 : }
124 :
125 : impl std::ops::Deref for ConnCfg {
126 : type Target = tokio_postgres::Config;
127 :
128 132 : fn deref(&self) -> &Self::Target {
129 132 : &self.0
130 132 : }
131 : }
132 :
133 : /// For now, let's make it easier to setup the config.
134 : impl std::ops::DerefMut for ConnCfg {
135 226 : fn deref_mut(&mut self) -> &mut Self::Target {
136 226 : &mut self.0
137 226 : }
138 : }
139 :
140 : impl Default for ConnCfg {
141 UBC 0 : fn default() -> Self {
142 0 : Self::new()
143 0 : }
144 : }
145 :
146 : impl ConnCfg {
147 : /// Establish a raw TCP connection to the compute node.
148 CBC 38 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
149 38 : use tokio_postgres::config::Host;
150 38 :
151 38 : // wrap TcpStream::connect with timeout
152 38 : let connect_with_timeout = |host, port| {
153 38 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
154 38 : move |res| match res {
155 38 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
156 UBC 0 : Err(_) => Err(io::Error::new(
157 0 : io::ErrorKind::TimedOut,
158 0 : format!("exceeded connection timeout {timeout:?}"),
159 0 : )),
160 CBC 38 : },
161 38 : )
162 38 : };
163 38 :
164 38 : let connect_once = |host, port| {
165 38 : info!("trying to connect to compute node at {host}:{port}");
166 38 : connect_with_timeout(host, port).and_then(|socket| async {
167 38 : let socket_addr = socket.peer_addr()?;
168 : // This prevents load balancer from severing the connection.
169 38 : socket2::SockRef::from(&socket).set_keepalive(true)?;
170 38 : Ok((socket_addr, socket))
171 38 : })
172 38 : };
173 38 :
174 38 : // We can't reuse connection establishing logic from `tokio_postgres` here,
175 38 : // because it has no means for extracting the underlying socket which we
176 38 : // require for our business.
177 38 : let mut connection_error = None;
178 38 : let ports = self.0.get_ports();
179 38 : let hosts = self.0.get_hosts();
180 38 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
181 38 : if ports.len() > 1 && ports.len() != hosts.len() {
182 UBC 0 : return Err(io::Error::new(
183 0 : io::ErrorKind::Other,
184 0 : format!(
185 0 : "bad compute config, \
186 0 : ports and hosts entries' count does not match: {:?}",
187 0 : self.0
188 0 : ),
189 0 : ));
190 CBC 38 : }
191 :
192 38 : for (i, host) in hosts.iter().enumerate() {
193 38 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
194 38 : let host = match host {
195 38 : Host::Tcp(host) => host.as_str(),
196 UBC 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
197 : };
198 :
199 CBC 68 : match connect_once(host, *port).await {
200 38 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
201 UBC 0 : Err(err) => {
202 : // We can't throw an error here, as there might be more hosts to try.
203 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
204 0 : connection_error = Some(err);
205 : }
206 : }
207 : }
208 :
209 0 : Err(connection_error.unwrap_or_else(|| {
210 0 : io::Error::new(
211 0 : io::ErrorKind::Other,
212 0 : format!("bad compute config: {:?}", self.0),
213 0 : )
214 0 : }))
215 CBC 38 : }
216 : }
217 :
218 : pub struct PostgresConnection {
219 : /// Socket connected to a compute node.
220 : pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
221 : tokio::net::TcpStream,
222 : postgres_native_tls::TlsStream<tokio::net::TcpStream>,
223 : >,
224 : /// PostgreSQL connection parameters.
225 : pub params: std::collections::HashMap<String, String>,
226 : /// Query cancellation token.
227 : pub cancel_closure: CancelClosure,
228 :
229 : _guage: IntCounterPairGuard,
230 : }
231 :
232 : impl ConnCfg {
233 : /// Connect to a corresponding compute node.
234 38 : pub async fn connect(
235 38 : &self,
236 38 : ctx: &mut RequestMonitoring,
237 38 : allow_self_signed_compute: bool,
238 38 : timeout: Duration,
239 38 : ) -> Result<PostgresConnection, ConnectionError> {
240 68 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
241 :
242 38 : let tls_connector = native_tls::TlsConnector::builder()
243 38 : .danger_accept_invalid_certs(allow_self_signed_compute)
244 38 : .build()
245 38 : .unwrap();
246 38 : let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
247 38 : let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
248 :
249 : // connect_raw() will not use TLS if sslmode is "disable"
250 42 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
251 38 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
252 38 : let stream = connection.stream.into_inner();
253 :
254 38 : info!(
255 38 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
256 38 : self.0.get_ssl_mode()
257 38 : );
258 :
259 : // This is very ugly but as of now there's no better way to
260 : // extract the connection parameters from tokio-postgres' connection.
261 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
262 38 : let params = connection.parameters;
263 38 :
264 38 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
265 38 : // Yet another reason to rework the connection establishing code.
266 38 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
267 38 :
268 38 : let connection = PostgresConnection {
269 38 : stream,
270 38 : params,
271 38 : cancel_closure,
272 38 : _guage: NUM_DB_CONNECTIONS_GAUGE
273 38 : .with_label_values(&[ctx.protocol])
274 38 : .guard(),
275 38 : };
276 38 :
277 38 : Ok(connection)
278 38 : }
279 : }
280 :
281 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
282 44 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
283 : #[allow(unstable_name_collisions)]
284 44 : let options: String = params
285 44 : .options_raw()?
286 71 : .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
287 44 : .intersperse(" ") // TODO: use impl from std once it's stabilized
288 44 : .collect();
289 44 :
290 44 : // Don't even bother with empty options.
291 44 : if options.is_empty() {
292 6 : return None;
293 38 : }
294 38 :
295 38 : Some(options)
296 44 : }
297 :
#[cfg(test)]
mod tests {
    use super::*;

    /// `filtered_options` must strip proxy-only entries. It must keep everything else,
    /// including escaped whitespace.
    #[test]
    fn test_filtered_options() {
        // Empty options is unlikely to be useful anyway.
        let params = StartupMessageParams::new([("options", "")]);
        assert_eq!(filtered_options(&params), None);

        // It's likely that clients will only use options to specify endpoint/project.
        let params = StartupMessageParams::new([("options", "project=foo")]);
        assert_eq!(filtered_options(&params), None);

        // Same, because unescaped whitespaces are no-op.
        let params = StartupMessageParams::new([("options", " project=foo ")]);
        assert_eq!(filtered_options(&params).as_deref(), None);

        let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
        assert_eq!(filtered_options(&params).as_deref(), Some(r"\ \ "));

        let params = StartupMessageParams::new([("options", "project = foo")]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));

        let params = StartupMessageParams::new([(
            "options",
            "project = foo neon_endpoint_type:read_write neon_lsn:0/2",
        )]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));

        // Only proxy-specific neon options: nothing is left to forward.
        let params = StartupMessageParams::new([("options", "neon_lsn:0/2")]);
        assert_eq!(filtered_options(&params), None);
    }
}
|