//! This program dumps a remote Postgres database into a local Postgres database
//! and uploads the resulting PGDATA into object storage for import into a Timeline.
//!
//! # Context, Architecture, Design
//!
//! See cloud.git Fast Imports RFC (<https://github.com/neondatabase/cloud/pull/19799>)
//! for the full picture.
//! The RFC describing the storage pieces of importing the PGDATA dump into a Timeline
//! is publicly accessible at <https://github.com/neondatabase/neon/pull/9538>.
//!
//! # This is a Prototype!
//!
//! This program is part of a prototype feature and not yet used in production.
//!
//! The cloud.git RFC contains lots of suggestions for improving e2e throughput
//! of this step of the timeline import process.
//!
//! # Local Testing
//!
//! - Comment out most of the pgxns in compute-node.Dockerfile to speed up the build.
//! - Build the image with the following command:
//!
//! ```bash
//! docker buildx build --platform linux/amd64 --build-arg DEBIAN_VERSION=bullseye --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/compute-node.Dockerfile .
//! docker push localhost:3030/localregistry/compute-node-v14:latest
//! ```
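//!
//! Once the image is built, the binary can also be invoked directly. The following
//! invocation is a sketch for local testing only; the binary name, paths, and
//! connection string are illustrative, not part of the build above:
//!
//! ```bash
//! fast_import --working-directory /tmp/fast-import-workdir \
//!     --pg-bin-dir /usr/local/v14/bin --pg-lib-dir /usr/local/v14/lib \
//!     pgdata --source-connection-string 'postgresql://user:password@source-host:5432/sourcedb' \
//!     --pg-port 5432 --interactive
//! ```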

use anyhow::{Context, bail};
use aws_config::BehaviorVersion;
use camino::{Utf8Path, Utf8PathBuf};
use clap::{Parser, Subcommand};
use compute_tools::extension_server::{PostgresMajorVersion, get_pg_version};
use nix::unistd::Pid;
use std::ops::Not;
use tracing::{Instrument, error, info, info_span, warn};
use utils::fs_ext::is_directory_empty;

#[path = "fast_import/aws_s3_sync.rs"]
mod aws_s3_sync;
#[path = "fast_import/child_stdio_to_log.rs"]
mod child_stdio_to_log;
#[path = "fast_import/s3_uri.rs"]
mod s3_uri;

const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600);
const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300);

#[derive(Subcommand, Debug, Clone, serde::Serialize)]
enum Command {
    /// Runs a local postgres (neon binary), restores the dump into it, and
    /// uploads the resulting pgdata to S3 to be consumed by pageservers.
    Pgdata {
        /// Raw connection string to the source database. Used only in tests;
        /// the real scenario uses an encrypted connection string in spec.json from S3.
        #[clap(long)]
        source_connection_string: Option<String>,
        /// If specified, will not shut down the local postgres after the import. Used in local testing.
        #[clap(short, long)]
        interactive: bool,
        /// Port to run postgres on. Default is 5432.
        #[clap(long, default_value_t = 5432)]
        pg_port: u16,

        /// Number of CPUs in the system. This is used to configure the number of
        /// parallel worker processes, e.g. for index creation.
        #[clap(long, env = "NEON_IMPORTER_NUM_CPUS")]
        num_cpus: Option<usize>,

        /// Amount of RAM in the system. This is used to configure shared_buffers
        /// and maintenance_work_mem.
        #[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
        memory_mb: Option<usize>,
    },

    /// Runs pg_dump and pg_restore from source to destination without running local postgres.
    DumpRestore {
        /// Raw connection string to the source database. Used only in tests;
        /// the real scenario uses an encrypted connection string in spec.json from S3.
        #[clap(long)]
        source_connection_string: Option<String>,
        /// Raw connection string to the destination database. Used only in tests;
        /// the real scenario uses an encrypted connection string in spec.json from S3.
        #[clap(long)]
        destination_connection_string: Option<String>,
    },
}

impl Command {
    fn as_str(&self) -> &'static str {
        match self {
            Command::Pgdata { .. } => "pgdata",
            Command::DumpRestore { .. } => "dump-restore",
        }
    }
}

#[derive(clap::Parser)]
struct Args {
    #[clap(long, env = "NEON_IMPORTER_WORKDIR")]
    working_directory: Utf8PathBuf,
    #[clap(long, env = "NEON_IMPORTER_S3_PREFIX")]
    s3_prefix: Option<s3_uri::S3Uri>,
    #[clap(long, env = "NEON_IMPORTER_PG_BIN_DIR")]
    pg_bin_dir: Utf8PathBuf,
    #[clap(long, env = "NEON_IMPORTER_PG_LIB_DIR")]
    pg_lib_dir: Utf8PathBuf,

    #[clap(subcommand)]
    command: Command,
}

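/// Instructions for a single import job, fetched from `{s3_prefix}/spec.json`.
///
/// A sketch of the expected JSON shape, following the serde derives below
/// (the key id and ciphertext values are illustrative):
///
/// ```json
/// {
///   "encryption_secret": { "KMS": { "key_id": "arn:aws:kms:..." } },
///   "source_connstring_ciphertext_base64": "<base64-encoded KMS ciphertext>",
///   "destination_connstring_ciphertext_base64": null
/// }
/// ```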
#[serde_with::serde_as]
#[derive(serde::Deserialize)]
struct Spec {
    encryption_secret: EncryptionSecret,
    #[serde_as(as = "serde_with::base64::Base64")]
    source_connstring_ciphertext_base64: Vec<u8>,
    #[serde_as(as = "Option<serde_with::base64::Base64>")]
    destination_connstring_ciphertext_base64: Option<Vec<u8>>,
}

#[derive(serde::Deserialize)]
enum EncryptionSecret {
    #[allow(clippy::upper_case_acronyms)]
    KMS { key_id: String },
}

// Copied from pageserver_api::config::defaults::DEFAULT_LOCALE to avoid a dependency just for a constant.
const DEFAULT_LOCALE: &str = if cfg!(target_os = "macos") {
    "C"
} else {
    "C.UTF-8"
};

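/// Decrypts a KMS-encrypted connection string ciphertext with the given key
/// and returns the plaintext as a UTF-8 string.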
async fn decode_connstring(
    kms_client: &aws_sdk_kms::Client,
    key_id: &String,
    connstring_ciphertext_base64: Vec<u8>,
) -> Result<String, anyhow::Error> {
    let mut output = kms_client
        .decrypt()
        .key_id(key_id)
        .ciphertext_blob(aws_sdk_s3::primitives::Blob::new(
            connstring_ciphertext_base64,
        ))
        .send()
        .await
        .context("decrypt connection string")?;

    let plaintext = output
        .plaintext
        .take()
        .context("get plaintext connection string")?;

    String::from_utf8(plaintext.into_inner()).context("parse connection string as utf8")
}

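/// A locally running throwaway Postgres instance into which the source
/// database is restored before its pgdata directory is uploaded to S3.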
struct PostgresProcess {
    pgdata_dir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pgbin: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    postgres_proc: Option<tokio::process::Child>,
}

impl PostgresProcess {
    fn new(pgdata_dir: Utf8PathBuf, pg_bin_dir: Utf8PathBuf, pg_lib_dir: Utf8PathBuf) -> Self {
        Self {
            pgdata_dir,
            pgbin: pg_bin_dir.join("postgres"),
            pg_bin_dir,
            pg_lib_dir,
            postgres_proc: None,
        }
    }

    async fn prepare(&self, initdb_user: &str) -> Result<(), anyhow::Error> {
        tokio::fs::create_dir(&self.pgdata_dir)
            .await
            .context("create pgdata directory")?;

        let pg_version = match get_pg_version(self.pgbin.as_ref()) {
            PostgresMajorVersion::V14 => 14,
            PostgresMajorVersion::V15 => 15,
            PostgresMajorVersion::V16 => 16,
            PostgresMajorVersion::V17 => 17,
        };
        postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
            superuser: initdb_user,
            locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded
            pg_version,
            initdb_bin: self.pg_bin_dir.join("initdb").as_ref(),
            library_search_path: &self.pg_lib_dir, // TODO: is this right? Probably works in the compute image, not sure about neon_local.
            pgdata: &self.pgdata_dir,
        })
        .await
        .context("initdb")
    }

    async fn start(
        &mut self,
        initdb_user: &str,
        port: u16,
        nproc: usize,
        memory_mb: usize,
    ) -> Result<&tokio::process::Child, anyhow::Error> {
        self.prepare(initdb_user).await?;

        // Somewhat arbitrarily, use 10% of memory for the shared buffer cache, 70% for
        // maintenance_work_mem (i.e. for sorting during index creation), and leave the rest
        // available for the misc other stuff that PostgreSQL uses memory for.
        let shared_buffers_mb = ((memory_mb as f32) * 0.10) as usize;
        let maintenance_work_mem_mb = ((memory_mb as f32) * 0.70) as usize;
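        // For example, memory_mb = 1024 yields shared_buffers=102MB and
        // maintenance_work_mem=716MB, leaving roughly 200MB for everything else.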

        //
        // Launch postgres process
        //
        let mut proc = tokio::process::Command::new(&self.pgbin)
            .arg("-D")
            .arg(&self.pgdata_dir)
            .args(["-p", &format!("{port}")])
            .args(["-c", "wal_level=minimal"])
            .args(["-c", &format!("shared_buffers={shared_buffers_mb}MB")])
            .args(["-c", "max_wal_senders=0"])
            .args(["-c", "fsync=off"])
            .args(["-c", "full_page_writes=off"])
            .args(["-c", "synchronous_commit=off"])
            .args([
                "-c",
                &format!("maintenance_work_mem={maintenance_work_mem_mb}MB"),
            ])
            .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")])
            .args(["-c", &format!("max_parallel_workers={nproc}")])
            .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")])
            .args(["-c", &format!("max_worker_processes={nproc}")])
            .args(["-c", "effective_io_concurrency=100"])
            .env_clear()
            .env("LD_LIBRARY_PATH", &self.pg_lib_dir)
            .env(
                "ASAN_OPTIONS",
                std::env::var("ASAN_OPTIONS").unwrap_or_default(),
            )
            .env(
                "UBSAN_OPTIONS",
                std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
            )
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .context("spawn postgres")?;

        info!("spawned postgres, waiting for it to become ready");
        tokio::spawn(
            child_stdio_to_log::relay_process_output(proc.stdout.take(), proc.stderr.take())
                .instrument(info_span!("postgres")),
        );

        self.postgres_proc = Some(proc);
        Ok(self.postgres_proc.as_ref().unwrap())
    }

    async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
        let proc: &mut tokio::process::Child = self.postgres_proc.as_mut().unwrap();
        info!("shutdown postgres");
        nix::sys::signal::kill(
            Pid::from_raw(i32::try_from(proc.id().unwrap()).expect("convert child pid to i32")),
            nix::sys::signal::SIGTERM,
        )
        .context("signal postgres to shut down")?;
        proc.wait()
            .await
            .context("wait for postgres to shut down")
            .map(|_| ())
    }
}

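/// Polls the local postgres on `connstring` until it accepts connections, then
/// creates the `create_dbname` database. Exits the whole process if postgres
/// does not become ready within `PG_WAIT_TIMEOUT`.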
async fn wait_until_ready(connstring: String, create_dbname: String) {
    // Create the target database in the running postgres
    let start_time = std::time::Instant::now();

    loop {
        if start_time.elapsed() > PG_WAIT_TIMEOUT {
            error!(
                "timeout exceeded: failed to poll postgres and create database within {:?}",
                PG_WAIT_TIMEOUT
            );
            std::process::exit(1);
        }

        match tokio_postgres::connect(
            &connstring.replace("dbname=neondb", "dbname=postgres"),
            tokio_postgres::NoTls,
        )
        .await
        {
            Ok((client, connection)) => {
                // Spawn the connection handling task to maintain the connection
                tokio::spawn(async move {
                    if let Err(e) = connection.await {
                        warn!("connection error: {}", e);
                    }
                });

                match client
                    .simple_query(format!("CREATE DATABASE {create_dbname};").as_str())
                    .await
                {
                    Ok(_) => {
                        info!("created {} database", create_dbname);
                        break;
                    }
                    Err(e) => {
                        warn!(
                            "failed to create database: {}, retrying in {}s",
                            e,
                            PG_WAIT_RETRY_INTERVAL.as_secs_f32()
                        );
                        tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
                        continue;
                    }
                }
            }
            Err(_) => {
                info!(
                    "postgres not ready yet, retrying in {}s",
                    PG_WAIT_RETRY_INTERVAL.as_secs_f32()
                );
                tokio::time::sleep(PG_WAIT_RETRY_INTERVAL).await;
                continue;
            }
        }
    }
}

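/// Dumps the source database into `{workdir}/dumpdir` with pg_dump in directory
/// format, then restores it into the destination with pg_restore, running both
/// with one job per CPU.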
async fn run_dump_restore(
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    source_connstring: String,
    destination_connstring: String,
) -> Result<(), anyhow::Error> {
    let dumpdir = workdir.join("dumpdir");

    let common_args = [
        // schema mapping (it probably suffices to specify these on one side)
        "--no-owner".to_string(),
        "--no-privileges".to_string(),
        "--no-publications".to_string(),
        "--no-security-labels".to_string(),
        "--no-subscriptions".to_string(),
        "--no-tablespaces".to_string(),
        // format
        "--format".to_string(),
        "directory".to_string(),
        // concurrency
        "--jobs".to_string(),
        num_cpus::get().to_string(),
        // progress updates
        "--verbose".to_string(),
    ];

    info!("dump into the working directory");
    {
        let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump"))
            .args(&common_args)
            .arg("-f")
            .arg(&dumpdir)
            .arg("--no-sync")
            // POSITIONAL args
            // source db (db name included in connection string)
            .arg(&source_connstring)
            // how we run it
            .env_clear()
            .env("LD_LIBRARY_PATH", &pg_lib_dir)
            .env(
                "ASAN_OPTIONS",
                std::env::var("ASAN_OPTIONS").unwrap_or_default(),
            )
            .env(
                "UBSAN_OPTIONS",
                std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
            )
            .kill_on_drop(true)
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .context("spawn pg_dump")?;

        info!(pid=%pg_dump.id().unwrap(), "spawned pg_dump");

        tokio::spawn(
            child_stdio_to_log::relay_process_output(pg_dump.stdout.take(), pg_dump.stderr.take())
                .instrument(info_span!("pg_dump")),
        );

        let st = pg_dump.wait().await.context("wait for pg_dump")?;
        info!(status=?st, "pg_dump exited");
        if !st.success() {
            error!(status=%st, "pg_dump failed");
            bail!("pg_dump failed");
        }
    }

    // TODO: maybe do it in a streaming way, plenty of internal research done on this already
    // TODO: do the unlogged table trick
    {
        let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore"))
            .args(&common_args)
            .arg("-d")
            .arg(&destination_connstring)
            // POSITIONAL args
            .arg(&dumpdir)
            // how we run it
            .env_clear()
            .env("LD_LIBRARY_PATH", &pg_lib_dir)
            .env(
                "ASAN_OPTIONS",
                std::env::var("ASAN_OPTIONS").unwrap_or_default(),
            )
            .env(
                "UBSAN_OPTIONS",
                std::env::var("UBSAN_OPTIONS").unwrap_or_default(),
            )
            .kill_on_drop(true)
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .context("spawn pg_restore")?;

        info!(pid=%pg_restore.id().unwrap(), "spawned pg_restore");
        tokio::spawn(
            child_stdio_to_log::relay_process_output(
                pg_restore.stdout.take(),
                pg_restore.stderr.take(),
            )
            .instrument(info_span!("pg_restore")),
        );
        let st = pg_restore.wait().await.context("wait for pg_restore")?;
        info!(status=?st, "pg_restore exited");
        if !st.success() {
            error!(status=%st, "pg_restore failed");
            bail!("pg_restore failed");
        }
    }

    Ok(())
}

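/// Implements the `pgdata` subcommand: starts a throwaway local postgres,
/// dump-restores the source database into it, shuts it down, and (if an S3
/// prefix is configured) uploads the resulting pgdata and a status file to S3.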
#[allow(clippy::too_many_arguments)]
async fn cmd_pgdata(
    s3_client: Option<&aws_sdk_s3::Client>,
    kms_client: Option<aws_sdk_kms::Client>,
    maybe_s3_prefix: Option<s3_uri::S3Uri>,
    maybe_spec: Option<Spec>,
    source_connection_string: Option<String>,
    interactive: bool,
    pg_port: u16,
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
    num_cpus: Option<usize>,
    memory_mb: Option<usize>,
) -> Result<(), anyhow::Error> {
    if maybe_spec.is_none() && source_connection_string.is_none() {
        bail!("either spec or source_connection_string must be provided for pgdata command");
    }
    if maybe_spec.is_some() && source_connection_string.is_some() {
        bail!("only one of spec or source_connection_string can be provided");
    }

    let source_connection_string = if let Some(spec) = maybe_spec {
        match spec.encryption_secret {
            EncryptionSecret::KMS { key_id } => {
                decode_connstring(
                    kms_client.as_ref().unwrap(),
                    &key_id,
                    spec.source_connstring_ciphertext_base64,
                )
                .await?
            }
        }
    } else {
        source_connection_string.unwrap()
    };

    let superuser = "cloud_admin";
    let destination_connstring = format!(
        "host=localhost port={} user={} dbname=neondb",
        pg_port, superuser
    );

    let pgdata_dir = workdir.join("pgdata");
    let mut proc = PostgresProcess::new(pgdata_dir.clone(), pg_bin_dir.clone(), pg_lib_dir.clone());
    let nproc = num_cpus.unwrap_or_else(num_cpus::get);
    let memory_mb = memory_mb.unwrap_or(256);
    proc.start(superuser, pg_port, nproc, memory_mb).await?;
    wait_until_ready(destination_connstring.clone(), "neondb".to_string()).await;

    run_dump_restore(
        workdir.clone(),
        pg_bin_dir,
        pg_lib_dir,
        source_connection_string,
        destination_connstring,
    )
    .await?;

    // In interactive mode, wait for Ctrl+C
    if interactive {
        info!("Running in interactive mode. Press Ctrl+C to shut down.");
        tokio::signal::ctrl_c().await.context("wait for ctrl-c")?;
    }

    proc.shutdown().await?;

    // Only sync if s3_prefix was specified
    if let Some(s3_prefix) = maybe_s3_prefix {
        info!("upload pgdata");
        aws_s3_sync::upload_dir_recursive(
            s3_client.unwrap(),
            Utf8Path::new(&pgdata_dir),
            &s3_prefix.append("/pgdata/"),
        )
        .await
        .context("sync dump directory to destination")?;

        info!("write pgdata status to s3");
        {
            let status_dir = workdir.join("status");
            std::fs::create_dir(&status_dir).context("create status directory")?;
            let status_file = status_dir.join("pgdata");
            std::fs::write(&status_file, serde_json::json!({"done": true}).to_string())
                .context("write status file")?;
            aws_s3_sync::upload_dir_recursive(
                s3_client.as_ref().unwrap(),
                &status_dir,
                &s3_prefix.append("/status/"),
            )
            .await
            .context("sync status directory to destination")?;
        }
    }

    Ok(())
}

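/// Implements the `dump-restore` subcommand: resolves the source and
/// destination connection strings (decrypting via KMS when a spec is given)
/// and runs pg_dump/pg_restore between them.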
async fn cmd_dumprestore(
    kms_client: Option<aws_sdk_kms::Client>,
    maybe_spec: Option<Spec>,
    source_connection_string: Option<String>,
    destination_connection_string: Option<String>,
    workdir: Utf8PathBuf,
    pg_bin_dir: Utf8PathBuf,
    pg_lib_dir: Utf8PathBuf,
) -> Result<(), anyhow::Error> {
    let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
        match spec.encryption_secret {
            EncryptionSecret::KMS { key_id } => {
                let source = decode_connstring(
                    kms_client.as_ref().unwrap(),
                    &key_id,
                    spec.source_connstring_ciphertext_base64,
                )
                .await
                .context("decrypt source connection string")?;

                let dest = if let Some(dest_ciphertext) =
                    spec.destination_connstring_ciphertext_base64
                {
                    decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext)
                        .await
                        .context("decrypt destination connection string")?
                } else {
                    bail!(
                        "destination connection string must be provided in spec for dump_restore command"
                    );
                };

                (source, dest)
            }
        }
    } else {
        (
            source_connection_string.unwrap(),
            if let Some(val) = destination_connection_string {
                val
            } else {
                bail!("destination connection string must be provided for dump_restore command");
            },
        )
    };

    run_dump_restore(
        workdir,
        pg_bin_dir,
        pg_lib_dir,
        source_connstring,
        destination_connstring,
    )
    .await
}

#[tokio::main]
pub(crate) async fn main() -> anyhow::Result<()> {
    utils::logging::init(
        utils::logging::LogFormat::Json,
        utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter,
        utils::logging::Output::Stdout,
    )?;

    info!("starting");

    let args = Args::parse();

    // Initialize AWS clients only if s3_prefix is specified
    let (s3_client, kms_client) = if args.s3_prefix.is_some() {
        // Create AWS config with enhanced retry settings
        let config = aws_config::defaults(BehaviorVersion::v2024_03_28())
            .retry_config(
                aws_config::retry::RetryConfig::standard()
                    .with_max_attempts(5) // Retry up to 5 times
                    .with_initial_backoff(std::time::Duration::from_millis(200)) // Start with a 200ms delay
                    .with_max_backoff(std::time::Duration::from_secs(5)), // Cap the delay at 5 seconds
            )
            .load()
            .await;

        // Create clients from the config with enhanced retry settings
        let s3_client = aws_sdk_s3::Client::new(&config);
        let kms = aws_sdk_kms::Client::new(&config);
        (Some(s3_client), Some(kms))
    } else {
        (None, None)
    };

    // Capture everything from spec assignment onwards to handle errors
    let res = async {
        let spec: Option<Spec> = if let Some(s3_prefix) = &args.s3_prefix {
            let spec_key = s3_prefix.append("/spec.json");
            let object = s3_client
                .as_ref()
                .unwrap()
                .get_object()
                .bucket(&spec_key.bucket)
                .key(spec_key.key)
                .send()
                .await
                .context("get spec from s3")?
                .body
                .collect()
                .await
                .context("download spec body")?;
            serde_json::from_slice(&object.into_bytes()).context("parse spec as json")?
        } else {
            None
        };

        match tokio::fs::create_dir(&args.working_directory).await {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
                if !is_directory_empty(&args.working_directory)
                    .await
                    .context("check if working directory is empty")?
                {
                    bail!("working directory is not empty");
                } else {
                    // ok
                }
            }
            Err(e) => return Err(anyhow::Error::new(e).context("create working directory")),
        }

        match args.command.clone() {
            Command::Pgdata {
                source_connection_string,
                interactive,
                pg_port,
                num_cpus,
                memory_mb,
            } => {
                cmd_pgdata(
                    s3_client.as_ref(),
                    kms_client,
                    args.s3_prefix.clone(),
                    spec,
                    source_connection_string,
                    interactive,
                    pg_port,
                    args.working_directory.clone(),
                    args.pg_bin_dir,
                    args.pg_lib_dir,
                    num_cpus,
                    memory_mb,
                )
                .await
            }
            Command::DumpRestore {
                source_connection_string,
                destination_connection_string,
            } => {
                cmd_dumprestore(
                    kms_client,
                    spec,
                    source_connection_string,
                    destination_connection_string,
                    args.working_directory.clone(),
                    args.pg_bin_dir,
                    args.pg_lib_dir,
                )
                .await
            }
        }
    }
    .await;

    if let Some(s3_prefix) = args.s3_prefix {
        info!("write job status to s3");
        {
            let status_dir = args.working_directory.join("status");
            if std::fs::exists(&status_dir)?.not() {
                std::fs::create_dir(&status_dir).context("create status directory")?;
            }
            let status_file = status_dir.join("fast_import");
            let res_obj = match res {
                Ok(_) => serde_json::json!({"command": args.command.as_str(), "done": true}),
                Err(err) => {
                    serde_json::json!({"command": args.command.as_str(), "done": false, "error": err.to_string()})
                }
            };
            std::fs::write(&status_file, res_obj.to_string()).context("write status file")?;
            aws_s3_sync::upload_dir_recursive(
                s3_client.as_ref().unwrap(),
                &status_dir,
                &s3_prefix.append("/status/"),
            )
            .await
            .context("sync status directory to destination")?;
        }
    }

    Ok(())
}