Line data Source code
1 : /// The attachment service mimics the aspects of the control plane API
2 : /// that are required for a pageserver to operate.
3 : ///
4 : /// This enables running & testing pageservers without a full-blown
5 : /// deployment of the Neon cloud platform.
6 : ///
7 : use anyhow::{anyhow, Context};
8 : use attachment_service::http::make_router;
9 : use attachment_service::persistence::Persistence;
10 : use attachment_service::service::{Config, Service};
11 : use aws_config::{self, BehaviorVersion, Region};
12 : use camino::Utf8PathBuf;
13 : use clap::Parser;
14 : use diesel::Connection;
15 : use metrics::launch_timestamp::LaunchTimestamp;
16 : use std::sync::Arc;
17 : use tokio::signal::unix::SignalKind;
18 : use utils::auth::{JwtAuth, SwappableJwtAuth};
19 : use utils::logging::{self, LogFormat};
20 :
21 : use utils::{project_build_tag, project_git_version, tcp_listener};
22 :
23 : project_git_version!(GIT_VERSION);
24 : project_build_tag!(BUILD_TAG);
25 :
26 : use diesel_migrations::{embed_migrations, EmbeddedMigrations};
27 : pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
28 :
29 732 : #[derive(Parser)]
30 : #[command(author, version, about, long_about = None)]
31 : #[command(arg_required_else_help(true))]
32 : struct Cli {
33 : /// Host and port to listen on, like `127.0.0.1:1234`
34 : #[arg(short, long)]
35 0 : listen: std::net::SocketAddr,
36 :
37 : /// Public key for JWT authentication of clients
38 : #[arg(long)]
39 : public_key: Option<String>,
40 :
41 : /// Token for authenticating this service with the pageservers it controls
42 : #[arg(long)]
43 : jwt_token: Option<String>,
44 :
45 : /// Token for authenticating this service with the control plane, when calling
46 : /// the compute notification endpoint
47 : #[arg(long)]
48 : control_plane_jwt_token: Option<String>,
49 :
50 : /// URL to control plane compute notification endpoint
51 : #[arg(long)]
52 : compute_hook_url: Option<String>,
53 :
54 : /// Path to the .json file to store state (will be created if it doesn't exist)
55 : #[arg(short, long)]
56 : path: Option<Utf8PathBuf>,
57 :
58 : /// URL to connect to postgres, like postgresql://localhost:1234/attachment_service
59 : #[arg(long)]
60 : database_url: Option<String>,
61 : }
62 :
63 : /// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this
64 : /// type encapsulates the logic to decide which and do the loading.
65 : struct Secrets {
66 : database_url: String,
67 : public_key: Option<JwtAuth>,
68 : jwt_token: Option<String>,
69 : control_plane_jwt_token: Option<String>,
70 : }
71 :
72 : impl Secrets {
73 : const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url";
74 : const PAGESERVER_JWT_TOKEN_SECRET: &'static str =
75 : "neon-storage-controller-pageserver-jwt-token";
76 : const CONTROL_PLANE_JWT_TOKEN_SECRET: &'static str =
77 : "neon-storage-controller-control-plane-jwt-token";
78 : const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key";
79 :
80 366 : async fn load(args: &Cli) -> anyhow::Result<Self> {
81 366 : match &args.database_url {
82 366 : Some(url) => Self::load_cli(url, args),
83 0 : None => Self::load_aws_sm().await,
84 : }
85 366 : }
86 :
87 0 : async fn load_aws_sm() -> anyhow::Result<Self> {
88 0 : let Ok(region) = std::env::var("AWS_REGION") else {
89 0 : anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets");
90 : };
91 0 : let config = aws_config::defaults(BehaviorVersion::v2023_11_09())
92 0 : .region(Region::new(region.clone()))
93 0 : .load()
94 0 : .await;
95 :
96 0 : let asm = aws_sdk_secretsmanager::Client::new(&config);
97 :
98 0 : let Some(database_url) = asm
99 0 : .get_secret_value()
100 0 : .secret_id(Self::DATABASE_URL_SECRET)
101 0 : .send()
102 0 : .await?
103 0 : .secret_string()
104 0 : .map(str::to_string)
105 : else {
106 0 : anyhow::bail!(
107 0 : "Database URL secret not found at {region}/{}",
108 0 : Self::DATABASE_URL_SECRET
109 0 : )
110 : };
111 :
112 0 : let jwt_token = asm
113 0 : .get_secret_value()
114 0 : .secret_id(Self::PAGESERVER_JWT_TOKEN_SECRET)
115 0 : .send()
116 0 : .await?
117 0 : .secret_string()
118 0 : .map(str::to_string);
119 0 : if jwt_token.is_none() {
120 0 : tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver");
121 0 : }
122 :
123 0 : let control_plane_jwt_token = asm
124 0 : .get_secret_value()
125 0 : .secret_id(Self::CONTROL_PLANE_JWT_TOKEN_SECRET)
126 0 : .send()
127 0 : .await?
128 0 : .secret_string()
129 0 : .map(str::to_string);
130 0 : if jwt_token.is_none() {
131 0 : tracing::warn!("No control plane JWT token set: this will only work if authentication is disabled on the pageserver");
132 0 : }
133 :
134 0 : let public_key = asm
135 0 : .get_secret_value()
136 0 : .secret_id(Self::PUBLIC_KEY_SECRET)
137 0 : .send()
138 0 : .await?
139 0 : .secret_string()
140 0 : .map(str::to_string);
141 0 : let public_key = match public_key {
142 0 : Some(key) => Some(JwtAuth::from_key(key)?),
143 : None => {
144 0 : tracing::warn!(
145 0 : "No public key set: inccoming HTTP requests will not be authenticated"
146 0 : );
147 0 : None
148 : }
149 : };
150 :
151 0 : Ok(Self {
152 0 : database_url,
153 0 : public_key,
154 0 : jwt_token,
155 0 : control_plane_jwt_token,
156 0 : })
157 0 : }
158 :
159 366 : fn load_cli(database_url: &str, args: &Cli) -> anyhow::Result<Self> {
160 366 : let public_key = match &args.public_key {
161 355 : None => None,
162 11 : Some(key) => Some(JwtAuth::from_key(key.clone()).context("Loading public key")?),
163 : };
164 366 : Ok(Self {
165 366 : database_url: database_url.to_owned(),
166 366 : public_key,
167 366 : jwt_token: args.jwt_token.clone(),
168 366 : control_plane_jwt_token: args.control_plane_jwt_token.clone(),
169 366 : })
170 366 : }
171 : }
172 :
173 : /// Execute the diesel migrations that are built into this binary
174 366 : async fn migration_run(database_url: &str) -> anyhow::Result<()> {
175 : use diesel::PgConnection;
176 : use diesel_migrations::{HarnessWithOutput, MigrationHarness};
177 366 : let mut conn = PgConnection::establish(database_url)?;
178 :
179 366 : HarnessWithOutput::write_to_stdout(&mut conn)
180 366 : .run_pending_migrations(MIGRATIONS)
181 366 : .map(|_| ())
182 366 : .map_err(|e| anyhow::anyhow!(e))?;
183 :
184 366 : Ok(())
185 366 : }
186 :
187 366 : fn main() -> anyhow::Result<()> {
188 366 : tokio::runtime::Builder::new_current_thread()
189 366 : // We use spawn_blocking for database operations, so require approximately
190 366 : // as many blocking threads as we will open database connections.
191 366 : .max_blocking_threads(Persistence::MAX_CONNECTIONS as usize)
192 366 : .enable_all()
193 366 : .build()
194 366 : .unwrap()
195 366 : .block_on(async_main())
196 366 : }
197 :
198 366 : async fn async_main() -> anyhow::Result<()> {
199 366 : let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
200 366 :
201 366 : logging::init(
202 366 : LogFormat::Plain,
203 366 : logging::TracingErrorLayerEnablement::Disabled,
204 366 : logging::Output::Stdout,
205 366 : )?;
206 :
207 366 : let args = Cli::parse();
208 366 : tracing::info!(
209 366 : "version: {}, launch_timestamp: {}, build_tag {}, state at {}, listening on {}",
210 366 : GIT_VERSION,
211 366 : launch_ts.to_string(),
212 366 : BUILD_TAG,
213 366 : args.path.as_ref().unwrap_or(&Utf8PathBuf::from("<none>")),
214 366 : args.listen
215 366 : );
216 :
217 366 : let secrets = Secrets::load(&args).await?;
218 :
219 366 : let config = Config {
220 366 : jwt_token: secrets.jwt_token,
221 366 : control_plane_jwt_token: secrets.control_plane_jwt_token,
222 366 : compute_hook_url: args.compute_hook_url,
223 366 : };
224 366 :
225 366 : // After loading secrets & config, but before starting anything else, apply database migrations
226 366 : migration_run(&secrets.database_url)
227 0 : .await
228 366 : .context("Running database migrations")?;
229 :
230 366 : let json_path = args.path;
231 366 : let persistence = Arc::new(Persistence::new(secrets.database_url, json_path.clone()));
232 :
233 1094 : let service = Service::spawn(config, persistence.clone()).await?;
234 :
235 366 : let http_listener = tcp_listener::bind(args.listen)?;
236 :
237 366 : let auth = secrets
238 366 : .public_key
239 366 : .map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
240 366 : let router = make_router(service, auth)
241 366 : .build()
242 366 : .map_err(|err| anyhow!(err))?;
243 366 : let router_service = utils::http::RouterService::new(router).unwrap();
244 366 : let server = hyper::Server::from_tcp(http_listener)?.serve(router_service);
245 :
246 366 : tracing::info!("Serving on {0}", args.listen);
247 :
248 366 : tokio::task::spawn(server);
249 :
250 : // Wait until we receive a signal
251 366 : let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
252 366 : let mut sigquit = tokio::signal::unix::signal(SignalKind::quit())?;
253 366 : let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
254 366 : tokio::select! {
255 732 : _ = sigint.recv() => {},
256 732 : _ = sigterm.recv() => {},
257 732 : _ = sigquit.recv() => {},
258 732 : }
259 366 : tracing::info!("Terminating on signal");
260 :
261 366 : if json_path.is_some() {
262 : // Write out a JSON dump on shutdown: this is used in compat tests to avoid passing
263 : // full postgres dumps around.
264 732 : if let Err(e) = persistence.write_tenants_json().await {
265 0 : tracing::error!("Failed to write JSON on shutdown: {e}")
266 366 : }
267 0 : }
268 :
269 366 : std::process::exit(0);
270 0 : }
|