LCOV - code coverage report
Current view: top level - proxy/src/bin - local_proxy.rs (source / functions)
Test:         4f58e98c51285c7fa348e0b410c88a10caf68ad2.info
Test Date:    2025-01-07 20:58:07

              Coverage    Total    Hit
Lines:           0.0 %      276      0
Functions:       0.0 %       76      0

            Line data    Source code
       1              : use std::net::SocketAddr;
       2              : use std::pin::pin;
       3              : use std::str::FromStr;
       4              : use std::sync::Arc;
       5              : use std::time::Duration;
       6              : 
       7              : use anyhow::{bail, ensure, Context};
       8              : use camino::{Utf8Path, Utf8PathBuf};
       9              : use compute_api::spec::LocalProxySpec;
      10              : use dashmap::DashMap;
      11              : use futures::future::Either;
      12              : use proxy::auth::backend::jwt::JwkCache;
      13              : use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
      14              : use proxy::auth::{self};
      15              : use proxy::cancellation::CancellationHandlerMain;
      16              : use proxy::config::{
      17              :     self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
      18              : };
      19              : use proxy::control_plane::locks::ApiLocks;
      20              : use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
      21              : use proxy::http::health_server::AppMetrics;
      22              : use proxy::intern::RoleNameInt;
      23              : use proxy::metrics::{Metrics, ThreadPoolMetrics};
      24              : use proxy::rate_limiter::{
      25              :     BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
      26              : };
      27              : use proxy::scram::threadpool::ThreadPool;
      28              : use proxy::serverless::cancel_set::CancelSet;
      29              : use proxy::serverless::{self, GlobalConnPoolOptions};
      30              : use proxy::tls::client_config::compute_client_config_with_root_certs;
      31              : use proxy::types::RoleName;
      32              : use proxy::url::ApiUrl;
      33              : 
      34              : project_git_version!(GIT_VERSION);
      35              : project_build_tag!(BUILD_TAG);
      36              : 
      37              : use clap::Parser;
      38              : use thiserror::Error;
      39              : use tokio::net::TcpListener;
      40              : use tokio::sync::Notify;
      41              : use tokio::task::JoinSet;
      42              : use tokio_util::sync::CancellationToken;
      43              : use tracing::{debug, error, info, warn};
      44              : use utils::sentry_init::init_sentry;
      45              : use utils::{pid_file, project_build_tag, project_git_version};
      46              : 
      47              : #[global_allocator]
      48              : static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
      49              : 
      50              : /// Neon proxy/router
      51              : #[derive(Parser)]
      52              : #[command(version = GIT_VERSION, about)]
      53              : struct LocalProxyCliArgs {
      54              :     /// listen for incoming metrics connections on ip:port
      55              :     #[clap(long, default_value = "127.0.0.1:7001")]
      56            0 :     metrics: String,
      57              :     /// listen for incoming http connections on ip:port
      58              :     #[clap(long)]
      59            0 :     http: String,
      60              :     /// timeout for the TLS handshake
      61              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
      62            0 :     handshake_timeout: tokio::time::Duration,
      63              :     /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
      64              :     #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
      65            0 :     connect_compute_lock: String,
      66              :     #[clap(flatten)]
      67              :     sql_over_http: SqlOverHttpArgs,
      68              :     /// User rate limiter max number of requests per second.
      69              :     ///
      70              :     /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
      71              :     /// Can be given multiple times for different bucket sizes.
      72            0 :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
      73            0 :     user_rps_limit: Vec<RateBucketInfo>,
      74              :     /// Whether the auth rate limiter actually takes effect (for testing)
      75            0 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
      76            0 :     auth_rate_limit_enabled: bool,
      77              :     /// Authentication rate limiter max number of hashes per second.
      78            0 :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
      79            0 :     auth_rate_limit: Vec<RateBucketInfo>,
      80              :     /// The IP subnet to use when considering whether two IP addresses are considered the same.
      81            0 :     #[clap(long, default_value_t = 64)]
      82            0 :     auth_rate_limit_ip_subnet: u8,
      83              :     /// Whether to retry the connection to the compute node
      84              :     #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
      85            0 :     connect_to_compute_retry: String,
      86              :     /// Address of the postgres server
      87              :     #[clap(long, default_value = "127.0.0.1:5432")]
      88            0 :     postgres: SocketAddr,
      89              :     /// Address of the compute-ctl api service
      90              :     #[clap(long, default_value = "http://127.0.0.1:3080/")]
      91            0 :     compute_ctl: ApiUrl,
      92              :     /// Path of the local proxy config file
      93              :     #[clap(long, default_value = "./local_proxy.json")]
      94            0 :     config_path: Utf8PathBuf,
      95              :     /// Path of the local proxy PID file
      96              :     #[clap(long, default_value = "./local_proxy.pid")]
      97            0 :     pid_path: Utf8PathBuf,
      98              : }
      99              : 
     100              : #[derive(clap::Args, Clone, Copy, Debug)]
     101              : struct SqlOverHttpArgs {
     102              :     /// How many connections to pool for each endpoint. Excess connections are discarded
     103            0 :     #[clap(long, default_value_t = 200)]
     104            0 :     sql_over_http_pool_max_total_conns: usize,
     105              : 
     106              :     /// How long pooled connections should remain idle for before closing
     107              :     #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
     108            0 :     sql_over_http_idle_timeout: tokio::time::Duration,
     109              : 
     110            0 :     #[clap(long, default_value_t = 100)]
     111            0 :     sql_over_http_client_conn_threshold: u64,
     112              : 
     113            0 :     #[clap(long, default_value_t = 16)]
     114            0 :     sql_over_http_cancel_set_shards: usize,
     115              : 
     116            0 :     #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
     117            0 :     sql_over_http_max_request_size_bytes: usize,
     118              : 
     119            0 :     #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
     120            0 :     sql_over_http_max_response_size_bytes: usize,
     121              : }
     122              : 
     123              : #[tokio::main]
     124            0 : async fn main() -> anyhow::Result<()> {
     125            0 :     let _logging_guard = proxy::logging::init_local_proxy()?;
     126            0 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     127            0 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
     128            0 : 
     129            0 :     Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
     130            0 : 
     131            0 :     // TODO: refactor these to use labels
     132            0 :     debug!("Version: {GIT_VERSION}");
     133            0 :     debug!("Build_tag: {BUILD_TAG}");
     134            0 :     let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
     135            0 :         revision: GIT_VERSION,
     136            0 :         build_tag: BUILD_TAG,
     137            0 :     });
     138            0 : 
     139            0 :     let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
     140            0 :         Ok(t) => Some(t),
     141            0 :         Err(e) => {
     142            0 :             tracing::error!(error = ?e, "could not start jemalloc metrics loop");
     143            0 :             None
     144            0 :         }
     145            0 :     };
     146            0 : 
     147            0 :     let args = LocalProxyCliArgs::parse();
     148            0 :     let config = build_config(&args)?;
     149            0 :     let auth_backend = build_auth_backend(&args)?;
     150            0 : 
     151            0 :     // before we bind to any ports, write the process ID to a file
     152            0 :     // so that compute-ctl can find our process later
     153            0 :     // in order to trigger the appropriate SIGHUP on config change.
     154            0 :     //
     155            0 :     // This also claims a "lock" that makes sure only one instance
     156            0 :     // of local_proxy runs at a time.
     157            0 :     let _process_guard = loop {
     158            0 :         match pid_file::claim_for_current_process(&args.pid_path) {
     159            0 :             Ok(guard) => break guard,
     160            0 :             Err(e) => {
     161            0 :                 // compute-ctl might have tried to read the pid-file to let us
     162            0 :                 // know about some config change. We should try again.
     163            0 :                 error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
     164            0 :                 tokio::time::sleep(Duration::from_secs(1)).await;
     165            0 :             }
     166            0 :         }
     167            0 :     };
     168            0 : 
     169            0 :     let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
     170            0 :     let http_listener = TcpListener::bind(args.http).await?;
     171            0 :     let shutdown = CancellationToken::new();
     172            0 : 
     173            0 :     // todo: should scale with CU
     174            0 :     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     175            0 :         LeakyBucketConfig {
     176            0 :             rps: 10.0,
     177            0 :             max: 100.0,
     178            0 :         },
     179            0 :         16,
     180            0 :     ));
     181            0 : 
     182            0 :     let mut maintenance_tasks = JoinSet::new();
     183            0 : 
     184            0 :     let refresh_config_notify = Arc::new(Notify::new());
     185            0 :     maintenance_tasks.spawn(proxy::signals::handle(shutdown.clone(), {
     186            0 :         let refresh_config_notify = Arc::clone(&refresh_config_notify);
     187            0 :         move || {
     188            0 :             refresh_config_notify.notify_one();
     189            0 :         }
     190            0 :     }));
     191            0 : 
     192            0 :     // trigger the first config load **after** setting up the signal hook
     193            0 :     // to avoid the race condition where:
     194            0 :     // 1. No config file registered when local_proxy starts up
     195            0 :     // 2. The config file is written but the signal hook is not yet received
      196            0 :     // 3. local_proxy completes startup but has no config loaded, despite there being a registered config.
     197            0 :     refresh_config_notify.notify_one();
     198            0 :     tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));
     199            0 : 
     200            0 :     maintenance_tasks.spawn(proxy::http::health_server::task_main(
     201            0 :         metrics_listener,
     202            0 :         AppMetrics {
     203            0 :             jemalloc,
     204            0 :             neon_metrics,
     205            0 :             proxy: proxy::metrics::Metrics::get(),
     206            0 :         },
     207            0 :     ));
     208            0 : 
     209            0 :     let task = serverless::task_main(
     210            0 :         config,
     211            0 :         auth_backend,
     212            0 :         http_listener,
     213            0 :         shutdown.clone(),
     214            0 :         Arc::new(CancellationHandlerMain::new(
     215            0 :             &config.connect_to_compute,
     216            0 :             Arc::new(DashMap::new()),
     217            0 :             None,
     218            0 :             proxy::metrics::CancellationSource::Local,
     219            0 :         )),
     220            0 :         endpoint_rate_limiter,
     221            0 :     );
     222            0 : 
     223            0 :     match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
     224            0 :         // exit immediately on maintenance task completion
     225            0 :         Either::Left((Some(res), _)) => match proxy::error::flatten_err(res)? {},
     226            0 :         // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     227            0 :         Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     228            0 :         // exit immediately on client task error
     229            0 :         Either::Right((res, _)) => res?,
     230            0 :     }
     231            0 : 
     232            0 :     Ok(())
     233            0 : }
     234              : 
     235              : /// ProxyConfig is created at proxy startup, and lives forever.
     236            0 : fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     237              :     let config::ConcurrencyLockOptions {
     238            0 :         shards,
     239            0 :         limiter,
     240            0 :         epoch,
     241            0 :         timeout,
     242            0 :     } = args.connect_compute_lock.parse()?;
     243            0 :     info!(
     244              :         ?limiter,
     245              :         shards,
     246              :         ?epoch,
     247            0 :         "Using NodeLocks (connect_compute)"
     248              :     );
     249            0 :     let connect_compute_locks = ApiLocks::new(
     250            0 :         "connect_compute_lock",
     251            0 :         limiter,
     252            0 :         shards,
     253            0 :         timeout,
     254            0 :         epoch,
     255            0 :         &Metrics::get().proxy.connect_compute_lock,
     256            0 :     )?;
     257              : 
     258            0 :     let http_config = HttpConfig {
     259            0 :         accept_websockets: false,
     260            0 :         pool_options: GlobalConnPoolOptions {
     261            0 :             gc_epoch: Duration::from_secs(60),
     262            0 :             pool_shards: 2,
     263            0 :             idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
     264            0 :             opt_in: false,
     265            0 : 
     266            0 :             max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
     267            0 :             max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
     268            0 :         },
     269            0 :         cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
     270            0 :         client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
     271            0 :         max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
     272            0 :         max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
     273            0 :     };
     274              : 
     275            0 :     let compute_config = ComputeConfig {
     276            0 :         retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
     277            0 :         tls: Arc::new(compute_client_config_with_root_certs()?),
     278            0 :         timeout: Duration::from_secs(2),
     279            0 :     };
     280            0 : 
     281            0 :     Ok(Box::leak(Box::new(ProxyConfig {
     282            0 :         tls_config: None,
     283            0 :         metric_collection: None,
     284            0 :         http_config,
     285            0 :         authentication_config: AuthenticationConfig {
     286            0 :             jwks_cache: JwkCache::default(),
     287            0 :             thread_pool: ThreadPool::new(0),
     288            0 :             scram_protocol_timeout: Duration::from_secs(10),
     289            0 :             rate_limiter_enabled: false,
     290            0 :             rate_limiter: BucketRateLimiter::new(vec![]),
     291            0 :             rate_limit_ip_subnet: 64,
     292            0 :             ip_allowlist_check_enabled: true,
     293            0 :             is_auth_broker: false,
     294            0 :             accept_jwts: true,
     295            0 :             console_redirect_confirmation_timeout: Duration::ZERO,
     296            0 :         },
     297            0 :         proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
     298            0 :         handshake_timeout: Duration::from_secs(10),
     299            0 :         region: "local".into(),
     300            0 :         wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
     301            0 :         connect_compute_locks,
     302            0 :         connect_to_compute: compute_config,
     303              :     })))
     304            0 : }
     305              : 
     306              : /// auth::Backend is created at proxy startup, and lives forever.
     307            0 : fn build_auth_backend(
     308            0 :     args: &LocalProxyCliArgs,
     309            0 : ) -> anyhow::Result<&'static auth::Backend<'static, ()>> {
     310            0 :     let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
     311            0 :         LocalBackend::new(args.postgres, args.compute_ctl.clone()),
     312            0 :     ));
     313            0 : 
     314            0 :     Ok(Box::leak(Box::new(auth_backend)))
     315            0 : }
     316              : 
     317              : #[derive(Error, Debug)]
     318              : enum RefreshConfigError {
     319              :     #[error(transparent)]
     320              :     Read(#[from] std::io::Error),
     321              :     #[error(transparent)]
     322              :     Parse(#[from] serde_json::Error),
     323              :     #[error(transparent)]
     324              :     Validate(anyhow::Error),
     325              : }
     326              : 
     327            0 : async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
     328            0 :     let mut init = true;
     329              :     loop {
     330            0 :         rx.notified().await;
     331              : 
     332            0 :         match refresh_config_inner(&path).await {
     333            0 :             Ok(()) => {}
      334              :             // don't log file-not-found errors if this is the first time we are checking;
      335              :             // for computes that don't use local_proxy, this is not an error.
     336            0 :             Err(RefreshConfigError::Read(e))
     337            0 :                 if init && e.kind() == std::io::ErrorKind::NotFound =>
     338            0 :             {
     339            0 :                 debug!(error=?e, ?path, "could not read config file");
     340              :             }
     341            0 :             Err(e) => {
     342            0 :                 error!(error=?e, ?path, "could not read config file");
     343              :             }
     344              :         }
     345              : 
     346            0 :         init = false;
     347              :     }
     348              : }
     349              : 
     350            0 : async fn refresh_config_inner(path: &Utf8Path) -> Result<(), RefreshConfigError> {
     351            0 :     let bytes = tokio::fs::read(&path).await?;
     352            0 :     let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
     353              : 
     354            0 :     let mut jwks_set = vec![];
     355              : 
     356            0 :     fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
     357            0 :         let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
     358              : 
     359            0 :         ensure!(
     360            0 :             jwks_url.has_authority()
     361            0 :                 && (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
     362            0 :             "Invalid JWKS url. Must be HTTP",
     363              :         );
     364              : 
     365            0 :         ensure!(
     366            0 :             jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
     367            0 :             "Invalid JWKS url. No domain listed",
     368              :         );
     369              : 
     370              :         // clear username, password and ports
     371            0 :         jwks_url
     372            0 :             .set_username("")
     373            0 :             .expect("url can be a base and has a valid host and is not a file. should not error");
     374            0 :         jwks_url
     375            0 :             .set_password(None)
     376            0 :             .expect("url can be a base and has a valid host and is not a file. should not error");
     377            0 :         // local testing is hard if we need to have a specific restricted port
     378            0 :         if cfg!(not(feature = "testing")) {
     379            0 :             jwks_url.set_port(None).expect(
     380            0 :                 "url can be a base and has a valid host and is not a file. should not error",
     381            0 :             );
     382            0 :         }
     383              : 
      384              :         // clear the fragment and query params
     385            0 :         jwks_url.set_fragment(None);
     386            0 :         jwks_url.query_pairs_mut().clear().finish();
     387            0 : 
     388            0 :         if jwks_url.scheme() != "https" {
     389              :             // local testing is hard if we need to set up https support.
     390            0 :             if cfg!(not(feature = "testing")) {
     391            0 :                 jwks_url
     392            0 :                     .set_scheme("https")
     393            0 :                     .expect("should not error to set the scheme to https if it was http");
     394            0 :             } else {
     395            0 :                 warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
     396              :             }
     397            0 :         }
     398              : 
     399            0 :         Ok(JwksSettings {
     400            0 :             id: jwks.id,
     401            0 :             jwks_url,
     402            0 :             provider_name: jwks.provider_name,
     403            0 :             jwt_audience: jwks.jwt_audience,
     404            0 :             role_names: jwks
     405            0 :                 .role_names
     406            0 :                 .into_iter()
     407            0 :                 .map(RoleName::from)
     408            0 :                 .map(|s| RoleNameInt::from(&s))
     409            0 :                 .collect(),
     410            0 :         })
     411            0 :     }
     412              : 
     413            0 :     for jwks in data.jwks.into_iter().flatten() {
     414            0 :         jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
     415              :     }
     416              : 
     417            0 :     info!("successfully loaded new config");
     418            0 :     JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
     419            0 : 
     420            0 :     Ok(())
     421            0 : }
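
For reference, a minimal sketch of the local_proxy.json shape that refresh_config_inner and parse_jwks_settings above expect: a LocalProxySpec with an optional list of JWKS entries, each carrying id, jwks_url, provider_name, jwt_audience and role_names. The key names assume compute_api::spec serializes fields under their Rust names, and every value below is a made-up placeholder rather than a real configuration:

    {
        "jwks": [
            {
                "id": "example-jwks",
                "jwks_url": "https://auth.example.com/.well-known/jwks.json",
                "provider_name": "example-provider",
                "jwt_audience": null,
                "role_names": ["authenticated"]
            }
        ]
    }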
        

Generated by: LCOV version 2.1-beta