Line data Source code
1 : use futures::Stream;
2 : use postgres::NoTls;
3 : use std::{path::Path, process::Stdio, result::Result, sync::Arc};
4 : use tokio::{
5 : io::{AsyncBufReadExt, BufReader},
6 : process::Command,
7 : spawn,
8 : };
9 : use tokio_stream::{self as stream, StreamExt};
10 : use tokio_util::codec::{BytesCodec, FramedRead};
11 : use tracing::warn;
12 :
13 : use crate::compute::ComputeNode;
14 : use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
15 : use compute_api::responses::CatalogObjects;
16 :
17 0 : pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
18 0 : let conf = compute.get_tokio_conn_conf(Some("compute_ctl:get_dbs_and_roles"));
19 0 : let (client, connection): (tokio_postgres::Client, _) = conf.connect(NoTls).await?;
20 :
21 0 : spawn(async move {
22 0 : if let Err(e) = connection.await {
23 0 : eprintln!("connection error: {}", e);
24 0 : }
25 0 : });
26 :
27 0 : let roles = get_existing_roles_async(&client).await?;
28 :
29 0 : let databases = get_existing_dbs_async(&client)
30 0 : .await?
31 0 : .into_values()
32 0 : .collect();
33 0 :
34 0 : Ok(CatalogObjects { roles, databases })
35 0 : }
36 :
37 : #[derive(Debug, thiserror::Error)]
38 : pub enum SchemaDumpError {
39 : #[error("database does not exist")]
40 : DatabaseDoesNotExist,
41 : #[error("failed to execute pg_dump")]
42 : IO(#[from] std::io::Error),
43 : #[error("unexpected I/O error")]
44 : Unexpected,
45 : }
46 :
47 : // It uses the pg_dump utility to dump the schema of the specified database.
48 : // The output is streamed back to the caller and supposed to be streamed via HTTP.
49 : //
50 : // Before return the result with the output, it checks that pg_dump produced any output.
51 : // If not, it tries to parse the stderr output to determine if the database does not exist
52 : // and special error is returned.
53 : //
54 : // To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature.
55 0 : pub async fn get_database_schema(
56 0 : compute: &Arc<ComputeNode>,
57 0 : dbname: &str,
58 0 : ) -> Result<impl Stream<Item = Result<bytes::Bytes, std::io::Error>>, SchemaDumpError> {
59 0 : let pgbin = &compute.pgbin;
60 0 : let basepath = Path::new(pgbin).parent().unwrap();
61 0 : let pgdump = basepath.join("pg_dump");
62 :
63 : // Replace the DB in the connection string and disable it to parts.
64 : // This is the only option to handle DBs with special characters.
65 0 : let conf =
66 0 : postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?;
67 0 : let host = conf
68 0 : .get_hosts()
69 0 : .first()
70 0 : .ok_or(SchemaDumpError::Unexpected)?;
71 0 : let host = match host {
72 0 : tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
73 : #[cfg(unix)]
74 0 : tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
75 : };
76 0 : let port = conf
77 0 : .get_ports()
78 0 : .first()
79 0 : .ok_or(SchemaDumpError::Unexpected)?;
80 0 : let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
81 0 : let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
82 :
83 0 : let mut cmd = Command::new(pgdump)
84 0 : // XXX: this seems to be the only option to deal with DBs with `=` in the name
85 0 : // See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
86 0 : .env("PGDATABASE", dbname)
87 0 : .arg("--host")
88 0 : .arg(host)
89 0 : .arg("--port")
90 0 : .arg(port.to_string())
91 0 : .arg("--username")
92 0 : .arg(user)
93 0 : .arg("--schema-only")
94 0 : .stdout(Stdio::piped())
95 0 : .stderr(Stdio::piped())
96 0 : .kill_on_drop(true)
97 0 : .spawn()?;
98 :
99 0 : let stdout = cmd.stdout.take().ok_or_else(|| {
100 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.")
101 0 : })?;
102 :
103 0 : let stderr = cmd.stderr.take().ok_or_else(|| {
104 0 : std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.")
105 0 : })?;
106 :
107 0 : let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new());
108 0 : let stderr_reader = BufReader::new(stderr);
109 :
110 0 : let first_chunk = match stdout_reader.next().await {
111 0 : Some(Ok(bytes)) if !bytes.is_empty() => bytes,
112 0 : Some(Err(e)) => {
113 0 : return Err(SchemaDumpError::IO(e));
114 : }
115 : _ => {
116 0 : let mut lines = stderr_reader.lines();
117 0 : if let Some(line) = lines.next_line().await? {
118 0 : if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) {
119 0 : return Err(SchemaDumpError::DatabaseDoesNotExist);
120 0 : }
121 0 : warn!("pg_dump stderr: {}", line)
122 0 : }
123 0 : tokio::spawn(async move {
124 0 : while let Ok(Some(line)) = lines.next_line().await {
125 0 : warn!("pg_dump stderr: {}", line)
126 : }
127 0 : });
128 0 :
129 0 : return Err(SchemaDumpError::IO(std::io::Error::new(
130 0 : std::io::ErrorKind::Other,
131 0 : "failed to start pg_dump",
132 0 : )));
133 : }
134 : };
135 0 : let initial_stream = stream::once(Ok(first_chunk.freeze()));
136 0 : // Consume stderr and log warnings
137 0 : tokio::spawn(async move {
138 0 : let mut lines = stderr_reader.lines();
139 0 : while let Ok(Some(line)) = lines.next_line().await {
140 0 : warn!("pg_dump stderr: {}", line)
141 : }
142 0 : });
143 :
144 : #[allow(dead_code)]
145 : struct SchemaStream<S> {
146 : // We keep a reference to the child process to ensure it stays alive
147 : // while the stream is being consumed. When SchemaStream is dropped,
148 : // cmd will be dropped, which triggers kill_on_drop and terminates pg_dump
149 : cmd: tokio::process::Child,
150 : stream: S,
151 : }
152 :
153 : impl<S> Stream for SchemaStream<S>
154 : where
155 : S: Stream<Item = Result<bytes::Bytes, std::io::Error>> + Unpin,
156 : {
157 : type Item = Result<bytes::Bytes, std::io::Error>;
158 :
159 0 : fn poll_next(
160 0 : mut self: std::pin::Pin<&mut Self>,
161 0 : cx: &mut std::task::Context<'_>,
162 0 : ) -> std::task::Poll<Option<Self::Item>> {
163 0 : Stream::poll_next(std::pin::Pin::new(&mut self.stream), cx)
164 0 : }
165 : }
166 :
167 0 : let schema_stream = SchemaStream {
168 0 : cmd,
169 0 : stream: initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze()))),
170 0 : };
171 0 :
172 0 : Ok(schema_stream)
173 0 : }
|