TLA Line data Source code
1 : use crate::{
2 : auth::parse_endpoint_param,
3 : cancellation::CancelClosure,
4 : console::errors::WakeComputeError,
5 : error::{io_error, UserFacingError},
6 : };
7 : use futures::{FutureExt, TryFutureExt};
8 : use itertools::Itertools;
9 : use pq_proto::StartupMessageParams;
10 : use std::{io, net::SocketAddr, time::Duration};
11 : use thiserror::Error;
12 : use tokio::net::TcpStream;
13 : use tokio_postgres::tls::MakeTlsConnect;
14 : use tracing::{error, info, warn};
15 :
/// Generic "couldn't connect" text. It prefixes every [`ConnectionError`] message
/// and is also the whole client-facing message for non-Postgres errors
/// (see `to_string_client`).
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";

/// Errors that can occur while establishing a connection to a compute node.
///
/// Every variant displays as `COULD_NOT_CONNECT`, a colon, and the wrapped error.
#[derive(Debug, Error)]
pub enum ConnectionError {
    /// This error doesn't seem to reveal any secrets; for instance,
    /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    Postgres(#[from] tokio_postgres::Error),

    /// I/O-level failure, e.g. any error returned by `connect_raw`.
    /// `WakeComputeError`s are also routed here through `io_error`
    /// (assuming it returns `io::Error` — confirm).
    #[error("{COULD_NOT_CONNECT}: {0}")]
    CouldNotConnect(#[from] io::Error),

    /// Failure reported by `native_tls` while preparing the TLS connection.
    #[error("{COULD_NOT_CONNECT}: {0}")]
    TlsError(#[from] native_tls::Error),
}
31 :
32 : impl From<WakeComputeError> for ConnectionError {
33 0 : fn from(value: WakeComputeError) -> Self {
34 0 : io_error(value).into()
35 0 : }
36 : }
37 :
38 : impl UserFacingError for ConnectionError {
39 0 : fn to_string_client(&self) -> String {
40 0 : use ConnectionError::*;
41 0 : match self {
42 : // This helps us drop irrelevant library-specific prefixes.
43 : // TODO: propagate severity level and other parameters.
44 0 : Postgres(err) => match err.as_db_error() {
45 0 : Some(err) => err.message().to_owned(),
46 0 : None => err.to_string(),
47 : },
48 0 : _ => COULD_NOT_CONNECT.to_owned(),
49 : }
50 0 : }
51 : }
52 :
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
///
/// NOTE(review): the const parameter `32` is presumably the key length in bytes
/// (the SHA-256 digest size) — confirm against `tokio_postgres::config::ScramKeys`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
55 :
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
///
/// The inner `tokio_postgres::Config` is boxed. The `Deref`/`DerefMut` impls in
/// this module expose its getters and setters directly on `ConnCfg`.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnCfg(Box<tokio_postgres::Config>);
62 :
63 : /// Creation and initialization routines.
64 : impl ConnCfg {
65 CBC 70 : pub fn new() -> Self {
66 70 : Self(Default::default())
67 70 : }
68 :
69 : /// Reuse password or auth keys from the other config.
70 : pub fn reuse_password(&mut self, other: &Self) {
71 7 : if let Some(password) = other.get_password() {
72 UBC 0 : self.password(password);
73 CBC 7 : }
74 :
75 7 : if let Some(keys) = other.get_auth_keys() {
76 UBC 0 : self.auth_keys(keys);
77 CBC 7 : }
78 7 : }
79 :
80 : /// Apply startup message params to the connection config.
81 : pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
82 : // Only set `user` if it's not present in the config.
83 : // Link auth flow takes username from the console's response.
84 29 : if let (None, Some(user)) = (self.get_user(), params.get("user")) {
85 26 : self.user(user);
86 26 : }
87 :
88 : // Only set `dbname` if it's not present in the config.
89 : // Link auth flow takes dbname from the console's response.
90 29 : if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
91 26 : self.dbname(dbname);
92 26 : }
93 :
94 : // Don't add `options` if they were only used for specifying a project.
95 : // Connection pools don't support `options`, because they affect backend startup.
96 29 : if let Some(options) = filtered_options(params) {
97 26 : self.options(&options);
98 26 : }
99 :
100 29 : if let Some(app_name) = params.get("application_name") {
101 3 : self.application_name(app_name);
102 26 : }
103 :
104 : // TODO: This is especially ugly...
105 29 : if let Some(replication) = params.get("replication") {
106 : use tokio_postgres::config::ReplicationMode;
107 UBC 0 : match replication {
108 0 : "true" | "on" | "yes" | "1" => {
109 0 : self.replication_mode(ReplicationMode::Physical);
110 0 : }
111 0 : "database" => {
112 0 : self.replication_mode(ReplicationMode::Logical);
113 0 : }
114 0 : _other => {}
115 : }
116 CBC 29 : }
117 :
118 : // TODO: extend the list of the forwarded startup parameters.
119 : // Currently, tokio-postgres doesn't allow us to pass
120 : // arbitrary parameters, but the ones above are a good start.
121 : //
122 : // This and the reverse params problem can be better addressed
123 : // in a bespoke connection machinery (a new library for that sake).
124 29 : }
125 : }
126 :
127 : impl std::ops::Deref for ConnCfg {
128 : type Target = tokio_postgres::Config;
129 :
130 101 : fn deref(&self) -> &Self::Target {
131 101 : &self.0
132 101 : }
133 : }
134 :
135 : /// For now, let's make it easier to setup the config.
136 : impl std::ops::DerefMut for ConnCfg {
137 168 : fn deref_mut(&mut self) -> &mut Self::Target {
138 168 : &mut self.0
139 168 : }
140 : }
141 :
142 : impl Default for ConnCfg {
143 UBC 0 : fn default() -> Self {
144 0 : Self::new()
145 0 : }
146 : }
147 :
148 : impl ConnCfg {
149 : /// Establish a raw TCP connection to the compute node.
150 CBC 29 : async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
151 29 : use tokio_postgres::config::Host;
152 29 :
153 29 : // wrap TcpStream::connect with timeout
154 29 : let connect_with_timeout = |host, port| {
155 29 : tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
156 29 : move |res| match res {
157 29 : Ok(tcpstream_connect_res) => tcpstream_connect_res,
158 UBC 0 : Err(_) => Err(io::Error::new(
159 0 : io::ErrorKind::TimedOut,
160 0 : format!("exceeded connection timeout {timeout:?}"),
161 0 : )),
162 CBC 29 : },
163 29 : )
164 29 : };
165 29 :
166 29 : let connect_once = |host, port| {
167 29 : info!("trying to connect to compute node at {host}:{port}");
168 29 : connect_with_timeout(host, port).and_then(|socket| async {
169 29 : let socket_addr = socket.peer_addr()?;
170 : // This prevents load balancer from severing the connection.
171 29 : socket2::SockRef::from(&socket).set_keepalive(true)?;
172 29 : Ok((socket_addr, socket))
173 29 : })
174 29 : };
175 29 :
176 29 : // We can't reuse connection establishing logic from `tokio_postgres` here,
177 29 : // because it has no means for extracting the underlying socket which we
178 29 : // require for our business.
179 29 : let mut connection_error = None;
180 29 : let ports = self.0.get_ports();
181 29 : let hosts = self.0.get_hosts();
182 29 : // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
183 29 : if ports.len() > 1 && ports.len() != hosts.len() {
184 UBC 0 : return Err(io::Error::new(
185 0 : io::ErrorKind::Other,
186 0 : format!(
187 0 : "bad compute config, \
188 0 : ports and hosts entries' count does not match: {:?}",
189 0 : self.0
190 0 : ),
191 0 : ));
192 CBC 29 : }
193 :
194 29 : for (i, host) in hosts.iter().enumerate() {
195 29 : let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
196 29 : let host = match host {
197 29 : Host::Tcp(host) => host.as_str(),
198 UBC 0 : Host::Unix(_) => continue, // unix sockets are not welcome here
199 : };
200 :
201 CBC 58 : match connect_once(host, *port).await {
202 29 : Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
203 UBC 0 : Err(err) => {
204 : // We can't throw an error here, as there might be more hosts to try.
205 0 : warn!("couldn't connect to compute node at {host}:{port}: {err}");
206 0 : connection_error = Some(err);
207 : }
208 : }
209 : }
210 :
211 0 : Err(connection_error.unwrap_or_else(|| {
212 0 : io::Error::new(
213 0 : io::ErrorKind::Other,
214 0 : format!("bad compute config: {:?}", self.0),
215 0 : )
216 0 : }))
217 CBC 29 : }
218 : }
219 :
/// A connection to a compute node that has completed the startup handshake,
/// as returned by `ConnCfg::connect`.
pub struct PostgresConnection {
    /// Socket connected to a compute node.
    /// Either plain TCP or TLS-wrapped (`MaybeTlsStream`); TLS is not used when
    /// the config's `sslmode` is "disable".
    pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
        tokio::net::TcpStream,
        postgres_native_tls::TlsStream<tokio::net::TcpStream>,
    >,
    /// PostgreSQL connection parameters.
    /// Taken from `tokio_postgres`' connection after the handshake (presumably
    /// the server's `ParameterStatus` values — confirm).
    pub params: std::collections::HashMap<String, String>,
    /// Query cancellation token.
    /// Built from the compute node's socket address and the client's cancel token.
    pub cancel_closure: CancelClosure,
}
231 :
232 : impl ConnCfg {
233 : /// Connect to a corresponding compute node.
234 29 : pub async fn connect(
235 29 : &self,
236 29 : allow_self_signed_compute: bool,
237 29 : timeout: Duration,
238 29 : ) -> Result<PostgresConnection, ConnectionError> {
239 58 : let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
240 :
241 29 : let tls_connector = native_tls::TlsConnector::builder()
242 29 : .danger_accept_invalid_certs(allow_self_signed_compute)
243 29 : .build()
244 29 : .unwrap();
245 29 : let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
246 29 : let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
247 :
248 : // connect_raw() will not use TLS if sslmode is "disable"
249 29 : let (client, connection) = self.0.connect_raw(stream, tls).await?;
250 29 : let stream = connection.stream.into_inner();
251 :
252 29 : info!(
253 29 : "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
254 29 : self.0.get_ssl_mode()
255 29 : );
256 :
257 : // This is very ugly but as of now there's no better way to
258 : // extract the connection parameters from tokio-postgres' connection.
259 : // TODO: solve this problem in a more elegant manner (e.g. the new library).
260 29 : let params = connection.parameters;
261 29 :
262 29 : // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
263 29 : // Yet another reason to rework the connection establishing code.
264 29 : let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
265 29 :
266 29 : let connection = PostgresConnection {
267 29 : stream,
268 29 : params,
269 29 : cancel_closure,
270 29 : };
271 29 :
272 29 : Ok(connection)
273 29 : }
274 : }
275 :
276 : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
277 34 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
278 : #[allow(unstable_name_collisions)]
279 34 : let options: String = params
280 34 : .options_raw()?
281 55 : .filter(|opt| parse_endpoint_param(opt).is_none())
282 34 : .intersperse(" ") // TODO: use impl from std once it's stabilized
283 34 : .collect();
284 34 :
285 34 : // Don't even bother with empty options.
286 34 : if options.is_empty() {
287 6 : return None;
288 28 : }
289 28 :
290 28 : Some(options)
291 34 : }
292 :
#[cfg(test)]
mod tests {
    use super::*;

    /// `filtered_options` drops endpoint/project selectors and empty leftovers,
    /// but keeps escaped whitespace and options it doesn't recognize.
    #[test]
    fn test_filtered_options() {
        // Empty options is unlikely to be useful anyway.
        let params = StartupMessageParams::new([("options", "")]);
        assert_eq!(filtered_options(&params), None);

        // It's likely that clients will only use options to specify endpoint/project.
        let params = StartupMessageParams::new([("options", "project=foo")]);
        assert_eq!(filtered_options(&params), None);

        // Same, because unescaped whitespaces are no-op.
        let params = StartupMessageParams::new([("options", " project=foo ")]);
        assert_eq!(filtered_options(&params).as_deref(), None);

        let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
        assert_eq!(filtered_options(&params).as_deref(), Some(r"\ \ "));

        let params = StartupMessageParams::new([("options", "project = foo")]);
        assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
    }
}
|