Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::thread::JoinHandle;
10 : use std::time::{Duration, Instant};
11 :
12 : use anyhow::{bail, Result};
13 : use ini::Ini;
14 : use notify::{RecursiveMode, Watcher};
15 : use postgres::{Client, Transaction};
16 : use tokio::io::AsyncBufReadExt;
17 : use tokio::time::timeout;
18 : use tokio_postgres::NoTls;
19 : use tracing::{debug, error, info, instrument};
20 :
21 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
22 :
/// Upper bound on how long [`wait_for_postgres`] waits for Postgres to report
/// that it is ready to accept connections (60 seconds).
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
24 :
/// Quote `s` as a ready-to-use SQL string literal, mirroring Postgres'
/// `quote_literal()`.
///
/// Single quotes and backslashes are both doubled. If the input contains any
/// backslash, the literal gets the `E` (escape string) prefix. For example,
/// `db'` becomes `'db'''` and `db\` becomes `E'db\\'`. The result already
/// includes the surrounding quotes, so callers must not wrap it again.
///
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let escaped = s.replace('\'', "''").replace('\\', "\\\\");
    // Escaping never introduces a backslash that was not in the input, so
    // checking the input is equivalent to checking the escaped text.
    let prefix = if s.contains('\\') { "E" } else { "" };
    format!("{prefix}'{escaped}'")
}
39 :
/// Quote `s` as a string value for `postgresql.conf`.
///
/// Each single quote is doubled and each backslash is doubled, and the whole
/// value is wrapped in single quotes. The result is ready to be placed after
/// `name = ` and must not be quoted again.
pub fn escape_conf_value(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => quoted.push_str("''"),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }
    quoted.push('\'');
    quoted
}
46 :
47 : trait GenericOptionExt {
48 : fn to_pg_option(&self) -> String;
49 : fn to_pg_setting(&self) -> String;
50 : }
51 :
52 : impl GenericOptionExt for GenericOption {
53 : /// Represent `GenericOption` as SQL statement parameter.
54 6 : fn to_pg_option(&self) -> String {
55 6 : if let Some(val) = &self.value {
56 6 : match self.vartype.as_ref() {
57 6 : "string" => format!("{} {}", self.name, escape_literal(val)),
58 2 : _ => format!("{} {}", self.name, val),
59 : }
60 : } else {
61 0 : self.name.to_owned()
62 : }
63 6 : }
64 :
65 : /// Represent `GenericOption` as configuration option.
66 46 : fn to_pg_setting(&self) -> String {
67 46 : if let Some(val) = &self.value {
68 46 : match self.vartype.as_ref() {
69 46 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
70 30 : _ => format!("{} = {}", self.name, val),
71 : }
72 : } else {
73 0 : self.name.to_owned()
74 : }
75 46 : }
76 : }
77 :
78 : pub trait PgOptionsSerialize {
79 : fn as_pg_options(&self) -> String;
80 : fn as_pg_settings(&self) -> String;
81 : }
82 :
83 : impl PgOptionsSerialize for GenericOptions {
84 : /// Serialize an optional collection of `GenericOption`'s to
85 : /// Postgres SQL statement arguments.
86 4 : fn as_pg_options(&self) -> String {
87 4 : if let Some(ops) = &self {
88 2 : ops.iter()
89 6 : .map(|op| op.to_pg_option())
90 2 : .collect::<Vec<String>>()
91 2 : .join(" ")
92 : } else {
93 2 : "".to_string()
94 : }
95 4 : }
96 :
97 : /// Serialize an optional collection of `GenericOption`'s to
98 : /// `postgresql.conf` compatible format.
99 2 : fn as_pg_settings(&self) -> String {
100 2 : if let Some(ops) = &self {
101 2 : ops.iter()
102 46 : .map(|op| op.to_pg_setting())
103 2 : .collect::<Vec<String>>()
104 2 : .join("\n")
105 2 : + "\n" // newline after last setting
106 : } else {
107 0 : "".to_string()
108 : }
109 2 : }
110 : }
111 :
112 : pub trait GenericOptionsSearch {
113 : fn find(&self, name: &str) -> Option<String>;
114 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
115 : }
116 :
117 : impl GenericOptionsSearch for GenericOptions {
118 : /// Lookup option by name
119 258 : fn find(&self, name: &str) -> Option<String> {
120 258 : let ops = self.as_ref()?;
121 12 : let op = ops.iter().find(|s| s.name == name)?;
122 4 : op.value.clone()
123 258 : }
124 :
125 : /// Lookup option by name, returning ref
126 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
127 0 : let ops = self.as_ref()?;
128 0 : ops.iter().find(|s| s.name == name)
129 0 : }
130 : }
131 :
132 : pub trait RoleExt {
133 : fn to_pg_options(&self) -> String;
134 : }
135 :
136 : impl RoleExt for Role {
137 : /// Serialize a list of role parameters into a Postgres-acceptable
138 : /// string of arguments.
139 2 : fn to_pg_options(&self) -> String {
140 2 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
141 2 : let mut params: String = self.options.as_pg_options();
142 2 : params.push_str(" LOGIN");
143 :
144 2 : if let Some(pass) = &self.encrypted_password {
145 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
146 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
147 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
148 2 : if pass.starts_with("SCRAM-SHA-256") {
149 0 : write!(params, " PASSWORD '{pass}'")
150 0 : .expect("String is documented to not to error during write operations");
151 2 : } else {
152 2 : write!(params, " PASSWORD 'md5{pass}'")
153 2 : .expect("String is documented to not to error during write operations");
154 2 : }
155 0 : } else {
156 0 : params.push_str(" PASSWORD NULL");
157 0 : }
158 :
159 2 : params
160 2 : }
161 : }
162 :
163 : pub trait DatabaseExt {
164 : fn to_pg_options(&self) -> String;
165 : }
166 :
167 : impl DatabaseExt for Database {
168 : /// Serialize a list of database parameters into a Postgres-acceptable
169 : /// string of arguments.
170 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
171 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
172 : /// it may require a proper quoting too.
173 2 : fn to_pg_options(&self) -> String {
174 2 : let mut params: String = self.options.as_pg_options();
175 2 : write!(params, " OWNER {}", &self.owner.pg_quote())
176 2 : .expect("String is documented to not to error during write operations");
177 2 :
178 2 : params
179 2 : }
180 : }
181 :
182 : /// Generic trait used to provide quoting / encoding for strings used in the
183 : /// Postgres SQL queries and DATABASE_URL.
184 : pub trait Escaping {
185 : fn pg_quote(&self) -> String;
186 : }
187 :
188 : impl Escaping for PgIdent {
189 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
190 : /// always quotes provided string with `""` and escapes every `"`.
191 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
192 5 : fn pg_quote(&self) -> String {
193 5 : let result = format!("\"{}\"", self.replace('"', "\"\""));
194 5 : result
195 5 : }
196 : }
197 :
198 : /// Build a list of existing Postgres roles
199 478 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
200 478 : let postgres_roles = xact
201 478 : .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
202 478 : .iter()
203 5752 : .map(|row| Role {
204 5752 : name: row.get("rolname"),
205 5752 : encrypted_password: row.get("rolpassword"),
206 5752 : options: None,
207 5752 : })
208 478 : .collect();
209 478 :
210 478 : Ok(postgres_roles)
211 478 : }
212 :
213 : /// Build a list of existing Postgres databases
214 956 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
215 : // `pg_database.datconnlimit = -2` means that the database is in the
216 : // invalid state. See:
217 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
218 956 : let postgres_dbs: Vec<Database> = client
219 956 : .query(
220 956 : "SELECT
221 956 : datname AS name,
222 956 : datdba::regrole::text AS owner,
223 956 : NOT datallowconn AS restrict_conn,
224 956 : datconnlimit = - 2 AS invalid
225 956 : FROM
226 956 : pg_catalog.pg_database;",
227 956 : &[],
228 956 : )?
229 956 : .iter()
230 2869 : .map(|row| Database {
231 2869 : name: row.get("name"),
232 2869 : owner: row.get("owner"),
233 2869 : restrict_conn: row.get("restrict_conn"),
234 2869 : invalid: row.get("invalid"),
235 2869 : options: None,
236 2869 : })
237 956 : .collect();
238 956 :
239 956 : let dbs_map = postgres_dbs
240 956 : .iter()
241 2869 : .map(|db| (db.name.clone(), db.clone()))
242 956 : .collect::<HashMap<_, _>>();
243 956 :
244 956 : Ok(dbs_map)
245 956 : }
246 :
247 : /// Wait for Postgres to become ready to accept connections. It's ready to
248 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
249 : /// 'ready'.
250 571 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
251 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
252 : let pid_path = pgdata.join("postmaster.pid");
253 :
254 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
255 : // completed initialization and is ready to accept connections. We want to
256 : // react quickly and perform the rest of our initialization as soon as
257 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
258 : // whenever the PID file is changed, and whenever it changes, read it to
259 : // check if it's now "ready".
260 : //
261 : // You cannot actually watch a file before it exists, so we first watch the
262 : // data directory, and once the postmaster.pid file appears, we switch to
263 : // watch the file instead. We also wake up every 100 ms to poll, just in
264 : // case we miss some events for some reason. Not strictly necessary, but
265 : // better safe than sorry.
266 : let (tx, rx) = std::sync::mpsc::channel();
267 7368 : let watcher_res = notify::recommended_watcher(move |res| {
268 7368 : let _ = tx.send(res);
269 7368 : });
270 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
271 : Ok(watcher) => (Box::new(watcher), rx),
272 : Err(e) => {
273 : match e.kind {
274 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
275 : // docker on m1 macs does not support recommended_watcher
276 : // but return "Function not implemented (os error 38)"
277 : // see https://github.com/notify-rs/notify/issues/423
278 : let (tx, rx) = std::sync::mpsc::channel();
279 :
280 : // let's poll it faster than what we check the results for (100ms)
281 : let config =
282 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
283 :
284 : let watcher = notify::PollWatcher::new(
285 0 : move |res| {
286 0 : let _ = tx.send(res);
287 0 : },
288 : config,
289 : )?;
290 :
291 : (Box::new(watcher), rx)
292 : }
293 : _ => return Err(e.into()),
294 : }
295 : }
296 : };
297 :
298 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
299 :
300 : let started_at = Instant::now();
301 : let mut postmaster_pid_seen = false;
302 : loop {
303 : if let Ok(Some(status)) = pg.try_wait() {
304 : // Postgres exited, that is not what we expected, bail out earlier.
305 : let code = status.code().unwrap_or(-1);
306 : bail!("Postgres exited unexpectedly with code {}", code);
307 : }
308 :
309 : let res = rx.recv_timeout(Duration::from_millis(100));
310 0 : debug!("woken up by notify: {res:?}");
311 : // If there are multiple events in the channel already, we only need to be
312 : // check once. Swallow the extra events before we go ahead to check the
313 : // pid file.
314 : while let Ok(res) = rx.try_recv() {
315 0 : debug!("swallowing extra event: {res:?}");
316 : }
317 :
318 : // Check that we can open pid file first.
319 : if let Ok(file) = File::open(&pid_path) {
320 : if !postmaster_pid_seen {
321 0 : debug!("postmaster.pid appeared");
322 : watcher
323 : .unwatch(pgdata)
324 : .expect("Failed to remove pgdata dir watch");
325 : watcher
326 : .watch(&pid_path, RecursiveMode::NonRecursive)
327 : .expect("Failed to add postmaster.pid file watch");
328 : postmaster_pid_seen = true;
329 : }
330 :
331 : let file = BufReader::new(file);
332 : let last_line = file.lines().last();
333 :
334 : // Pid file could be there and we could read it, but it could be empty, for example.
335 : if let Some(Ok(line)) = last_line {
336 : let status = line.trim();
337 0 : debug!("last line of postmaster.pid: {status:?}");
338 :
339 : // Now Postgres is ready to accept connections
340 : if status == "ready" {
341 : break;
342 : }
343 : }
344 : }
345 :
346 : // Give up after POSTGRES_WAIT_TIMEOUT.
347 : let duration = started_at.elapsed();
348 : if duration >= POSTGRES_WAIT_TIMEOUT {
349 : bail!("timed out while waiting for Postgres to start");
350 : }
351 : }
352 :
353 571 : tracing::info!("PostgreSQL is now running, continuing to configure it");
354 :
355 : Ok(())
356 : }
357 :
358 : /// Remove `pgdata` directory and create it again with right permissions.
359 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
360 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
361 0 : // If it is something different then create_dir() will error out anyway.
362 0 : let _ok = fs::remove_dir_all(pgdata);
363 0 : fs::create_dir(pgdata)?;
364 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
365 :
366 0 : Ok(())
367 0 : }
368 :
369 : /// Update pgbouncer.ini with provided options
370 0 : fn update_pgbouncer_ini(
371 0 : pgbouncer_config: HashMap<String, String>,
372 0 : pgbouncer_ini_path: &str,
373 0 : ) -> Result<()> {
374 0 : let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
375 0 : let section = conf.section_mut(Some("pgbouncer")).unwrap();
376 :
377 0 : for (option_name, value) in pgbouncer_config.iter() {
378 0 : section.insert(option_name, value);
379 0 : debug!(
380 0 : "Updating pgbouncer.ini with new values {}={}",
381 0 : option_name, value
382 0 : );
383 : }
384 :
385 0 : conf.write_to_file(pgbouncer_ini_path)?;
386 0 : Ok(())
387 0 : }
388 :
389 : /// Tune pgbouncer.
390 : /// 1. Apply new config using pgbouncer admin console
391 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
392 0 : pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
393 0 : let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
394 : // for VMs use pgbouncer specific way to connect to
395 : // pgbouncer admin console without password
396 : // when pgbouncer is running under the same user.
397 0 : "host=/tmp port=6432 dbname=pgbouncer user=pgbouncer".to_string()
398 : } else {
399 : // for k8s use normal connection string with password
400 : // to connect to pgbouncer admin console
401 0 : let mut pgbouncer_connstr =
402 0 : "host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string();
403 0 : if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") {
404 0 : pgbouncer_connstr.push_str(format!(" password={}", pass).as_str());
405 0 : }
406 0 : pgbouncer_connstr
407 : };
408 :
409 0 : info!(
410 0 : "Connecting to pgbouncer with connection string: {}",
411 0 : pgbouncer_connstr
412 0 : );
413 :
414 : // connect to pgbouncer, retrying several times
415 : // because pgbouncer may not be ready yet
416 0 : let mut retries = 3;
417 0 : let client = loop {
418 0 : match tokio_postgres::connect(&pgbouncer_connstr, NoTls).await {
419 0 : Ok((client, connection)) => {
420 0 : tokio::spawn(async move {
421 0 : if let Err(e) = connection.await {
422 0 : eprintln!("connection error: {}", e);
423 0 : }
424 0 : });
425 0 : break client;
426 : }
427 0 : Err(e) => {
428 0 : if retries == 0 {
429 0 : return Err(e.into());
430 0 : }
431 0 : error!("Failed to connect to pgbouncer: pgbouncer_connstr {}", e);
432 0 : retries -= 1;
433 0 : tokio::time::sleep(Duration::from_secs(1)).await;
434 : }
435 : }
436 : };
437 :
438 : // Apply new config
439 0 : for (option_name, value) in pgbouncer_config.iter() {
440 0 : let query = format!("SET {}={}", option_name, value);
441 : // keep this log line for debugging purposes
442 0 : info!("Applying pgbouncer setting change: {}", query);
443 :
444 0 : if let Err(err) = client.simple_query(&query).await {
445 : // Don't fail on error, just print it into log
446 0 : error!(
447 0 : "Failed to apply pgbouncer setting change: {}, {}",
448 0 : query, err
449 0 : );
450 0 : };
451 : }
452 :
453 : // save values to pgbouncer.ini
454 : // so that they are preserved after pgbouncer restart
455 0 : let pgbouncer_ini_path = if std::env::var_os("AUTOSCALING").is_some() {
456 : // in VMs we use /etc/pgbouncer.ini
457 0 : "/etc/pgbouncer.ini".to_string()
458 : } else {
459 : // in pods we use /var/db/postgres/pgbouncer/pgbouncer.ini
460 : // this is a shared volume between pgbouncer and postgres containers
461 : // FIXME: fix permissions for this file
462 0 : "/var/db/postgres/pgbouncer/pgbouncer.ini".to_string()
463 : };
464 0 : update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
465 :
466 0 : Ok(())
467 0 : }
468 :
469 : /// Spawn a thread that will read Postgres logs from `stderr`, join multiline logs
470 : /// and send them to the logger. In the future we may also want to add context to
471 : /// these logs.
472 1500 : pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<()> {
473 1500 : std::thread::spawn(move || {
474 1500 : let runtime = tokio::runtime::Builder::new_current_thread()
475 1500 : .enable_all()
476 1500 : .build()
477 1500 : .expect("failed to build tokio runtime");
478 1500 :
479 1500 : let res = runtime.block_on(async move {
480 1500 : let stderr = tokio::process::ChildStderr::from_std(stderr)?;
481 126703 : handle_postgres_logs_async(stderr).await
482 1500 : });
483 1500 : if let Err(e) = res {
484 0 : tracing::error!("error while processing postgres logs: {}", e);
485 1500 : }
486 1500 : })
487 1500 : }
488 :
489 : /// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
490 : /// - next line starts with timestamp
491 : /// - EOF
492 : /// - no new lines were written for the last second
493 1500 : async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
494 1500 : let mut lines = tokio::io::BufReader::new(stderr).lines();
495 1500 : let timeout_duration = Duration::from_millis(100);
496 1500 : let ts_regex =
497 1500 : regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");
498 1500 :
499 1500 : let mut buf = vec![];
500 : loop {
501 2560903 : let next_line = timeout(timeout_duration, lines.next_line()).await;
502 :
503 : // we should flush lines from the buffer if we cannot continue reading multiline message
504 2560903 : let should_flush_buf = match next_line {
505 : // Flushing if new line starts with timestamp
506 2505241 : Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
507 : // Flushing on EOF, timeout or error
508 55662 : _ => true,
509 : };
510 :
511 2560903 : if !buf.is_empty() && should_flush_buf {
512 : // join multiline message into a single line, separated by unicode Zero Width Space.
513 : // "PG:" suffix is used to distinguish postgres logs from other logs.
514 1691013 : let combined = format!("PG:{}\n", buf.join("\u{200B}"));
515 1691013 : buf.clear();
516 1691013 :
517 1691013 : // sync write to stderr to avoid interleaving with other logs
518 1691013 : use std::io::Write;
519 1691013 : let res = std::io::stderr().lock().write_all(combined.as_bytes());
520 1691013 : if let Err(e) = res {
521 0 : tracing::error!("error while writing to stderr: {}", e);
522 1691013 : }
523 869890 : }
524 :
525 : // if not timeout, append line to the buffer
526 2560903 : if next_line.is_ok() {
527 2506741 : match next_line?? {
528 2505241 : Some(line) => buf.push(line),
529 : // EOF
530 1500 : None => break,
531 : };
532 54162 : }
533 : }
534 :
535 1500 : Ok(())
536 1500 : }
|