Line data Source code
1 : use compute_api::{
2 : responses::CatalogObjects,
3 : spec::{Database, Role},
4 : };
5 : use futures::Stream;
6 : use postgres::{Client, NoTls};
7 : use std::{path::Path, process::Stdio, result::Result, sync::Arc};
8 : use tokio::{
9 : io::{AsyncBufReadExt, BufReader},
10 : process::Command,
11 : task,
12 : };
13 : use tokio_stream::{self as stream, StreamExt};
14 : use tokio_util::codec::{BytesCodec, FramedRead};
15 : use tracing::warn;
16 :
17 : use crate::{
18 : compute::ComputeNode,
19 : pg_helpers::{get_existing_dbs, get_existing_roles},
20 : };
21 :
22 0 : pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
23 0 : let connstr = compute.connstr.clone();
24 0 : task::spawn_blocking(move || {
25 0 : let mut client = Client::connect(connstr.as_str(), NoTls)?;
26 : let roles: Vec<Role>;
27 : {
28 0 : let mut xact = client.transaction()?;
29 0 : roles = get_existing_roles(&mut xact)?;
30 : }
31 0 : let databases: Vec<Database> = get_existing_dbs(&mut client)?.values().cloned().collect();
32 0 :
33 0 : Ok(CatalogObjects { roles, databases })
34 0 : })
35 0 : .await?
36 0 : }
37 :
/// Errors returned by [`get_database_schema`].
#[derive(Debug, thiserror::Error)]
pub enum SchemaDumpError {
    /// `pg_dump` produced no output, and its stderr reported that the requested database
    /// does not exist.
    #[error("Database does not exist.")]
    DatabaseDoesNotExist,
    /// Any other failure: spawning `pg_dump`, capturing or reading its stdout/stderr, or
    /// `pg_dump` producing no output for a reason other than a missing database.
    ///
    /// The `Display` text does not include the wrapped error's message. The underlying
    /// [`std::io::Error`] is reachable through `source()`, because `#[from]` also marks it as
    /// the source. The `#[from]` also lets `?` convert `std::io::Error` into this variant.
    #[error("Failed to execute pg_dump.")]
    IO(#[from] std::io::Error),
}
45 :
46 : // It uses the pg_dump utility to dump the schema of the specified database.
47 : // The output is streamed back to the caller and supposed to be streamed via HTTP.
48 : //
49 : // Before return the result with the output, it checks that pg_dump produced any output.
50 : // If not, it tries to parse the stderr output to determine if the database does not exist
51 : // and special error is returned.
52 : //
53 : // To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature.
54 0 : pub async fn get_database_schema(
55 0 : compute: &Arc<ComputeNode>,
56 0 : dbname: &str,
57 0 : ) -> Result<impl Stream<Item = Result<bytes::Bytes, std::io::Error>>, SchemaDumpError> {
58 0 : let pgbin = &compute.pgbin;
59 0 : let basepath = Path::new(pgbin).parent().unwrap();
60 0 : let pgdump = basepath.join("pg_dump");
61 0 : let mut connstr = compute.connstr.clone();
62 0 : connstr.set_path(dbname);
63 0 : let mut cmd = Command::new(pgdump)
64 0 : .arg("--schema-only")
65 0 : .arg(connstr.as_str())
66 0 : .stdout(Stdio::piped())
67 0 : .stderr(Stdio::piped())
68 0 : .kill_on_drop(true)
69 0 : .spawn()?;
70 :
71 0 : let stdout = cmd.stdout.take().ok_or_else(|| {
72 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.")
73 0 : })?;
74 :
75 0 : let stderr = cmd.stderr.take().ok_or_else(|| {
76 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.")
77 0 : })?;
78 :
79 0 : let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new());
80 0 : let stderr_reader = BufReader::new(stderr);
81 :
82 0 : let first_chunk = match stdout_reader.next().await {
83 0 : Some(Ok(bytes)) if !bytes.is_empty() => bytes,
84 0 : Some(Err(e)) => {
85 0 : return Err(SchemaDumpError::IO(e));
86 : }
87 : _ => {
88 0 : let mut lines = stderr_reader.lines();
89 0 : if let Some(line) = lines.next_line().await? {
90 0 : if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) {
91 0 : return Err(SchemaDumpError::DatabaseDoesNotExist);
92 0 : }
93 0 : warn!("pg_dump stderr: {}", line)
94 0 : }
95 0 : tokio::spawn(async move {
96 0 : while let Ok(Some(line)) = lines.next_line().await {
97 0 : warn!("pg_dump stderr: {}", line)
98 : }
99 0 : });
100 0 :
101 0 : return Err(SchemaDumpError::IO(std::io::Error::new(
102 0 : std::io::ErrorKind::Other,
103 0 : "failed to start pg_dump",
104 0 : )));
105 : }
106 : };
107 0 : let initial_stream = stream::once(Ok(first_chunk.freeze()));
108 0 : // Consume stderr and log warnings
109 0 : tokio::spawn(async move {
110 0 : let mut lines = stderr_reader.lines();
111 0 : while let Ok(Some(line)) = lines.next_line().await {
112 0 : warn!("pg_dump stderr: {}", line)
113 : }
114 0 : });
115 0 : Ok(initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze()))))
116 0 : }
|