LCOV - differential code coverage report
Current view: top level - compute_tools/src - pg_helpers.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 65.8 % 187 123 64 123
Current Date: 2024-01-09 02:06:09 Functions: 62.2 % 37 23 14 23
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use std::collections::HashMap;
       2                 : use std::fmt::Write;
       3                 : use std::fs;
       4                 : use std::fs::File;
       5                 : use std::io::{BufRead, BufReader};
       6                 : use std::os::unix::fs::PermissionsExt;
       7                 : use std::path::Path;
       8                 : use std::process::Child;
       9                 : use std::time::{Duration, Instant};
      10                 : 
      11                 : use anyhow::{bail, Result};
      12                 : use ini::Ini;
      13                 : use notify::{RecursiveMode, Watcher};
      14                 : use postgres::{Client, Transaction};
      15                 : use tokio_postgres::NoTls;
      16                 : use tracing::{debug, error, info, instrument};
      17                 : 
      18                 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
      19                 : 
      20                 : const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_millis(60 * 1000); // milliseconds
      21                 : 
      22                 : /// Escape a string for including it in a SQL literal. Wrapping the result
      23                 : /// with `E'{}'` or `'{}'` is not required, as it returns a ready-to-use
      24                 : /// SQL string literal, e.g. `'db'''` or `E'db\\'`.
      25                 : /// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
      26                 : /// for the original implementation.
      27 CBC           6 : pub fn escape_literal(s: &str) -> String {
      28               6 :     let res = s.replace('\'', "''").replace('\\', "\\\\");
      29               6 : 
      30               6 :     if res.contains('\\') {
      31               2 :         format!("E'{}'", res)
      32                 :     } else {
      33               4 :         format!("'{}'", res)
      34                 :     }
      35               6 : }
      36                 : 
      37                 : /// Escape a string so that it can be used in postgresql.conf. Wrapping the result
      38                 : /// with `'{}'` is not required, as it returns a ready-to-use config string.
      39            3002 : pub fn escape_conf_value(s: &str) -> String {
      40            3002 :     let res = s.replace('\'', "''").replace('\\', "\\\\");
      41            3002 :     format!("'{}'", res)
      42            3002 : }
      43                 : 
      44                 : trait GenericOptionExt {
      45                 :     fn to_pg_option(&self) -> String;
      46                 :     fn to_pg_setting(&self) -> String;
      47                 : }
      48                 : 
      49                 : impl GenericOptionExt for GenericOption {
      50                 :     /// Represent `GenericOption` as SQL statement parameter.
      51               3 :     fn to_pg_option(&self) -> String {
      52               3 :         if let Some(val) = &self.value {
      53               3 :             match self.vartype.as_ref() {
      54               3 :                 "string" => format!("{} {}", self.name, escape_literal(val)),
      55               1 :                 _ => format!("{} {}", self.name, val),
      56                 :             }
      57                 :         } else {
      58 UBC           0 :             self.name.to_owned()
      59                 :         }
      60 CBC           3 :     }
      61                 : 
      62                 :     /// Represent `GenericOption` as configuration option.
      63              23 :     fn to_pg_setting(&self) -> String {
      64              23 :         if let Some(val) = &self.value {
      65              23 :             match self.vartype.as_ref() {
      66              23 :                 "string" => format!("{} = {}", self.name, escape_conf_value(val)),
      67              15 :                 _ => format!("{} = {}", self.name, val),
      68                 :             }
      69                 :         } else {
      70 UBC           0 :             self.name.to_owned()
      71                 :         }
      72 CBC          23 :     }
      73                 : }
      74                 : 
      75                 : pub trait PgOptionsSerialize {
      76                 :     fn as_pg_options(&self) -> String;
      77                 :     fn as_pg_settings(&self) -> String;
      78                 : }
      79                 : 
      80                 : impl PgOptionsSerialize for GenericOptions {
      81                 :     /// Serialize an optional collection of `GenericOption`'s to
      82                 :     /// Postgres SQL statement arguments.
      83               2 :     fn as_pg_options(&self) -> String {
      84               2 :         if let Some(ops) = &self {
      85               1 :             ops.iter()
      86               3 :                 .map(|op| op.to_pg_option())
      87               1 :                 .collect::<Vec<String>>()
      88               1 :                 .join(" ")
      89                 :         } else {
      90               1 :             "".to_string()
      91                 :         }
      92               2 :     }
      93                 : 
      94                 :     /// Serialize an optional collection of `GenericOption`'s to
      95                 :     /// `postgresql.conf` compatible format.
      96               1 :     fn as_pg_settings(&self) -> String {
      97               1 :         if let Some(ops) = &self {
      98               1 :             ops.iter()
      99              23 :                 .map(|op| op.to_pg_setting())
     100               1 :                 .collect::<Vec<String>>()
     101               1 :                 .join("\n")
     102               1 :                 + "\n" // newline after last setting
     103                 :         } else {
     104 UBC           0 :             "".to_string()
     105                 :         }
     106 CBC           1 :     }
     107                 : }
     108                 : 
     109                 : pub trait GenericOptionsSearch {
     110                 :     fn find(&self, name: &str) -> Option<String>;
     111                 :     fn find_ref(&self, name: &str) -> Option<&GenericOption>;
     112                 : }
     113                 : 
     114                 : impl GenericOptionsSearch for GenericOptions {
     115                 :     /// Lookup option by name
     116             231 :     fn find(&self, name: &str) -> Option<String> {
     117             231 :         let ops = self.as_ref()?;
     118               6 :         let op = ops.iter().find(|s| s.name == name)?;
     119               2 :         op.value.clone()
     120             231 :     }
     121                 : 
     122                 :     /// Lookup option by name, returning ref
     123 UBC           0 :     fn find_ref(&self, name: &str) -> Option<&GenericOption> {
     124               0 :         let ops = self.as_ref()?;
     125               0 :         ops.iter().find(|s| s.name == name)
     126               0 :     }
     127                 : }
     128                 : 
     129                 : pub trait RoleExt {
     130                 :     fn to_pg_options(&self) -> String;
     131                 : }
     132                 : 
     133                 : impl RoleExt for Role {
     134                 :     /// Serialize a list of role parameters into a Postgres-acceptable
     135                 :     /// string of arguments.
     136 CBC           1 :     fn to_pg_options(&self) -> String {
     137               1 :         // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
     138               1 :         let mut params: String = self.options.as_pg_options();
     139               1 :         params.push_str(" LOGIN");
     140                 : 
     141               1 :         if let Some(pass) = &self.encrypted_password {
     142                 :             // Some time ago we supported only md5 and treated all encrypted_password as md5.
     143                 :             // Now we also support SCRAM-SHA-256 and to preserve compatibility
     144                 :             // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
     145               1 :             if pass.starts_with("SCRAM-SHA-256") {
     146 UBC           0 :                 write!(params, " PASSWORD '{pass}'")
     147               0 :                     .expect("String is documented to not to error during write operations");
     148 CBC           1 :             } else {
     149               1 :                 write!(params, " PASSWORD 'md5{pass}'")
     150               1 :                     .expect("String is documented to not to error during write operations");
     151               1 :             }
     152 UBC           0 :         } else {
     153               0 :             params.push_str(" PASSWORD NULL");
     154               0 :         }
     155                 : 
     156 CBC           1 :         params
     157               1 :     }
     158                 : }
     159                 : 
     160                 : pub trait DatabaseExt {
     161                 :     fn to_pg_options(&self) -> String;
     162                 : }
     163                 : 
     164                 : impl DatabaseExt for Database {
     165                 :     /// Serialize a list of database parameters into a Postgres-acceptable
     166                 :     /// string of arguments.
     167                 :     /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
     168                 :     /// to use `template0` and `template1`, so it is not a problem. Yet in the future
     169                 :     /// it may require a proper quoting too.
     170               1 :     fn to_pg_options(&self) -> String {
     171               1 :         let mut params: String = self.options.as_pg_options();
     172               1 :         write!(params, " OWNER {}", &self.owner.pg_quote())
     173               1 :             .expect("String is documented to not to error during write operations");
     174               1 : 
     175               1 :         params
     176               1 :     }
     177                 : }
     178                 : 
     179                 : /// Generic trait used to provide quoting / encoding for strings used in the
     180                 : /// Postgres SQL queries and DATABASE_URL.
     181                 : pub trait Escaping {
     182                 :     fn pg_quote(&self) -> String;
     183                 : }
     184                 : 
     185                 : impl Escaping for PgIdent {
     186                 :     /// This is intended to mimic Postgres quote_ident(), but for simplicity it
     187                 :     /// always quotes provided string with `""` and escapes every `"`.
     188                 :     /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
     189               3 :     fn pg_quote(&self) -> String {
     190               3 :         let result = format!("\"{}\"", self.replace('"', "\"\""));
     191               3 :         result
     192               3 :     }
     193                 : }
     194                 : 
     195                 : /// Build a list of existing Postgres roles
     196             442 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
     197             442 :     let postgres_roles = xact
     198             442 :         .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
     199             442 :         .iter()
     200            5312 :         .map(|row| Role {
     201            5312 :             name: row.get("rolname"),
     202            5312 :             encrypted_password: row.get("rolpassword"),
     203            5312 :             options: None,
     204            5312 :         })
     205             442 :         .collect();
     206             442 : 
     207             442 :     Ok(postgres_roles)
     208             442 : }
     209                 : 
     210                 : /// Build a list of existing Postgres databases
     211             884 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
     212                 :     // `pg_database.datconnlimit = -2` means that the database is in the
     213                 :     // invalid state. See:
     214                 :     //   https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
     215             884 :     let postgres_dbs: Vec<Database> = client
     216             884 :         .query(
     217             884 :             "SELECT
     218             884 :                 datname AS name,
     219             884 :                 datdba::regrole::text AS owner,
     220             884 :                 NOT datallowconn AS restrict_conn,
     221             884 :                 datconnlimit = - 2 AS invalid
     222             884 :             FROM
     223             884 :                 pg_catalog.pg_database;",
     224             884 :             &[],
     225             884 :         )?
     226             884 :         .iter()
     227            2653 :         .map(|row| Database {
     228            2653 :             name: row.get("name"),
     229            2653 :             owner: row.get("owner"),
     230            2653 :             restrict_conn: row.get("restrict_conn"),
     231            2653 :             invalid: row.get("invalid"),
     232            2653 :             options: None,
     233            2653 :         })
     234             884 :         .collect();
     235             884 : 
     236             884 :     let dbs_map = postgres_dbs
     237             884 :         .iter()
     238            2653 :         .map(|db| (db.name.clone(), db.clone()))
     239             884 :         .collect::<HashMap<_, _>>();
     240             884 : 
     241             884 :     Ok(dbs_map)
     242             884 : }
     243                 : 
     244                 : /// Wait for Postgres to become ready to accept connections. It's ready to
     245                 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
     246                 : /// 'ready'.
     247             537 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
     248                 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
     249                 :     let pid_path = pgdata.join("postmaster.pid");
     250                 : 
     251                 :     // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
     252                 :     // completed initialization and is ready to accept connections. We want to
     253                 :     // react quickly and perform the rest of our initialization as soon as
     254                 :     // PostgreSQL starts accepting connections. Use 'notify' to be notified
     255                 :     // whenever the PID file is changed, and whenever it changes, read it to
     256                 :     // check if it's now "ready".
     257                 :     //
     258                 :     // You cannot actually watch a file before it exists, so we first watch the
     259                 :     // data directory, and once the postmaster.pid file appears, we switch to
     260                 :     // watch the file instead. We also wake up every 100 ms to poll, just in
     261                 :     // case we miss some events for some reason. Not strictly necessary, but
     262                 :     // better safe than sorry.
     263                 :     let (tx, rx) = std::sync::mpsc::channel();
     264            6948 :     let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
     265            6948 :         let _ = tx.send(res);
     266            6948 :     }) {
     267                 :         Ok(watcher) => (Box::new(watcher), rx),
     268                 :         Err(e) => {
     269                 :             match e.kind {
     270                 :                 notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
     271                 :                     // docker on m1 macs does not support recommended_watcher
     272                 :                     // but return "Function not implemented (os error 38)"
     273                 :                     // see https://github.com/notify-rs/notify/issues/423
     274                 :                     let (tx, rx) = std::sync::mpsc::channel();
     275                 : 
     276                 :                     // let's poll it faster than what we check the results for (100ms)
     277                 :                     let config =
     278                 :                         notify::Config::default().with_poll_interval(Duration::from_millis(50));
     279                 : 
     280                 :                     let watcher = notify::PollWatcher::new(
     281 UBC           0 :                         move |res| {
     282               0 :                             let _ = tx.send(res);
     283               0 :                         },
     284                 :                         config,
     285                 :                     )?;
     286                 : 
     287                 :                     (Box::new(watcher), rx)
     288                 :                 }
     289                 :                 _ => return Err(e.into()),
     290                 :             }
     291                 :         }
     292                 :     };
     293                 : 
     294                 :     watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
     295                 : 
     296                 :     let started_at = Instant::now();
     297                 :     let mut postmaster_pid_seen = false;
     298                 :     loop {
     299                 :         if let Ok(Some(status)) = pg.try_wait() {
     300                 :             // Postgres exited, that is not what we expected, bail out earlier.
     301                 :             let code = status.code().unwrap_or(-1);
     302                 :             bail!("Postgres exited unexpectedly with code {}", code);
     303                 :         }
     304                 : 
     305                 :         let res = rx.recv_timeout(Duration::from_millis(100));
     306               0 :         debug!("woken up by notify: {res:?}");
     307                 :         // If there are multiple events in the channel already, we only need to be
     308                 :         // check once. Swallow the extra events before we go ahead to check the
     309                 :         // pid file.
     310                 :         while let Ok(res) = rx.try_recv() {
     311               0 :             debug!("swallowing extra event: {res:?}");
     312                 :         }
     313                 : 
     314                 :         // Check that we can open pid file first.
     315                 :         if let Ok(file) = File::open(&pid_path) {
     316                 :             if !postmaster_pid_seen {
     317               0 :                 debug!("postmaster.pid appeared");
     318                 :                 watcher
     319                 :                     .unwatch(pgdata)
     320                 :                     .expect("Failed to remove pgdata dir watch");
     321                 :                 watcher
     322                 :                     .watch(&pid_path, RecursiveMode::NonRecursive)
     323                 :                     .expect("Failed to add postmaster.pid file watch");
     324                 :                 postmaster_pid_seen = true;
     325                 :             }
     326                 : 
     327                 :             let file = BufReader::new(file);
     328                 :             let last_line = file.lines().last();
     329                 : 
     330                 :             // Pid file could be there and we could read it, but it could be empty, for example.
     331                 :             if let Some(Ok(line)) = last_line {
     332                 :                 let status = line.trim();
     333               0 :                 debug!("last line of postmaster.pid: {status:?}");
     334                 : 
     335                 :                 // Now Postgres is ready to accept connections
     336                 :                 if status == "ready" {
     337                 :                     break;
     338                 :                 }
     339                 :             }
     340                 :         }
     341                 : 
     342                 :         // Give up after POSTGRES_WAIT_TIMEOUT.
     343                 :         let duration = started_at.elapsed();
     344                 :         if duration >= POSTGRES_WAIT_TIMEOUT {
     345                 :             bail!("timed out while waiting for Postgres to start");
     346                 :         }
     347                 :     }
     348                 : 
     349 CBC         537 :     tracing::info!("PostgreSQL is now running, continuing to configure it");
     350                 : 
     351                 :     Ok(())
     352                 : }
     353                 : 
     354                 : /// Remove `pgdata` directory and create it again with right permissions.
     355 UBC           0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
     356               0 :     // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
     357               0 :     // If it is something different then create_dir() will error out anyway.
     358               0 :     let _ok = fs::remove_dir_all(pgdata);
     359               0 :     fs::create_dir(pgdata)?;
     360               0 :     fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
     361                 : 
     362               0 :     Ok(())
     363               0 : }
     364                 : 
     365                 : /// Update pgbouncer.ini with provided options
     366               0 : pub fn update_pgbouncer_ini(
     367               0 :     pgbouncer_config: HashMap<String, String>,
     368               0 :     pgbouncer_ini_path: &str,
     369               0 : ) -> Result<()> {
     370               0 :     let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
     371               0 :     let section = conf.section_mut(Some("pgbouncer")).unwrap();
     372                 : 
     373               0 :     for (option_name, value) in pgbouncer_config.iter() {
     374               0 :         section.insert(option_name, value);
     375               0 :     }
     376                 : 
     377               0 :     conf.write_to_file(pgbouncer_ini_path)?;
     378               0 :     Ok(())
     379               0 : }
     380                 : 
     381                 : /// Tune pgbouncer.
     382                 : /// 1. Apply new config using pgbouncer admin console
     383                 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
     384 CBC         761 : pub async fn tune_pgbouncer(
     385             761 :     pgbouncer_settings: Option<HashMap<String, String>>,
     386             761 :     pgbouncer_connstr: &str,
     387             761 :     pgbouncer_ini_path: Option<String>,
     388             761 : ) -> Result<()> {
     389             761 :     if let Some(pgbouncer_config) = pgbouncer_settings {
     390                 :         // Apply new config
     391 UBC           0 :         let connect_result = tokio_postgres::connect(pgbouncer_connstr, NoTls).await;
     392               0 :         let (client, connection) = connect_result.unwrap();
     393               0 :         tokio::spawn(async move {
     394               0 :             if let Err(e) = connection.await {
     395               0 :                 eprintln!("connection error: {}", e);
     396               0 :             }
     397               0 :         });
     398                 : 
     399               0 :         for (option_name, value) in pgbouncer_config.iter() {
     400               0 :             info!(
     401               0 :                 "Applying pgbouncer setting change: {} = {}",
     402               0 :                 option_name, value
     403               0 :             );
     404               0 :             let query = format!("SET {} = {}", option_name, value);
     405                 : 
     406               0 :             let result = client.simple_query(&query).await;
     407                 : 
     408               0 :             info!("Applying pgbouncer setting change: {}", query);
     409               0 :             info!("pgbouncer setting change result: {:?}", result);
     410                 : 
     411               0 :             if let Err(err) = result {
     412                 :                 // Don't fail on error, just print it into log
     413               0 :                 error!(
     414               0 :                     "Failed to apply pgbouncer setting change: {},  {}",
     415               0 :                     query, err
     416               0 :                 );
     417               0 :             };
     418                 :         }
     419                 : 
     420                 :         // save values to pgbouncer.ini
     421                 :         // so that they are preserved after pgbouncer restart
     422               0 :         if let Some(pgbouncer_ini_path) = pgbouncer_ini_path {
     423               0 :             update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
     424               0 :         }
     425 CBC         761 :     }
     426                 : 
     427             761 :     Ok(())
     428             761 : }
        

Generated by: LCOV version 2.1-beta