            Source code
       1              : use std::{
       2              :     net::SocketAddr,
       3              :     path::{Path, PathBuf},
       4              :     pin::pin,
       5              :     sync::Arc,
       6              :     time::Duration,
       7              : };
       8              : 
       9              : use anyhow::{bail, ensure};
      10              : use dashmap::DashMap;
      11              : use futures::{future::Either, FutureExt};
      12              : use proxy::{
      13              :     auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
      14              :     cancellation::CancellationHandlerMain,
      15              :     config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
      16              :     console::{locks::ApiLocks, messages::JwksRoleMapping},
      17              :     http::health_server::AppMetrics,
      18              :     metrics::{Metrics, ThreadPoolMetrics},
      19              :     rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
      20              :     scram::threadpool::ThreadPool,
      21              :     serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
      22              : };
      23              : 
// Define the `GIT_VERSION` and `BUILD_TAG` constants via the `utils` macros.
// `GIT_VERSION` is the clap `--version` string and is passed to Sentry; both
// are logged at startup and reported through `metrics::BuildInfo`.
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
      26              : 
      27              : use clap::Parser;
      28              : use tokio::{net::TcpListener, task::JoinSet};
      29              : use tokio_util::sync::CancellationToken;
      30              : use tracing::{error, info, warn};
      31              : use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
      32              : 
// Use jemalloc as the process-wide allocator. `main` additionally tries to start
// `proxy::jemalloc::MetricRecorder` and hands it to the health/metrics server.
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
      35              : 
      36              : /// Neon proxy/router
      37            0 : #[derive(Parser)]
      38              : #[command(version = GIT_VERSION, about)]
      39              : struct LocalProxyCliArgs {
      40              :     /// listen for incoming metrics connections on ip:port
      41              :     #[clap(long, default_value = "")]
      42            0 :     metrics: String,
      43              :     /// listen for incoming http connections on ip:port
      44              :     #[clap(long)]
      45            0 :     http: String,
      46              :     /// timeout for the TLS handshake
      47              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
      48            0 :     handshake_timeout: tokio::time::Duration,
      49              :     /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
      50              :     #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
      51            0 :     connect_compute_lock: String,
      52              :     #[clap(flatten)]
      53              :     sql_over_http: SqlOverHttpArgs,
      54              :     /// User rate limiter max number of requests per second.
      55              :     ///
      56              :     /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
      57              :     /// Can be given multiple times for different bucket sizes.
      58            0 :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
      59            0 :     user_rps_limit: Vec<RateBucketInfo>,
      60              :     /// Whether the auth rate limiter actually takes effect (for testing)
      61            0 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
      62            0 :     auth_rate_limit_enabled: bool,
      63              :     /// Authentication rate limiter max number of hashes per second.
      64            0 :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
      65            0 :     auth_rate_limit: Vec<RateBucketInfo>,
      66              :     /// The IP subnet to use when considering whether two IP addresses are considered the same.
      67            0 :     #[clap(long, default_value_t = 64)]
      68            0 :     auth_rate_limit_ip_subnet: u8,
      69              :     /// Whether to retry the connection to the compute node
      70              :     #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
      71            0 :     connect_to_compute_retry: String,
      72              :     /// Address of the postgres server
      73              :     #[clap(long, default_value = "")]
      74            0 :     compute: SocketAddr,
      75              :     /// File address of the local proxy config file
      76              :     #[clap(long, default_value = "./localproxy.json")]
      77            0 :     config_path: PathBuf,
      78              : }
      79              : 
// CLI knobs for the SQL-over-HTTP (serverless) endpoint. Flattened into
// `LocalProxyCliArgs` via `#[clap(flatten)]`; `build_config` copies these
// values into `HttpConfig` / `GlobalConnPoolOptions`.
//
// NOTE: plain `//` comments are used below on purpose — adding `///` doc
// comments would change the `--help` output of these flags.
#[derive(clap::Args, Clone, Copy, Debug)]
struct SqlOverHttpArgs {
    /// How many connections to pool for each endpoint. Excess connections are discarded
    // `build_config` uses this value as both the per-endpoint and the total pool limit.
    #[clap(long, default_value_t = 200)]
    sql_over_http_pool_max_total_conns: usize,

    /// How long pooled connections should remain idle for before closing
    #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
    sql_over_http_idle_timeout: tokio::time::Duration,

    // Forwarded unchanged to `HttpConfig::client_conn_threshold`; its exact
    // semantics are defined in the serverless module (not visible here).
    #[clap(long, default_value_t = 100)]
    sql_over_http_client_conn_threshold: u64,

