Line data Source code
1 : use std::net::SocketAddr;
2 : use std::pin::pin;
3 : use std::str::FromStr;
4 : use std::sync::Arc;
5 : use std::time::Duration;
6 :
7 : use anyhow::{Context, bail, ensure};
8 : use arc_swap::ArcSwapOption;
9 : use camino::{Utf8Path, Utf8PathBuf};
10 : use clap::Parser;
11 : use compute_api::spec::LocalProxySpec;
12 : use futures::future::Either;
13 : use thiserror::Error;
14 : use tokio::net::TcpListener;
15 : use tokio::sync::Notify;
16 : use tokio::task::JoinSet;
17 : use tokio_util::sync::CancellationToken;
18 : use tracing::{debug, error, info, warn};
19 : use utils::sentry_init::init_sentry;
20 : use utils::{pid_file, project_build_tag, project_git_version};
21 :
22 : use crate::auth::backend::jwt::JwkCache;
23 : use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
24 : use crate::auth::{self};
25 : use crate::cancellation::CancellationHandler;
26 : use crate::config::{
27 : self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
28 : };
29 : use crate::control_plane::locks::ApiLocks;
30 : use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
31 : use crate::ext::TaskExt;
32 : use crate::http::health_server::AppMetrics;
33 : use crate::intern::RoleNameInt;
34 : use crate::metrics::{Metrics, ThreadPoolMetrics};
35 : use crate::rate_limiter::{
36 : BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
37 : };
38 : use crate::scram::threadpool::ThreadPool;
39 : use crate::serverless::cancel_set::CancelSet;
40 : use crate::serverless::{self, GlobalConnPoolOptions};
41 : use crate::tls::client_config::compute_client_config_with_root_certs;
42 : use crate::types::RoleName;
43 : use crate::url::ApiUrl;
44 :
// Define the `GIT_VERSION` and `BUILD_TAG` items used below for clap's `--version`,
// startup logging, sentry initialisation and the build-info metrics.
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
47 :
// Command-line arguments of local_proxy, parsed in `run` via `LocalProxyCliArgs::parse()`.
// NOTE: the `///` doc comments in this struct feed clap's generated help text, so
// reviewer notes are written as plain `//` comments to keep `--help` output unchanged.
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
struct LocalProxyCliArgs {
    // Bound in `run` and handed to the health/metrics server.
    /// listen for incoming metrics connections on ip:port
    #[clap(long, default_value = "127.0.0.1:7001")]
    metrics: String,
    // Bound in `run` and served by `serverless::task_main`.
    /// listen for incoming http connections on ip:port
    #[clap(long)]
    http: String,
    // NOTE(review): not read anywhere in this file; `build_config` uses a fixed 10s
    // handshake timeout instead (note the differing 15s default here). Confirm intent.
    /// timeout for the TLS handshake
    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    handshake_timeout: tokio::time::Duration,
    // Parsed into `config::ConcurrencyLockOptions` by `build_config`.
    /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
    #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
    connect_compute_lock: String,
    #[clap(flatten)]
    sql_over_http: SqlOverHttpArgs,
    // NOTE(review): not read anywhere in this file.
    /// User rate limiter max number of requests per second.
    ///
    /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
    /// Can be given multiple times for different bucket sizes.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
    user_rps_limit: Vec<RateBucketInfo>,
    // NOTE(review): the three `auth_rate_limit*` flags are not read in this file;
    // `build_config` configures the auth rate limiter as disabled with no buckets.
    /// Whether the auth rate limiter actually takes effect (for testing)
    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
    auth_rate_limit_enabled: bool,
    /// Authentication rate limiter max number of hashes per second.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
    auth_rate_limit: Vec<RateBucketInfo>,
    /// The IP subnet to use when considering whether two IP addresses are considered the same.
    #[clap(long, default_value_t = 64)]
    auth_rate_limit_ip_subnet: u8,
    // String in the format accepted by `RetryConfig::parse`.
    /// Whether to retry the connection to the compute node
    #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
    connect_to_compute_retry: String,
    // Passed to `LocalBackend::new` in `build_auth_backend`.
    /// Address of the postgres server
    #[clap(long, default_value = "127.0.0.1:5432")]
    postgres: SocketAddr,
    // Passed to `LocalBackend::new` in `build_auth_backend`.
    /// Address of the internal compute-ctl api service
    #[clap(long, default_value = "http://127.0.0.1:3081/")]
    compute_ctl: ApiUrl,
    // Re-read by `refresh_config_loop` every time a reload is signalled.
    /// Path of the local proxy config file
    #[clap(long, default_value = "./local_proxy.json")]
    config_path: Utf8PathBuf,
    // Claimed in `run` before any port is bound; retried every second until it succeeds.
    /// Path of the local proxy PID file
    #[clap(long, default_value = "./local_proxy.pid")]
    pid_path: Utf8PathBuf,
}
97 :
// Tuning knobs for the SQL-over-HTTP handler, flattened into `LocalProxyCliArgs`;
// `build_config` copies them into `HttpConfig`.
// As with `LocalProxyCliArgs`, `///` comments here become clap help text, so notes use `//`.
#[derive(clap::Args, Clone, Copy, Debug)]
struct SqlOverHttpArgs {
    // NOTE(review): `build_config` uses this single value for both the per-endpoint
    // and the total connection-pool limit.
    /// How many connections to pool for each endpoint. Excess connections are discarded
    #[clap(long, default_value_t = 200)]
    sql_over_http_pool_max_total_conns: usize,

