Line data Source code
1 : use std::path::Path;
2 : use std::process::Stdio;
3 : use std::result::Result;
4 : use std::sync::Arc;
5 :
6 : use compute_api::responses::CatalogObjects;
7 : use futures::Stream;
8 : use postgres::NoTls;
9 : use tokio::io::{AsyncBufReadExt, BufReader};
10 : use tokio::process::Command;
11 : use tokio::spawn;
12 : use tokio_stream::{self as stream, StreamExt};
13 : use tokio_util::codec::{BytesCodec, FramedRead};
14 : use tracing::warn;
15 :
16 : use crate::compute::ComputeNode;
17 : use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
18 :
19 0 : pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
20 0 : let conf = compute.get_tokio_conn_conf(Some("compute_ctl:get_dbs_and_roles"));
21 0 : let (client, connection): (tokio_postgres::Client, _) = conf.connect(NoTls).await?;
22 :
23 0 : spawn(async move {
24 0 : if let Err(e) = connection.await {
25 0 : eprintln!("connection error: {}", e);
26 0 : }
27 0 : });
28 :
29 0 : let roles = get_existing_roles_async(&client).await?;
30 :
31 0 : let databases = get_existing_dbs_async(&client)
32 0 : .await?
33 0 : .into_values()
34 0 : .collect();
35 0 :
36 0 : Ok(CatalogObjects { roles, databases })
37 0 : }
38 :
39 : #[derive(Debug, thiserror::Error)]
40 : pub enum SchemaDumpError {
41 : #[error("database does not exist")]
42 : DatabaseDoesNotExist,
43 : #[error("failed to execute pg_dump")]
44 : IO(#[from] std::io::Error),
45 : #[error("unexpected I/O error")]
46 : Unexpected,
47 : }
48 :
49 : // It uses the pg_dump utility to dump the schema of the specified database.
50 : // The output is streamed back to the caller and supposed to be streamed via HTTP.
51 : //
52 : // Before return the result with the output, it checks that pg_dump produced any output.
53 : // If not, it tries to parse the stderr output to determine if the database does not exist
54 : // and special error is returned.
55 : //
56 : // To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature.
57 0 : pub async fn get_database_schema(
58 0 : compute: &Arc<ComputeNode>,
59 0 : dbname: &str,
60 0 : ) -> Result<impl Stream<Item = Result<bytes::Bytes, std::io::Error>> + use<>, SchemaDumpError> {
61 0 : let pgbin = &compute.params.pgbin;
62 0 : let basepath = Path::new(pgbin).parent().unwrap();
63 0 : let pgdump = basepath.join("pg_dump");
64 :
65 : // Replace the DB in the connection string and disable it to parts.
66 : // This is the only option to handle DBs with special characters.
67 0 : let conf = postgres_conf_for_db(&compute.params.connstr, dbname)
68 0 : .map_err(|_| SchemaDumpError::Unexpected)?;
69 0 : let host = conf
70 0 : .get_hosts()
71 0 : .first()
72 0 : .ok_or(SchemaDumpError::Unexpected)?;
73 0 : let host = match host {
74 0 : tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
75 : #[cfg(unix)]
76 0 : tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
77 : };
78 0 : let port = conf
79 0 : .get_ports()
80 0 : .first()
81 0 : .ok_or(SchemaDumpError::Unexpected)?;
82 0 : let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
83 0 : let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
84 :
85 0 : let mut cmd = Command::new(pgdump)
86 0 : // XXX: this seems to be the only option to deal with DBs with `=` in the name
87 0 : // See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
88 0 : .env("PGDATABASE", dbname)
89 0 : .arg("--host")
90 0 : .arg(host)
91 0 : .arg("--port")
92 0 : .arg(port.to_string())
93 0 : .arg("--username")
94 0 : .arg(user)
95 0 : .arg("--schema-only")
96 0 : .stdout(Stdio::piped())
97 0 : .stderr(Stdio::piped())
98 0 : .kill_on_drop(true)
99 0 : .spawn()?;
100 :
101 0 : let stdout = cmd.stdout.take().ok_or_else(|| {
102 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.")
103 0 : })?;
104 :
105 0 : let stderr = cmd.stderr.take().ok_or_else(|| {
106 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.")
107 0 : })?;
108 :
109 0 : let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new());
110 0 : let stderr_reader = BufReader::new(stderr);
111 :
112 0 : let first_chunk = match stdout_reader.next().await {
113 0 : Some(Ok(bytes)) if !bytes.is_empty() => bytes,
114 0 : Some(Err(e)) => {
115 0 : return Err(SchemaDumpError::IO(e));
116 : }
117 : _ => {
118 0 : let mut lines = stderr_reader.lines();
119 0 : if let Some(line) = lines.next_line().await? {
120 0 : if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) {
121 0 : return Err(SchemaDumpError::DatabaseDoesNotExist);
122 0 : }
123 0 : warn!("pg_dump stderr: {}", line)
124 0 : }
125 0 : tokio::spawn(async move {
126 0 : while let Ok(Some(line)) = lines.next_line().await {
127 0 : warn!("pg_dump stderr: {}", line)
128 : }
129 0 : });
130 0 :
131 0 : return Err(SchemaDumpError::IO(std::io::Error::new(
132 0 : std::io::ErrorKind::Other,
133 0 : "failed to start pg_dump",
134 0 : )));
135 : }
136 : };
137 0 : let initial_stream = stream::once(Ok(first_chunk.freeze()));
138 0 : // Consume stderr and log warnings
139 0 : tokio::spawn(async move {
140 0 : let mut lines = stderr_reader.lines();
141 0 : while let Ok(Some(line)) = lines.next_line().await {
142 0 : warn!("pg_dump stderr: {}", line)
143 : }
144 0 : });
145 :
146 : #[allow(dead_code)]
147 : struct SchemaStream<S> {
148 : // We keep a reference to the child process to ensure it stays alive
149 : // while the stream is being consumed. When SchemaStream is dropped,
150 : // cmd will be dropped, which triggers kill_on_drop and terminates pg_dump
151 : cmd: tokio::process::Child,
152 : stream: S,
153 : }
154 :
155 : impl<S> Stream for SchemaStream<S>
156 : where
157 : S: Stream<Item = Result<bytes::Bytes, std::io::Error>> + Unpin,
158 : {
159 : type Item = Result<bytes::Bytes, std::io::Error>;
160 :
161 0 : fn poll_next(
162 0 : mut self: std::pin::Pin<&mut Self>,
163 0 : cx: &mut std::task::Context<'_>,
164 0 : ) -> std::task::Poll<Option<Self::Item>> {
165 0 : Stream::poll_next(std::pin::Pin::new(&mut self.stream), cx)
166 0 : }
167 : }
168 :
169 0 : let schema_stream = SchemaStream {
170 0 : cmd,
171 0 : stream: initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze()))),
172 0 : };
173 0 :
174 0 : Ok(schema_stream)
175 0 : }
|