Line data Source code
//! The attachment service mimics the aspects of the control plane API
//! that are required for a pageserver to operate.
//!
//! This enables running & testing pageservers without a full-blown
//! deployment of the Neon cloud platform.
7 : use anyhow::{anyhow, Context};
8 : use attachment_service::http::make_router;
9 : use attachment_service::persistence::Persistence;
10 : use attachment_service::service::{Config, Service};
11 : use aws_config::{self, BehaviorVersion, Region};
12 : use camino::Utf8PathBuf;
13 : use clap::Parser;
14 : use diesel::Connection;
15 : use metrics::launch_timestamp::LaunchTimestamp;
16 : use std::sync::Arc;
17 : use tokio::signal::unix::SignalKind;
18 : use utils::auth::{JwtAuth, SwappableJwtAuth};
19 : use utils::logging::{self, LogFormat};
20 :
21 : use utils::{project_build_tag, project_git_version, tcp_listener};
22 :
// Define the `GIT_VERSION` and `BUILD_TAG` constants; both are logged at startup by `main`.
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
25 :
26 : use diesel_migrations::{embed_migrations, EmbeddedMigrations};
/// Database schema migrations embedded from the `./migrations` directory.
/// Applied at startup by `migration_run`, before the service opens its persistence layer.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
28 :
29 361 : #[derive(Parser)]
30 : #[command(author, version, about, long_about = None)]
31 : #[command(arg_required_else_help(true))]
32 : struct Cli {
33 : /// Host and port to listen on, like `127.0.0.1:1234`
34 : #[arg(short, long)]
35 0 : listen: std::net::SocketAddr,
36 :
37 : /// Public key for JWT authentication of clients
38 : #[arg(long)]
39 : public_key: Option<String>,
40 :
41 : /// Token for authenticating this service with the pageservers it controls
42 : #[arg(long)]
43 : jwt_token: Option<String>,
44 :
45 : /// Token for authenticating this service with the control plane, when calling
46 : /// the compute notification endpoint
47 : #[arg(long)]
48 : control_plane_jwt_token: Option<String>,
49 :
50 : /// URL to control plane compute notification endpoint
51 : #[arg(long)]
52 : compute_hook_url: Option<String>,
53 :
54 : /// Path to the .json file to store state (will be created if it doesn't exist)
55 : #[arg(short, long)]
56 : path: Option<Utf8PathBuf>,
57 :
58 : /// URL to connect to postgres, like postgresql://localhost:1234/attachment_service
59 : #[arg(long)]
60 : database_url: Option<String>,
61 : }
62 :
63 : /// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this
64 : /// type encapsulates the logic to decide which and do the loading.
65 : struct Secrets {
66 : database_url: String,
67 : public_key: Option<JwtAuth>,
68 : jwt_token: Option<String>,
69 : control_plane_jwt_token: Option<String>,
70 : }
71 :
72 : impl Secrets {
73 : const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url";
74 : const PAGESERVER_JWT_TOKEN_SECRET: &'static str =
75 : "neon-storage-controller-pageserver-jwt-token";
76 : const CONTROL_PLANE_JWT_TOKEN_SECRET: &'static str =
77 : "neon-storage-controller-control-plane-jwt-token";
78 : const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key";
79 :
80 361 : async fn load(args: &Cli) -> anyhow::Result<Self> {
81 361 : match &args.database_url {
82 361 : Some(url) => Self::load_cli(url, args),
83 0 : None => Self::load_aws_sm().await,
84 : }
85 361 : }
86 :
87 0 : async fn load_aws_sm() -> anyhow::Result<Self> {
88 0 : let Ok(region) = std::env::var("AWS_REGION") else {
89 0 : anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets");
90 : };
91 0 : let config = aws_config::defaults(BehaviorVersion::v2023_11_09())
92 0 : .region(Region::new(region.clone()))
93 0 : .load()
94 0 : .await;
95 :
96 0 : let asm = aws_sdk_secretsmanager::Client::new(&config);
97 :
98 0 : let Some(database_url) = asm
99 0 : .get_secret_value()
100 0 : .secret_id(Self::DATABASE_URL_SECRET)
101 0 : .send()
102 0 : .await?
103 0 : .secret_string()
104 0 : .map(str::to_string)
105 : else {
106 0 : anyhow::bail!(
107 0 : "Database URL secret not found at {region}/{}",
108 0 : Self::DATABASE_URL_SECRET
109 0 : )
110 : };
111 :
112 0 : let jwt_token = asm
113 0 : .get_secret_value()
114 0 : .secret_id(Self::PAGESERVER_JWT_TOKEN_SECRET)
115 0 : .send()
116 0 : .await?
117 0 : .secret_string()
118 0 : .map(str::to_string);
119 0 : if jwt_token.is_none() {
120 0 : tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver");
121 0 : }
122 :
123 0 : let control_plane_jwt_token = asm
124 0 : .get_secret_value()
125 0 : .secret_id(Self::CONTROL_PLANE_JWT_TOKEN_SECRET)
126 0 : .send()
127 0 : .await?
128 0 : .secret_string()
129 0 : .map(str::to_string);
130 0 : if jwt_token.is_none() {
131 0 : tracing::warn!("No control plane JWT token set: this will only work if authentication is disabled on the pageserver");
132 0 : }
133 :
134 0 : let public_key = asm
135 0 : .get_secret_value()
136 0 : .secret_id(Self::PUBLIC_KEY_SECRET)
137 0 : .send()
138 0 : .await?
139 0 : .secret_string()
140 0 : .map(str::to_string);
141 0 : let public_key = match public_key {
142 0 : Some(key) => Some(JwtAuth::from_key(key)?),
143 : None => {
144 0 : tracing::warn!(
145 0 : "No public key set: inccoming HTTP requests will not be authenticated"
146 0 : );
147 0 : None
148 : }
149 : };
150 :
151 0 : Ok(Self {
152 0 : database_url,
153 0 : public_key,
154 0 : jwt_token,
155 0 : control_plane_jwt_token,
156 0 : })
157 0 : }
158 :
159 361 : fn load_cli(database_url: &str, args: &Cli) -> anyhow::Result<Self> {
160 361 : let public_key = match &args.public_key {
161 350 : None => None,
162 11 : Some(key) => Some(JwtAuth::from_key(key.clone()).context("Loading public key")?),
163 : };
164 361 : Ok(Self {
165 361 : database_url: database_url.to_owned(),
166 361 : public_key,
167 361 : jwt_token: args.jwt_token.clone(),
168 361 : control_plane_jwt_token: args.control_plane_jwt_token.clone(),
169 361 : })
170 361 : }
171 : }
172 :
173 361 : async fn migration_run(database_url: &str) -> anyhow::Result<()> {
174 : use diesel::PgConnection;
175 : use diesel_migrations::{HarnessWithOutput, MigrationHarness};
176 361 : let mut conn = PgConnection::establish(database_url)?;
177 :
178 361 : HarnessWithOutput::write_to_stdout(&mut conn)
179 361 : .run_pending_migrations(MIGRATIONS)
180 361 : .map(|_| ())
181 361 : .map_err(|e| anyhow::anyhow!(e))?;
182 :
183 361 : Ok(())
184 361 : }
185 :
186 : #[tokio::main]
187 361 : async fn main() -> anyhow::Result<()> {
188 361 : let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
189 361 :
190 361 : logging::init(
191 361 : LogFormat::Plain,
192 361 : logging::TracingErrorLayerEnablement::Disabled,
193 361 : logging::Output::Stdout,
194 361 : )?;
195 :
196 361 : let args = Cli::parse();
197 361 : tracing::info!(
198 361 : "version: {}, launch_timestamp: {}, build_tag {}, state at {}, listening on {}",
199 361 : GIT_VERSION,
200 361 : launch_ts.to_string(),
201 361 : BUILD_TAG,
202 361 : args.path.as_ref().unwrap_or(&Utf8PathBuf::from("<none>")),
203 361 : args.listen
204 361 : );
205 :
206 361 : let secrets = Secrets::load(&args).await?;
207 :
208 361 : let config = Config {
209 361 : jwt_token: secrets.jwt_token,
210 361 : control_plane_jwt_token: secrets.control_plane_jwt_token,
211 361 : compute_hook_url: args.compute_hook_url,
212 361 : };
213 361 :
214 361 : // After loading secrets & config, but before starting anything else, apply database migrations
215 361 : migration_run(&secrets.database_url)
216 0 : .await
217 361 : .context("Running database migrations")?;
218 :
219 361 : let json_path = args.path;
220 361 : let persistence = Arc::new(Persistence::new(secrets.database_url, json_path.clone()));
221 :
222 1077 : let service = Service::spawn(config, persistence.clone()).await?;
223 :
224 361 : let http_listener = tcp_listener::bind(args.listen)?;
225 :
226 361 : let auth = secrets
227 361 : .public_key
228 361 : .map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
229 361 : let router = make_router(service, auth)
230 361 : .build()
231 361 : .map_err(|err| anyhow!(err))?;
232 361 : let router_service = utils::http::RouterService::new(router).unwrap();
233 361 : let server = hyper::Server::from_tcp(http_listener)?.serve(router_service);
234 361 :
235 361 : tracing::info!("Serving on {0}", args.listen);
236 :
237 361 : tokio::task::spawn(server);
238 :
239 : // Wait until we receive a signal
240 361 : let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
241 361 : let mut sigquit = tokio::signal::unix::signal(SignalKind::quit())?;
242 361 : let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
243 361 : tokio::select! {
244 722 : _ = sigint.recv() => {},
245 722 : _ = sigterm.recv() => {},
246 722 : _ = sigquit.recv() => {},
247 722 : }
248 361 : tracing::info!("Terminating on signal");
249 :
250 361 : if json_path.is_some() {
251 : // Write out a JSON dump on shutdown: this is used in compat tests to avoid passing
252 : // full postgres dumps around.
253 722 : if let Err(e) = persistence.write_tenants_json().await {
254 0 : tracing::error!("Failed to write JSON on shutdown: {e}")
255 361 : }
256 0 : }
257 :
258 361 : std::process::exit(0);
259 : }
|