    // Shard count for the `CancelSet` created in `build_config`.
    #[clap(long, default_value_t = 16)]
    sql_over_http_cancel_set_shards: usize,

    // Forwarded to `HttpConfig::max_request_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_request_size_bytes: u64,

    // Forwarded to `HttpConfig::max_response_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_response_size_bytes: usize,
}
     102              : 
     103              : #[tokio::main]
     104            0 : async fn main() -> anyhow::Result<()> {
     105            0 :     let _logging_guard = proxy::logging::init().await?;
     106            0 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     107            0 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
     108            0 : 
     109            0 :     Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
     110            0 : 
     111            0 :     info!("Version: {GIT_VERSION}");
     112            0 :     info!("Build_tag: {BUILD_TAG}");
     113            0 :     let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
     114            0 :         revision: GIT_VERSION,
     115            0 :         build_tag: BUILD_TAG,
     116            0 :     });
     117            0 : 
     118            0 :     let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
     119            0 :         Ok(t) => Some(t),
     120            0 :         Err(e) => {
     121            0 :             tracing::error!(error = ?e, "could not start jemalloc metrics loop");
     122            0 :             None
     123            0 :         }
     124            0 :     };
     125            0 : 
     126            0 :     let args = LocalProxyCliArgs::parse();
     127            0 :     let config = build_config(&args)?;
     128            0 : 
     129            0 :     let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
     130            0 :     let http_listener = TcpListener::bind(args.http).await?;
     131            0 :     let shutdown = CancellationToken::new();
     132            0 : 
     133            0 :     // todo: should scale with CU
     134            0 :     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     135            0 :         LeakyBucketConfig {
     136            0 :             rps: 10.0,
     137            0 :             max: 100.0,
     138            0 :         },
     139            0 :         16,
     140            0 :     ));
     141            0 : 
     142            0 :     refresh_config(args.config_path.clone()).await;
     143            0 : 
     144            0 :     let mut maintenance_tasks = JoinSet::new();
     145            0 :     maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
     146            0 :         refresh_config(args.config_path.clone()).map(Ok)
     147            0 :     }));
     148            0 :     maintenance_tasks.spawn(proxy::http::health_server::task_main(
     149            0 :         metrics_listener,
     150            0 :         AppMetrics {
     151            0 :             jemalloc,
     152            0 :             neon_metrics,
     153            0 :             proxy: proxy::metrics::Metrics::get(),
     154            0 :         },
     155            0 :     ));
     156            0 : 
     157            0 :     let task = serverless::task_main(
     158            0 :         config,
     159            0 :         http_listener,
     160            0 :         shutdown.clone(),
     161            0 :         Arc::new(CancellationHandlerMain::new(
     162            0 :             Arc::new(DashMap::new()),
     163            0 :             None,
     164            0 :             proxy::metrics::CancellationSource::Local,
     165            0 :         )),
     166            0 :         endpoint_rate_limiter,
     167            0 :     );
     168            0 : 
     169            0 :     match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
     170            0 :         // exit immediately on maintenance task completion
     171            0 :         Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {},
     172            0 :         // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     173            0 :         Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     174            0 :         // exit immediately on client task error
     175            0 :         Either::Right((res, _)) => res?,
     176            0 :     }
     177            0 : 
     178            0 :     Ok(())
     179            0 : }
     180              : 
     181              : /// ProxyConfig is created at proxy startup, and lives forever.
     182            0 : fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     183              :     let config::ConcurrencyLockOptions {
     184            0 :         shards,
     185            0 :         limiter,
     186            0 :         epoch,
     187            0 :         timeout,
     188            0 :     } = args.connect_compute_lock.parse()?;
     189            0 :     info!(
     190              :         ?limiter,
     191              :         shards,
     192              :         ?epoch,
     193            0 :         "Using NodeLocks (connect_compute)"
     194              :     );
     195            0 :     let connect_compute_locks = ApiLocks::new(
     196            0 :         "connect_compute_lock",
     197            0 :         limiter,
     198            0 :         shards,
     199            0 :         timeout,
     200            0 :         epoch,
     201            0 :         &Metrics::get().proxy.connect_compute_lock,
     202            0 :     )?;
     203              : 
     204            0 :     let http_config = HttpConfig {
     205            0 :         accept_websockets: false,
     206            0 :         pool_options: GlobalConnPoolOptions {
     207            0 :             gc_epoch: Duration::from_secs(60),
     208            0 :             pool_shards: 2,
     209            0 :             idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
     210            0 :             opt_in: false,
     211            0 : 
     212            0 :             max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
     213            0 :             max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
     214            0 :         },
     215            0 :         cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
     216            0 :         client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
     217            0 :         max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
     218            0 :         max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
     219            0 :     };
     220            0 : 
     221            0 :     Ok(Box::leak(Box::new(ProxyConfig {
     222            0 :         tls_config: None,
     223            0 :         auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
     224            0 :             LocalBackend::new(args.compute),
     225            0 :         )),
     226            0 :         metric_collection: None,
     227            0 :         allow_self_signed_compute: false,
     228            0 :         http_config,
     229            0 :         authentication_config: AuthenticationConfig {
     230            0 :             thread_pool: ThreadPool::new(0),
     231            0 :             scram_protocol_timeout: Duration::from_secs(10),
     232            0 :             rate_limiter_enabled: false,
     233            0 :             rate_limiter: BucketRateLimiter::new(vec![]),
     234            0 :             rate_limit_ip_subnet: 64,
     235            0 :             ip_allowlist_check_enabled: true,
     236            0 :         },
     237            0 :         require_client_ip: false,
     238            0 :         handshake_timeout: Duration::from_secs(10),
     239            0 :         region: "local".into(),
     240            0 :         wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
     241            0 :         connect_compute_locks,
     242            0 :         connect_to_compute_retry_config: RetryConfig::parse(
     243            0 :             RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
     244            0 :         )?,
     245              :     })))
     246            0 : }
     247              : 
     248            0 : async fn refresh_config(path: PathBuf) {
     249            0 :     match refresh_config_inner(&path).await {
     250            0 :         Ok(()) => {}
     251            0 :         Err(e) => {
     252            0 :             error!(error=?e, ?path, "could not read config file");
     253              :         }
     254              :     }
     255            0 : }
     256              : 
     257            0 : async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
     258            0 :     let bytes = tokio::fs::read(&path).await?;
     259            0 :     let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;
     260              : 
     261            0 :     let mut settings = None;
     262              : 
     263            0 :     for mapping in data.roles.values_mut() {
     264            0 :         for jwks in &mut mapping.jwks {
     265            0 :             ensure!(
     266            0 :                 jwks.jwks_url.has_authority()
     267            0 :                     && (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
     268            0 :                 "Invalid JWKS url. Must be HTTP",
     269              :             );
     270              : 
     271            0 :             ensure!(
     272            0 :                 jwks.jwks_url
     273            0 :                     .host()
     274            0 :                     .is_some_and(|h| h != url::Host::Domain("")),
     275            0 :                 "Invalid JWKS url. No domain listed",
     276              :             );
     277              : 
     278              :             // clear username, password and ports
     279            0 :             jwks.jwks_url.set_username("").expect(
     280            0 :                 "url can be a base and has a valid host and is not a file. should not error",
     281            0 :             );
     282            0 :             jwks.jwks_url.set_password(None).expect(
     283            0 :                 "url can be a base and has a valid host and is not a file. should not error",
     284            0 :             );
     285            0 :             // local testing is hard if we need to have a specific restricted port
     286            0 :             if cfg!(not(feature = "testing")) {
     287            0 :                 jwks.jwks_url.set_port(None).expect(
     288            0 :                     "url can be a base and has a valid host and is not a file. should not error",
     289            0 :                 );
     290            0 :             }
     291              : 
     292              :             // clear query params
     293            0 :             jwks.jwks_url.set_fragment(None);
     294            0 :             jwks.jwks_url.query_pairs_mut().clear().finish();
     295            0 : 
     296            0 :             if jwks.jwks_url.scheme() != "https" {
     297              :                 // local testing is hard if we need to set up https support.
     298            0 :                 if cfg!(not(feature = "testing")) {
     299            0 :                     jwks.jwks_url
     300            0 :                         .set_scheme("https")
     301            0 :                         .expect("should not error to set the scheme to https if it was http");
     302            0 :                 } else {
     303            0 :                     warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
     304              :                 }
     305            0 :             }
     306              : 
     307            0 :             let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
     308            0 :             ensure!(
     309            0 :                 *pr == jwks.project_id,
     310            0 :                 "inconsistent project IDs configured"
     311              :             );
     312            0 :             ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
     313              :         }
     314              :     }
     315              : 
     316            0 :     if let Some((project_id, branch_id)) = settings {
     317            0 : {
     318            0 :             roles: data.roles,
     319            0 :             project_id,
     320            0 :             branch_id,
     321            0 :         })));
     322            0 :     }
     323              : 
     324            0 :     Ok(())
     325            0 : }

