TLA Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::time::{Duration, Instant};
10 :
11 : use anyhow::{bail, Result};
12 : use ini::Ini;
13 : use notify::{RecursiveMode, Watcher};
14 : use postgres::{Client, Transaction};
15 : use tokio_postgres::NoTls;
16 : use tracing::{debug, error, info, instrument};
17 :
18 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
19 :
/// Upper bound on how long we wait for Postgres to report `ready` (one minute).
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
21 :
/// Turn `s` into a complete SQL string literal, quotes included, so callers
/// must not wrap the result in `'{}'` or `E'{}'` themselves.
///
/// Single quotes and backslashes are doubled. If the input contained at least
/// one backslash, the literal is emitted in escape-string form (`E'...'`),
/// otherwise as a plain `'...'` literal. Examples: `db'` becomes `'db'''`,
/// `db\` becomes `E'db\\'`.
///
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let mut body = String::with_capacity(s.len() + 2);
    let mut saw_backslash = false;

    for ch in s.chars() {
        match ch {
            '\'' => body.push_str("''"),
            '\\' => {
                saw_backslash = true;
                body.push_str("\\\\");
            }
            other => body.push(other),
        }
    }

    // Backslashes are only special inside an escape-string literal.
    let prefix = if saw_backslash { "E" } else { "" };
    format!("{prefix}'{body}'")
}
36 :
/// Escape a string so that it can be used as a value in `postgresql.conf`.
///
/// The result is a ready-to-use quoted config string, so wrapping it with
/// `'{}'` is not required. Single quotes and backslashes are doubled.
/// Newlines are written as the two-character escape `\n`: the configuration
/// file lexer does not accept a raw newline inside a quoted value, so leaving
/// it in would produce a file that fails to parse.
pub fn escape_conf_value(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => out.push_str("''"),
            '\\' => out.push_str("\\\\"),
            // Keep the whole setting on one physical line of the config file.
            '\n' => out.push_str("\\n"),
            other => out.push(other),
        }
    }
    out.push('\'');
    out
}
43 :
44 : trait GenericOptionExt {
45 : fn to_pg_option(&self) -> String;
46 : fn to_pg_setting(&self) -> String;
47 : }
48 :
49 : impl GenericOptionExt for GenericOption {
50 : /// Represent `GenericOption` as SQL statement parameter.
51 3 : fn to_pg_option(&self) -> String {
52 3 : if let Some(val) = &self.value {
53 3 : match self.vartype.as_ref() {
54 3 : "string" => format!("{} {}", self.name, escape_literal(val)),
55 1 : _ => format!("{} {}", self.name, val),
56 : }
57 : } else {
58 UBC 0 : self.name.to_owned()
59 : }
60 CBC 3 : }
61 :
62 : /// Represent `GenericOption` as configuration option.
63 23 : fn to_pg_setting(&self) -> String {
64 23 : if let Some(val) = &self.value {
65 23 : match self.vartype.as_ref() {
66 23 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
67 15 : _ => format!("{} = {}", self.name, val),
68 : }
69 : } else {
70 UBC 0 : self.name.to_owned()
71 : }
72 CBC 23 : }
73 : }
74 :
75 : pub trait PgOptionsSerialize {
76 : fn as_pg_options(&self) -> String;
77 : fn as_pg_settings(&self) -> String;
78 : }
79 :
80 : impl PgOptionsSerialize for GenericOptions {
81 : /// Serialize an optional collection of `GenericOption`'s to
82 : /// Postgres SQL statement arguments.
83 2 : fn as_pg_options(&self) -> String {
84 2 : if let Some(ops) = &self {
85 1 : ops.iter()
86 3 : .map(|op| op.to_pg_option())
87 1 : .collect::<Vec<String>>()
88 1 : .join(" ")
89 : } else {
90 1 : "".to_string()
91 : }
92 2 : }
93 :
94 : /// Serialize an optional collection of `GenericOption`'s to
95 : /// `postgresql.conf` compatible format.
96 1 : fn as_pg_settings(&self) -> String {
97 1 : if let Some(ops) = &self {
98 1 : ops.iter()
99 23 : .map(|op| op.to_pg_setting())
100 1 : .collect::<Vec<String>>()
101 1 : .join("\n")
102 1 : + "\n" // newline after last setting
103 : } else {
104 UBC 0 : "".to_string()
105 : }
106 CBC 1 : }
107 : }
108 :
109 : pub trait GenericOptionsSearch {
110 : fn find(&self, name: &str) -> Option<String>;
111 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
112 : }
113 :
114 : impl GenericOptionsSearch for GenericOptions {
115 : /// Lookup option by name
116 231 : fn find(&self, name: &str) -> Option<String> {
117 231 : let ops = self.as_ref()?;
118 6 : let op = ops.iter().find(|s| s.name == name)?;
119 2 : op.value.clone()
120 231 : }
121 :
122 : /// Lookup option by name, returning ref
123 UBC 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
124 0 : let ops = self.as_ref()?;
125 0 : ops.iter().find(|s| s.name == name)
126 0 : }
127 : }
128 :
129 : pub trait RoleExt {
130 : fn to_pg_options(&self) -> String;
131 : }
132 :
133 : impl RoleExt for Role {
134 : /// Serialize a list of role parameters into a Postgres-acceptable
135 : /// string of arguments.
136 CBC 1 : fn to_pg_options(&self) -> String {
137 1 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
138 1 : let mut params: String = self.options.as_pg_options();
139 1 : params.push_str(" LOGIN");
140 :
141 1 : if let Some(pass) = &self.encrypted_password {
142 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
143 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
144 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
145 1 : if pass.starts_with("SCRAM-SHA-256") {
146 UBC 0 : write!(params, " PASSWORD '{pass}'")
147 0 : .expect("String is documented to not to error during write operations");
148 CBC 1 : } else {
149 1 : write!(params, " PASSWORD 'md5{pass}'")
150 1 : .expect("String is documented to not to error during write operations");
151 1 : }
152 UBC 0 : } else {
153 0 : params.push_str(" PASSWORD NULL");
154 0 : }
155 :
156 CBC 1 : params
157 1 : }
158 : }
159 :
160 : pub trait DatabaseExt {
161 : fn to_pg_options(&self) -> String;
162 : }
163 :
164 : impl DatabaseExt for Database {
165 : /// Serialize a list of database parameters into a Postgres-acceptable
166 : /// string of arguments.
167 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
168 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
169 : /// it may require a proper quoting too.
170 1 : fn to_pg_options(&self) -> String {
171 1 : let mut params: String = self.options.as_pg_options();
172 1 : write!(params, " OWNER {}", &self.owner.pg_quote())
173 1 : .expect("String is documented to not to error during write operations");
174 1 :
175 1 : params
176 1 : }
177 : }
178 :
179 : /// Generic trait used to provide quoting / encoding for strings used in the
180 : /// Postgres SQL queries and DATABASE_URL.
181 : pub trait Escaping {
182 : fn pg_quote(&self) -> String;
183 : }
184 :
185 : impl Escaping for PgIdent {
186 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
187 : /// always quotes provided string with `""` and escapes every `"`.
188 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
189 3 : fn pg_quote(&self) -> String {
190 3 : let result = format!("\"{}\"", self.replace('"', "\"\""));
191 3 : result
192 3 : }
193 : }
194 :
195 : /// Build a list of existing Postgres roles
196 442 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
197 442 : let postgres_roles = xact
198 442 : .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
199 442 : .iter()
200 5312 : .map(|row| Role {
201 5312 : name: row.get("rolname"),
202 5312 : encrypted_password: row.get("rolpassword"),
203 5312 : options: None,
204 5312 : })
205 442 : .collect();
206 442 :
207 442 : Ok(postgres_roles)
208 442 : }
209 :
210 : /// Build a list of existing Postgres databases
211 884 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
212 : // `pg_database.datconnlimit = -2` means that the database is in the
213 : // invalid state. See:
214 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
215 884 : let postgres_dbs: Vec<Database> = client
216 884 : .query(
217 884 : "SELECT
218 884 : datname AS name,
219 884 : datdba::regrole::text AS owner,
220 884 : NOT datallowconn AS restrict_conn,
221 884 : datconnlimit = - 2 AS invalid
222 884 : FROM
223 884 : pg_catalog.pg_database;",
224 884 : &[],
225 884 : )?
226 884 : .iter()
227 2653 : .map(|row| Database {
228 2653 : name: row.get("name"),
229 2653 : owner: row.get("owner"),
230 2653 : restrict_conn: row.get("restrict_conn"),
231 2653 : invalid: row.get("invalid"),
232 2653 : options: None,
233 2653 : })
234 884 : .collect();
235 884 :
236 884 : let dbs_map = postgres_dbs
237 884 : .iter()
238 2653 : .map(|db| (db.name.clone(), db.clone()))
239 884 : .collect::<HashMap<_, _>>();
240 884 :
241 884 : Ok(dbs_map)
242 884 : }
243 :
244 : /// Wait for Postgres to become ready to accept connections. It's ready to
245 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
246 : /// 'ready'.
247 537 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
248 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
249 : let pid_path = pgdata.join("postmaster.pid");
250 :
251 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
252 : // completed initialization and is ready to accept connections. We want to
253 : // react quickly and perform the rest of our initialization as soon as
254 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
255 : // whenever the PID file is changed, and whenever it changes, read it to
256 : // check if it's now "ready".
257 : //
258 : // You cannot actually watch a file before it exists, so we first watch the
259 : // data directory, and once the postmaster.pid file appears, we switch to
260 : // watch the file instead. We also wake up every 100 ms to poll, just in
261 : // case we miss some events for some reason. Not strictly necessary, but
262 : // better safe than sorry.
263 : let (tx, rx) = std::sync::mpsc::channel();
264 6948 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
265 6948 : let _ = tx.send(res);
266 6948 : }) {
267 : Ok(watcher) => (Box::new(watcher), rx),
268 : Err(e) => {
269 : match e.kind {
270 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
271 : // docker on m1 macs does not support recommended_watcher
272 : // but return "Function not implemented (os error 38)"
273 : // see https://github.com/notify-rs/notify/issues/423
274 : let (tx, rx) = std::sync::mpsc::channel();
275 :
276 : // let's poll it faster than what we check the results for (100ms)
277 : let config =
278 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
279 :
280 : let watcher = notify::PollWatcher::new(
281 UBC 0 : move |res| {
282 0 : let _ = tx.send(res);
283 0 : },
284 : config,
285 : )?;
286 :
287 : (Box::new(watcher), rx)
288 : }
289 : _ => return Err(e.into()),
290 : }
291 : }
292 : };
293 :
294 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
295 :
296 : let started_at = Instant::now();
297 : let mut postmaster_pid_seen = false;
298 : loop {
299 : if let Ok(Some(status)) = pg.try_wait() {
300 : // Postgres exited, that is not what we expected, bail out earlier.
301 : let code = status.code().unwrap_or(-1);
302 : bail!("Postgres exited unexpectedly with code {}", code);
303 : }
304 :
305 : let res = rx.recv_timeout(Duration::from_millis(100));
306 0 : debug!("woken up by notify: {res:?}");
307 : // If there are multiple events in the channel already, we only need to be
308 : // check once. Swallow the extra events before we go ahead to check the
309 : // pid file.
310 : while let Ok(res) = rx.try_recv() {
311 0 : debug!("swallowing extra event: {res:?}");
312 : }
313 :
314 : // Check that we can open pid file first.
315 : if let Ok(file) = File::open(&pid_path) {
316 : if !postmaster_pid_seen {
317 0 : debug!("postmaster.pid appeared");
318 : watcher
319 : .unwatch(pgdata)
320 : .expect("Failed to remove pgdata dir watch");
321 : watcher
322 : .watch(&pid_path, RecursiveMode::NonRecursive)
323 : .expect("Failed to add postmaster.pid file watch");
324 : postmaster_pid_seen = true;
325 : }
326 :
327 : let file = BufReader::new(file);
328 : let last_line = file.lines().last();
329 :
330 : // Pid file could be there and we could read it, but it could be empty, for example.
331 : if let Some(Ok(line)) = last_line {
332 : let status = line.trim();
333 0 : debug!("last line of postmaster.pid: {status:?}");
334 :
335 : // Now Postgres is ready to accept connections
336 : if status == "ready" {
337 : break;
338 : }
339 : }
340 : }
341 :
342 : // Give up after POSTGRES_WAIT_TIMEOUT.
343 : let duration = started_at.elapsed();
344 : if duration >= POSTGRES_WAIT_TIMEOUT {
345 : bail!("timed out while waiting for Postgres to start");
346 : }
347 : }
348 :
349 CBC 537 : tracing::info!("PostgreSQL is now running, continuing to configure it");
350 :
351 : Ok(())
352 : }
353 :
354 : /// Remove `pgdata` directory and create it again with right permissions.
355 UBC 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
356 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
357 0 : // If it is something different then create_dir() will error out anyway.
358 0 : let _ok = fs::remove_dir_all(pgdata);
359 0 : fs::create_dir(pgdata)?;
360 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
361 :
362 0 : Ok(())
363 0 : }
364 :
365 : /// Update pgbouncer.ini with provided options
366 0 : pub fn update_pgbouncer_ini(
367 0 : pgbouncer_config: HashMap<String, String>,
368 0 : pgbouncer_ini_path: &str,
369 0 : ) -> Result<()> {
370 0 : let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
371 0 : let section = conf.section_mut(Some("pgbouncer")).unwrap();
372 :
373 0 : for (option_name, value) in pgbouncer_config.iter() {
374 0 : section.insert(option_name, value);
375 0 : }
376 :
377 0 : conf.write_to_file(pgbouncer_ini_path)?;
378 0 : Ok(())
379 0 : }
380 :
381 : /// Tune pgbouncer.
382 : /// 1. Apply new config using pgbouncer admin console
383 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
384 CBC 761 : pub async fn tune_pgbouncer(
385 761 : pgbouncer_settings: Option<HashMap<String, String>>,
386 761 : pgbouncer_connstr: &str,
387 761 : pgbouncer_ini_path: Option<String>,
388 761 : ) -> Result<()> {
389 761 : if let Some(pgbouncer_config) = pgbouncer_settings {
390 : // Apply new config
391 UBC 0 : let connect_result = tokio_postgres::connect(pgbouncer_connstr, NoTls).await;
392 0 : let (client, connection) = connect_result.unwrap();
393 0 : tokio::spawn(async move {
394 0 : if let Err(e) = connection.await {
395 0 : eprintln!("connection error: {}", e);
396 0 : }
397 0 : });
398 :
399 0 : for (option_name, value) in pgbouncer_config.iter() {
400 0 : info!(
401 0 : "Applying pgbouncer setting change: {} = {}",
402 0 : option_name, value
403 0 : );
404 0 : let query = format!("SET {} = {}", option_name, value);
405 :
406 0 : let result = client.simple_query(&query).await;
407 :
408 0 : info!("Applying pgbouncer setting change: {}", query);
409 0 : info!("pgbouncer setting change result: {:?}", result);
410 :
411 0 : if let Err(err) = result {
412 : // Don't fail on error, just print it into log
413 0 : error!(
414 0 : "Failed to apply pgbouncer setting change: {}, {}",
415 0 : query, err
416 0 : );
417 0 : };
418 : }
419 :
420 : // save values to pgbouncer.ini
421 : // so that they are preserved after pgbouncer restart
422 0 : if let Some(pgbouncer_ini_path) = pgbouncer_ini_path {
423 0 : update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
424 0 : }
425 CBC 761 : }
426 :
427 761 : Ok(())
428 761 : }
|