LCOV - code coverage report
Current view: top level - compute_tools/src - catalog.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 0.0 % 101 0
Test Date: 2025-03-12 18:28:53 Functions: 0.0 % 13 0

            Line data    Source code
       1              : use std::path::Path;
       2              : use std::process::Stdio;
       3              : use std::result::Result;
       4              : use std::sync::Arc;
       5              : 
       6              : use compute_api::responses::CatalogObjects;
       7              : use futures::Stream;
       8              : use postgres::NoTls;
       9              : use tokio::io::{AsyncBufReadExt, BufReader};
      10              : use tokio::process::Command;
      11              : use tokio::spawn;
      12              : use tokio_stream::{self as stream, StreamExt};
      13              : use tokio_util::codec::{BytesCodec, FramedRead};
      14              : use tracing::warn;
      15              : 
      16              : use crate::compute::ComputeNode;
      17              : use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
      18              : 
      19            0 : pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
      20            0 :     let conf = compute.get_tokio_conn_conf(Some("compute_ctl:get_dbs_and_roles"));
      21            0 :     let (client, connection): (tokio_postgres::Client, _) = conf.connect(NoTls).await?;
      22              : 
      23            0 :     spawn(async move {
      24            0 :         if let Err(e) = connection.await {
      25            0 :             eprintln!("connection error: {}", e);
      26            0 :         }
      27            0 :     });
      28              : 
      29            0 :     let roles = get_existing_roles_async(&client).await?;
      30              : 
      31            0 :     let databases = get_existing_dbs_async(&client)
      32            0 :         .await?
      33            0 :         .into_values()
      34            0 :         .collect();
      35            0 : 
      36            0 :     Ok(CatalogObjects { roles, databases })
      37            0 : }
      38              : 
      39              : #[derive(Debug, thiserror::Error)]
      40              : pub enum SchemaDumpError {
      41              :     #[error("database does not exist")]
      42              :     DatabaseDoesNotExist,
      43              :     #[error("failed to execute pg_dump")]
      44              :     IO(#[from] std::io::Error),
      45              :     #[error("unexpected I/O error")]
      46              :     Unexpected,
      47              : }
      48              : 
      49              : // It uses the pg_dump utility to dump the schema of the specified database.
      50              : // The output is streamed back to the caller and is supposed to be streamed via HTTP.
      51              : //
      52              : // Before returning the result with the output, it checks that pg_dump produced any output.
      53              : // If not, it tries to parse the stderr output to determine whether the database does not exist,
      54              : // in which case a special error is returned.
      55              : //
      56              : // To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature.
      57            0 : pub async fn get_database_schema(
      58            0 :     compute: &Arc<ComputeNode>,
      59            0 :     dbname: &str,
      60            0 : ) -> Result<impl Stream<Item = Result<bytes::Bytes, std::io::Error>> + use<>, SchemaDumpError> {
      61            0 :     let pgbin = &compute.params.pgbin;
      62            0 :     let basepath = Path::new(pgbin).parent().unwrap();
      63            0 :     let pgdump = basepath.join("pg_dump");
      64              : 
      65              :     // Replace the DB in the connection string and split it into parts.
      66              :     // This is the only option to handle DBs with special characters.
      67            0 :     let conf = postgres_conf_for_db(&compute.params.connstr, dbname)
      68            0 :         .map_err(|_| SchemaDumpError::Unexpected)?;
      69            0 :     let host = conf
      70            0 :         .get_hosts()
      71            0 :         .first()
      72            0 :         .ok_or(SchemaDumpError::Unexpected)?;
      73            0 :     let host = match host {
      74            0 :         tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
      75              :         #[cfg(unix)]
      76            0 :         tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
      77              :     };
      78            0 :     let port = conf
      79            0 :         .get_ports()
      80            0 :         .first()
      81            0 :         .ok_or(SchemaDumpError::Unexpected)?;
      82            0 :     let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
      83            0 :     let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
      84              : 
      85            0 :     let mut cmd = Command::new(pgdump)
      86            0 :         // XXX: this seems to be the only option to deal with DBs with `=` in the name
      87            0 :         // See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
      88            0 :         .env("PGDATABASE", dbname)
      89            0 :         .arg("--host")
      90            0 :         .arg(host)
      91            0 :         .arg("--port")
      92            0 :         .arg(port.to_string())
      93            0 :         .arg("--username")
      94            0 :         .arg(user)
      95            0 :         .arg("--schema-only")
      96            0 :         .stdout(Stdio::piped())
      97            0 :         .stderr(Stdio::piped())
      98            0 :         .kill_on_drop(true)
      99            0 :         .spawn()?;
     100              : 
     101            0 :     let stdout = cmd.stdout.take().ok_or_else(|| {
     102            0 :         std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.")
     103            0 :     })?;
     104              : 
     105            0 :     let stderr = cmd.stderr.take().ok_or_else(|| {
     106            0 :         std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.")
     107            0 :     })?;
     108              : 
     109            0 :     let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new());
     110            0 :     let stderr_reader = BufReader::new(stderr);
     111              : 
     112            0 :     let first_chunk = match stdout_reader.next().await {
     113            0 :         Some(Ok(bytes)) if !bytes.is_empty() => bytes,
     114            0 :         Some(Err(e)) => {
     115            0 :             return Err(SchemaDumpError::IO(e));
     116              :         }
     117              :         _ => {
     118            0 :             let mut lines = stderr_reader.lines();
     119            0 :             if let Some(line) = lines.next_line().await? {
     120            0 :                 if line.contains(&format!("FATAL:  database \"{}\" does not exist", dbname)) {
     121            0 :                     return Err(SchemaDumpError::DatabaseDoesNotExist);
     122            0 :                 }
     123            0 :                 warn!("pg_dump stderr: {}", line)
     124            0 :             }
     125            0 :             tokio::spawn(async move {
     126            0 :                 while let Ok(Some(line)) = lines.next_line().await {
     127            0 :                     warn!("pg_dump stderr: {}", line)
     128              :                 }
     129            0 :             });
     130            0 : 
     131            0 :             return Err(SchemaDumpError::IO(std::io::Error::new(
     132            0 :                 std::io::ErrorKind::Other,
     133            0 :                 "failed to start pg_dump",
     134            0 :             )));
     135              :         }
     136              :     };
     137            0 :     let initial_stream = stream::once(Ok(first_chunk.freeze()));
     138            0 :     // Consume stderr and log warnings
     139            0 :     tokio::spawn(async move {
     140            0 :         let mut lines = stderr_reader.lines();
     141            0 :         while let Ok(Some(line)) = lines.next_line().await {
     142            0 :             warn!("pg_dump stderr: {}", line)
     143              :         }
     144            0 :     });
     145              : 
     146              :     #[allow(dead_code)]
     147              :     struct SchemaStream<S> {
     148              :         // We keep a reference to the child process to ensure it stays alive
     149              :         // while the stream is being consumed. When SchemaStream is dropped,
     150              :         // cmd will be dropped, which triggers kill_on_drop and terminates pg_dump
     151              :         cmd: tokio::process::Child,
     152              :         stream: S,
     153              :     }
     154              : 
     155              :     impl<S> Stream for SchemaStream<S>
     156              :     where
     157              :         S: Stream<Item = Result<bytes::Bytes, std::io::Error>> + Unpin,
     158              :     {
     159              :         type Item = Result<bytes::Bytes, std::io::Error>;
     160              : 
     161            0 :         fn poll_next(
     162            0 :             mut self: std::pin::Pin<&mut Self>,
     163            0 :             cx: &mut std::task::Context<'_>,
     164            0 :         ) -> std::task::Poll<Option<Self::Item>> {
     165            0 :             Stream::poll_next(std::pin::Pin::new(&mut self.stream), cx)
     166            0 :         }
     167              :     }
     168              : 
     169            0 :     let schema_stream = SchemaStream {
     170            0 :         cmd,
     171            0 :         stream: initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze()))),
     172            0 :     };
     173            0 : 
     174            0 :     Ok(schema_stream)
     175            0 : }
        

Generated by: LCOV version 2.1-beta