LCOV - code coverage report
Current view: top level - proxy/src/binary - local_proxy.rs (source / functions)
Test: 2da33972d224f272aca590aed34655d25ca4e832.info
Test Date: 2025-07-28 19:13:49
Coverage:    Lines:     0.0 %  (0 of 145 hit)
             Functions: 0.0 %  (0 of 6 hit)

            Line data    Source code
       1              : use std::env;
       2              : use std::net::SocketAddr;
       3              : use std::pin::pin;
       4              : use std::sync::Arc;
       5              : use std::time::Duration;
       6              : 
       7              : use anyhow::bail;
       8              : use arc_swap::ArcSwapOption;
       9              : use camino::Utf8PathBuf;
      10              : use clap::Parser;
      11              : use futures::future::Either;
      12              : use tokio::net::TcpListener;
      13              : use tokio::sync::Notify;
      14              : use tokio::task::JoinSet;
      15              : use tokio_util::sync::CancellationToken;
      16              : use tracing::{debug, error, info};
      17              : use utils::sentry_init::init_sentry;
      18              : use utils::{pid_file, project_build_tag, project_git_version};
      19              : 
      20              : use crate::auth::backend::jwt::JwkCache;
      21              : use crate::auth::backend::local::LocalBackend;
      22              : use crate::auth::{self};
      23              : use crate::cancellation::CancellationHandler;
      24              : #[cfg(feature = "rest_broker")]
      25              : use crate::config::RestConfig;
      26              : use crate::config::{
      27              :     self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
      28              :     refresh_config_loop,
      29              : };
      30              : use crate::control_plane::locks::ApiLocks;
      31              : use crate::http::health_server::AppMetrics;
      32              : use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
      33              : use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
      34              : use crate::scram::threadpool::ThreadPool;
      35              : use crate::serverless::cancel_set::CancelSet;
      36              : use crate::serverless::{self, GlobalConnPoolOptions};
      37              : use crate::tls::client_config::compute_client_config_with_root_certs;
      38              : use crate::url::ApiUrl;
      39              : 
      40              : project_git_version!(GIT_VERSION);
      41              : project_build_tag!(BUILD_TAG);
      42              : 
      43              : /// Neon proxy/router
      44              : #[derive(Parser)]
      45              : #[command(version = GIT_VERSION, about)]
      46              : struct LocalProxyCliArgs {
      47              :     /// listen for incoming metrics connections on ip:port
      48              :     #[clap(long, default_value = "127.0.0.1:7001")]
      49              :     metrics: String,
      50              :     /// listen for incoming TCP connections on ip:port
      51              :     #[clap(short, long, default_value = "127.0.0.1:4432")]
      52              :     proxy: Option<SocketAddr>,
      53              :     /// listen for incoming http connections on ip:port
      54              :     #[clap(long)]
      55              :     http: String,
      56              :     /// timeout for the TLS handshake
      57              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
      58              :     handshake_timeout: tokio::time::Duration,
      59              :     /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
      60              :     #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
      61              :     connect_compute_lock: String,
      62              :     #[clap(flatten)]
      63              :     sql_over_http: SqlOverHttpArgs,
      64              :     /// User rate limiter max number of requests per second.
      65              :     ///
      66              :     /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
      67              :     /// Can be given multiple times for different bucket sizes.
      68              :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
      69              :     user_rps_limit: Vec<RateBucketInfo>,
      70              :     /// Whether to retry the connection to the compute node
      71              :     #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
      72              :     connect_to_compute_retry: String,
      73              :     /// Address of the postgres server
      74              :     #[clap(long, default_value = "127.0.0.1:5432")]
      75              :     postgres: SocketAddr,
      76              :     /// Address of the internal compute-ctl api service
      77              :     #[clap(long, default_value = "http://127.0.0.1:3081/")]
      78              :     compute_ctl: ApiUrl,
      79              :     /// Path of the local proxy config file
      80              :     #[clap(long, default_value = "./local_proxy.json")]
      81              :     config_path: Utf8PathBuf,
      82              :     /// Path of the local proxy PID file
      83              :     #[clap(long, default_value = "./local_proxy.pid")]
      84              :     pid_path: Utf8PathBuf,
      85              :     /// Disable pg_session_jwt extension installation
      86              :     /// This is useful for testing the local proxy with vanilla postgres.
      87              :     #[clap(long, default_value = "false")]
      88              :     #[cfg(feature = "testing")]
      89              :     disable_pg_session_jwt: bool,
      90              : }
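
For illustration only, not part of local_proxy.rs: a minimal sketch of how clap resolves flags declared in the style of LocalProxyCliArgs above. The Args struct and its two fields are trimmed, hypothetical stand-ins; the attribute style mirrors the real struct.

    // Sketch: clap derive with the same attribute style as LocalProxyCliArgs.
    // Requires the `clap` crate with the "derive" feature.
    use clap::Parser;

    #[derive(Parser)]
    struct Args {
        /// listen for incoming metrics connections on ip:port
        #[clap(long, default_value = "127.0.0.1:7001")]
        metrics: String,
        /// Path of the local proxy config file
        #[clap(long, default_value = "./local_proxy.json")]
        config_path: String,
    }

    fn main() {
        // Equivalent to running: local_proxy --metrics 0.0.0.0:7001
        let args = Args::parse_from(["local_proxy", "--metrics", "0.0.0.0:7001"]);
        assert_eq!(args.metrics, "0.0.0.0:7001");           // explicit flag wins
        assert_eq!(args.config_path, "./local_proxy.json"); // default kept
    }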
      91              : 
      92              : #[derive(clap::Args, Clone, Copy, Debug)]
      93              : struct SqlOverHttpArgs {
      94              :     /// How many connections to pool for each endpoint. Excess connections are discarded
      95              :     #[clap(long, default_value_t = 200)]
      96              :     sql_over_http_pool_max_total_conns: usize,
      97              : 
       98              :     /// How long pooled connections should remain idle before closing
      99              :     #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
     100              :     sql_over_http_idle_timeout: tokio::time::Duration,
     101              : 
     102              :     #[clap(long, default_value_t = 100)]
     103              :     sql_over_http_client_conn_threshold: u64,
     104              : 
     105              :     #[clap(long, default_value_t = 16)]
     106              :     sql_over_http_cancel_set_shards: usize,
     107              : 
     108              :     #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
     109              :     sql_over_http_max_request_size_bytes: usize,
     110              : 
     111              :     #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
     112              :     sql_over_http_max_response_size_bytes: usize,
     113              : }
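
For illustration only: the duration-valued defaults above ("5m", "15s") go through humantime::parse_duration. A standalone sketch of what those strings resolve to:

    // Sketch: parsing the humantime-style duration defaults used above.
    // Requires the `humantime` crate; not part of this file.
    use std::time::Duration;

    fn main() {
        let idle = humantime::parse_duration("5m").expect("valid duration");
        assert_eq!(idle, Duration::from_secs(5 * 60)); // sql_over_http_idle_timeout default

        let handshake = humantime::parse_duration("15s").expect("valid duration");
        assert_eq!(handshake, Duration::from_secs(15)); // handshake_timeout default
    }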
     114              : 
     115            0 : pub async fn run() -> anyhow::Result<()> {
     116            0 :     let _logging_guard = crate::logging::init_local_proxy()?;
     117            0 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     118            0 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
     119              : 
     120            0 :     Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
     121              : 
     122              :     // TODO: refactor these to use labels
     123            0 :     debug!("Version: {GIT_VERSION}");
     124            0 :     debug!("Build_tag: {BUILD_TAG}");
     125            0 :     let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
     126            0 :         revision: GIT_VERSION,
     127            0 :         build_tag: BUILD_TAG,
     128            0 :     });
     129              : 
     130            0 :     let jemalloc = match crate::jemalloc::MetricRecorder::new() {
     131            0 :         Ok(t) => Some(t),
     132            0 :         Err(e) => {
     133            0 :             tracing::error!(error = ?e, "could not start jemalloc metrics loop");
     134            0 :             None
     135              :         }
     136              :     };
     137              : 
     138            0 :     let args = LocalProxyCliArgs::parse();
     139            0 :     let config = build_config(&args)?;
     140            0 :     let auth_backend = build_auth_backend(&args);
     141              : 
     142              :     // before we bind to any ports, write the process ID to a file
     143              :     // so that compute-ctl can find our process later
     144              :     // in order to trigger the appropriate SIGHUP on config change.
     145              :     //
     146              :     // This also claims a "lock" that makes sure only one instance
     147              :     // of local_proxy runs at a time.
     148            0 :     let _process_guard = loop {
     149            0 :         match pid_file::claim_for_current_process(&args.pid_path) {
     150            0 :             Ok(guard) => break guard,
     151            0 :             Err(e) => {
     152              :                 // compute-ctl might have tried to read the pid-file to let us
     153              :                 // know about some config change. We should try again.
     154            0 :                 error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
     155            0 :                 tokio::time::sleep(Duration::from_secs(1)).await;
     156              :             }
     157              :         }
     158              :     };
     159              : 
     160            0 :     let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
     161            0 :     let http_listener = TcpListener::bind(args.http).await?;
     162              : 
     163            0 :     let tcp_listener = if let Some(proxy_addr) = args.proxy {
     164            0 :         Some(TcpListener::bind(proxy_addr).await?)
     165              :     } else {
     166            0 :         None
     167              :     };
     168              : 
     169            0 :     let shutdown = CancellationToken::new();
     170              : 
     171              :     // todo: should scale with CU
     172            0 :     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     173            0 :         LeakyBucketConfig {
     174            0 :             rps: 10.0,
     175            0 :             max: 100.0,
     176            0 :         },
     177              :         16,
     178              :     ));
     179              : 
     180            0 :     let mut maintenance_tasks = JoinSet::new();
     181              : 
     182            0 :     let refresh_config_notify = Arc::new(Notify::new());
     183            0 :     maintenance_tasks.spawn(crate::signals::handle(shutdown.clone(), {
     184            0 :         let refresh_config_notify = Arc::clone(&refresh_config_notify);
     185            0 :         move || {
     186            0 :             refresh_config_notify.notify_one();
     187            0 :         }
     188              :     }));
     189              : 
     190              :     // trigger the first config load **after** setting up the signal hook
     191              :     // to avoid the race condition where:
     192              :     // 1. No config file registered when local_proxy starts up
     193              :     // 2. The config file is written but the signal hook is not yet received
      194              :     // 3. local_proxy completes startup but has no config loaded, despite there being a registered config.
     195            0 :     refresh_config_notify.notify_one();
     196            0 :     tokio::spawn(refresh_config_loop(
     197            0 :         config,
     198            0 :         args.config_path,
     199            0 :         refresh_config_notify,
     200              :     ));
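
For illustration only: a minimal sketch of the Notify ordering relied on above. notify_one() stores a permit when no task is waiting yet, so triggering the first config load before the refresh loop starts cannot be lost; the names below are illustrative.

    // Sketch of the notify-then-spawn pattern used for config refresh.
    // Requires tokio with the "sync", "macros" and "rt-multi-thread" features.
    use std::sync::Arc;
    use tokio::sync::Notify;

    #[tokio::main]
    async fn main() {
        let refresh = Arc::new(Notify::new());

        // Trigger the first "config load" up front; since nothing is waiting
        // yet, Notify stores a single permit instead of dropping the wakeup.
        refresh.notify_one();

        let waiter = Arc::clone(&refresh);
        let refresh_loop = tokio::spawn(async move {
            // Completes immediately because of the stored permit, so the
            // config is read at least once even if no signal ever arrives.
            waiter.notified().await;
            println!("reloading config");
        });

        refresh_loop.await.unwrap();
    }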
     201              : 
     202            0 :     maintenance_tasks.spawn(crate::http::health_server::task_main(
     203            0 :         metrics_listener,
     204            0 :         AppMetrics {
     205            0 :             jemalloc,
     206            0 :             neon_metrics,
     207            0 :             proxy: crate::metrics::Metrics::get(),
     208            0 :         },
     209              :     ));
     210              : 
     211            0 :     let task = serverless::task_main(
     212            0 :         config,
     213            0 :         auth_backend,
     214            0 :         http_listener,
     215            0 :         shutdown.clone(),
     216            0 :         Arc::new(CancellationHandler::new(&config.connect_to_compute)),
     217            0 :         endpoint_rate_limiter,
     218              :     );
     219              : 
     220              :     // if tcp_listener.is_some() {
     221              :     //     maintenance_tasks.spawn(serverless::tcp_listener::task_main(
     222              :     //         tcp_listener.unwrap(),
     223              :     //         shutdown.clone(),
     224              :     //         auth_backend,
     225              :     //         config,
     226              :     //         endpoint_rate_limiter,
     227              :     //     ));
     228              :     // }
     229              : 
     230            0 :     Metrics::get()
     231            0 :         .service
     232            0 :         .info
     233            0 :         .set_label(ServiceInfo::running());
     234              : 
     235            0 :     match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
     236              :         // exit immediately on maintenance task completion
     237            0 :         Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {},
     238              :         // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     239            0 :         Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     240              :         // exit immediately on client task error
     241            0 :         Either::Right((res, _)) => res?,
     242              :     }
     243              : 
     244            0 :     Ok(())
     245            0 : }
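
For illustration only: a stripped-down sketch of the select/Either pattern that ends run() above, with stand-in futures in place of the real maintenance and client tasks.

    // Sketch: whichever future finishes first decides how the process exits;
    // the other future is simply dropped. Requires `futures` and `tokio`.
    use std::pin::pin;
    use std::time::Duration;
    use futures::future::Either;

    #[tokio::main]
    async fn main() {
        let maintenance = async { "a maintenance task finished" };
        let client = async {
            tokio::time::sleep(Duration::from_secs(3600)).await;
            "the client task finished"
        };

        match futures::future::select(pin!(maintenance), pin!(client)).await {
            Either::Left((msg, _client_still_pending)) => println!("{msg}"),
            Either::Right((msg, _maintenance_still_pending)) => println!("{msg}"),
        }
    }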
     246              : 
     247              : /// ProxyConfig is created at proxy startup, and lives forever.
     248            0 : fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     249              :     let config::ConcurrencyLockOptions {
     250            0 :         shards,
     251            0 :         limiter,
     252            0 :         epoch,
     253            0 :         timeout,
     254            0 :     } = args.connect_compute_lock.parse()?;
     255            0 :     info!(
     256              :         ?limiter,
     257              :         shards,
     258              :         ?epoch,
     259            0 :         "Using NodeLocks (connect_compute)"
     260              :     );
     261            0 :     let connect_compute_locks = ApiLocks::new(
     262              :         "connect_compute_lock",
     263            0 :         limiter,
     264            0 :         shards,
     265            0 :         timeout,
     266            0 :         epoch,
     267            0 :         &Metrics::get().proxy.connect_compute_lock,
     268              :     );
     269              : 
     270            0 :     let http_config = HttpConfig {
     271            0 :         accept_websockets: false,
     272            0 :         pool_options: GlobalConnPoolOptions {
     273            0 :             gc_epoch: Duration::from_secs(60),
     274            0 :             pool_shards: 2,
     275            0 :             idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
     276            0 :             opt_in: false,
     277            0 : 
     278            0 :             max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
     279            0 :             max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
     280            0 :         },
     281            0 :         cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
     282            0 :         client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
     283            0 :         max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
     284            0 :         max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
     285            0 :     };
     286              : 
     287            0 :     let compute_config = ComputeConfig {
     288            0 :         retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
     289            0 :         tls: Arc::new(compute_client_config_with_root_certs()?),
     290            0 :         timeout: Duration::from_secs(2),
     291              :     };
     292              : 
     293            0 :     let greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() {
     294            0 :         Ok(s) => s,
     295              :         Err(_) => {
     296            0 :             debug!("NEON_MOTD environment variable is not valid UTF-8");
     297            0 :             String::new()
     298              :         }
     299            0 :     });
     300              : 
     301            0 :     Ok(Box::leak(Box::new(ProxyConfig {
     302            0 :         tls_config: ArcSwapOption::from(None),
     303            0 :         metric_collection: None,
     304            0 :         http_config,
     305            0 :         authentication_config: AuthenticationConfig {
     306            0 :             jwks_cache: JwkCache::default(),
     307            0 :             thread_pool: ThreadPool::new(0),
     308            0 :             scram_protocol_timeout: Duration::from_secs(10),
     309            0 :             ip_allowlist_check_enabled: true,
     310            0 :             is_vpc_acccess_proxy: false,
     311            0 :             is_auth_broker: false,
     312            0 :             accept_jwts: true,
     313            0 :             console_redirect_confirmation_timeout: Duration::ZERO,
     314            0 :         },
     315              :         #[cfg(feature = "rest_broker")]
     316            0 :         rest_config: RestConfig {
     317            0 :             is_rest_broker: false,
     318            0 :             db_schema_cache: None,
     319            0 :             max_schema_size: 0,
     320            0 :             hostname_prefix: String::new(),
     321            0 :         },
     322            0 :         proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
     323            0 :         handshake_timeout: Duration::from_secs(10),
     324            0 :         wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
     325            0 :         connect_compute_locks,
     326            0 :         connect_to_compute: compute_config,
     327            0 :         greetings,
     328              :         #[cfg(feature = "testing")]
     329            0 :         disable_pg_session_jwt: args.disable_pg_session_jwt,
     330              :     })))
     331            0 : }
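
For illustration only: the &'static lifetime of the returned ProxyConfig comes from the Box::leak idiom. A tiny standalone sketch with a made-up Config type:

    // Sketch of the Box::leak idiom used by build_config and
    // build_auth_backend: allocate once at startup, intentionally never free,
    // and hand out a plain &'static reference that any task can share.
    struct Config {
        handshake_timeout_secs: u64,
    }

    fn leak_config() -> &'static Config {
        Box::leak(Box::new(Config { handshake_timeout_secs: 10 }))
    }

    fn main() {
        let cfg: &'static Config = leak_config();
        assert_eq!(cfg.handshake_timeout_secs, 10);
    }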
     332              : 
     333              : /// auth::Backend is created at proxy startup, and lives forever.
     334            0 : fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'static, ()> {
     335            0 :     let auth_backend = crate::auth::Backend::Local(crate::auth::backend::MaybeOwned::Owned(
     336            0 :         LocalBackend::new(args.postgres, args.compute_ctl.clone()),
     337            0 :     ));
     338              : 
     339            0 :     Box::leak(Box::new(auth_backend))
     340            0 : }
        

Generated by: LCOV version 2.1-beta