Line data Source code
1 : use compute_api::responses::CatalogObjects;
2 : use futures::Stream;
3 : use postgres::NoTls;
4 : use std::{path::Path, process::Stdio, result::Result, sync::Arc};
5 : use tokio::{
6 : io::{AsyncBufReadExt, BufReader},
7 : process::Command,
8 : spawn,
9 : };
10 : use tokio_postgres::connect;
11 : use tokio_stream::{self as stream, StreamExt};
12 : use tokio_util::codec::{BytesCodec, FramedRead};
13 : use tracing::warn;
14 :
15 : use crate::compute::ComputeNode;
16 : use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async};
17 :
18 0 : pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
19 0 : let connstr = compute.connstr.clone();
20 :
21 0 : let (client, connection): (tokio_postgres::Client, _) =
22 0 : connect(connstr.as_str(), NoTls).await?;
23 :
24 0 : spawn(async move {
25 0 : if let Err(e) = connection.await {
26 0 : eprintln!("connection error: {}", e);
27 0 : }
28 0 : });
29 :
30 0 : let roles = get_existing_roles_async(&client).await?;
31 :
32 0 : let databases = get_existing_dbs_async(&client)
33 0 : .await?
34 0 : .into_values()
35 0 : .collect();
36 0 :
37 0 : Ok(CatalogObjects { roles, databases })
38 0 : }
39 :
40 0 : #[derive(Debug, thiserror::Error)]
41 : pub enum SchemaDumpError {
42 : #[error("Database does not exist.")]
43 : DatabaseDoesNotExist,
44 : #[error("Failed to execute pg_dump.")]
45 : IO(#[from] std::io::Error),
46 : }
47 :
48 : // It uses the pg_dump utility to dump the schema of the specified database.
49 : // The output is streamed back to the caller and supposed to be streamed via HTTP.
50 : //
51 : // Before return the result with the output, it checks that pg_dump produced any output.
52 : // If not, it tries to parse the stderr output to determine if the database does not exist
53 : // and special error is returned.
54 : //
55 : // To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature.
56 0 : pub async fn get_database_schema(
57 0 : compute: &Arc<ComputeNode>,
58 0 : dbname: &str,
59 0 : ) -> Result<impl Stream<Item = Result<bytes::Bytes, std::io::Error>>, SchemaDumpError> {
60 0 : let pgbin = &compute.pgbin;
61 0 : let basepath = Path::new(pgbin).parent().unwrap();
62 0 : let pgdump = basepath.join("pg_dump");
63 0 : let mut connstr = compute.connstr.clone();
64 0 : connstr.set_path(dbname);
65 0 : let mut cmd = Command::new(pgdump)
66 0 : .arg("--schema-only")
67 0 : .arg(connstr.as_str())
68 0 : .stdout(Stdio::piped())
69 0 : .stderr(Stdio::piped())
70 0 : .kill_on_drop(true)
71 0 : .spawn()?;
72 :
73 0 : let stdout = cmd.stdout.take().ok_or_else(|| {
74 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.")
75 0 : })?;
76 :
77 0 : let stderr = cmd.stderr.take().ok_or_else(|| {
78 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.")
79 0 : })?;
80 :
81 0 : let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new());
82 0 : let stderr_reader = BufReader::new(stderr);
83 :
84 0 : let first_chunk = match stdout_reader.next().await {
85 0 : Some(Ok(bytes)) if !bytes.is_empty() => bytes,
86 0 : Some(Err(e)) => {
87 0 : return Err(SchemaDumpError::IO(e));
88 : }
89 : _ => {
90 0 : let mut lines = stderr_reader.lines();
91 0 : if let Some(line) = lines.next_line().await? {
92 0 : if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) {
93 0 : return Err(SchemaDumpError::DatabaseDoesNotExist);
94 0 : }
95 0 : warn!("pg_dump stderr: {}", line)
96 0 : }
97 0 : tokio::spawn(async move {
98 0 : while let Ok(Some(line)) = lines.next_line().await {
99 0 : warn!("pg_dump stderr: {}", line)
100 : }
101 0 : });
102 0 :
103 0 : return Err(SchemaDumpError::IO(std::io::Error::new(
104 0 : std::io::ErrorKind::Other,
105 0 : "failed to start pg_dump",
106 0 : )));
107 : }
108 : };
109 0 : let initial_stream = stream::once(Ok(first_chunk.freeze()));
110 0 : // Consume stderr and log warnings
111 0 : tokio::spawn(async move {
112 0 : let mut lines = stderr_reader.lines();
113 0 : while let Ok(Some(line)) = lines.next_line().await {
114 0 : warn!("pg_dump stderr: {}", line)
115 : }
116 0 : });
117 0 : Ok(initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze()))))
118 0 : }
|