Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::thread::JoinHandle;
10 : use std::time::{Duration, Instant};
11 :
12 : use anyhow::{bail, Result};
13 : use ini::Ini;
14 : use notify::{RecursiveMode, Watcher};
15 : use postgres::{Client, Transaction};
16 : use tokio::io::AsyncBufReadExt;
17 : use tokio::time::timeout;
18 : use tokio_postgres::NoTls;
19 : use tracing::{debug, error, info, instrument};
20 :
21 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
22 :
/// Upper bound on how long `wait_for_postgres` waits for Postgres to report
/// `ready` in `postmaster.pid` (one minute).
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
24 :
/// Escape a string for inclusion in a SQL statement as a string literal.
///
/// The result already carries its surrounding quotes, so wrapping it with
/// `E'{}'` or `'{}'` is not required, e.g. `'db'''` or `E'db\\'`.
/// Every `'` and every `\` in the input is doubled. If the input contains at
/// least one backslash, the literal is emitted in the `E'...'` form.
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let has_backslash = s.contains('\\');

    // Room for the content, both quotes and an optional `E` prefix.
    let mut quoted = String::with_capacity(s.len() + 3);
    if has_backslash {
        quoted.push('E');
    }
    quoted.push('\'');
    for ch in s.chars() {
        // Doubling a quote or a backslash escapes it.
        if ch == '\'' || ch == '\\' {
            quoted.push(ch);
        }
        quoted.push(ch);
    }
    quoted.push('\'');
    quoted
}
40 :
/// Escape a string so that it can be used as a value in `postgresql.conf`.
///
/// The result already includes the surrounding single quotes, so wrapping it
/// with `'{}'` is not required; it is a ready-to-use config string.
///
/// Escaping rules:
/// - `'` is doubled (`''`);
/// - `\` is doubled (`\\`);
/// - a newline becomes the two-character escape `\n`. A raw newline cannot
///   appear inside a quoted value in `postgresql.conf`; the value would be
///   split across lines and the file would fail to parse.
pub fn escape_conf_value(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => out.push_str("''"),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            other => out.push(other),
        }
    }
    out.push('\'');
    out
}
47 :
48 : pub trait GenericOptionExt {
49 : fn to_pg_option(&self) -> String;
50 : fn to_pg_setting(&self) -> String;
51 : }
52 :
53 : impl GenericOptionExt for GenericOption {
54 : /// Represent `GenericOption` as SQL statement parameter.
55 3 : fn to_pg_option(&self) -> String {
56 3 : if let Some(val) = &self.value {
57 3 : match self.vartype.as_ref() {
58 3 : "string" => format!("{} {}", self.name, escape_literal(val)),
59 1 : _ => format!("{} {}", self.name, val),
60 : }
61 : } else {
62 0 : self.name.to_owned()
63 : }
64 3 : }
65 :
66 : /// Represent `GenericOption` as configuration option.
67 23 : fn to_pg_setting(&self) -> String {
68 23 : if let Some(val) = &self.value {
69 23 : match self.vartype.as_ref() {
70 23 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
71 15 : _ => format!("{} = {}", self.name, val),
72 : }
73 : } else {
74 0 : self.name.to_owned()
75 : }
76 23 : }
77 : }
78 :
79 : pub trait PgOptionsSerialize {
80 : fn as_pg_options(&self) -> String;
81 : fn as_pg_settings(&self) -> String;
82 : }
83 :
84 : impl PgOptionsSerialize for GenericOptions {
85 : /// Serialize an optional collection of `GenericOption`'s to
86 : /// Postgres SQL statement arguments.
87 2 : fn as_pg_options(&self) -> String {
88 2 : if let Some(ops) = &self {
89 1 : ops.iter()
90 3 : .map(|op| op.to_pg_option())
91 1 : .collect::<Vec<String>>()
92 1 : .join(" ")
93 : } else {
94 1 : "".to_string()
95 : }
96 2 : }
97 :
98 : /// Serialize an optional collection of `GenericOption`'s to
99 : /// `postgresql.conf` compatible format.
100 1 : fn as_pg_settings(&self) -> String {
101 1 : if let Some(ops) = &self {
102 1 : ops.iter()
103 23 : .map(|op| op.to_pg_setting())
104 1 : .collect::<Vec<String>>()
105 1 : .join("\n")
106 1 : + "\n" // newline after last setting
107 : } else {
108 0 : "".to_string()
109 : }
110 1 : }
111 : }
112 :
113 : pub trait GenericOptionsSearch {
114 : fn find(&self, name: &str) -> Option<String>;
115 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
116 : }
117 :
118 : impl GenericOptionsSearch for GenericOptions {
119 : /// Lookup option by name
120 9 : fn find(&self, name: &str) -> Option<String> {
121 9 : let ops = self.as_ref()?;
122 6 : let op = ops.iter().find(|s| s.name == name)?;
123 2 : op.value.clone()
124 9 : }
125 :
126 : /// Lookup option by name, returning ref
127 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
128 0 : let ops = self.as_ref()?;
129 0 : ops.iter().find(|s| s.name == name)
130 0 : }
131 : }
132 :
133 : pub trait RoleExt {
134 : fn to_pg_options(&self) -> String;
135 : }
136 :
137 : impl RoleExt for Role {
138 : /// Serialize a list of role parameters into a Postgres-acceptable
139 : /// string of arguments.
140 1 : fn to_pg_options(&self) -> String {
141 1 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
142 1 : let mut params: String = self.options.as_pg_options();
143 1 : params.push_str(" LOGIN");
144 :
145 1 : if let Some(pass) = &self.encrypted_password {
146 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
147 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
148 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
149 1 : if pass.starts_with("SCRAM-SHA-256") {
150 0 : write!(params, " PASSWORD '{pass}'")
151 0 : .expect("String is documented to not to error during write operations");
152 1 : } else {
153 1 : write!(params, " PASSWORD 'md5{pass}'")
154 1 : .expect("String is documented to not to error during write operations");
155 1 : }
156 0 : } else {
157 0 : params.push_str(" PASSWORD NULL");
158 0 : }
159 :
160 1 : params
161 1 : }
162 : }
163 :
164 : pub trait DatabaseExt {
165 : fn to_pg_options(&self) -> String;
166 : }
167 :
168 : impl DatabaseExt for Database {
169 : /// Serialize a list of database parameters into a Postgres-acceptable
170 : /// string of arguments.
171 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
172 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
173 : /// it may require a proper quoting too.
174 1 : fn to_pg_options(&self) -> String {
175 1 : let mut params: String = self.options.as_pg_options();
176 1 : write!(params, " OWNER {}", &self.owner.pg_quote())
177 1 : .expect("String is documented to not to error during write operations");
178 1 :
179 1 : params
180 1 : }
181 : }
182 :
183 : /// Generic trait used to provide quoting / encoding for strings used in the
184 : /// Postgres SQL queries and DATABASE_URL.
185 : pub trait Escaping {
186 : fn pg_quote(&self) -> String;
187 : }
188 :
189 : impl Escaping for PgIdent {
190 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
191 : /// always quotes provided string with `""` and escapes every `"`.
192 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
193 2 : fn pg_quote(&self) -> String {
194 2 : let result = format!("\"{}\"", self.replace('"', "\"\""));
195 2 : result
196 2 : }
197 : }
198 :
199 : /// Build a list of existing Postgres roles
200 0 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
201 0 : let postgres_roles = xact
202 0 : .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
203 0 : .iter()
204 0 : .map(|row| Role {
205 0 : name: row.get("rolname"),
206 0 : encrypted_password: row.get("rolpassword"),
207 0 : options: None,
208 0 : })
209 0 : .collect();
210 0 :
211 0 : Ok(postgres_roles)
212 0 : }
213 :
214 : /// Build a list of existing Postgres databases
215 0 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
216 : // `pg_database.datconnlimit = -2` means that the database is in the
217 : // invalid state. See:
218 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
219 0 : let postgres_dbs: Vec<Database> = client
220 0 : .query(
221 0 : "SELECT
222 0 : datname AS name,
223 0 : datdba::regrole::text AS owner,
224 0 : NOT datallowconn AS restrict_conn,
225 0 : datconnlimit = - 2 AS invalid
226 0 : FROM
227 0 : pg_catalog.pg_database;",
228 0 : &[],
229 0 : )?
230 0 : .iter()
231 0 : .map(|row| Database {
232 0 : name: row.get("name"),
233 0 : owner: row.get("owner"),
234 0 : restrict_conn: row.get("restrict_conn"),
235 0 : invalid: row.get("invalid"),
236 0 : options: None,
237 0 : })
238 0 : .collect();
239 0 :
240 0 : let dbs_map = postgres_dbs
241 0 : .iter()
242 0 : .map(|db| (db.name.clone(), db.clone()))
243 0 : .collect::<HashMap<_, _>>();
244 0 :
245 0 : Ok(dbs_map)
246 0 : }
247 :
248 : /// Wait for Postgres to become ready to accept connections. It's ready to
249 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
250 : /// 'ready'.
251 0 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
252 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
253 : let pid_path = pgdata.join("postmaster.pid");
254 :
255 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
256 : // completed initialization and is ready to accept connections. We want to
257 : // react quickly and perform the rest of our initialization as soon as
258 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
259 : // whenever the PID file is changed, and whenever it changes, read it to
260 : // check if it's now "ready".
261 : //
262 : // You cannot actually watch a file before it exists, so we first watch the
263 : // data directory, and once the postmaster.pid file appears, we switch to
264 : // watch the file instead. We also wake up every 100 ms to poll, just in
265 : // case we miss some events for some reason. Not strictly necessary, but
266 : // better safe than sorry.
267 : let (tx, rx) = std::sync::mpsc::channel();
268 0 : let watcher_res = notify::recommended_watcher(move |res| {
269 0 : let _ = tx.send(res);
270 0 : });
271 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
272 : Ok(watcher) => (Box::new(watcher), rx),
273 : Err(e) => {
274 : match e.kind {
275 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
276 : // docker on m1 macs does not support recommended_watcher
277 : // but return "Function not implemented (os error 38)"
278 : // see https://github.com/notify-rs/notify/issues/423
279 : let (tx, rx) = std::sync::mpsc::channel();
280 :
281 : // let's poll it faster than what we check the results for (100ms)
282 : let config =
283 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
284 :
285 : let watcher = notify::PollWatcher::new(
286 0 : move |res| {
287 0 : let _ = tx.send(res);
288 0 : },
289 : config,
290 : )?;
291 :
292 : (Box::new(watcher), rx)
293 : }
294 : _ => return Err(e.into()),
295 : }
296 : }
297 : };
298 :
299 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
300 :
301 : let started_at = Instant::now();
302 : let mut postmaster_pid_seen = false;
303 : loop {
304 : if let Ok(Some(status)) = pg.try_wait() {
305 : // Postgres exited, that is not what we expected, bail out earlier.
306 : let code = status.code().unwrap_or(-1);
307 : bail!("Postgres exited unexpectedly with code {}", code);
308 : }
309 :
310 : let res = rx.recv_timeout(Duration::from_millis(100));
311 : debug!("woken up by notify: {res:?}");
312 : // If there are multiple events in the channel already, we only need to be
313 : // check once. Swallow the extra events before we go ahead to check the
314 : // pid file.
315 : while let Ok(res) = rx.try_recv() {
316 : debug!("swallowing extra event: {res:?}");
317 : }
318 :
319 : // Check that we can open pid file first.
320 : if let Ok(file) = File::open(&pid_path) {
321 : if !postmaster_pid_seen {
322 : debug!("postmaster.pid appeared");
323 : watcher
324 : .unwatch(pgdata)
325 : .expect("Failed to remove pgdata dir watch");
326 : watcher
327 : .watch(&pid_path, RecursiveMode::NonRecursive)
328 : .expect("Failed to add postmaster.pid file watch");
329 : postmaster_pid_seen = true;
330 : }
331 :
332 : let file = BufReader::new(file);
333 : let last_line = file.lines().last();
334 :
335 : // Pid file could be there and we could read it, but it could be empty, for example.
336 : if let Some(Ok(line)) = last_line {
337 : let status = line.trim();
338 : debug!("last line of postmaster.pid: {status:?}");
339 :
340 : // Now Postgres is ready to accept connections
341 : if status == "ready" {
342 : break;
343 : }
344 : }
345 : }
346 :
347 : // Give up after POSTGRES_WAIT_TIMEOUT.
348 : let duration = started_at.elapsed();
349 : if duration >= POSTGRES_WAIT_TIMEOUT {
350 : bail!("timed out while waiting for Postgres to start");
351 : }
352 : }
353 :
354 : tracing::info!("PostgreSQL is now running, continuing to configure it");
355 :
356 : Ok(())
357 : }
358 :
359 : /// Remove `pgdata` directory and create it again with right permissions.
360 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
361 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
362 0 : // If it is something different then create_dir() will error out anyway.
363 0 : let _ok = fs::remove_dir_all(pgdata);
364 0 : fs::create_dir(pgdata)?;
365 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
366 :
367 0 : Ok(())
368 0 : }
369 :
370 : /// Update pgbouncer.ini with provided options
371 0 : fn update_pgbouncer_ini(
372 0 : pgbouncer_config: HashMap<String, String>,
373 0 : pgbouncer_ini_path: &str,
374 0 : ) -> Result<()> {
375 0 : let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
376 0 : let section = conf.section_mut(Some("pgbouncer")).unwrap();
377 :
378 0 : for (option_name, value) in pgbouncer_config.iter() {
379 0 : section.insert(option_name, value);
380 0 : debug!(
381 0 : "Updating pgbouncer.ini with new values {}={}",
382 : option_name, value
383 : );
384 : }
385 :
386 0 : conf.write_to_file(pgbouncer_ini_path)?;
387 0 : Ok(())
388 0 : }
389 :
390 : /// Tune pgbouncer.
391 : /// 1. Apply new config using pgbouncer admin console
392 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
393 0 : pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
394 0 : let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
395 : // for VMs use pgbouncer specific way to connect to
396 : // pgbouncer admin console without password
397 : // when pgbouncer is running under the same user.
398 0 : "host=/tmp port=6432 dbname=pgbouncer user=pgbouncer".to_string()
399 : } else {
400 : // for k8s use normal connection string with password
401 : // to connect to pgbouncer admin console
402 0 : let mut pgbouncer_connstr =
403 0 : "host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string();
404 0 : if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") {
405 0 : pgbouncer_connstr.push_str(format!(" password={}", pass).as_str());
406 0 : }
407 0 : pgbouncer_connstr
408 : };
409 :
410 0 : info!(
411 0 : "Connecting to pgbouncer with connection string: {}",
412 : pgbouncer_connstr
413 : );
414 :
415 : // connect to pgbouncer, retrying several times
416 : // because pgbouncer may not be ready yet
417 0 : let mut retries = 3;
418 0 : let client = loop {
419 0 : match tokio_postgres::connect(&pgbouncer_connstr, NoTls).await {
420 0 : Ok((client, connection)) => {
421 0 : tokio::spawn(async move {
422 0 : if let Err(e) = connection.await {
423 0 : eprintln!("connection error: {}", e);
424 0 : }
425 0 : });
426 0 : break client;
427 : }
428 0 : Err(e) => {
429 0 : if retries == 0 {
430 0 : return Err(e.into());
431 0 : }
432 0 : error!("Failed to connect to pgbouncer: pgbouncer_connstr {}", e);
433 0 : retries -= 1;
434 0 : tokio::time::sleep(Duration::from_secs(1)).await;
435 : }
436 : }
437 : };
438 :
439 : // Apply new config
440 0 : for (option_name, value) in pgbouncer_config.iter() {
441 0 : let query = format!("SET {}={}", option_name, value);
442 0 : // keep this log line for debugging purposes
443 0 : info!("Applying pgbouncer setting change: {}", query);
444 :
445 0 : if let Err(err) = client.simple_query(&query).await {
446 : // Don't fail on error, just print it into log
447 0 : error!(
448 0 : "Failed to apply pgbouncer setting change: {}, {}",
449 : query, err
450 : );
451 0 : };
452 : }
453 :
454 : // save values to pgbouncer.ini
455 : // so that they are preserved after pgbouncer restart
456 0 : let pgbouncer_ini_path = if std::env::var_os("AUTOSCALING").is_some() {
457 : // in VMs we use /etc/pgbouncer.ini
458 0 : "/etc/pgbouncer.ini".to_string()
459 : } else {
460 : // in pods we use /var/db/postgres/pgbouncer/pgbouncer.ini
461 : // this is a shared volume between pgbouncer and postgres containers
462 : // FIXME: fix permissions for this file
463 0 : "/var/db/postgres/pgbouncer/pgbouncer.ini".to_string()
464 : };
465 0 : update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
466 :
467 0 : Ok(())
468 0 : }
469 :
470 : /// Spawn a thread that will read Postgres logs from `stderr`, join multiline logs
471 : /// and send them to the logger. In the future we may also want to add context to
472 : /// these logs.
473 0 : pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<()> {
474 0 : std::thread::spawn(move || {
475 0 : let runtime = tokio::runtime::Builder::new_current_thread()
476 0 : .enable_all()
477 0 : .build()
478 0 : .expect("failed to build tokio runtime");
479 0 :
480 0 : let res = runtime.block_on(async move {
481 0 : let stderr = tokio::process::ChildStderr::from_std(stderr)?;
482 0 : handle_postgres_logs_async(stderr).await
483 0 : });
484 0 : if let Err(e) = res {
485 0 : tracing::error!("error while processing postgres logs: {}", e);
486 0 : }
487 0 : })
488 0 : }
489 :
490 : /// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
491 : /// - next line starts with timestamp
492 : /// - EOF
493 : /// - no new lines were written for the last 100 milliseconds
494 0 : async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
495 0 : let mut lines = tokio::io::BufReader::new(stderr).lines();
496 0 : let timeout_duration = Duration::from_millis(100);
497 0 : let ts_regex =
498 0 : regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");
499 0 :
500 0 : let mut buf = vec![];
501 : loop {
502 0 : let next_line = timeout(timeout_duration, lines.next_line()).await;
503 :
504 : // we should flush lines from the buffer if we cannot continue reading multiline message
505 0 : let should_flush_buf = match next_line {
506 : // Flushing if new line starts with timestamp
507 0 : Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
508 : // Flushing on EOF, timeout or error
509 0 : _ => true,
510 : };
511 :
512 0 : if !buf.is_empty() && should_flush_buf {
513 : // join multiline message into a single line, separated by unicode Zero Width Space.
514 : // "PG:" suffix is used to distinguish postgres logs from other logs.
515 0 : let combined = format!("PG:{}\n", buf.join("\u{200B}"));
516 0 : buf.clear();
517 :
518 : // sync write to stderr to avoid interleaving with other logs
519 : use std::io::Write;
520 0 : let res = std::io::stderr().lock().write_all(combined.as_bytes());
521 0 : if let Err(e) = res {
522 0 : tracing::error!("error while writing to stderr: {}", e);
523 0 : }
524 0 : }
525 :
526 : // if not timeout, append line to the buffer
527 0 : if next_line.is_ok() {
528 0 : match next_line?? {
529 0 : Some(line) => buf.push(line),
530 : // EOF
531 0 : None => break,
532 : };
533 0 : }
534 : }
535 :
536 0 : Ok(())
537 0 : }
|