Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Write;
3 : use std::fs;
4 : use std::fs::File;
5 : use std::io::{BufRead, BufReader};
6 : use std::os::unix::fs::PermissionsExt;
7 : use std::path::Path;
8 : use std::process::Child;
9 : use std::thread::JoinHandle;
10 : use std::time::{Duration, Instant};
11 :
12 : use anyhow::{bail, Result};
13 : use ini::Ini;
14 : use notify::{RecursiveMode, Watcher};
15 : use postgres::{Client, Transaction};
16 : use tokio::io::AsyncBufReadExt;
17 : use tokio::time::timeout;
18 : use tokio_postgres::NoTls;
19 : use tracing::{debug, error, info, instrument};
20 :
21 : use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
22 :
/// Upper bound on how long we wait for Postgres to report `ready` (60 seconds).
const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_secs(60);
24 :
/// Turn `s` into a complete, ready-to-use SQL string literal.
///
/// Single quotes are doubled and backslashes are doubled. When the input
/// contains a backslash the literal is emitted in the escape-string form
/// (`E'...'`), otherwise as a plain `'...'` literal. The caller must not add
/// any further quoting, e.g. the results look like `'db'''` or `E'db\\'`.
///
/// See <https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47>
/// for the original implementation.
pub fn escape_literal(s: &str) -> String {
    let escaped = s.replace('\'', "''").replace('\\', "\\\\");
    // Backslashes are only special inside `E'...'` literals.
    let prefix = if escaped.contains('\\') { "E" } else { "" };

    format!("{prefix}'{escaped}'")
}
39 :
/// Turn `s` into a complete, ready-to-use quoted value for `postgresql.conf`.
///
/// Every single quote and every backslash is doubled and the result is
/// wrapped in single quotes, so the caller must not quote it again.
pub fn escape_conf_value(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        match ch {
            '\'' => quoted.push_str("''"),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }
    quoted.push('\'');
    quoted
}
46 :
47 : trait GenericOptionExt {
48 : fn to_pg_option(&self) -> String;
49 : fn to_pg_setting(&self) -> String;
50 : }
51 :
52 : impl GenericOptionExt for GenericOption {
53 : /// Represent `GenericOption` as SQL statement parameter.
54 6 : fn to_pg_option(&self) -> String {
55 6 : if let Some(val) = &self.value {
56 6 : match self.vartype.as_ref() {
57 6 : "string" => format!("{} {}", self.name, escape_literal(val)),
58 2 : _ => format!("{} {}", self.name, val),
59 : }
60 : } else {
61 0 : self.name.to_owned()
62 : }
63 6 : }
64 :
65 : /// Represent `GenericOption` as configuration option.
66 46 : fn to_pg_setting(&self) -> String {
67 46 : if let Some(val) = &self.value {
68 46 : match self.vartype.as_ref() {
69 46 : "string" => format!("{} = {}", self.name, escape_conf_value(val)),
70 30 : _ => format!("{} = {}", self.name, val),
71 : }
72 : } else {
73 0 : self.name.to_owned()
74 : }
75 46 : }
76 : }
77 :
78 : pub trait PgOptionsSerialize {
79 : fn as_pg_options(&self) -> String;
80 : fn as_pg_settings(&self) -> String;
81 : }
82 :
83 : impl PgOptionsSerialize for GenericOptions {
84 : /// Serialize an optional collection of `GenericOption`'s to
85 : /// Postgres SQL statement arguments.
86 4 : fn as_pg_options(&self) -> String {
87 4 : if let Some(ops) = &self {
88 2 : ops.iter()
89 6 : .map(|op| op.to_pg_option())
90 2 : .collect::<Vec<String>>()
91 2 : .join(" ")
92 : } else {
93 2 : "".to_string()
94 : }
95 4 : }
96 :
97 : /// Serialize an optional collection of `GenericOption`'s to
98 : /// `postgresql.conf` compatible format.
99 2 : fn as_pg_settings(&self) -> String {
100 2 : if let Some(ops) = &self {
101 2 : ops.iter()
102 46 : .map(|op| op.to_pg_setting())
103 2 : .collect::<Vec<String>>()
104 2 : .join("\n")
105 2 : + "\n" // newline after last setting
106 : } else {
107 0 : "".to_string()
108 : }
109 2 : }
110 : }
111 :
112 : pub trait GenericOptionsSearch {
113 : fn find(&self, name: &str) -> Option<String>;
114 : fn find_ref(&self, name: &str) -> Option<&GenericOption>;
115 : }
116 :
117 : impl GenericOptionsSearch for GenericOptions {
118 : /// Lookup option by name
119 248 : fn find(&self, name: &str) -> Option<String> {
120 248 : let ops = self.as_ref()?;
121 12 : let op = ops.iter().find(|s| s.name == name)?;
122 4 : op.value.clone()
123 248 : }
124 :
125 : /// Lookup option by name, returning ref
126 0 : fn find_ref(&self, name: &str) -> Option<&GenericOption> {
127 0 : let ops = self.as_ref()?;
128 0 : ops.iter().find(|s| s.name == name)
129 0 : }
130 : }
131 :
132 : pub trait RoleExt {
133 : fn to_pg_options(&self) -> String;
134 : }
135 :
136 : impl RoleExt for Role {
137 : /// Serialize a list of role parameters into a Postgres-acceptable
138 : /// string of arguments.
139 2 : fn to_pg_options(&self) -> String {
140 2 : // XXX: consider putting LOGIN as a default option somewhere higher, e.g. in control-plane.
141 2 : let mut params: String = self.options.as_pg_options();
142 2 : params.push_str(" LOGIN");
143 :
144 2 : if let Some(pass) = &self.encrypted_password {
145 : // Some time ago we supported only md5 and treated all encrypted_password as md5.
146 : // Now we also support SCRAM-SHA-256 and to preserve compatibility
147 : // we treat all encrypted_password as md5 unless they starts with SCRAM-SHA-256.
148 2 : if pass.starts_with("SCRAM-SHA-256") {
149 0 : write!(params, " PASSWORD '{pass}'")
150 0 : .expect("String is documented to not to error during write operations");
151 2 : } else {
152 2 : write!(params, " PASSWORD 'md5{pass}'")
153 2 : .expect("String is documented to not to error during write operations");
154 2 : }
155 0 : } else {
156 0 : params.push_str(" PASSWORD NULL");
157 0 : }
158 :
159 2 : params
160 2 : }
161 : }
162 :
163 : pub trait DatabaseExt {
164 : fn to_pg_options(&self) -> String;
165 : }
166 :
167 : impl DatabaseExt for Database {
168 : /// Serialize a list of database parameters into a Postgres-acceptable
169 : /// string of arguments.
170 : /// NB: `TEMPLATE` is actually also an identifier, but so far we only need
171 : /// to use `template0` and `template1`, so it is not a problem. Yet in the future
172 : /// it may require a proper quoting too.
173 2 : fn to_pg_options(&self) -> String {
174 2 : let mut params: String = self.options.as_pg_options();
175 2 : write!(params, " OWNER {}", &self.owner.pg_quote())
176 2 : .expect("String is documented to not to error during write operations");
177 2 :
178 2 : params
179 2 : }
180 : }
181 :
182 : /// Generic trait used to provide quoting / encoding for strings used in the
183 : /// Postgres SQL queries and DATABASE_URL.
184 : pub trait Escaping {
185 : fn pg_quote(&self) -> String;
186 : }
187 :
188 : impl Escaping for PgIdent {
189 : /// This is intended to mimic Postgres quote_ident(), but for simplicity it
190 : /// always quotes provided string with `""` and escapes every `"`.
191 : /// **Not idempotent**, i.e. if string is already escaped it will be escaped again.
192 5 : fn pg_quote(&self) -> String {
193 5 : let result = format!("\"{}\"", self.replace('"', "\"\""));
194 5 : result
195 5 : }
196 : }
197 :
198 : /// Build a list of existing Postgres roles
199 458 : pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
200 458 : let postgres_roles = xact
201 458 : .query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
202 458 : .iter()
203 5512 : .map(|row| Role {
204 5512 : name: row.get("rolname"),
205 5512 : encrypted_password: row.get("rolpassword"),
206 5512 : options: None,
207 5512 : })
208 458 : .collect();
209 458 :
210 458 : Ok(postgres_roles)
211 458 : }
212 :
213 : /// Build a list of existing Postgres databases
214 916 : pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
215 : // `pg_database.datconnlimit = -2` means that the database is in the
216 : // invalid state. See:
217 : // https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
218 916 : let postgres_dbs: Vec<Database> = client
219 916 : .query(
220 916 : "SELECT
221 916 : datname AS name,
222 916 : datdba::regrole::text AS owner,
223 916 : NOT datallowconn AS restrict_conn,
224 916 : datconnlimit = - 2 AS invalid
225 916 : FROM
226 916 : pg_catalog.pg_database;",
227 916 : &[],
228 916 : )?
229 916 : .iter()
230 2749 : .map(|row| Database {
231 2749 : name: row.get("name"),
232 2749 : owner: row.get("owner"),
233 2749 : restrict_conn: row.get("restrict_conn"),
234 2749 : invalid: row.get("invalid"),
235 2749 : options: None,
236 2749 : })
237 916 : .collect();
238 916 :
239 916 : let dbs_map = postgres_dbs
240 916 : .iter()
241 2749 : .map(|db| (db.name.clone(), db.clone()))
242 916 : .collect::<HashMap<_, _>>();
243 916 :
244 916 : Ok(dbs_map)
245 916 : }
246 :
247 : /// Wait for Postgres to become ready to accept connections. It's ready to
248 : /// accept connections when the state-field in `pgdata/postmaster.pid` says
249 : /// 'ready'.
250 575 : #[instrument(skip_all, fields(pgdata = %pgdata.display()))]
251 : pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
252 : let pid_path = pgdata.join("postmaster.pid");
253 :
254 : // PostgreSQL writes line "ready" to the postmaster.pid file, when it has
255 : // completed initialization and is ready to accept connections. We want to
256 : // react quickly and perform the rest of our initialization as soon as
257 : // PostgreSQL starts accepting connections. Use 'notify' to be notified
258 : // whenever the PID file is changed, and whenever it changes, read it to
259 : // check if it's now "ready".
260 : //
261 : // You cannot actually watch a file before it exists, so we first watch the
262 : // data directory, and once the postmaster.pid file appears, we switch to
263 : // watch the file instead. We also wake up every 100 ms to poll, just in
264 : // case we miss some events for some reason. Not strictly necessary, but
265 : // better safe than sorry.
266 : let (tx, rx) = std::sync::mpsc::channel();
267 7414 : let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
268 7414 : let _ = tx.send(res);
269 7414 : }) {
270 : Ok(watcher) => (Box::new(watcher), rx),
271 : Err(e) => {
272 : match e.kind {
273 : notify::ErrorKind::Io(os) if os.raw_os_error() == Some(38) => {
274 : // docker on m1 macs does not support recommended_watcher
275 : // but return "Function not implemented (os error 38)"
276 : // see https://github.com/notify-rs/notify/issues/423
277 : let (tx, rx) = std::sync::mpsc::channel();
278 :
279 : // let's poll it faster than what we check the results for (100ms)
280 : let config =
281 : notify::Config::default().with_poll_interval(Duration::from_millis(50));
282 :
283 : let watcher = notify::PollWatcher::new(
284 0 : move |res| {
285 0 : let _ = tx.send(res);
286 0 : },
287 : config,
288 : )?;
289 :
290 : (Box::new(watcher), rx)
291 : }
292 : _ => return Err(e.into()),
293 : }
294 : }
295 : };
296 :
297 : watcher.watch(pgdata, RecursiveMode::NonRecursive)?;
298 :
299 : let started_at = Instant::now();
300 : let mut postmaster_pid_seen = false;
301 : loop {
302 : if let Ok(Some(status)) = pg.try_wait() {
303 : // Postgres exited, that is not what we expected, bail out earlier.
304 : let code = status.code().unwrap_or(-1);
305 : bail!("Postgres exited unexpectedly with code {}", code);
306 : }
307 :
308 : let res = rx.recv_timeout(Duration::from_millis(100));
309 0 : debug!("woken up by notify: {res:?}");
310 : // If there are multiple events in the channel already, we only need to be
311 : // check once. Swallow the extra events before we go ahead to check the
312 : // pid file.
313 : while let Ok(res) = rx.try_recv() {
314 0 : debug!("swallowing extra event: {res:?}");
315 : }
316 :
317 : // Check that we can open pid file first.
318 : if let Ok(file) = File::open(&pid_path) {
319 : if !postmaster_pid_seen {
320 0 : debug!("postmaster.pid appeared");
321 : watcher
322 : .unwatch(pgdata)
323 : .expect("Failed to remove pgdata dir watch");
324 : watcher
325 : .watch(&pid_path, RecursiveMode::NonRecursive)
326 : .expect("Failed to add postmaster.pid file watch");
327 : postmaster_pid_seen = true;
328 : }
329 :
330 : let file = BufReader::new(file);
331 : let last_line = file.lines().last();
332 :
333 : // Pid file could be there and we could read it, but it could be empty, for example.
334 : if let Some(Ok(line)) = last_line {
335 : let status = line.trim();
336 0 : debug!("last line of postmaster.pid: {status:?}");
337 :
338 : // Now Postgres is ready to accept connections
339 : if status == "ready" {
340 : break;
341 : }
342 : }
343 : }
344 :
345 : // Give up after POSTGRES_WAIT_TIMEOUT.
346 : let duration = started_at.elapsed();
347 : if duration >= POSTGRES_WAIT_TIMEOUT {
348 : bail!("timed out while waiting for Postgres to start");
349 : }
350 : }
351 :
352 575 : tracing::info!("PostgreSQL is now running, continuing to configure it");
353 :
354 : Ok(())
355 : }
356 :
357 : /// Remove `pgdata` directory and create it again with right permissions.
358 0 : pub fn create_pgdata(pgdata: &str) -> Result<()> {
359 0 : // Ignore removal error, likely it is a 'No such file or directory (os error 2)'.
360 0 : // If it is something different then create_dir() will error out anyway.
361 0 : let _ok = fs::remove_dir_all(pgdata);
362 0 : fs::create_dir(pgdata)?;
363 0 : fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
364 :
365 0 : Ok(())
366 0 : }
367 :
368 : /// Update pgbouncer.ini with provided options
369 0 : fn update_pgbouncer_ini(
370 0 : pgbouncer_config: HashMap<String, String>,
371 0 : pgbouncer_ini_path: &str,
372 0 : ) -> Result<()> {
373 0 : let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
374 0 : let section = conf.section_mut(Some("pgbouncer")).unwrap();
375 :
376 0 : for (option_name, value) in pgbouncer_config.iter() {
377 0 : section.insert(option_name, value);
378 0 : debug!(
379 0 : "Updating pgbouncer.ini with new values {}={}",
380 0 : option_name, value
381 0 : );
382 : }
383 :
384 0 : conf.write_to_file(pgbouncer_ini_path)?;
385 0 : Ok(())
386 0 : }
387 :
388 : /// Tune pgbouncer.
389 : /// 1. Apply new config using pgbouncer admin console
390 : /// 2. Add new values to pgbouncer.ini to preserve them after restart
391 0 : pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
392 0 : let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
393 : // for VMs use pgbouncer specific way to connect to
394 : // pgbouncer admin console without password
395 : // when pgbouncer is running under the same user.
396 0 : "host=/tmp port=6432 dbname=pgbouncer user=pgbouncer".to_string()
397 : } else {
398 : // for k8s use normal connection string with password
399 : // to connect to pgbouncer admin console
400 0 : let mut pgbouncer_connstr =
401 0 : "host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string();
402 0 : if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") {
403 0 : pgbouncer_connstr.push_str(format!(" password={}", pass).as_str());
404 0 : }
405 0 : pgbouncer_connstr
406 : };
407 :
408 0 : info!(
409 0 : "Connecting to pgbouncer with connection string: {}",
410 0 : pgbouncer_connstr
411 0 : );
412 :
413 : // connect to pgbouncer, retrying several times
414 : // because pgbouncer may not be ready yet
415 0 : let mut retries = 3;
416 0 : let client = loop {
417 0 : match tokio_postgres::connect(&pgbouncer_connstr, NoTls).await {
418 0 : Ok((client, connection)) => {
419 0 : tokio::spawn(async move {
420 0 : if let Err(e) = connection.await {
421 0 : eprintln!("connection error: {}", e);
422 0 : }
423 0 : });
424 0 : break client;
425 : }
426 0 : Err(e) => {
427 0 : if retries == 0 {
428 0 : return Err(e.into());
429 0 : }
430 0 : error!("Failed to connect to pgbouncer: pgbouncer_connstr {}", e);
431 0 : retries -= 1;
432 0 : tokio::time::sleep(Duration::from_secs(1)).await;
433 : }
434 : }
435 : };
436 :
437 : // Apply new config
438 0 : for (option_name, value) in pgbouncer_config.iter() {
439 0 : let query = format!("SET {}={}", option_name, value);
440 : // keep this log line for debugging purposes
441 0 : info!("Applying pgbouncer setting change: {}", query);
442 :
443 0 : if let Err(err) = client.simple_query(&query).await {
444 : // Don't fail on error, just print it into log
445 0 : error!(
446 0 : "Failed to apply pgbouncer setting change: {}, {}",
447 0 : query, err
448 0 : );
449 0 : };
450 : }
451 :
452 : // save values to pgbouncer.ini
453 : // so that they are preserved after pgbouncer restart
454 0 : let pgbouncer_ini_path = if std::env::var_os("AUTOSCALING").is_some() {
455 : // in VMs we use /etc/pgbouncer.ini
456 0 : "/etc/pgbouncer.ini".to_string()
457 : } else {
458 : // in pods we use /var/db/postgres/pgbouncer/pgbouncer.ini
459 : // this is a shared volume between pgbouncer and postgres containers
460 : // FIXME: fix permissions for this file
461 0 : "/var/db/postgres/pgbouncer/pgbouncer.ini".to_string()
462 : };
463 0 : update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
464 :
465 0 : Ok(())
466 0 : }
467 :
468 : /// Spawn a thread that will read Postgres logs from `stderr`, join multiline logs
469 : /// and send them to the logger. In the future we may also want to add context to
470 : /// these logs.
471 1506 : pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<()> {
472 1506 : std::thread::spawn(move || {
473 1506 : let runtime = tokio::runtime::Builder::new_current_thread()
474 1506 : .enable_all()
475 1506 : .build()
476 1506 : .expect("failed to build tokio runtime");
477 1506 :
478 1506 : let res = runtime.block_on(async move {
479 1506 : let stderr = tokio::process::ChildStderr::from_std(stderr)?;
480 155153 : handle_postgres_logs_async(stderr).await
481 1506 : });
482 1506 : if let Err(e) = res {
483 0 : tracing::error!("error while processing postgres logs: {}", e);
484 1506 : }
485 1506 : })
486 1506 : }
487 :
488 : /// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
489 : /// - next line starts with timestamp
490 : /// - EOF
491 : /// - no new lines were written for the last second
492 1506 : async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
493 1506 : let mut lines = tokio::io::BufReader::new(stderr).lines();
494 1506 : let timeout_duration = Duration::from_millis(100);
495 1506 : let ts_regex =
496 1506 : regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");
497 1506 :
498 1506 : let mut buf = vec![];
499 : loop {
500 3004103 : let next_line = timeout(timeout_duration, lines.next_line()).await;
501 :
502 : // we should flush lines from the buffer if we cannot continue reading multiline message
503 3004103 : let should_flush_buf = match next_line {
504 : // Flushing if new line starts with timestamp
505 2946824 : Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
506 : // Flushing on EOF, timeout or error
507 57279 : _ => true,
508 : };
509 :
510 3004103 : if !buf.is_empty() && should_flush_buf {
511 : // join multiline message into a single line, separated by unicode Zero Width Space.
512 : // "PG:" suffix is used to distinguish postgres logs from other logs.
513 1986562 : let combined = format!("PG:{}\n", buf.join("\u{200B}"));
514 1986562 : buf.clear();
515 1986562 :
516 1986562 : // sync write to stderr to avoid interleaving with other logs
517 1986562 : use std::io::Write;
518 1986562 : let res = std::io::stderr().lock().write_all(combined.as_bytes());
519 1986562 : if let Err(e) = res {
520 0 : tracing::error!("error while writing to stderr: {}", e);
521 1986562 : }
522 1017541 : }
523 :
524 : // if not timeout, append line to the buffer
525 3004103 : if next_line.is_ok() {
526 2948330 : match next_line?? {
527 2946824 : Some(line) => buf.push(line),
528 : // EOF
529 1506 : None => break,
530 : };
531 55773 : }
532 : }
533 :
534 1506 : Ok(())
535 1506 : }
|