    /// How long pooled connections should remain idle for before closing
    #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
    sql_over_http_idle_timeout: tokio::time::Duration,

    // Copied into `HttpConfig::client_conn_threshold`.
    #[clap(long, default_value_t = 100)]
    sql_over_http_client_conn_threshold: u64,

    // Passed to `CancelSet::new` when building the HTTP config.
    #[clap(long, default_value_t = 16)]
    sql_over_http_cancel_set_shards: usize,

    // Copied into `HttpConfig::max_request_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_request_size_bytes: usize,

    // Copied into `HttpConfig::max_response_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_response_size_bytes: usize,
}
120 :
121 0 : pub async fn run() -> anyhow::Result<()> {
122 0 : let _logging_guard = crate::logging::init_local_proxy()?;
123 0 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
124 0 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
125 0 :
126 0 : Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
127 0 :
128 0 : // TODO: refactor these to use labels
129 0 : debug!("Version: {GIT_VERSION}");
130 0 : debug!("Build_tag: {BUILD_TAG}");
131 0 : let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
132 0 : revision: GIT_VERSION,
133 0 : build_tag: BUILD_TAG,
134 0 : });
135 :
136 0 : let jemalloc = match crate::jemalloc::MetricRecorder::new() {
137 0 : Ok(t) => Some(t),
138 0 : Err(e) => {
139 0 : tracing::error!(error = ?e, "could not start jemalloc metrics loop");
140 0 : None
141 : }
142 : };
143 :
144 0 : let args = LocalProxyCliArgs::parse();
145 0 : let config = build_config(&args)?;
146 0 : let auth_backend = build_auth_backend(&args);
147 :
148 : // before we bind to any ports, write the process ID to a file
149 : // so that compute-ctl can find our process later
150 : // in order to trigger the appropriate SIGHUP on config change.
151 : //
152 : // This also claims a "lock" that makes sure only one instance
153 : // of local_proxy runs at a time.
154 0 : let _process_guard = loop {
155 0 : match pid_file::claim_for_current_process(&args.pid_path) {
156 0 : Ok(guard) => break guard,
157 0 : Err(e) => {
158 0 : // compute-ctl might have tried to read the pid-file to let us
159 0 : // know about some config change. We should try again.
160 0 : error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
161 0 : tokio::time::sleep(Duration::from_secs(1)).await;
162 : }
163 : }
164 : };
165 :
166 0 : let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
167 0 : let http_listener = TcpListener::bind(args.http).await?;
168 0 : let shutdown = CancellationToken::new();
169 0 :
170 0 : // todo: should scale with CU
171 0 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
172 0 : LeakyBucketConfig {
173 0 : rps: 10.0,
174 0 : max: 100.0,
175 0 : },
176 0 : 16,
177 0 : ));
178 0 :
179 0 : let mut maintenance_tasks = JoinSet::new();
180 0 :
181 0 : let refresh_config_notify = Arc::new(Notify::new());
182 0 : maintenance_tasks.spawn(crate::signals::handle(shutdown.clone(), {
183 0 : let refresh_config_notify = Arc::clone(&refresh_config_notify);
184 0 : move || {
185 0 : refresh_config_notify.notify_one();
186 0 : }
187 0 : }));
188 0 :
189 0 : // trigger the first config load **after** setting up the signal hook
190 0 : // to avoid the race condition where:
191 0 : // 1. No config file registered when local_proxy starts up
192 0 : // 2. The config file is written but the signal hook is not yet received
193 0 : // 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
194 0 : refresh_config_notify.notify_one();
195 0 : tokio::spawn(refresh_config_loop(
196 0 : config,
197 0 : args.config_path,
198 0 : refresh_config_notify,
199 0 : ));
200 0 :
201 0 : maintenance_tasks.spawn(crate::http::health_server::task_main(
202 0 : metrics_listener,
203 0 : AppMetrics {
204 0 : jemalloc,
205 0 : neon_metrics,
206 0 : proxy: crate::metrics::Metrics::get(),
207 0 : },
208 0 : ));
209 0 :
210 0 : let task = serverless::task_main(
211 0 : config,
212 0 : auth_backend,
213 0 : http_listener,
214 0 : shutdown.clone(),
215 0 : Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
216 0 : endpoint_rate_limiter,
217 0 : );
218 0 :
219 0 : match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
220 : // exit immediately on maintenance task completion
221 0 : Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {},
222 : // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
223 0 : Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
224 : // exit immediately on client task error
225 0 : Either::Right((res, _)) => res?,
226 : }
227 :
228 0 : Ok(())
229 0 : }
230 :
231 : /// ProxyConfig is created at proxy startup, and lives forever.
232 0 : fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
233 : let config::ConcurrencyLockOptions {
234 0 : shards,
235 0 : limiter,
236 0 : epoch,
237 0 : timeout,
238 0 : } = args.connect_compute_lock.parse()?;
239 0 : info!(
240 : ?limiter,
241 : shards,
242 : ?epoch,
243 0 : "Using NodeLocks (connect_compute)"
244 : );
245 0 : let connect_compute_locks = ApiLocks::new(
246 0 : "connect_compute_lock",
247 0 : limiter,
248 0 : shards,
249 0 : timeout,
250 0 : epoch,
251 0 : &Metrics::get().proxy.connect_compute_lock,
252 0 : );
253 0 :
254 0 : let http_config = HttpConfig {
255 0 : accept_websockets: false,
256 0 : pool_options: GlobalConnPoolOptions {
257 0 : gc_epoch: Duration::from_secs(60),
258 0 : pool_shards: 2,
259 0 : idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
260 0 : opt_in: false,
261 0 :
262 0 : max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
263 0 : max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
264 0 : },
265 0 : cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
266 0 : client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
267 0 : max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
268 0 : max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
269 0 : };
270 :
271 0 : let compute_config = ComputeConfig {
272 0 : retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
273 0 : tls: Arc::new(compute_client_config_with_root_certs()?),
274 0 : timeout: Duration::from_secs(2),
275 0 : };
276 0 :
277 0 : Ok(Box::leak(Box::new(ProxyConfig {
278 0 : tls_config: ArcSwapOption::from(None),
279 0 : metric_collection: None,
280 0 : http_config,
281 0 : authentication_config: AuthenticationConfig {
282 0 : jwks_cache: JwkCache::default(),
283 0 : thread_pool: ThreadPool::new(0),
284 0 : scram_protocol_timeout: Duration::from_secs(10),
285 0 : rate_limiter_enabled: false,
286 0 : rate_limiter: BucketRateLimiter::new(vec![]),
287 0 : rate_limit_ip_subnet: 64,
288 0 : ip_allowlist_check_enabled: true,
289 0 : is_vpc_acccess_proxy: false,
290 0 : is_auth_broker: false,
291 0 : accept_jwts: true,
292 0 : console_redirect_confirmation_timeout: Duration::ZERO,
293 0 : },
294 0 : proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
295 0 : handshake_timeout: Duration::from_secs(10),
296 0 : region: "local".into(),
297 0 : wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
298 0 : connect_compute_locks,
299 0 : connect_to_compute: compute_config,
300 : })))
301 0 : }
302 :
303 : /// auth::Backend is created at proxy startup, and lives forever.
304 0 : fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'static, ()> {
305 0 : let auth_backend = crate::auth::Backend::Local(crate::auth::backend::MaybeOwned::Owned(
306 0 : LocalBackend::new(args.postgres, args.compute_ctl.clone()),
307 0 : ));
308 0 :
309 0 : Box::leak(Box::new(auth_backend))
310 0 : }
311 :
312 : #[derive(Error, Debug)]
313 : enum RefreshConfigError {
314 : #[error(transparent)]
315 : Read(#[from] std::io::Error),
316 : #[error(transparent)]
317 : Parse(#[from] serde_json::Error),
318 : #[error(transparent)]
319 : Validate(anyhow::Error),
320 : #[error(transparent)]
321 : Tls(anyhow::Error),
322 : }
323 :
324 0 : async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
325 0 : let mut init = true;
326 : loop {
327 0 : rx.notified().await;
328 :
329 0 : match refresh_config_inner(config, &path).await {
330 0 : Ok(()) => {}
331 : // don't log for file not found errors if this is the first time we are checking
332 : // for computes that don't use local_proxy, this is not an error.
333 0 : Err(RefreshConfigError::Read(e))
334 0 : if init && e.kind() == std::io::ErrorKind::NotFound =>
335 0 : {
336 0 : debug!(error=?e, ?path, "could not read config file");
337 : }
338 0 : Err(RefreshConfigError::Tls(e)) => {
339 0 : error!(error=?e, ?path, "could not read TLS certificates");
340 : }
341 0 : Err(e) => {
342 0 : error!(error=?e, ?path, "could not read config file");
343 : }
344 : }
345 :
346 0 : init = false;
347 : }
348 : }
349 :
350 0 : async fn refresh_config_inner(
351 0 : config: &ProxyConfig,
352 0 : path: &Utf8Path,
353 0 : ) -> Result<(), RefreshConfigError> {
354 0 : let bytes = tokio::fs::read(&path).await?;
355 0 : let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
356 :
357 0 : let mut jwks_set = vec![];
358 :
359 0 : fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
360 0 : let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
361 :
362 0 : ensure!(
363 0 : jwks_url.has_authority()
364 0 : && (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
365 0 : "Invalid JWKS url. Must be HTTP",
366 : );
367 :
368 0 : ensure!(
369 0 : jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
370 0 : "Invalid JWKS url. No domain listed",
371 : );
372 :
373 : // clear username, password and ports
374 0 : jwks_url
375 0 : .set_username("")
376 0 : .expect("url can be a base and has a valid host and is not a file. should not error");
377 0 : jwks_url
378 0 : .set_password(None)
379 0 : .expect("url can be a base and has a valid host and is not a file. should not error");
380 0 : // local testing is hard if we need to have a specific restricted port
381 0 : if cfg!(not(feature = "testing")) {
382 0 : jwks_url.set_port(None).expect(
383 0 : "url can be a base and has a valid host and is not a file. should not error",
384 0 : );
385 0 : }
386 :
387 : // clear query params
388 0 : jwks_url.set_fragment(None);
389 0 : jwks_url.query_pairs_mut().clear().finish();
390 0 :
391 0 : if jwks_url.scheme() != "https" {
392 : // local testing is hard if we need to set up https support.
393 0 : if cfg!(not(feature = "testing")) {
394 0 : jwks_url
395 0 : .set_scheme("https")
396 0 : .expect("should not error to set the scheme to https if it was http");
397 0 : } else {
398 0 : warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
399 : }
400 0 : }
401 :
402 0 : Ok(JwksSettings {
403 0 : id: jwks.id,
404 0 : jwks_url,
405 0 : _provider_name: jwks.provider_name,
406 0 : jwt_audience: jwks.jwt_audience,
407 0 : role_names: jwks
408 0 : .role_names
409 0 : .into_iter()
410 0 : .map(RoleName::from)
411 0 : .map(|s| RoleNameInt::from(&s))
412 0 : .collect(),
413 0 : })
414 0 : }
415 :
416 0 : for jwks in data.jwks.into_iter().flatten() {
417 0 : jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
418 : }
419 :
420 0 : info!("successfully loaded new config");
421 0 : JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
422 :
423 0 : if let Some(tls_config) = data.tls {
424 0 : let tls_config = tokio::task::spawn_blocking(move || {
425 0 : crate::tls::server_config::configure_tls(
426 0 : &tls_config.key_path,
427 0 : &tls_config.cert_path,
428 0 : None,
429 0 : false,
430 0 : )
431 0 : })
432 0 : .await
433 0 : .propagate_task_panic()
434 0 : .map_err(RefreshConfigError::Tls)?;
435 0 : config.tls_config.store(Some(Arc::new(tls_config)));
436 0 : }
437 :
438 0 : Ok(())
439 0 : }
|