Line data Source code
1 : use futures::future::Either;
2 : use proxy::auth;
3 : use proxy::auth::backend::MaybeOwned;
4 : use proxy::config::AuthenticationConfig;
5 : use proxy::config::CacheOptions;
6 : use proxy::config::HttpConfig;
7 : use proxy::config::ProjectInfoCacheOptions;
8 : use proxy::console;
9 : use proxy::context::parquet::ParquetUploadArgs;
10 : use proxy::http;
11 : use proxy::rate_limiter::EndpointRateLimiter;
12 : use proxy::rate_limiter::RateBucketInfo;
13 : use proxy::rate_limiter::RateLimiterConfig;
14 : use proxy::redis::notifications;
15 : use proxy::serverless::GlobalConnPoolOptions;
16 : use proxy::usage_metrics;
17 :
18 : use anyhow::bail;
19 : use proxy::config::{self, ProxyConfig};
20 : use proxy::serverless;
21 : use std::net::SocketAddr;
22 : use std::pin::pin;
23 : use std::sync::Arc;
24 : use tokio::net::TcpListener;
25 : use tokio::task::JoinSet;
26 : use tokio_util::sync::CancellationToken;
27 : use tracing::info;
28 : use tracing::warn;
29 : use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
30 :
31 : project_git_version!(GIT_VERSION);
32 : project_build_tag!(BUILD_TAG);
33 :
34 : use clap::{Parser, ValueEnum};
35 :
36 : #[global_allocator]
37 7934 : static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
38 :
39 247 : #[derive(Clone, Debug, ValueEnum)]
40 : enum AuthBackend {
41 : Console,
42 : #[cfg(feature = "testing")]
43 : Postgres,
44 : Link,
45 : }
46 :
47 : /// Neon proxy/router
48 54 : #[derive(Parser)]
49 : #[command(version = GIT_VERSION, about)]
50 : struct ProxyCliArgs {
51 : /// Name of the region this proxy is deployed in
52 27 : #[clap(long, default_value_t = String::new())]
53 0 : region: String,
54 : /// listen for incoming client connections on ip:port
55 : #[clap(short, long, default_value = "127.0.0.1:4432")]
56 0 : proxy: String,
57 27 : #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
58 0 : auth_backend: AuthBackend,
59 : /// listen for management callback connection on ip:port
60 : #[clap(short, long, default_value = "127.0.0.1:7000")]
61 0 : mgmt: String,
62 : /// listen for incoming http connections (metrics, etc) on ip:port
63 : #[clap(long, default_value = "127.0.0.1:7001")]
64 0 : http: String,
65 : /// listen for incoming wss connections on ip:port
66 : #[clap(long)]
67 : wss: Option<String>,
68 : /// redirect unauthenticated users to the given uri in case of link auth
69 : #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
70 0 : uri: String,
71 : /// cloud API endpoint for authenticating users
72 : #[clap(
73 : short,
74 : long,
75 : default_value = "http://localhost:3000/authenticate_proxy_request/"
76 : )]
77 0 : auth_endpoint: String,
78 : /// path to TLS key for client postgres connections
79 : ///
80 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
81 : #[clap(short = 'k', long, alias = "ssl-key")]
82 : tls_key: Option<String>,
83 : /// path to TLS cert for client postgres connections
84 : ///
85 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
86 : #[clap(short = 'c', long, alias = "ssl-cert")]
87 : tls_cert: Option<String>,
88 : /// path to directory with TLS certificates for client postgres connections
89 : #[clap(long)]
90 : certs_dir: Option<String>,
91 : /// timeout for the TLS handshake
92 : #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
93 0 : handshake_timeout: tokio::time::Duration,
94 : /// http endpoint to receive periodic metric updates
95 : #[clap(long)]
96 : metric_collection_endpoint: Option<String>,
97 : /// how often metrics should be sent to a collection endpoint
98 : #[clap(long)]
99 : metric_collection_interval: Option<String>,
100 : /// cache for `wake_compute` api method (use `size=0` to disable)
101 : #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
102 0 : wake_compute_cache: String,
103 : /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
104 : #[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
105 0 : wake_compute_lock: String,
106 : /// Allow self-signed certificates for compute nodes (for testing)
107 27 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
108 0 : allow_self_signed_compute: bool,
109 : #[clap(flatten)]
110 : sql_over_http: SqlOverHttpArgs,
111 : /// timeout for scram authentication protocol
112 : #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
113 0 : scram_protocol_timeout: tokio::time::Duration,
114 : /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
115 27 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
116 0 : require_client_ip: bool,
117 : /// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
118 27 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
119 0 : disable_dynamic_rate_limiter: bool,
120 : /// Rate limit algorithm. Makes sense only if `disable_rate_limiter` is `false`.
121 27 : #[clap(value_enum, long, default_value_t = proxy::rate_limiter::RateLimitAlgorithm::Aimd)]
122 0 : rate_limit_algorithm: proxy::rate_limiter::RateLimitAlgorithm,
123 : /// Timeout for rate limiter. If it didn't manage to aquire a permit in this time, it will return an error.
124 : #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
125 0 : rate_limiter_timeout: tokio::time::Duration,
126 : /// Endpoint rate limiter max number of requests per second.
127 : ///
128 : /// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
129 : /// Can be given multiple times for different bucket sizes.
130 135 : #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
131 27 : endpoint_rps_limit: Vec<RateBucketInfo>,
132 : /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
133 27 : #[clap(long, default_value_t = 100)]
134 0 : initial_limit: usize,
135 : #[clap(flatten)]
136 : aimd_config: proxy::rate_limiter::AimdConfig,
137 : /// cache for `allowed_ips` (use `size=0` to disable)
138 : #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
139 0 : allowed_ips_cache: String,
140 : /// cache for `role_secret` (use `size=0` to disable)
141 : #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
142 0 : role_secret_cache: String,
143 : /// disable ip check for http requests. If it is too time consuming, it could be turned off.
144 27 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
145 0 : disable_ip_check_for_http: bool,
146 : /// redis url for notifications.
147 : #[clap(long)]
148 : redis_notifications: Option<String>,
149 : /// cache for `project_info` (use `size=0` to disable)
150 : #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
151 0 : project_info_cache: String,
152 :
153 : #[clap(flatten)]
154 : parquet_upload: ParquetUploadArgs,
155 : }
156 :
157 54 : #[derive(clap::Args, Clone, Copy, Debug)]
158 : struct SqlOverHttpArgs {
159 : /// timeout for http connection requests
160 : #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
161 0 : sql_over_http_timeout: tokio::time::Duration,
162 :
163 : /// Whether the SQL over http pool is opt-in
164 27 : #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
165 0 : sql_over_http_pool_opt_in: bool,
166 :
167 : /// How many connections to pool for each endpoint. Excess connections are discarded
168 27 : #[clap(long, default_value_t = 20)]
169 0 : sql_over_http_pool_max_conns_per_endpoint: usize,
170 :
171 : /// How many connections to pool for each endpoint. Excess connections are discarded
172 27 : #[clap(long, default_value_t = 20000)]
173 0 : sql_over_http_pool_max_total_conns: usize,
174 :
175 : /// How long pooled connections should remain idle for before closing
176 : #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
177 0 : sql_over_http_idle_timeout: tokio::time::Duration,
178 :
179 : /// Duration each shard will wait on average before a GC sweep.
180 : /// A longer time will causes sweeps to take longer but will interfere less frequently.
181 : #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
182 0 : sql_over_http_pool_gc_epoch: tokio::time::Duration,
183 :
184 : /// How many shards should the global pool have. Must be a power of two.
185 : /// More shards will introduce less contention for pool operations, but can
186 : /// increase memory used by the pool
187 27 : #[clap(long, default_value_t = 128)]
188 0 : sql_over_http_pool_shards: usize,
189 : }
190 :
191 : #[tokio::main]
192 25 : async fn main() -> anyhow::Result<()> {
193 25 : let _logging_guard = proxy::logging::init().await?;
194 25 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
195 25 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
196 25 :
197 25 : info!("Version: {GIT_VERSION}");
198 25 : info!("Build_tag: {BUILD_TAG}");
199 25 : ::metrics::set_build_info_metric(GIT_VERSION, BUILD_TAG);
200 25 :
201 25 : match proxy::jemalloc::MetricRecorder::new(prometheus::default_registry()) {
202 25 : Ok(t) => {
203 25 : t.start();
204 25 : }
205 25 : Err(e) => tracing::error!(error = ?e, "could not start jemalloc metrics loop"),
206 25 : }
207 25 :
208 25 : let args = ProxyCliArgs::parse();
209 25 : let config = build_config(&args)?;
210 25 :
211 25 : info!("Authentication backend: {}", config.auth_backend);
212 25 :
213 25 : // Check that we can bind to address before further initialization
214 25 : let http_address: SocketAddr = args.http.parse()?;
215 25 : info!("Starting http on {http_address}");
216 25 : let http_listener = TcpListener::bind(http_address).await?.into_std()?;
217 25 :
218 25 : let mgmt_address: SocketAddr = args.mgmt.parse()?;
219 25 : info!("Starting mgmt on {mgmt_address}");
220 25 : let mgmt_listener = TcpListener::bind(mgmt_address).await?;
221 25 :
222 25 : let proxy_address: SocketAddr = args.proxy.parse()?;
223 25 : info!("Starting proxy on {proxy_address}");
224 25 : let proxy_listener = TcpListener::bind(proxy_address).await?;
225 25 : let cancellation_token = CancellationToken::new();
226 25 :
227 25 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
228 25 :
229 25 : // client facing tasks. these will exit on error or on cancellation
230 25 : // cancellation returns Ok(())
231 25 : let mut client_tasks = JoinSet::new();
232 25 : client_tasks.spawn(proxy::proxy::task_main(
233 25 : config,
234 25 : proxy_listener,
235 25 : cancellation_token.clone(),
236 25 : endpoint_rate_limiter.clone(),
237 25 : ));
238 25 :
239 25 : // TODO: rename the argument to something like serverless.
240 25 : // It now covers more than just websockets, it also covers SQL over HTTP.
241 25 : if let Some(serverless_address) = args.wss {
242 25 : let serverless_address: SocketAddr = serverless_address.parse()?;
243 25 : info!("Starting wss on {serverless_address}");
244 25 : let serverless_listener = TcpListener::bind(serverless_address).await?;
245 25 :
246 25 : client_tasks.spawn(serverless::task_main(
247 25 : config,
248 25 : serverless_listener,
249 25 : cancellation_token.clone(),
250 25 : endpoint_rate_limiter.clone(),
251 25 : ));
252 25 : }
253 25 :
254 25 : client_tasks.spawn(proxy::context::parquet::worker(
255 25 : cancellation_token.clone(),
256 25 : args.parquet_upload,
257 25 : ));
258 25 :
259 25 : // maintenance tasks. these never return unless there's an error
260 25 : let mut maintenance_tasks = JoinSet::new();
261 25 : maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
262 25 : maintenance_tasks.spawn(http::health_server::task_main(http_listener));
263 25 : maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
264 25 :
265 25 : if let Some(metrics_config) = &config.metric_collection {
266 1 : maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
267 24 : }
268 25 :
269 25 : if let auth::BackendType::Console(api, _) = &config.auth_backend {
270 25 : if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
271 25 : let cache = api.caches.project_info.clone();
272 25 : if let Some(url) = args.redis_notifications {
273 25 : info!("Starting redis notifications listener ({url})");
274 25 : maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
275 25 : }
276 25 : maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
277 25 : }
278 25 : }
279 25 :
280 25 : let maintenance = loop {
281 25 : // get one complete task
282 100 : match futures::future::select(
283 100 : pin!(maintenance_tasks.join_next()),
284 100 : pin!(client_tasks.join_next()),
285 100 : )
286 56 : .await
287 25 : {
288 25 : // exit immediately on maintenance task completion
289 25 : Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
290 25 : // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
291 25 : Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
292 25 : // exit immediately on client task error
293 75 : Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
294 25 : // exit if all our client tasks have shutdown gracefully
295 25 : Either::Right((None, _)) => return Ok(()),
296 25 : }
297 25 : };
298 25 :
299 25 : // maintenance tasks return Infallible success values, this is an impossible value
300 25 : // so this match statically ensures that there are no possibilities for that value
301 25 : match maintenance {}
302 25 : }
303 :
304 : /// ProxyConfig is created at proxy startup, and lives forever.
305 25 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
306 25 : let tls_config = match (&args.tls_key, &args.tls_cert) {
307 25 : (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
308 25 : key_path,
309 25 : cert_path,
310 25 : args.certs_dir.as_ref(),
311 25 : )?),
312 0 : (None, None) => None,
313 0 : _ => bail!("either both or neither tls-key and tls-cert must be specified"),
314 : };
315 :
316 25 : if args.allow_self_signed_compute {
317 3 : warn!("allowing self-signed compute certificates");
318 22 : }
319 :
320 25 : let metric_collection = match (
321 25 : &args.metric_collection_endpoint,
322 25 : &args.metric_collection_interval,
323 : ) {
324 1 : (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
325 1 : endpoint: endpoint.parse()?,
326 1 : interval: humantime::parse_duration(interval)?,
327 : }),
328 24 : (None, None) => None,
329 0 : _ => bail!(
330 0 : "either both or neither metric-collection-endpoint \
331 0 : and metric-collection-interval must be specified"
332 0 : ),
333 : };
334 25 : let rate_limiter_config = RateLimiterConfig {
335 25 : disable: args.disable_dynamic_rate_limiter,
336 25 : algorithm: args.rate_limit_algorithm,
337 25 : timeout: args.rate_limiter_timeout,
338 25 : initial_limit: args.initial_limit,
339 25 : aimd_config: Some(args.aimd_config),
340 25 : };
341 :
342 25 : let auth_backend = match &args.auth_backend {
343 : AuthBackend::Console => {
344 1 : let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
345 1 : let project_info_cache_config: ProjectInfoCacheOptions =
346 1 : args.project_info_cache.parse()?;
347 :
348 1 : info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
349 1 : info!(
350 1 : "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
351 1 : );
352 1 : let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
353 1 : wake_compute_cache_config,
354 1 : project_info_cache_config,
355 1 : )));
356 :
357 : let config::WakeComputeLockOptions {
358 1 : shards,
359 1 : permits,
360 1 : epoch,
361 1 : timeout,
362 1 : } = args.wake_compute_lock.parse()?;
363 1 : info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
364 1 : let locks = Box::leak(Box::new(
365 1 : console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
366 1 : .unwrap(),
367 1 : ));
368 1 : tokio::spawn(locks.garbage_collect_worker(epoch));
369 :
370 1 : let url = args.auth_endpoint.parse()?;
371 1 : let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
372 1 :
373 1 : let api = console::provider::neon::Api::new(endpoint, caches, locks);
374 1 : let api = console::provider::ConsoleBackend::Console(api);
375 1 : auth::BackendType::Console(MaybeOwned::Owned(api), ())
376 : }
377 : #[cfg(feature = "testing")]
378 : AuthBackend::Postgres => {
379 21 : let url = args.auth_endpoint.parse()?;
380 21 : let api = console::provider::mock::Api::new(url);
381 21 : let api = console::provider::ConsoleBackend::Postgres(api);
382 21 : auth::BackendType::Console(MaybeOwned::Owned(api), ())
383 : }
384 : AuthBackend::Link => {
385 3 : let url = args.uri.parse()?;
386 3 : auth::BackendType::Link(MaybeOwned::Owned(url))
387 : }
388 : };
389 25 : let http_config = HttpConfig {
390 25 : request_timeout: args.sql_over_http.sql_over_http_timeout,
391 25 : pool_options: GlobalConnPoolOptions {
392 25 : max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
393 25 : gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
394 25 : pool_shards: args.sql_over_http.sql_over_http_pool_shards,
395 25 : idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
396 25 : opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
397 25 : max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
398 25 : },
399 25 : };
400 25 : let authentication_config = AuthenticationConfig {
401 25 : scram_protocol_timeout: args.scram_protocol_timeout,
402 25 : };
403 25 :
404 25 : let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
405 25 : RateBucketInfo::validate(&mut endpoint_rps_limit)?;
406 :
407 25 : let config = Box::leak(Box::new(ProxyConfig {
408 25 : tls_config,
409 25 : auth_backend,
410 25 : metric_collection,
411 25 : allow_self_signed_compute: args.allow_self_signed_compute,
412 25 : http_config,
413 25 : authentication_config,
414 25 : require_client_ip: args.require_client_ip,
415 25 : disable_ip_check_for_http: args.disable_ip_check_for_http,
416 25 : endpoint_rps_limit,
417 25 : handshake_timeout: args.handshake_timeout,
418 25 : // TODO: add this argument
419 25 : region: args.region.clone(),
420 25 : }));
421 25 :
422 25 : Ok(config)
423 25 : }
424 :
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use clap::Parser;
    use proxy::rate_limiter::RateBucketInfo;

    /// Passing `--endpoint-rps-limit` twice must produce both buckets, in the
    /// order they were given on the command line.
    #[test]
    fn parse_endpoint_rps_limit() {
        let argv = [
            "proxy",
            "--endpoint-rps-limit",
            "100@1s",
            "--endpoint-rps-limit",
            "20@30s",
        ];
        let parsed = super::ProxyCliArgs::parse_from(argv);

        let expected = vec![
            RateBucketInfo::new(100, Duration::from_secs(1)),
            RateBucketInfo::new(20, Duration::from_secs(30)),
        ];
        assert_eq!(parsed.endpoint_rps_limit, expected);
    }
}
|