Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::str::FromStr;
10 : use std::time::{Duration, Instant};
11 :
12 : use anyhow::{Result, bail};
13 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
14 : use futures::StreamExt;
15 : use ini::Ini;
16 : use notify::{RecursiveMode, Watcher};
17 : use postgres::config::Config;
18 : use tokio::io::AsyncBufReadExt;
19 : use tokio::task::JoinHandle;
20 : use tokio::time::timeout;
21 : use tokio_postgres;
22 : use tokio_postgres::NoTls;
23 : use tracing::{debug, error, info, instrument};
24 :
/// How long [`wait_for_postgres`] waits for Postgres to report `ready` before giving up.
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
26 :
/// Escape a string for inclusion in SQL as a complete string literal.
///
/// Single quotes are doubled and backslashes are doubled. If the input contains
/// any backslash, the result uses the `E'...'` escape-string syntax; otherwise
/// it is a plain `'...'` literal. The caller must not add surrounding quotes,
/// e.g. `db'` becomes `'db'''` and `db\` becomes `E'db\\'`.
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let mut body = String::with_capacity(s.len());
    let mut saw_backslash = false;

    for ch in s.chars() {
        match ch {
            '\'' => body.push_str("''"),
            '\\' => {
                saw_backslash = true;
                body.push_str("\\\\");
            }
            other => body.push(other),
        }
    }

    // Doubled backslashes are only interpreted as escapes inside E'' strings.
    if saw_backslash {
        format!("E'{}'", body)
    } else {
        format!("'{}'", body)
    }
}
42 :
/// Escape a string for use as a value in `postgresql.conf`.
///
/// Single quotes and backslashes are doubled and the result is wrapped in
/// single quotes, so it can be written after `name = ` as-is.
pub fn escape_conf_value(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => quoted.push_str("''"),
            '\\' => quoted.push_str("\\\\"),
            _ => quoted.push(ch),
        }
    }
    quoted.push('\'');
    quoted
}
49 :
/// Rendering of a single option either as a SQL statement argument or as a
/// `postgresql.conf` line.
pub trait GenericOptionExt {
    /// Render as a SQL statement argument (`name value`).
    fn to_pg_option(&self) -> String;
    /// Render as a configuration line (`name = value`).
    fn to_pg_setting(&self) -> String;
}
54 :
55 : impl GenericOptionExt for GenericOption {
56 : /// Represent `GenericOption` as SQL statement parameter.
57 3 : fn to_pg_option(&self) -> String {
58 3 : if let Some(val) = &self.value {
59 3 : match self.vartype.as_ref() {
60 3 : "string" => format!("{} {}", self.name, escape_literal(val)),
61 1 : _ => format!("{} {}", self.name, val),
62 : }
63 : } else {
64 0 : self.name.to_owned()
65 : }
66 3 : }
67 :
68 : /// Represent `GenericOption` as configuration option.
69 23 : fn to_pg_setting(&self) -> String {
70 23 : if let Some(val) = &self.value {
71 23 : match self.vartype.as_ref() {
72 23 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
73 15 : _ => format!("{} = {}", self.name, val),
74 : }
75 : } else {
76 0 : self.name.to_owned()
77 : }
78 23 : }
79 : }
80 :
/// Serialization of a whole option collection.
pub trait PgOptionsSerialize {
    /// Serialize as space-separated SQL statement arguments.
    fn as_pg_options(&self) -> String;
    /// Serialize as newline-separated `postgresql.conf` lines.
    fn as_pg_settings(&self) -> String;
}
85 :
86 : impl PgOptionsSerialize for GenericOptions {
87 : /// Serialize an optional collection of `GenericOption`'s to
88 : /// Postgres SQL statement arguments.
89 2 : fn as_pg_options(&self) -> String {
90 2 : if let Some(ops) = &self {
91 1 : ops.iter()
92 3 : .map(|op| op.to_pg_option())
93 1 : .collect::<Vec<String>>()
94 1 : .join(" ")
95 : } else {
96 1 : "".to_string()
97 : }
98 2 : }
99 :
100 : /// Serialize an optional collection of `GenericOption`'s to
101 : /// `postgresql.conf` compatible format.
102 1 : fn as_pg_settings(&self) -> String {
103 1 : if let Some(ops) = &self {
104 1 : ops.iter()
105 23 : .map(|op| op.to_pg_setting())
106 1 : .collect::<Vec<String>>()
107 1 : .join("\n")
108 1 : + "\n" // newline after last setting
109 : } else {
110 0 : "".to_string()
111 : }
112 1 : }
113 : }
114 :
115 : pub trait GenericOptionsSearch {
116 : fn find(&self, name: &str) -> Option<String>;
117 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
118 : }
119 :
120 : impl GenericOptionsSearch for GenericOptions {
121 : /// Lookup option by name
122 9 : fn find(&self, name: &str) -> Option<String> {
123 9 : let ops = self.as_ref()?;
124 6 : let op = ops.iter().find(|s| s.name == name)?;
125 2 : op.value.clone()
126 9 : }
127 :
128 : /// Lookup option by name, returning ref
129 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
130 0 : let ops = self.as_ref()?;
131 0 : ops.iter().find(|s| s.name == name)
132 0 : }
133 : }
134 :
/// SQL rendering of role attributes.
pub trait RoleExt {
    /// Render the role's options as arguments for `CREATE/ALTER ROLE`.
    fn to_pg_options(&self) -> String;
}
139 : impl RoleExt for Role {
140 : /// Serialize a list of role parameters into a Postgres-acceptable
141 : /// string of arguments.
142 1 : fn to_pg_options(&self) -> String {
143 1 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
144 1 : let mut params: String = self.options.as_pg_options();
145 1 : params.push_str(" LOGIN");
146 :
147 1 : if let Some(pass) = &self.encrypted_password {
148 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
149 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
150 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
151 1 : if pass.starts_with("SCRAM-SHA-256") {
152 0 : write!(params, " PASSWORD '{pass}'")
153 0 : .expect("String is documented to not to error during write operations");
154 1 : } else {
155 1 : write!(params, " PASSWORD 'md5{pass}'")
156 1 : .expect("String is documented to not to error during write operations");
157 1 : }
158 0 : } else {
159 0 : params.push_str(" PASSWORD NULL");
160 0 : }
161 :
162 1 : params
163 1 : }
164 : }
165 :
/// SQL rendering of database attributes.
pub trait DatabaseExt {
    /// Render the database's options as arguments for `CREATE/ALTER DATABASE`.
    fn to_pg_options(&self) -> String;
}
169 :
170 : impl DatabaseExt for Database {
171 : /// Serialize a list of database parameters into a Postgres-acceptable
172 : /// string of arguments.
173 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
174 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
175 : /// it may require a proper quoting too.
176 1 : fn to_pg_options(&self) -> String {
177 1 : let mut params: String = self.options.as_pg_options();
178 1 : write!(params, " OWNER {}", &self.owner.pg_quote())
179 1 : .expect("String is documented to not to error during write operations");
180 1 :
181 1 : params
182 1 : }
183 : }
184 :
/// Quoting helpers for strings embedded in Postgres SQL queries.
pub trait Escaping {
    /// Quote as a SQL identifier (`"..."`, with embedded `"` doubled).
    fn pg_quote(&self) -> String;
    /// Dollar-quote the string. Returns `(quoted_string, outer_tag)`, where
    /// `outer_tag` can be used for an enclosing dollar-quoted block.
    fn pg_quote_dollar(&self) -> (String, String);
}
191 :
192 : impl Escaping for PgIdent {
193 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
194 : /// always quotes provided string with `""` and escapes every `"`.
195 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
196 : /// N.B. it's not useful for escaping identifiers that are used inside WHERE
197 : /// clause, use `escape_literal()` instead.
198 2 : fn pg_quote(&self) -> String {
199 2 : format!("\"{}\"", self.replace('"', "\"\""))
200 2 : }
201 :
202 : /// This helper is intended to be used for dollar-escaping strings for usage
203 : /// inside PL/pgSQL procedures. In addition to dollar-escaping the string,
204 : /// it also returns a tag that is intended to be used inside the outer
205 : /// PL/pgSQL procedure. If you do not need an outer tag, just discard it.
206 : /// Here we somewhat mimic the logic of Postgres' `pg_get_functiondef()`,
207 : /// <https://github.com/postgres/postgres/blob/8b49392b270b4ac0b9f5c210e2a503546841e832/src/backend/utils/adt/ruleutils.c#L2924>
208 5 : fn pg_quote_dollar(&self) -> (String, String) {
209 5 : let mut tag: String = "".to_string();
210 5 : let mut outer_tag = "x".to_string();
211 :
212 : // Find the first suitable tag that is not present in the string.
213 : // Postgres' max role/DB name length is 63 bytes, so even in the
214 : // worst case it won't take long.
215 10 : while self.contains(&format!("${tag}$")) || self.contains(&format!("${outer_tag}$")) {
216 5 : tag += "x";
217 5 : outer_tag = tag.clone() + "x";
218 5 : }
219 :
220 5 : let escaped = format!("${tag}${self}${tag}$");
221 5 :
222 5 : (escaped, outer_tag)
223 5 : }
224 : }
225 :
226 : /// Build a list of existing Postgres roles
227 0 : pub async fn get_existing_roles_async(client: &tokio_postgres::Client) -> Result<Vec<Role>> {
228 0 : let postgres_roles = client
229 0 : .query_raw::<str, &String, &[String; 0]>(
230 0 : "SELECT rolname, rolpassword FROM pg_catalog.pg_authid",
231 0 : &[],
232 0 : )
233 0 : .await?
234 0 : .filter_map(|row| async { row.ok() })
235 0 : .map(|row| Role {
236 0 : name: row.get("rolname"),
237 0 : encrypted_password: row.get("rolpassword"),
238 0 : options: None,
239 0 : })
240 0 : .collect()
241 0 : .await;
242 :
243 0 : Ok(postgres_roles)
244 0 : }
245 :
246 : /// Build a list of existing Postgres databases
247 0 : pub async fn get_existing_dbs_async(
248 0 : client: &tokio_postgres::Client,
249 0 : ) -> Result<HashMap<String, Database>> {
250 : // `pg_database.datconnlimit = -2` means that the database is in the
251 : // invalid state. See:
252 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
253 0 : let rowstream = client
254 0 : // We use a subquery instead of a fancy `datdba::regrole::text AS owner`,
255 0 : // because the latter automatically wraps the result in double quotes,
256 0 : // if the role name contains special characters.
257 0 : .query_raw::<str, &String, &[String; 0]>(
258 0 : "SELECT
259 0 : datname AS name,
260 0 : (SELECT rolname FROM pg_roles WHERE oid = datdba) AS owner,
261 0 : NOT datallowconn AS restrict_conn,
262 0 : datconnlimit = - 2 AS invalid
263 0 : FROM
264 0 : pg_catalog.pg_database;",
265 0 : &[],
266 0 : )
267 0 : .await?;
268 :
269 0 : let dbs_map = rowstream
270 0 : .filter_map(|r| async { r.ok() })
271 0 : .map(|row| Database {
272 0 : name: row.get("name"),
273 0 : owner: row.get("owner"),
274 0 : restrict_conn: row.get("restrict_conn"),
275 0 : invalid: row.get("invalid"),
276 0 : options: None,
277 0 : })
278 0 : .map(|db| (db.name.clone(), db.clone()))
279 0 : .collect::<HashMap<_, _>>()
280 0 : .await;
281 :
282 0 : Ok(dbs_map)
283 0 : }
284 :
285 : /// Wait for Postgres to become ready to accept connections. It's ready to
286 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
287 : /// 'ready'.
288 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
289 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
290 : let pid_path = pgdata.join("postmaster.pid");
291 :
292 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
293 : // completed initialization and is ready to accept connections. We want to
294 : // react quickly and perform the rest of our initialization as soon as
295 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
296 : // whenever the PID file is changed, and whenever it changes, read it to
297 : // check if it's now "ready".
298 : //
299 : // You cannot actually watch a file before it exists, so we first watch the
300 : // data directory, and once the postmaster.pid file appears, we switch to
301 : // watch the file instead. We also wake up every 100 ms to poll, just in
302 : // case we miss some events for some reason. Not strictly necessary, but
303 : // better safe than sorry.
304 : let (tx, rx) = std::sync::mpsc::channel();
305 0 : let watcher_res = notify::recommended_watcher(move |res| {
306 0 : let _ = tx.send(res);
307 0 : });
308 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
309 : Ok(watcher) => (Box::new(watcher), rx),
310 : Err(e) => {
311 : match e.kind {
312 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
313 : // docker on m1 macs does not support recommended_watcher
314 : // but return "Function not implemented (os error 38)"
315 : // see https://github.com/notify-rs/notify/issues/423
316 : let (tx, rx) = std::sync::mpsc::channel();
317 :
318 : // let's poll it faster than what we check the results for (100ms)
319 : let config =
320 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
321 :
322 : let watcher = notify::PollWatcher::new(
323 0 : move |res| {
324 0 : let _ = tx.send(res);
325 0 : },
326 : config,
327 : )?;
328 :
329 : (Box::new(watcher), rx)
330 : }
331 : _ => return Err(e.into()),
332 : }
333 : }
334 : };
335 :
336 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
337 :
338 : let started_at = Instant::now();
339 : let mut postmaster_pid_seen = false;
340 : loop {
341 : if let Ok(Some(status)) = pg.try_wait() {
342 : // Postgres exited, that is not what we expected, bail out earlier.
343 : let code = status.code().unwrap_or(-1);
344 : bail!("Postgres exited unexpectedly with code {}", code);
345 : }
346 :
347 : let res = rx.recv_timeout(Duration::from_millis(100));
348 : debug!("woken up by notify: {res:?}");
349 : // If there are multiple events in the channel already, we only need to be
350 : // check once. Swallow the extra events before we go ahead to check the
351 : // pid file.
352 : while let Ok(res) = rx.try_recv() {
353 : debug!("swallowing extra event: {res:?}");
354 : }
355 :
356 : // Check that we can open pid file first.
357 : if let Ok(file) = File::open(&pid_path) {
358 : if !postmaster_pid_seen {
359 : debug!("postmaster.pid appeared");
360 : watcher
361 : .unwatch(pgdata)
362 : .expect("Failed to remove pgdata dir watch");
363 : watcher
364 : .watch(&pid_path, RecursiveMode::NonRecursive)
365 : .expect("Failed to add postmaster.pid file watch");
366 : postmaster_pid_seen = true;
367 : }
368 :
369 : let file = BufReader::new(file);
370 : let last_line = file.lines().last();
371 :
372 : // Pid file could be there and we could read it, but it could be empty, for example.
373 : if let Some(Ok(line)) = last_line {
374 : let status = line.trim();
375 : debug!("last line of postmaster.pid: {status:?}");
376 :
377 : // Now Postgres is ready to accept connections
378 : if status == "ready" {
379 : break;
380 : }
381 : }
382 : }
383 :
384 : // Give up after POSTGRES_WAIT_TIMEOUT.
385 : let duration = started_at.elapsed();
386 : if duration >= POSTGRES_WAIT_TIMEOUT {
387 : bail!("timed out while waiting for Postgres to start");
388 : }
389 : }
390 :
391 : tracing::info!("PostgreSQL is now running, continuing to configure it");
392 :
393 : Ok(())
394 : }
395 :
396 : /// Remove `pgdata` directory and create it again with right permissions.
397 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
398 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
399 0 : // If it is something different then create_dir() will error out anyway.
400 0 : let _ok = fs::remove_dir_all(pgdata);
401 0 : fs::create_dir(pgdata)?;
402 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
403 :
404 0 : Ok(())
405 0 : }
406 :
407 : /// Update pgbouncer.ini with provided options
408 0 : fn update_pgbouncer_ini(
409 0 : pgbouncer_config: HashMap<String, String>,
410 0 : pgbouncer_ini_path: &str,
411 0 : ) -> Result<()> {
412 0 : let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
413 0 : let section = conf.section_mut(Some("pgbouncer")).unwrap();
414 :
415 0 : for (option_name, value) in pgbouncer_config.iter() {
416 0 : section.insert(option_name, value);
417 0 : debug!(
418 0 : "Updating pgbouncer.ini with new values {}={}",
419 : option_name, value
420 : );
421 : }
422 :
423 0 : conf.write_to_file(pgbouncer_ini_path)?;
424 0 : Ok(())
425 0 : }
426 :
427 : /// Tune pgbouncer.
428 : /// 1. Apply new config using pgbouncer admin console
429 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
430 0 : pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
431 0 : let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
432 : // for VMs use pgbouncer specific way to connect to
433 : // pgbouncer admin console without password
434 : // when pgbouncer is running under the same user.
435 0 : "host=/tmp port=6432 dbname=pgbouncer user=pgbouncer".to_string()
436 : } else {
437 : // for k8s use normal connection string with password
438 : // to connect to pgbouncer admin console
439 0 : let mut pgbouncer_connstr =
440 0 : "host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string();
441 0 : if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") {
442 0 : pgbouncer_connstr.push_str(format!(" password={}", pass).as_str());
443 0 : }
444 0 : pgbouncer_connstr
445 : };
446 :
447 0 : info!(
448 0 : "Connecting to pgbouncer with connection string: {}",
449 : pgbouncer_connstr
450 : );
451 :
452 : // connect to pgbouncer, retrying several times
453 : // because pgbouncer may not be ready yet
454 0 : let mut retries = 3;
455 0 : let client = loop {
456 0 : match tokio_postgres::connect(&pgbouncer_connstr, NoTls).await {
457 0 : Ok((client, connection)) => {
458 0 : tokio::spawn(async move {
459 0 : if let Err(e) = connection.await {
460 0 : eprintln!("connection error: {}", e);
461 0 : }
462 0 : });
463 0 : break client;
464 : }
465 0 : Err(e) => {
466 0 : if retries == 0 {
467 0 : return Err(e.into());
468 0 : }
469 0 : error!("Failed to connect to pgbouncer: pgbouncer_connstr {}", e);
470 0 : retries -= 1;
471 0 : tokio::time::sleep(Duration::from_secs(1)).await;
472 : }
473 : }
474 : };
475 :
476 : // Apply new config
477 0 : for (option_name, value) in pgbouncer_config.iter() {
478 0 : let query = format!("SET {}={}", option_name, value);
479 0 : // keep this log line for debugging purposes
480 0 : info!("Applying pgbouncer setting change: {}", query);
481 :
482 0 : if let Err(err) = client.simple_query(&query).await {
483 : // Don't fail on error, just print it into log
484 0 : error!(
485 0 : "Failed to apply pgbouncer setting change: {}, {}",
486 : query, err
487 : );
488 0 : };
489 : }
490 :
491 : // save values to pgbouncer.ini
492 : // so that they are preserved after pgbouncer restart
493 0 : let pgbouncer_ini_path = if std::env::var_os("AUTOSCALING").is_some() {
494 : // in VMs we use /etc/pgbouncer.ini
495 0 : "/etc/pgbouncer.ini".to_string()
496 : } else {
497 : // in pods we use /var/db/postgres/pgbouncer/pgbouncer.ini
498 : // this is a shared volume between pgbouncer and postgres containers
499 : // FIXME: fix permissions for this file
500 0 : "/var/db/postgres/pgbouncer/pgbouncer.ini".to_string()
501 : };
502 0 : update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
503 :
504 0 : Ok(())
505 0 : }
506 :
507 : /// Spawn a task that will read Postgres logs from `stderr`, join multiline logs
508 : /// and send them to the logger. In the future we may also want to add context to
509 : /// these logs.
510 0 : pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<Result<()>> {
511 0 : tokio::spawn(async move {
512 0 : let stderr = tokio::process::ChildStderr::from_std(stderr)?;
513 0 : handle_postgres_logs_async(stderr).await
514 0 : })
515 0 : }
516 :
517 : /// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
518 : /// - next line starts with timestamp
519 : /// - EOF
520 : /// - no new lines were written for the last 100 milliseconds
521 0 : async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
522 0 : let mut lines = tokio::io::BufReader::new(stderr).lines();
523 0 : let timeout_duration = Duration::from_millis(100);
524 0 : let ts_regex =
525 0 : regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");
526 0 :
527 0 : let mut buf = vec![];
528 : loop {
529 0 : let next_line = timeout(timeout_duration, lines.next_line()).await;
530 :
531 : // we should flush lines from the buffer if we cannot continue reading multiline message
532 0 : let should_flush_buf = match next_line {
533 : // Flushing if new line starts with timestamp
534 0 : Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
535 : // Flushing on EOF, timeout or error
536 0 : _ => true,
537 : };
538 :
539 0 : if !buf.is_empty() && should_flush_buf {
540 : // join multiline message into a single line, separated by unicode Zero Width Space.
541 : // "PG:" suffix is used to distinguish postgres logs from other logs.
542 0 : let combined = format!("PG:{}\n", buf.join("\u{200B}"));
543 0 : buf.clear();
544 :
545 : // sync write to stderr to avoid interleaving with other logs
546 : use std::io::Write;
547 0 : let res = std::io::stderr().lock().write_all(combined.as_bytes());
548 0 : if let Err(e) = res {
549 0 : tracing::error!("error while writing to stderr: {}", e);
550 0 : }
551 0 : }
552 :
553 : // if not timeout, append line to the buffer
554 0 : if next_line.is_ok() {
555 0 : match next_line?? {
556 0 : Some(line) => buf.push(line),
557 : // EOF
558 0 : None => break,
559 : };
560 0 : }
561 : }
562 :
563 0 : Ok(())
564 0 : }
565 :
566 : /// `Postgres::config::Config` handles database names with whitespaces
567 : /// and special characters properly.
568 0 : pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result<Config> {
569 0 : let mut conf = Config::from_str(connstr.as_str())?;
570 0 : conf.dbname(dbname);
571 0 : Ok(conf)
572 0 : }
|