TLA Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::time::{Duration, Instant};
10 :
11 : use anyhow::{bail, Result};
12 : use notify::{RecursiveMode, Watcher};
13 : use postgres::{Client, Transaction};
14 : use tracing::{debug, instrument};
15 :
16 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
17 :
/// Upper bound on how long `wait_for_postgres` keeps waiting for Postgres to
/// report `ready` before it gives up with an error (60 seconds).
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
19 :
/// Turn `s` into a complete, ready-to-use SQL string literal.
///
/// Every `'` is doubled and every `\` is doubled. When the input contains at
/// least one backslash the literal is emitted in the escape-string form
/// (`E'...'`), otherwise as a plain `'...'` literal. The caller must not wrap
/// the result in quotes again, e.g. `db'` becomes `'db'''` and `db\` becomes
/// `E'db\\'`.
///
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let mut body = String::with_capacity(s.len() + 2);
    let mut saw_backslash = false;

    for ch in s.chars() {
        match ch {
            '\'' => body.push_str("''"),
            '\\' => {
                saw_backslash = true;
                body.push_str("\\\\");
            }
            other => body.push(other),
        }
    }

    // Backslashes are only special inside an E'' literal, so the prefix is
    // added exactly when one was escaped.
    let prefix = if saw_backslash { "E" } else { "" };
    format!("{prefix}'{body}'")
}
34 :
/// Turn `s` into a complete, ready-to-use quoted value for `postgresql.conf`.
///
/// Every `'` is doubled, every `\` is doubled, and the result is wrapped in
/// single quotes, so the caller must not add quotes of its own.
pub fn escape_conf_value(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');

    for ch in s.chars() {
        match ch {
            '\'' => quoted.push_str("''"),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }

    quoted.push('\'');
    quoted
}
41 :
42 : trait GenericOptionExt {
43 : fn to_pg_option(&self) -> String;
44 : fn to_pg_setting(&self) -> String;
45 : }
46 :
47 : impl GenericOptionExt for GenericOption {
48 : /// Represent `GenericOption` as SQL statement parameter.
49 : fn to_pg_option(&self) -> String {
50 3 : if let Some(val) = &self.value {
51 3 : match self.vartype.as_ref() {
52 3 : "string" => format!("{} {}", self.name, escape_literal(val)),
53 1 : _ => format!("{} {}", self.name, val),
54 : }
55 : } else {
56 UBC 0 : self.name.to_owned()
57 : }
58 CBC 3 : }
59 :
60 : /// Represent `GenericOption` as configuration option.
61 : fn to_pg_setting(&self) -> String {
62 23 : if let Some(val) = &self.value {
63 23 : match self.vartype.as_ref() {
64 23 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
65 15 : _ => format!("{} = {}", self.name, val),
66 : }
67 : } else {
68 UBC 0 : self.name.to_owned()
69 : }
70 CBC 23 : }
71 : }
72 :
73 : pub trait PgOptionsSerialize {
74 : fn as_pg_options(&self) -> String;
75 : fn as_pg_settings(&self) -> String;
76 : }
77 :
78 : impl PgOptionsSerialize for GenericOptions {
79 : /// Serialize an optional collection of `GenericOption`'s to
80 : /// Postgres SQL statement arguments.
81 : fn as_pg_options(&self) -> String {
82 2 : if let Some(ops) = &self {
83 1 : ops.iter()
84 3 : .map(|op| op.to_pg_option())
85 1 : .collect::<Vec<String>>()
86 1 : .join(" ")
87 : } else {
88 1 : "".to_string()
89 : }
90 2 : }
91 :
92 : /// Serialize an optional collection of `GenericOption`'s to
93 : /// `postgresql.conf` compatible format.
94 : fn as_pg_settings(&self) -> String {
95 1 : if let Some(ops) = &self {
96 1 : ops.iter()
97 23 : .map(|op| op.to_pg_setting())
98 1 : .collect::<Vec<String>>()
99 1 : .join("\n")
100 1 : + "\n" // newline after last setting
101 : } else {
102 UBC 0 : "".to_string()
103 : }
104 CBC 1 : }
105 : }
106 :
107 : pub trait GenericOptionsSearch {
108 : fn find(&self, name: &str) -> Option<String>;
109 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
110 : }
111 :
112 : impl GenericOptionsSearch for GenericOptions {
113 : /// Lookup option by name
114 10 : fn find(&self, name: &str) -> Option<String> {
115 10 : let ops = self.as_ref()?;
116 6 : let op = ops.iter().find(|s| s.name == name)?;
117 2 : op.value.clone()
118 10 : }
119 :
120 : /// Lookup option by name, returning ref
121 UBC 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
122 0 : let ops = self.as_ref()?;
123 0 : ops.iter().find(|s| s.name == name)
124 0 : }
125 : }
126 :
127 : pub trait RoleExt {
128 : fn to_pg_options(&self) -> String;
129 : }
130 :
131 : impl RoleExt for Role {
132 : /// Serialize a list of role parameters into a Postgres-acceptable
133 : /// string of arguments.
134 CBC 1 : fn to_pg_options(&self) -> String {
135 1 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
136 1 : let mut params: String = self.options.as_pg_options();
137 1 : params.push_str(" LOGIN");
138 :
139 1 : if let Some(pass) = &self.encrypted_password {
140 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
141 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
142 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
143 1 : if pass.starts_with("SCRAM-SHA-256") {
144 UBC 0 : write!(params, " PASSWORD '{pass}'")
145 0 : .expect("String is documented to not to error during write operations");
146 CBC 1 : } else {
147 1 : write!(params, " PASSWORD 'md5{pass}'")
148 1 : .expect("String is documented to not to error during write operations");
149 1 : }
150 UBC 0 : } else {
151 0 : params.push_str(" PASSWORD NULL");
152 0 : }
153 :
154 CBC 1 : params
155 1 : }
156 : }
157 :
158 : pub trait DatabaseExt {
159 : fn to_pg_options(&self) -> String;
160 : }
161 :
162 : impl DatabaseExt for Database {
163 : /// Serialize a list of database parameters into a Postgres-acceptable
164 : /// string of arguments.
165 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
166 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
167 : /// it may require a proper quoting too.
168 1 : fn to_pg_options(&self) -> String {
169 1 : let mut params: String = self.options.as_pg_options();
170 1 : write!(params, " OWNER {}", &self.owner.pg_quote())
171 1 : .expect("String is documented to not to error during write operations");
172 1 :
173 1 : params
174 1 : }
175 : }
176 :
177 : /// Generic trait used to provide quoting / encoding for strings used in the
178 : /// Postgres SQL queries and DATABASE_URL.
179 : pub trait Escaping {
180 : fn pg_quote(&self) -> String;
181 : }
182 :
183 : impl Escaping for PgIdent {
184 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
185 : /// always quotes provided string with `""` and escapes every `"`.
186 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
187 3 : fn pg_quote(&self) -> String {
188 3 : let result = format!("\"{}\"", self.replace('"', "\"\""));
189 3 : result
190 3 : }
191 : }
192 :
193 : /// Build a list of existing Postgres roles
194 2 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
195 2 : let postgres_roles = xact
196 2 : .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
197 2 : .iter()
198 26 : .map(|row| Role {
199 26 : name: row.get("rolname"),
200 26 : encrypted_password: row.get("rolpassword"),
201 26 : options: None,
202 26 : })
203 2 : .collect();
204 2 :
205 2 : Ok(postgres_roles)
206 2 : }
207 :
208 : /// Build a list of existing Postgres databases
209 4 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
210 : // `pg_database.datconnlimit = -2` means that the database is in the
211 : // invalid state. See:
212 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
213 4 : let postgres_dbs: Vec<Database> = client
214 4 : .query(
215 4 : "SELECT
216 4 : datname AS name,
217 4 : datdba::regrole::text AS owner,
218 4 : NOT datallowconn AS restrict_conn,
219 4 : datconnlimit = - 2 AS invalid
220 4 : FROM
221 4 : pg_catalog.pg_database;",
222 4 : &[],
223 4 : )?
224 4 : .iter()
225 13 : .map(|row| Database {
226 13 : name: row.get("name"),
227 13 : owner: row.get("owner"),
228 13 : restrict_conn: row.get("restrict_conn"),
229 13 : invalid: row.get("invalid"),
230 13 : options: None,
231 13 : })
232 4 : .collect();
233 4 :
234 4 : let dbs_map = postgres_dbs
235 4 : .iter()
236 13 : .map(|db| (db.name.clone(), db.clone()))
237 4 : .collect::<HashMap<_, _>>();
238 4 :
239 4 : Ok(dbs_map)
240 4 : }
241 :
242 : /// Wait for Postgres to become ready to accept connections. It's ready to
243 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
244 : /// 'ready'.
245 632 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
246 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
247 : let pid_path = pgdata.join("postmaster.pid");
248 :
249 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
250 : // completed initialization and is ready to accept connections. We want to
251 : // react quickly and perform the rest of our initialization as soon as
252 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
253 : // whenever the PID file is changed, and whenever it changes, read it to
254 : // check if it's now "ready".
255 : //
256 : // You cannot actually watch a file before it exists, so we first watch the
257 : // data directory, and once the postmaster.pid file appears, we switch to
258 : // watch the file instead. We also wake up every 100 ms to poll, just in
259 : // case we miss some events for some reason. Not strictly necessary, but
260 : // better safe than sorry.
261 : let (tx, rx) = std::sync::mpsc::channel();
262 8199 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
263 8199 : let _ = tx.send(res);
264 8199 : }) {
265 : Ok(watcher) => (Box::new(watcher), rx),
266 : Err(e) => {
267 : match e.kind {
268 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
269 : // docker on m1 macs does not support recommended_watcher
270 : // but return "Function not implemented (os error 38)"
271 : // see https://github.com/notify-rs/notify/issues/423
272 : let (tx, rx) = std::sync::mpsc::channel();
273 :
274 : // let's poll it faster than what we check the results for (100ms)
275 : let config =
276 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
277 :
278 : let watcher = notify::PollWatcher::new(
279 UBC 0 : move |res| {
280 0 : let _ = tx.send(res);
281 0 : },
282 : config,
283 : )?;
284 :
285 : (Box::new(watcher), rx)
286 : }
287 : _ => return Err(e.into()),
288 : }
289 : }
290 : };
291 :
292 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
293 :
294 : let started_at = Instant::now();
295 : let mut postmaster_pid_seen = false;
296 : loop {
297 : if let Ok(Some(status)) = pg.try_wait() {
298 : // Postgres exited, that is not what we expected, bail out earlier.
299 : let code = status.code().unwrap_or(-1);
300 : bail!("Postgres exited unexpectedly with code {}", code);
301 : }
302 :
303 : let res = rx.recv_timeout(Duration::from_millis(100));
304 0 : debug!("woken up by notify: {res:?}");
305 : // If there are multiple events in the channel already, we only need to be
306 : // check once. Swallow the extra events before we go ahead to check the
307 : // pid file.
308 : while let Ok(res) = rx.try_recv() {
309 0 : debug!("swallowing extra event: {res:?}");
310 : }
311 :
312 : // Check that we can open pid file first.
313 : if let Ok(file) = File::open(&pid_path) {
314 : if !postmaster_pid_seen {
315 0 : debug!("postmaster.pid appeared");
316 : watcher
317 : .unwatch(pgdata)
318 : .expect("Failed to remove pgdata dir watch");
319 : watcher
320 : .watch(&pid_path, RecursiveMode::NonRecursive)
321 : .expect("Failed to add postmaster.pid file watch");
322 : postmaster_pid_seen = true;
323 : }
324 :
325 : let file = BufReader::new(file);
326 : let last_line = file.lines().last();
327 :
328 : // Pid file could be there and we could read it, but it could be empty, for example.
329 : if let Some(Ok(line)) = last_line {
330 : let status = line.trim();
331 0 : debug!("last line of postmaster.pid: {status:?}");
332 :
333 : // Now Postgres is ready to accept connections
334 : if status == "ready" {
335 : break;
336 : }
337 : }
338 : }
339 :
340 : // Give up after POSTGRES_WAIT_TIMEOUT.
341 : let duration = started_at.elapsed();
342 : if duration >= POSTGRES_WAIT_TIMEOUT {
343 : bail!("timed out while waiting for Postgres to start");
344 : }
345 : }
346 :
347 CBC 632 : tracing::info!("PostgreSQL is now running, continuing to configure it");
348 :
349 : Ok(())
350 : }
351 :
352 : /// Remove `pgdata` directory and create it again with right permissions.
353 UBC 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
354 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
355 0 : // If it is something different then create_dir() will error out anyway.
356 0 : let _ok = fs::remove_dir_all(pgdata);
357 0 : fs::create_dir(pgdata)?;
358 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
359 :
360 0 : Ok(())
361 0 : }
|