LCOV - code coverage report
Current view: top level - compute_tools/src/bin - fast_import.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 475 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 24 0

            Line data    Source code
       1              : //! This program dumps a remote Postgres database into a local Postgres database
       2              : //! and uploads the resulting PGDATA into object storage for import into a Timeline.
       3              : //!
       4              : //! # Context, Architecture, Design
       5              : //!
       6              : //! See cloud.git Fast Imports RFC (<https://github.com/neondatabase/cloud/pull/19799>)
       7              : //! for the full picture.
       8              : //! The RFC describing the storage pieces of importing the PGDATA dump into a Timeline
       9              : //! is publicly accessible at <https://github.com/neondatabase/neon/pull/9538>.
      10              : //!
      11              : //! # This is a Prototype!
      12              : //!
      13              : //! This program is part of a prototype feature and not yet used in production.
      14              : //!
      15              : //! The cloud.git RFC contains lots of suggestions for improving e2e throughput
      16              : //! of this step of the timeline import process.
      17              : //!
      18              : //! # Local Testing
      19              : //!
      20              : //! - Comment out most of the pgxns in compute-node.Dockerfile to speed up the build.
      21              : //! - Build the image with the following command:
      22              : //!
      23              : //! ```bash
      24              : //! docker buildx build --platform linux/amd64 --build-arg DEBIAN_VERSION=bullseye --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/compute-node.Dockerfile .
      25              : //! docker push localhost:3030/localregistry/compute-node-v14:latest
      26              : //! ```
      27              : 
      28              : use anyhow::{Context, bail};
      29              : use aws_config::BehaviorVersion;
      30              : use camino::{Utf8Path, Utf8PathBuf};
      31              : use clap::{Parser, Subcommand};
      32              : use compute_tools::extension_server::get_pg_version;
      33              : use nix::unistd::Pid;
      34              : use std::ops::Not;
      35              : use tracing::{Instrument, error, info, info_span, warn};
      36              : use utils::fs_ext::is_directory_empty;
      37              : 
      38              : #[path = "fast_import/aws_s3_sync.rs"]
      39              : mod aws_s3_sync;
      40              : #[path = "fast_import/child_stdio_to_log.rs"]
      41              : mod child_stdio_to_log;
      42              : #[path = "fast_import/s3_uri.rs"]
      43              : mod s3_uri;
      44              : 
      45              : const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600);
      46              : const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300);
      47              : 
      48              : #[derive(Subcommand, Debug, Clone, serde::Serialize)]
      49              : enum Command {
      50              :     /// Runs local postgres (neon binary), restores into it,
      51              :     /// uploads pgdata to s3 to be consumed by pageservers
      52              :     Pgdata {
      53              :         /// Raw connection string to the source database. Used only in tests,
      54              :         /// real scenario uses encrypted connection string in spec.json from s3.
      55              :         #[clap(long)]
      56              :         source_connection_string: Option<String>,
      57              :         /// If specified, will not shut down the local postgres after the import. Used in local testing
      58              :         #[clap(short, long)]
      59              :         interactive: bool,
      60              :         /// Port to run postgres on. Default is 5432.
      61              :         #[clap(long, default_value_t = 5432)]
      62              :         pg_port: u16, // port to run postgres on, 5432 is default
      63              : 
      64              :         /// Number of CPUs in the system. This is used to configure # of
      65              :         /// parallel worker processes, for index creation.
      66              :         #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")]
      67              :         num_cpus: Option<usize>,
      68              : 
      69              :         /// Amount of RAM in the system. This is used to configure shared_buffers
      70              :         /// and maintenance_work_mem.
      71              :         #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
      72              :         memory_mb: Option<usize>,
      73              :     },
      74              : 
      75              :     /// Runs pg_dump-pg_restore from source to destination without running local postgres.
      76              :     DumpRestore {
      77              :         /// Raw connection string to the source database. Used only in tests,
      78              :         /// real scenario uses encrypted connection string in spec.json from s3.
      79              :         #[clap(long)]
      80              :         source_connection_string: Option<String>,
      81              :         /// Raw connection string to the destination database. Used only in tests,
      82              :         /// real scenario uses encrypted connection string in spec.json from s3.
      83              :         #[clap(long)]
      84              :         destination_connection_string: Option<String>,
      85              :     },
      86              : }
      87              : 
      88              : impl Command {
      89            0 :     fn as_str(&self) -> &'static str {
      90            0 :         match self {
      91            0 :             Command::Pgdata { .. } => "pgdata",
      92            0 :             Command::DumpRestore { .. } => "dump-restore",
      93              :         }
      94            0 :     }
      95              : }
      96              : 
      97              : #[derive(clap::Parser)]
      98              : struct Args {
      99              :     #[clap(long, env = "NEON_IMPORTER_WORKDIR")]
     100              :     working_directory: Utf8PathBuf,
     101              :     #[clap(long, env = "NEON_IMPORTER_S3_PREFIX")]
     102              :     s3_prefix: Option<s3_uri::S3Uri>,
     103              :     #[clap(long, env = "NEON_IMPORTER_PG_BIN_DIR")]
     104              :     pg_bin_dir: Utf8PathBuf,
     105              :     #[clap(long, env = "NEON_IMPORTER_PG_LIB_DIR")]
     106              :     pg_lib_dir: Utf8PathBuf,
     107              : 
     108              :     #[clap(subcommand)]
     109              :     command: Command,
     110              : }
     111              : 
     112              : #[serde_with::serde_as]
     113              : #[derive(serde::Deserialize)]
     114              : struct Spec {
     115              :     encryption_secret: EncryptionSecret,
     116              :     #[serde_as(as = "serde_with::base64::Base64")]
     117              :     source_connstring_ciphertext_base64: Vec<u8>,
     118              :     #[serde_as(as = "Option<serde_with::base64::Base64>")]
     119              :     destination_connstring_ciphertext_base64: Option<Vec<u8>>,
     120              : }
     121              : 
     122            0 : #[derive(serde::Deserialize)]
     123              : enum EncryptionSecret {
     124              :     #[allow(clippy::upper_case_acronyms)]
     125              :     KMS { key_id: String },
     126              : }
     127              : 
     128              : // copied from pageserver_api::config::defaults::DEFAULT_LOCALE to avoid dependency just for a constant
     129              : const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") {
     130              :     "C"
     131              : } else {
     132              :     "C.UTF-8"
     133              : };
     134              : 
     135            0 : async fn decode_connstring(
     136            0 :     kms_client: &aws_sdk_kms::Client,
     137            0 :     key_id: &String,
     138            0 :     connstring_ciphertext_base64: Vec<u8>,
     139            0 : ) -> Result<String, anyhow::Error> {
     140            0 :     let mut output = kms_client
     141            0 :         .decrypt()
     142            0 :         .key_id(key_id)
     143            0 :         .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
     144            0 :             connstring_ciphertext_base64,
     145            0 :         ))
     146            0 :         .send()
     147            0 :         .await
     148            0 :         .context("decrypt connection string")?;
     149              : 
     150            0 :     let plaintext = output
     151            0 :         .plaintext
     152            0 :         .take()
     153            0 :         .context("get plaintext connection string")?;
     154              : 
     155            0 :     String::from_utf8(plaintext.into_inner()).context("parse connection string as utf8")
     156            0 : }
     157              : 
     158              : struct PostgresProcess {
     159              :     pgdata_dir: Utf8PathBuf,
     160              :     pg_bin_dir: Utf8PathBuf,
     161              :     pgbin: Utf8PathBuf,
     162              :     pg_lib_dir: Utf8PathBuf,
     163              :     postgres_proc: Option<tokio::process::Child>,
     164              : }
     165              : 
     166              : impl PostgresProcess {
     167            0 :     fn new(pgdata_dir: Utf8PathBuf, pg_bin_dir: Utf8PathBuf, pg_lib_dir: Utf8PathBuf) -> Self {
     168            0 :         Self {
     169            0 :             pgdata_dir,
     170            0 :             pgbin: pg_bin_dir.join("postgres"),
     171            0 :             pg_bin_dir,
     172            0 :             pg_lib_dir,
     173            0 :             postgres_proc: None,
     174            0 :         }
     175            0 :     }
     176              : 
     177            0 :     async fn prepare(&self, initdb_user: &str) -> Result<(), anyhow::Error> {
     178            0 :         tokio::fs::create_dir(&self.pgdata_dir)
     179            0 :             .await
     180            0 :             .context("create pgdata directory")?;
     181              : 
     182            0 :         let pg_version = get_pg_version(self.pgbin.as_ref());
     183              : 
     184            0 :         postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
     185            0 :             superuser: initdb_user,
     186            0 :             locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded,
     187            0 :             pg_version,
     188            0 :             initdb_bin: self.pg_bin_dir.join("initdb").as_ref(),
     189            0 :             library_search_path: &self.pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
     190            0 :             pgdata: &self.pgdata_dir,
     191            0 :         })
     192            0 :         .await
     193            0 :         .context("initdb")
     194            0 :     }
     195              : 
     196            0 :     async fn start(
     197            0 :         &mut self,
     198            0 :         initdb_user: &str,
     199            0 :         port: u16,
     200            0 :         nproc: usize,
     201            0 :         memory_mb: usize,
     202            0 :     ) -> Result<&tokio::process::Child, anyhow::Error> {
     203            0 :         self.prepare(initdb_user).await?;
     204              : 
     205              :         // Somewhat arbitrarily, use 10 % of memory for shared buffer cache, 70% for
     206              :         // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest
     207              :         // available for misc other stuff that PostgreSQL uses memory for.
     208            0 :         let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize;
     209            0 :         let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize;
     210              : 
     211              :         //
     212              :         // Launch postgres process
     213              :         //
     214            0 :         let mut proc = tokio::process::Command::new(&self.pgbin)
     215            0 :             .arg("-D")
     216            0 :             .arg(&self.pgdata_dir)
     217            0 :             .args(["-p", &format!("{port}")])
     218            0 :             .args(["-c", "wal_level=minimal"])
     219            0 :             .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")])
     220            0 :             .args(["-c", "max_wal_senders=0"])
     221            0 :             .args(["-c", "fsync=off"])
     222            0 :             .args(["-c", "full_page_writes=off"])
     223            0 :             .args(["-c", "synchronous_commit=off"])
     224            0 :             .args([
     225            0 :                 "-c",
     226            0 :                 &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"),
     227            0 :             ])
     228            0 :             .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
     229            0 :             .args(["-c", &format!("max_parallel_workers={nproc}")])
     230            0 :             .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
     231            0 :             .args(["-c", &format!("max_worker_processes={nproc}")])
     232            0 :             .args(["-c", "effective_io_concurrency=100"])
     233            0 :             .env_clear()
     234            0 :             .env("LD_LIBRARY_PATH", &self.pg_lib_dir)
     235            0 :             .env(
     236            0 :                 "ASAN_OPTIONS",
     237            0 :                 std::env::var("ASAN_OPTIONS").unwrap_or_default(),
     238            0 :             )
     239            0 :             .env(
     240            0 :                 "UBSAN_OPTIONS",
     241            0 :                 std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
     242            0 :             )
     243            0 :             .stdout(std::process::Stdio::piped())
     244            0 :             .stderr(std::process::Stdio::piped())
     245            0 :             .spawn()
     246            0 :             .context("spawn postgres")?;
     247              : 
     248            0 :         info!("spawned postgres, waiting for it to become ready");
     249            0 :         tokio::spawn(
     250            0 :             child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take())
     251            0 :                 .instrument(info_span!("postgres")),
     252              :         );
     253              : 
     254            0 :         self.postgres_proc = Some(proc);
     255            0 :         Ok(self.postgres_proc.as_ref().unwrap())
     256            0 :     }
     257              : 
     258            0 :     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
     259            0 :         let proc: &mut tokio::process::Child = self.postgres_proc.as_mut().unwrap();
     260            0 :         info!("shutdown postgres");
     261            0 :         nix::sys::signal::kill(
     262            0 :             Pid::from_raw(i32::try_from(proc.id().unwrap()).expect("convert child pid to i32")),
     263            0 :             nix::sys::signal::SIGTERM,
     264              :         )
     265            0 :         .context("signal postgres to shut down")?;
     266            0 :         proc.wait()
     267            0 :             .await
     268            0 :             .context("wait for postgres to shut down")
     269            0 :             .map(|_| ())
     270            0 :     }
     271              : }
     272              : 
     273            0 : async fn wait_until_ready(connstring: String, create_dbname: String) {
     274              :     // Create neondb database in the running postgres
     275            0 :     let start_time = std::time::Instant::now();
     276              : 
     277              :     loop {
     278            0 :         if start_time.elapsed() > PG_WAIT_TIMEOUT {
     279            0 :             error!(
     280            0 :                 "timeout exceeded: failed to poll postgres and create database within 10 minutes"
     281              :             );
     282            0 :             std::process::exit(1);
     283            0 :         }
     284              : 
     285            0 :         match tokio_postgres::connect(
     286            0 :             &connstring.replace("dbname=neondb", "dbname=postgres"),
     287            0 :             tokio_postgres::NoTls,
     288              :         )
     289            0 :         .await
     290              :         {
     291            0 :             Ok((client, connection)) => {
     292              :                 // Spawn the connection handling task to maintain the connection
     293            0 :                 tokio::spawn(async move {
     294            0 :                     if let Err(e) = connection.await {
     295            0 :                         warn!("connection error: {}", e);
     296            0 :                     }
     297            0 :                 });
     298              : 
     299            0 :                 match client
     300            0 :                     .simple_query(format!("CREATE DATABASE {create_dbname};").as_str())
     301            0 :                     .await
     302              :                 {
     303              :                     Ok(_) => {
     304            0 :                         info!("created {} database", create_dbname);
     305            0 :                         break;
     306              :                     }
     307            0 :                     Err(e) => {
     308            0 :                         warn!(
     309            0 :                             "failed to create database: {}, retying in {}s",
     310              :                             e,
     311            0 :                             PG_WAIT_RETRY_INTERVAL.as_secs_f32()
     312              :                         );
     313            0 :                         tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
     314            0 :                         continue;
     315              :                     }
     316              :                 }
     317              :             }
     318              :             Err(_) => {
     319            0 :                 info!(
     320            0 :                     "postgres not ready yet, retrying in {}s",
     321            0 :                     PG_WAIT_RETRY_INTERVAL.as_secs_f32()
     322              :                 );
     323            0 :                 tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
     324            0 :                 continue;
     325              :             }
     326              :         }
     327              :     }
     328            0 : }
     329              : 
     330            0 : async fn run_dump_restore(
     331            0 :     workdir: Utf8PathBuf,
     332            0 :     pg_bin_dir: Utf8PathBuf,
     333            0 :     pg_lib_dir: Utf8PathBuf,
     334            0 :     source_connstring: String,
     335            0 :     destination_connstring: String,
     336            0 : ) -> Result<(), anyhow::Error> {
     337            0 :     let dumpdir = workdir.join("dumpdir");
     338            0 :     let num_jobs = num_cpus::get().to_string();
     339            0 :     info!("using {num_jobs} jobs for dump/restore");
     340              : 
     341            0 :     let common_args = [
     342            0 :         // schema mapping (prob suffices to specify them on one side)
     343            0 :         "--no-owner".to_string(),
     344            0 :         "--no-privileges".to_string(),
     345            0 :         "--no-publications".to_string(),
     346            0 :         "--no-security-labels".to_string(),
     347            0 :         "--no-subscriptions".to_string(),
     348            0 :         "--no-tablespaces".to_string(),
     349            0 :         "--no-event-triggers".to_string(),
     350            0 :         // format
     351            0 :         "--format".to_string(),
     352            0 :         "directory".to_string(),
     353            0 :         // concurrency
     354            0 :         "--jobs".to_string(),
     355            0 :         num_jobs,
     356            0 :         // progress updates
     357            0 :         "--verbose".to_string(),
     358            0 :     ];
     359              : 
     360            0 :     info!("dump into the working directory");
     361              :     {
     362            0 :         let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump"))
     363            0 :             .args(&common_args)
     364            0 :             .arg("-f")
     365            0 :             .arg(&dumpdir)
     366            0 :             .arg("--no-sync")
     367            0 :             // POSITIONAL args
     368            0 :             // source db (db name included in connection string)
     369            0 :             .arg(&source_connstring)
     370            0 :             // how we run it
     371            0 :             .env_clear()
     372            0 :             .env("LD_LIBRARY_PATH", &pg_lib_dir)
     373            0 :             .env(
     374            0 :                 "ASAN_OPTIONS",
     375            0 :                 std::env::var("ASAN_OPTIONS").unwrap_or_default(),
     376            0 :             )
     377            0 :             .env(
     378            0 :                 "UBSAN_OPTIONS",
     379            0 :                 std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
     380            0 :             )
     381            0 :             .kill_on_drop(true)
     382            0 :             .stdout(std::process::Stdio::piped())
     383            0 :             .stderr(std::process::Stdio::piped())
     384            0 :             .spawn()
     385            0 :             .context("spawn pg_dump")?;
     386              : 
     387            0 :         info!(pid=%pg_dump.id().unwrap(), "spawned pg_dump");
     388              : 
     389            0 :         tokio::spawn(
     390            0 :             child_stdio_to_log::relay_process_output(pg_dump.stdout.take(), pg_dump.stderr.take())
     391            0 :                 .instrument(info_span!("pg_dump")),
     392              :         );
     393              : 
     394            0 :         let st = pg_dump.wait().await.context("wait for pg_dump")?;
     395            0 :         info!(status=?st, "pg_dump exited");
     396            0 :         if !st.success() {
     397            0 :             error!(status=%st, "pg_dump failed, restore will likely fail as well");
     398            0 :             bail!("pg_dump failed");
     399            0 :         }
     400              :     }
     401              : 
     402              :     // TODO: maybe do it in a streaming way, plenty of internal research done on this already
     403              :     // TODO: do the unlogged table trick
     404              :     {
     405            0 :         let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore"))
     406            0 :             .args(&common_args)
     407            0 :             .arg("-d")
     408            0 :             .arg(&destination_connstring)
     409            0 :             // POSITIONAL args
     410            0 :             .arg(&dumpdir)
     411            0 :             // how we run it
     412            0 :             .env_clear()
     413            0 :             .env("LD_LIBRARY_PATH", &pg_lib_dir)
     414            0 :             .env(
     415            0 :                 "ASAN_OPTIONS",
     416            0 :                 std::env::var("ASAN_OPTIONS").unwrap_or_default(),
     417            0 :             )
     418            0 :             .env(
     419            0 :                 "UBSAN_OPTIONS",
     420            0 :                 std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
     421            0 :             )
     422            0 :             .kill_on_drop(true)
     423            0 :             .stdout(std::process::Stdio::piped())
     424            0 :             .stderr(std::process::Stdio::piped())
     425            0 :             .spawn()
     426            0 :             .context("spawn pg_restore")?;
     427              : 
     428            0 :         info!(pid=%pg_restore.id().unwrap(), "spawned pg_restore");
     429            0 :         tokio::spawn(
     430            0 :             child_stdio_to_log::relay_process_output(
     431            0 :                 pg_restore.stdout.take(),
     432            0 :                 pg_restore.stderr.take(),
     433              :             )
     434            0 :             .instrument(info_span!("pg_restore")),
     435              :         );
     436            0 :         let st = pg_restore.wait().await.context("wait for pg_restore")?;
     437            0 :         info!(status=?st, "pg_restore exited");
     438            0 :         if !st.success() {
     439            0 :             error!(status=%st, "pg_restore failed, restore will likely fail as well");
     440            0 :             bail!("pg_restore failed");
     441            0 :         }
     442              :     }
     443              : 
     444            0 :     Ok(())
     445            0 : }
     446              : 
     447              : #[allow(clippy::too_many_arguments)]
     448            0 : async fn cmd_pgdata(
     449            0 :     s3_client: Option<&aws_sdk_s3::Client>,
     450            0 :     kms_client: Option<aws_sdk_kms::Client>,
     451            0 :     maybe_s3_prefix: Option<s3_uri::S3Uri>,
     452            0 :     maybe_spec: Option<Spec>,
     453            0 :     source_connection_string: Option<String>,
     454            0 :     interactive: bool,
     455            0 :     pg_port: u16,
     456            0 :     workdir: Utf8PathBuf,
     457            0 :     pg_bin_dir: Utf8PathBuf,
     458            0 :     pg_lib_dir: Utf8PathBuf,
     459            0 :     num_cpus: Option<usize>,
     460            0 :     memory_mb: Option<usize>,
     461            0 : ) -> Result<(), anyhow::Error> {
     462            0 :     if maybe_spec.is_none() && source_connection_string.is_none() {
     463            0 :         bail!("spec must be provided for pgdata command");
     464            0 :     }
     465            0 :     if maybe_spec.is_some() && source_connection_string.is_some() {
     466            0 :         bail!("only one of spec or source_connection_string can be provided");
     467            0 :     }
     468              : 
     469            0 :     let source_connection_string = if let Some(spec) = maybe_spec {
     470            0 :         match spec.encryption_secret {
     471            0 :             EncryptionSecret::KMS { key_id } => {
     472            0 :                 decode_connstring(
     473            0 :                     kms_client.as_ref().unwrap(),
     474            0 :                     &key_id,
     475            0 :                     spec.source_connstring_ciphertext_base64,
     476            0 :                 )
     477            0 :                 .await?
     478              :             }
     479              :         }
     480              :     } else {
     481            0 :         source_connection_string.unwrap()
     482              :     };
     483              : 
     484            0 :     let superuser = "cloud_admin";
     485            0 :     let destination_connstring =
     486            0 :         format!("host=localhost port={pg_port} user={superuser} dbname=neondb");
     487              : 
     488            0 :     let pgdata_dir = workdir.join("pgdata");
     489            0 :     let mut proc = PostgresProcess::new(pgdata_dir.clone(), pg_bin_dir.clone(), pg_lib_dir.clone());
     490            0 :     let nproc = num_cpus.unwrap_or_else(num_cpus::get);
     491            0 :     let memory_mb = memory_mb.unwrap_or(256);
     492            0 :     proc.start(superuser, pg_port, nproc, memory_mb).await?;
     493            0 :     wait_until_ready(destination_connstring.clone(), "neondb".to_string()).await;
     494              : 
     495            0 :     run_dump_restore(
     496            0 :         workdir.clone(),
     497            0 :         pg_bin_dir,
     498            0 :         pg_lib_dir,
     499            0 :         source_connection_string,
     500            0 :         destination_connstring,
     501            0 :     )
     502            0 :     .await?;
     503              : 
     504              :     // If interactive mode, wait for Ctrl+C
     505            0 :     if interactive {
     506            0 :         info!("Running in interactive mode. Press Ctrl+C to shut down.");
     507            0 :         tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
     508            0 :     }
     509              : 
     510            0 :     proc.shutdown().await?;
     511              : 
     512              :     // Only sync if s3_prefix was specified
     513            0 :     if let Some(s3_prefix) = maybe_s3_prefix {
     514            0 :         info!("upload pgdata");
     515            0 :         aws_s3_sync::upload_dir_recursive(
     516            0 :             s3_client.unwrap(),
     517            0 :             Utf8Path::new(&pgdata_dir),
     518            0 :             &s3_prefix.append("/pgdata/"),
     519            0 :         )
     520            0 :         .await
     521            0 :         .context("sync dump directory to destination")?;
     522              : 
     523            0 :         info!("write pgdata status to s3");
     524              :         {
     525            0 :             let status_dir = workdir.join("status");
     526            0 :             std::fs::create_dir(&status_dir).context("create status directory")?;
     527            0 :             let status_file = status_dir.join("pgdata");
     528            0 :             std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
     529            0 :                 .context("write status file")?;
     530            0 :             aws_s3_sync::upload_dir_recursive(
     531            0 :                 s3_client.as_ref().unwrap(),
     532            0 :                 &status_dir,
     533            0 :                 &s3_prefix.append("/status/"),
     534            0 :             )
     535            0 :             .await
     536            0 :             .context("sync status directory to destination")?;
     537              :         }
     538            0 :     }
     539              : 
     540            0 :     Ok(())
     541            0 : }
     542              : 
     543            0 : async fn cmd_dumprestore(
     544            0 :     kms_client: Option<aws_sdk_kms::Client>,
     545            0 :     maybe_spec: Option<Spec>,
     546            0 :     source_connection_string: Option<String>,
     547            0 :     destination_connection_string: Option<String>,
     548            0 :     workdir: Utf8PathBuf,
     549            0 :     pg_bin_dir: Utf8PathBuf,
     550            0 :     pg_lib_dir: Utf8PathBuf,
     551            0 : ) -> Result<(), anyhow::Error> {
     552            0 :     let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
     553            0 :         match spec.encryption_secret {
     554            0 :             EncryptionSecret::KMS { key_id } => {
     555            0 :                 let source = decode_connstring(
     556            0 :                     kms_client.as_ref().unwrap(),
     557            0 :                     &key_id,
     558            0 :                     spec.source_connstring_ciphertext_base64,
     559            0 :                 )
     560            0 :                 .await
     561            0 :                 .context("decrypt source connection string")?;
     562              : 
     563            0 :                 let dest = if let Some(dest_ciphertext) =
     564            0 :                     spec.destination_connstring_ciphertext_base64
     565              :                 {
     566            0 :                     decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext)
     567            0 :                         .await
     568            0 :                         .context("decrypt destination connection string")?
     569              :                 } else {
     570            0 :                     bail!(
     571            0 :                         "destination connection string must be provided in spec for dump_restore command"
     572              :                     );
     573              :                 };
     574              : 
     575            0 :                 (source, dest)
     576              :             }
     577              :         }
     578              :     } else {
     579              :         (
     580            0 :             source_connection_string.unwrap(),
     581            0 :             if let Some(val) = destination_connection_string {
     582            0 :                 val
     583              :             } else {
     584            0 :                 bail!("destination connection string must be provided for dump_restore command");
     585              :             },
     586              :         )
     587              :     };
     588              : 
     589            0 :     run_dump_restore(
     590            0 :         workdir,
     591            0 :         pg_bin_dir,
     592            0 :         pg_lib_dir,
     593            0 :         source_connstring,
     594            0 :         destination_connstring,
     595            0 :     )
     596            0 :     .await
     597            0 : }
     598              : 
     599              : #[tokio::main]
     600            0 : pub(crate) async fn main() -> anyhow::Result<()> {
     601            0 :     utils::logging::init(
     602            0 :         utils::logging::LogFormat::Json,
     603            0 :         utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter,
     604            0 :         utils::logging::Output::Stdout,
     605            0 :     )?;
     606              : 
     607            0 :     info!("starting");
     608              : 
     609            0 :     let args = Args::parse();
     610              : 
     611              :     // Initialize AWS clients only if s3_prefix is specified
     612            0 :     let (s3_client, kms_client) = if args.s3_prefix.is_some() {
     613              :         // Create AWS config with enhanced retry settings
     614            0 :         let config = aws_config::defaults(BehaviorVersion::v2024_03_28())
     615            0 :             .retry_config(
     616            0 :                 aws_config::retry::RetryConfig::standard()
     617            0 :                     .with_max_attempts(5) // Retry up to 5 times
     618            0 :                     .with_initial_backoff(std::time::Duration::from_millis(200)) // Start with 200ms delay
     619            0 :                     .with_max_backoff(std::time::Duration::from_secs(5)), // Cap at 5 seconds
     620            0 :             )
     621            0 :             .load()
     622            0 :             .await;
     623              : 
     624              :         // Create clients from the config with enhanced retry settings
     625            0 :         let s3_client = aws_sdk_s3::Client::new(&config);
     626            0 :         let kms = aws_sdk_kms::Client::new(&config);
     627            0 :         (Some(s3_client), Some(kms))
     628              :     } else {
     629            0 :         (None, None)
     630              :     };
     631              : 
     632              :     // Capture everything from spec assignment onwards to handle errors
     633            0 :     let res = async {
     634            0 :         let spec: Option<Spec> = if let Some(s3_prefix) = &args.s3_prefix {
     635            0 :             let spec_key = s3_prefix.append("/spec.json");
     636            0 :             let object = s3_client
     637            0 :                 .as_ref()
     638            0 :                 .unwrap()
     639            0 :                 .get_object()
     640            0 :                 .bucket(&spec_key.bucket)
     641            0 :                 .key(spec_key.key)
     642            0 :                 .send()
     643            0 :                 .await
     644            0 :                 .context("get spec from s3")?
     645              :                 .body
     646            0 :                 .collect()
     647            0 :                 .await
     648            0 :                 .context("download spec body")?;
     649            0 :             serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
     650              :         } else {
     651            0 :             None
     652              :         };
     653              : 
     654            0 :         match tokio::fs::create_dir(&args.working_directory).await {
     655            0 :             Ok(()) => {}
     656            0 :             Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
     657            0 :                 if !is_directory_empty(&args.working_directory)
     658            0 :                     .await
     659            0 :                     .context("check if working directory is empty")?
     660              :                 {
     661            0 :                     bail!("working directory is not empty");
     662            0 :                 } else {
     663            0 :                     // ok
     664            0 :                 }
     665              :             }
     666            0 :             Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
     667              :         }
     668              : 
     669            0 :         match args.command.clone() {
     670              :             Command::Pgdata {
     671            0 :                 source_connection_string,
     672            0 :                 interactive,
     673            0 :                 pg_port,
     674            0 :                 num_cpus,
     675            0 :                 memory_mb,
     676              :             } => {
     677            0 :                 cmd_pgdata(
     678            0 :                     s3_client.as_ref(),
     679            0 :                     kms_client,
     680            0 :                     args.s3_prefix.clone(),
     681            0 :                     spec,
     682            0 :                     source_connection_string,
     683            0 :                     interactive,
     684            0 :                     pg_port,
     685            0 :                     args.working_directory.clone(),
     686            0 :                     args.pg_bin_dir,
     687            0 :                     args.pg_lib_dir,
     688            0 :                     num_cpus,
     689            0 :                     memory_mb,
     690            0 :                 )
     691            0 :                 .await
     692              :             }
     693              :             Command::DumpRestore {
     694            0 :                 source_connection_string,
     695            0 :                 destination_connection_string,
     696              :             } => {
     697            0 :                 cmd_dumprestore(
     698            0 :                     kms_client,
     699            0 :                     spec,
     700            0 :                     source_connection_string,
     701            0 :                     destination_connection_string,
     702            0 :                     args.working_directory.clone(),
     703            0 :                     args.pg_bin_dir,
     704            0 :                     args.pg_lib_dir,
     705            0 :                 )
     706            0 :                 .await
     707              :             }
     708              :         }
     709            0 :     }
     710            0 :     .await;
     711              : 
     712            0 :     if let Some(s3_prefix) = args.s3_prefix {
     713            0 :         info!("write job status to s3");
     714            0 :         {
     715            0 :             let status_dir = args.working_directory.join("status");
     716            0 :             if std::fs::exists(&status_dir)?.not() {
     717            0 :                 std::fs::create_dir(&status_dir).context("create status directory")?;
     718            0 :             }
     719            0 :             let status_file = status_dir.join("fast_import");
     720            0 :             let res_obj = match res {
     721            0 :                 Ok(_) => serde_json::json!({"command": args.command.as_str(), "done": true}),
     722            0 :                 Err(err) => {
     723            0 :                     serde_json::json!({"command": args.command.as_str(), "done": false, "error": err.to_string()})
     724            0 :                 }
     725            0 :             };
     726            0 :             std::fs::write(&status_file, res_obj.to_string()).context("write status file")?;
     727            0 :             aws_s3_sync::upload_dir_recursive(
     728            0 :                 s3_client.as_ref().unwrap(),
     729            0 :                 &status_dir,
     730            0 :                 &s3_prefix.append("/status/"),
     731            0 :             )
     732            0 :             .await
     733            0 :             .context("sync status directory to destination")?;
     734            0 :         }
     735            0 :     }
     736            0 : 
     737            0 :     Ok(())
     738            0 : }
        

Generated by: LCOV version 2.1-beta