LCOV - code coverage report
Current view: top level - compute_tools/src - pg_helpers.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 55.2 % 125 69
Test Date: 2023-09-06 10:18:01 Functions: 57.1 % 28 16

            Line data    Source code
       1              : use std::fmt::Write;
       2              : use std::fs;
       3              : use std::fs::File;
       4              : use std::io::{BufRead, BufReader};
       5              : use std::os::unix::fs::PermissionsExt;
       6              : use std::path::Path;
       7              : use std::process::Child;
       8              : use std::time::{Duration, Instant};
       9              : 
      10              : use anyhow::{bail, Result};
      11              : use notify::{RecursiveMode, Watcher};
      12              : use postgres::{Client, Transaction};
      13              : use tracing::{debug, instrument};
      14              : 
      15              : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
      16              : 
      17              : const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_millis(60 * 1000); // milliseconds
      18              : 
      19              : /// Escape a string for including it in a SQL literal. Wrapping the result
      20              : /// with `E'{}'` or `'{}'` is not required, as it returns a ready-to-use
      21              : /// SQL string literal, e.g. `'db'''` or `E'db\\'`.
      22              : /// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
      23              : /// for the original implementation.
      24            6 : pub fn escape_literal(s: &str) -> String {
      25            6 :     let res = s.replace('\'', "''").replace('\\', "\\\\");
      26            6 : 
      27            6 :     if res.contains('\\') {
      28            2 :         format!("E'{}'", res)
      29              :     } else {
      30            4 :         format!("'{}'", res)
      31              :     }
      32            6 : }
      33              : 
      34              : /// Escape a string so that it can be used in postgresql.conf. Wrapping the result
      35              : /// with `'{}'` is not required, as it returns a ready-to-use config string.
      36         2571 : pub fn escape_conf_value(s: &str) -> String {
      37         2571 :     let res = s.replace('\'', "''").replace('\\', "\\\\");
      38         2571 :     format!("'{}'", res)
      39         2571 : }
      40              : 
      41              : trait GenericOptionExt {
      42              :     fn to_pg_option(&self) -> String;
      43              :     fn to_pg_setting(&self) -> String;
      44              : }
      45              : 
      46              : impl GenericOptionExt for GenericOption {
      47              :     /// Represent `GenericOption` as SQL statement parameter.
      48              :     fn to_pg_option(&self) -> String {
      49            3 :         if let Some(val) = &self.value {
      50            3 :             match self.vartype.as_ref() {
      51            3 :                 "string" => format!("{} {}", self.name, escape_literal(val)),
      52            1 :                 _ => format!("{} {}", self.name, val),
      53              :             }
      54              :         } else {
      55            0 :             self.name.to_owned()
      56              :         }
      57            3 :     }
      58              : 
      59              :     /// Represent `GenericOption` as configuration option.
      60              :     fn to_pg_setting(&self) -> String {
      61           23 :         if let Some(val) = &self.value {
      62           23 :             match self.vartype.as_ref() {
      63           23 :                 "string" => format!("{} = {}", self.name, escape_conf_value(val)),
      64           15 :                 _ => format!("{} = {}", self.name, val),
      65              :             }
      66              :         } else {
      67            0 :             self.name.to_owned()
      68              :         }
      69           23 :     }
      70              : }
      71              : 
      72              : pub trait PgOptionsSerialize {
      73              :     fn as_pg_options(&self) -> String;
      74              :     fn as_pg_settings(&self) -> String;
      75              : }
      76              : 
      77              : impl PgOptionsSerialize for GenericOptions {
      78              :     /// Serialize an optional collection of `GenericOption`'s to
      79              :     /// Postgres SQL statement arguments.
      80              :     fn as_pg_options(&self) -> String {
      81            2 :         if let Some(ops) = &self {
      82            1 :             ops.iter()
      83            3 :                 .map(|op| op.to_pg_option())
      84            1 :                 .collect::<Vec<String>>()
      85            1 :                 .join(" ")
      86              :         } else {
      87            1 :             "".to_string()
      88              :         }
      89            2 :     }
      90              : 
      91              :     /// Serialize an optional collection of `GenericOption`'s to
      92              :     /// `postgresql.conf` compatible format.
      93              :     fn as_pg_settings(&self) -> String {
      94          664 :         if let Some(ops) = &self {
      95            1 :             ops.iter()
      96           23 :                 .map(|op| op.to_pg_setting())
      97            1 :                 .collect::<Vec<String>>()
      98            1 :                 .join("\n")
      99            1 :                 + "\n" // newline after last setting
     100              :         } else {
     101          663 :             "".to_string()
     102              :         }
     103          664 :     }
     104              : }
     105              : 
     106              : pub trait GenericOptionsSearch {
     107              :     fn find(&self, name: &str) -> Option<String>;
     108              :     fn find_ref(&self, name: &str) -> Option<&GenericOption>;
     109              : }
     110              : 
     111              : impl GenericOptionsSearch for GenericOptions {
     112              :     /// Lookup option by name
     113            9 :     fn find(&self, name: &str) -> Option<String> {
     114            9 :         let ops = self.as_ref()?;
     115            6 :         let op = ops.iter().find(|s| s.name == name)?;
     116            2 :         op.value.clone()
     117            9 :     }
     118              : 
     119              :     /// Lookup option by name, returning ref
     120            0 :     fn find_ref(&self, name: &str) -> Option<&GenericOption> {
     121            0 :         let ops = self.as_ref()?;
     122            0 :         ops.iter().find(|s| s.name == name)
     123            0 :     }
     124              : }
     125              : 
     126              : pub trait RoleExt {
     127              :     fn to_pg_options(&self) -> String;
     128              : }
     129              : 
     130              : impl RoleExt for Role {
     131              :     /// Serialize a list of role parameters into a Postgres-acceptable
     132              :     /// string of arguments.
     133            1 :     fn to_pg_options(&self) -> String {
     134            1 :         // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
     135            1 :         let mut params: String = self.options.as_pg_options();
     136            1 :         params.push_str(" LOGIN");
     137              : 
     138            1 :         if let Some(pass) = &self.encrypted_password {
     139              :             // Some time ago we supported only md5 and treated all encrypted_password as md5.
     140              :             // Now we also support SCRAM-SHA-256 and to preserve compatibility
     141              :             // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
     142            1 :             if pass.starts_with("SCRAM-SHA-256") {
     143            0 :                 write!(params, " PASSWORD '{pass}'")
     144            0 :                     .expect("String is documented to not to error during write operations");
     145            1 :             } else {
     146            1 :                 write!(params, " PASSWORD 'md5{pass}'")
     147            1 :                     .expect("String is documented to not to error during write operations");
     148            1 :             }
     149            0 :         } else {
     150            0 :             params.push_str(" PASSWORD NULL");
     151            0 :         }
     152              : 
     153            1 :         params
     154            1 :     }
     155              : }
     156              : 
     157              : pub trait DatabaseExt {
     158              :     fn to_pg_options(&self) -> String;
     159              : }
     160              : 
     161              : impl DatabaseExt for Database {
     162              :     /// Serialize a list of database parameters into a Postgres-acceptable
     163              :     /// string of arguments.
     164              :     /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
     165              :     /// to use `template0` and `template1`, so it is not a problem. Yet in the future
     166              :     /// it may require a proper quoting too.
     167            1 :     fn to_pg_options(&self) -> String {
     168            1 :         let mut params: String = self.options.as_pg_options();
     169            1 :         write!(params, " OWNER {}", &self.owner.pg_quote())
     170            1 :             .expect("String is documented to not to error during write operations");
     171            1 : 
     172            1 :         params
     173            1 :     }
     174              : }
     175              : 
     176              : /// Generic trait used to provide quoting / encoding for strings used in the
     177              : /// Postgres SQL queries and DATABASE_URL.
     178              : pub trait Escaping {
     179              :     fn pg_quote(&self) -> String;
     180              : }
     181              : 
     182              : impl Escaping for PgIdent {
     183              :     /// This is intended to mimic Postgres quote_ident(), but for simplicity it
     184              :     /// always quotes provided string with `""` and escapes every `"`.
     185              :     /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
     186            2 :     fn pg_quote(&self) -> String {
     187            2 :         let result = format!("\"{}\"", self.replace('"', "\"\""));
     188            2 :         result
     189            2 :     }
     190              : }
     191              : 
     192              : /// Build a list of existing Postgres roles
     193            0 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
     194            0 :     let postgres_roles = xact
     195            0 :         .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
     196            0 :         .iter()
     197            0 :         .map(|row| Role {
     198            0 :             name: row.get("rolname"),
     199            0 :             encrypted_password: row.get("rolpassword"),
     200            0 :             options: None,
     201            0 :         })
     202            0 :         .collect();
     203            0 : 
     204            0 :     Ok(postgres_roles)
     205            0 : }
     206              : 
     207              : /// Build a list of existing Postgres databases
     208            0 : pub fn get_existing_dbs(client: &mut Client) -> Result<Vec<Database>> {
     209            0 :     let postgres_dbs = client
     210            0 :         .query(
     211            0 :             "SELECT datname, datdba::regrole::text as owner
     212            0 :                FROM pg_catalog.pg_database;",
     213            0 :             &[],
     214            0 :         )?
     215            0 :         .iter()
     216            0 :         .map(|row| Database {
     217            0 :             name: row.get("datname"),
     218            0 :             owner: row.get("owner"),
     219            0 :             options: None,
     220            0 :         })
     221            0 :         .collect();
     222            0 : 
     223            0 :     Ok(postgres_dbs)
     224            0 : }
     225              : 
     226              : /// Wait for Postgres to become ready to accept connections. It's ready to
     227              : /// accept connections when the state-field in `pgdata/postmaster.pid` says
     228              : /// 'ready'.
     229          654 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
     230              : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
     231              :     let pid_path = pgdata.join("postmaster.pid");
     232              : 
     233              :     // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
     234              :     // completed initialization and is ready to accept connections. We want to
     235              :     // react quickly and perform the rest of our initialization as soon as
     236              :     // PostgreSQL starts accepting connections. Use 'notify' to be notified
     237              :     // whenever the PID file is changed, and whenever it changes, read it to
     238              :     // check if it's now "ready".
     239              :     //
     240              :     // You cannot actually watch a file before it exists, so we first watch the
     241              :     // data directory, and once the postmaster.pid file appears, we switch to
     242              :     // watch the file instead. We also wake up every 100 ms to poll, just in
     243              :     // case we miss some events for some reason. Not strictly necessary, but
     244              :     // better safe than sorry.
     245              :     let (tx, rx) = std::sync::mpsc::channel();
     246         8461 :     let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
     247         8461 :         let _ = tx.send(res);
     248         8461 :     }) {
     249              :         Ok(watcher) => (Box::new(watcher), rx),
     250              :         Err(e) => {
     251              :             match e.kind {
     252              :                 notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
     253              :                     // docker on m1 macs does not support recommended_watcher
     254              :                     // but return "Function not implemented (os error 38)"
     255              :                     // see https://github.com/notify-rs/notify/issues/423
     256              :                     let (tx, rx) = std::sync::mpsc::channel();
     257              : 
     258              :                     // let's poll it faster than what we check the results for (100ms)
     259              :                     let config =
     260              :                         notify::Config::default().with_poll_interval(Duration::from_millis(50));
     261              : 
     262              :                     let watcher = notify::PollWatcher::new(
     263            0 :                         move |res| {
     264            0 :                             let _ = tx.send(res);
     265            0 :                         },
     266              :                         config,
     267              :                     )?;
     268              : 
     269              :                     (Box::new(watcher), rx)
     270              :                 }
     271              :                 _ => return Err(e.into()),
     272              :             }
     273              :         }
     274              :     };
     275              : 
     276              :     watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
     277              : 
     278              :     let started_at = Instant::now();
     279              :     let mut postmaster_pid_seen = false;
     280              :     loop {
     281              :         if let Ok(Some(status)) = pg.try_wait() {
     282              :             // Postgres exited, that is not what we expected, bail out earlier.
     283              :             let code = status.code().unwrap_or(-1);
     284              :             bail!("Postgres exited unexpectedly with code {}", code);
     285              :         }
     286              : 
     287              :         let res = rx.recv_timeout(Duration::from_millis(100));
     288            0 :         debug!("woken up by notify: {res:?}");
     289              :         // If there are multiple events in the channel already, we only need to be
     290              :         // check once. Swallow the extra events before we go ahead to check the
     291              :         // pid file.
     292              :         while let Ok(res) = rx.try_recv() {
     293            0 :             debug!("swallowing extra event: {res:?}");
     294              :         }
     295              : 
     296              :         // Check that we can open pid file first.
     297              :         if let Ok(file) = File::open(&pid_path) {
     298              :             if !postmaster_pid_seen {
     299            0 :                 debug!("postmaster.pid appeared");
     300              :                 watcher
     301              :                     .unwatch(pgdata)
     302              :                     .expect("Failed to remove pgdata dir watch");
     303              :                 watcher
     304              :                     .watch(&pid_path, RecursiveMode::NonRecursive)
     305              :                     .expect("Failed to add postmaster.pid file watch");
     306              :                 postmaster_pid_seen = true;
     307              :             }
     308              : 
     309              :             let file = BufReader::new(file);
     310              :             let last_line = file.lines().last();
     311              : 
     312              :             // Pid file could be there and we could read it, but it could be empty, for example.
     313              :             if let Some(Ok(line)) = last_line {
     314              :                 let status = line.trim();
     315            0 :                 debug!("last line of postmaster.pid: {status:?}");
     316              : 
     317              :                 // Now Postgres is ready to accept connections
     318              :                 if status == "ready" {
     319              :                     break;
     320              :                 }
     321              :             }
     322              :         }
     323              : 
     324              :         // Give up after POSTGRES_WAIT_TIMEOUT.
     325              :         let duration = started_at.elapsed();
     326              :         if duration >= POSTGRES_WAIT_TIMEOUT {
     327              :             bail!("timed out while waiting for Postgres to start");
     328              :         }
     329              :     }
     330              : 
     331          654 :     tracing::info!("PostgreSQL is now running, continuing to configure it");
     332              : 
     333              :     Ok(())
     334              : }
     335              : 
     336              : /// Remove `pgdata` directory and create it again with right permissions.
     337            0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
     338            0 :     // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
     339            0 :     // If it is something different then create_dir() will error out anyway.
     340            0 :     let _ok = fs::remove_dir_all(pgdata);
     341            0 :     fs::create_dir(pgdata)?;
     342            0 :     fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
     343              : 
     344            0 :     Ok(())
     345            0 : }
        

Generated by: LCOV version 2.1-beta