LCOV - code coverage report
Current view: top level - proxy/src/bin - proxy.rs (source / functions)
Test:         c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info
Test Date:    2024-02-12 20:26:03

              Coverage    Total    Hit
Lines:          87.6 %      274    240
Functions:      31.3 %      115     36

            Line data    Source code
       1              : use futures::future::Either;
       2              : use proxy::auth;
       3              : use proxy::auth::backend::MaybeOwned;
       4              : use proxy::config::AuthenticationConfig;
       5              : use proxy::config::CacheOptions;
       6              : use proxy::config::HttpConfig;
       7              : use proxy::config::ProjectInfoCacheOptions;
       8              : use proxy::console;
       9              : use proxy::context::parquet::ParquetUploadArgs;
      10              : use proxy::http;
      11              : use proxy::rate_limiter::EndpointRateLimiter;
      12              : use proxy::rate_limiter::RateBucketInfo;
      13              : use proxy::rate_limiter::RateLimiterConfig;
      14              : use proxy::redis::notifications;
      15              : use proxy::serverless::GlobalConnPoolOptions;
      16              : use proxy::usage_metrics;
      17              : 
      18              : use anyhow::bail;
      19              : use proxy::config::{self, ProxyConfig};
      20              : use proxy::serverless;
      21              : use std::net::SocketAddr;
      22              : use std::pin::pin;
      23              : use std::sync::Arc;
      24              : use tokio::net::TcpListener;
      25              : use tokio::task::JoinSet;
      26              : use tokio_util::sync::CancellationToken;
      27              : use tracing::info;
      28              : use tracing::warn;
      29              : use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
      30              : 
      31              : project_git_version!(GIT_VERSION);
      32              : project_build_tag!(BUILD_TAG);
      33              : 
      34              : use clap::{Parser, ValueEnum};
      35              : 
      36              : #[global_allocator]
      37         7934 : static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
      38              : 
      39          247 : #[derive(Clone, Debug, ValueEnum)]
      40              : enum AuthBackend {
      41              :     Console,
      42              :     #[cfg(feature = "testing")]
      43              :     Postgres,
      44              :     Link,
      45              : }
      46              : 
      47              : /// Neon proxy/router
      48           54 : #[derive(Parser)]
      49              : #[command(version = GIT_VERSION, about)]
      50              : struct ProxyCliArgs {
      51              :     /// Name of the region this proxy is deployed in
      52           27 :     #[clap(long, default_value_t = String::new())]
      53            0 :     region: String,
      54              :     /// listen for incoming client connections on ip:port
      55              :     #[clap(short, long, default_value = "127.0.0.1:4432")]
      56            0 :     proxy: String,
      57           27 :     #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
      58            0 :     auth_backend: AuthBackend,
      59              :     /// listen for management callback connection on ip:port
      60              :     #[clap(short, long, default_value = "127.0.0.1:7000")]
      61            0 :     mgmt: String,
      62              :     /// listen for incoming http connections (metrics, etc) on ip:port
      63              :     #[clap(long, default_value = "127.0.0.1:7001")]
      64            0 :     http: String,
      65              :     /// listen for incoming wss connections on ip:port
      66              :     #[clap(long)]
      67              :     wss: Option<String>,
      68              :     /// redirect unauthenticated users to the given uri in case of link auth
      69              :     #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
      70            0 :     uri: String,
      71              :     /// cloud API endpoint for authenticating users
      72              :     #[clap(
      73              :         short,
      74              :         long,
      75              :         default_value = "http://localhost:3000/authenticate_proxy_request/"
      76              :     )]
      77            0 :     auth_endpoint: String,
      78              :     /// path to TLS key for client postgres connections
      79              :     ///
      80              :     /// tls-key and tls-cert are kept for backwards compatibility; all certs can be put in one dir instead
      81              :     #[clap(short = 'k', long, alias = "ssl-key")]
      82              :     tls_key: Option<String>,
      83              :     /// path to TLS cert for client postgres connections
      84              :     ///
      85              :     /// tls-key and tls-cert are kept for backwards compatibility; all certs can be put in one dir instead
      86              :     #[clap(short = 'c', long, alias = "ssl-cert")]
      87              :     tls_cert: Option<String>,
      88              :     /// path to directory with TLS certificates for client postgres connections
      89              :     #[clap(long)]
      90              :     certs_dir: Option<String>,
      91              :     /// timeout for the TLS handshake
      92              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
      93            0 :     handshake_timeout: tokio::time::Duration,
      94              :     /// http endpoint to receive periodic metric updates
      95              :     #[clap(long)]
      96              :     metric_collection_endpoint: Option<String>,
      97              :     /// how often metrics should be sent to a collection endpoint
      98              :     #[clap(long)]
      99              :     metric_collection_interval: Option<String>,
     100              :     /// cache for `wake_compute` api method (use `size=0` to disable)
     101              :     #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
     102            0 :     wake_compute_cache: String,
     103              :     /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
     104              :     #[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
     105            0 :     wake_compute_lock: String,
     106              :     /// Allow self-signed certificates for compute nodes (for testing)
     107           27 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     108            0 :     allow_self_signed_compute: bool,
     109              :     #[clap(flatten)]
     110              :     sql_over_http: SqlOverHttpArgs,
     111              :     /// timeout for scram authentication protocol
     112              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
     113            0 :     scram_protocol_timeout: tokio::time::Duration,
     114              :     /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
     115           27 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     116            0 :     require_client_ip: bool,
     117              :     /// Disable the dynamic rate limiter but keep recording its metrics, to verify its behaviour in production.
     118           27 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     119            0 :     disable_dynamic_rate_limiter: bool,
     120              :     /// Rate limit algorithm. Makes sense only if `disable_dynamic_rate_limiter` is `false`.
     121           27 :     #[clap(value_enum, long, default_value_t = proxy::rate_limiter::RateLimitAlgorithm::Aimd)]
     122            0 :     rate_limit_algorithm: proxy::rate_limiter::RateLimitAlgorithm,
     123              :     /// Timeout for the rate limiter. If it fails to acquire a permit within this time, an error is returned.
     124              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
     125            0 :     rate_limiter_timeout: tokio::time::Duration,
     126              :     /// Endpoint rate limiter max number of requests per second.
     127              :     ///
     128              :     /// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
     129              :     /// Can be given multiple times for different bucket sizes.
     130          135 :     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
     131           27 :     endpoint_rps_limit: Vec<RateBucketInfo>,
     132              :     /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
     133           27 :     #[clap(long, default_value_t = 100)]
     134            0 :     initial_limit: usize,
     135              :     #[clap(flatten)]
     136              :     aimd_config: proxy::rate_limiter::AimdConfig,
     137              :     /// cache for `allowed_ips` (use `size=0` to disable)
     138              :     #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
     139            0 :     allowed_ips_cache: String,
     140              :     /// cache for `role_secret` (use `size=0` to disable)
     141              :     #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
     142            0 :     role_secret_cache: String,
     143              :     /// Disable the IP check for HTTP requests (useful if the check proves too time-consuming).
     144           27 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     145            0 :     disable_ip_check_for_http: bool,
     146              :     /// redis url for notifications.
     147              :     #[clap(long)]
     148              :     redis_notifications: Option<String>,
     149              :     /// cache for `project_info` (use `size=0` to disable)
     150              :     #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
     151            0 :     project_info_cache: String,
     152              : 
     153              :     #[clap(flatten)]
     154              :     parquet_upload: ParquetUploadArgs,
     155              : }
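// Added illustration (not part of the original source): a minimal sketch of how the
// option-string formats documented above are consumed, using clap's `parse_from` in
// the same way as the test module at the bottom of this file. The flag values and
// the function name are illustrative assumptions, not recommended settings.
#[cfg(test)]
#[test]
fn parse_proxy_cli_args_example() {
    let args = ProxyCliArgs::parse_from([
        "proxy",
        "--proxy",
        "0.0.0.0:4432",
        // humantime duration, as accepted by the various *_timeout options
        "--handshake-timeout",
        "30s",
        // option string documented for `wake_compute_lock`
        "--wake-compute-lock",
        "shards=32,permits=4,epoch=10m,timeout=1s",
        // rate buckets: '<requests per second>@<bucket duration>', repeatable
        "--endpoint-rps-limit",
        "300@1s",
        "--endpoint-rps-limit",
        "100@60s",
    ]);
    assert_eq!(args.endpoint_rps_limit.len(), 2);
    assert_eq!(args.proxy, "0.0.0.0:4432");
}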
     156              : 
     157           54 : #[derive(clap::Args, Clone, Copy, Debug)]
     158              : struct SqlOverHttpArgs {
     159              :     /// timeout for http connection requests
     160              :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
     161            0 :     sql_over_http_timeout: tokio::time::Duration,
     162              : 
     163              :     /// Whether the SQL over http pool is opt-in
     164           27 :     #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     165            0 :     sql_over_http_pool_opt_in: bool,
     166              : 
     167              :     /// How many connections to pool for each endpoint. Excess connections are discarded
     168           27 :     #[clap(long, default_value_t = 20)]
     169            0 :     sql_over_http_pool_max_conns_per_endpoint: usize,
     170              : 
     171              :     /// How many connections to pool in total, across all endpoints. Excess connections are discarded
     172           27 :     #[clap(long, default_value_t = 20000)]
     173            0 :     sql_over_http_pool_max_total_conns: usize,
     174              : 
     175              :     /// How long pooled connections should remain idle before being closed
     176              :     #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
     177            0 :     sql_over_http_idle_timeout: tokio::time::Duration,
     178              : 
     179              :     /// Duration each shard will wait on average before a GC sweep.
     180              :     /// A longer time will cause sweeps to take longer but will interfere less frequently.
     181              :     #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
     182            0 :     sql_over_http_pool_gc_epoch: tokio::time::Duration,
     183              : 
     184              :     /// How many shards the global pool should have. Must be a power of two.
     185              :     /// More shards will introduce less contention for pool operations, but can
     186              :     /// increase memory used by the pool
     187           27 :     #[clap(long, default_value_t = 128)]
     188            0 :     sql_over_http_pool_shards: usize,
     189              : }
     190              : 
     191              : #[tokio::main]
     192           25 : async fn main() -> anyhow::Result<()> {
     193           25 :     let _logging_guard = proxy::logging::init().await?;
     194           25 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     195           25 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
     196           25 : 
     197           25 :     info!("Version: {GIT_VERSION}");
     198           25 :     info!("Build_tag: {BUILD_TAG}");
     199           25 :     ::metrics::set_build_info_metric(GIT_VERSION, BUILD_TAG);
     200           25 : 
     201           25 :     match proxy::jemalloc::MetricRecorder::new(prometheus::default_registry()) {
     202           25 :         Ok(t) => {
     203           25 :             t.start();
     204           25 :         }
     205           25 :         Err(e) => tracing::error!(error = ?e, "could not start jemalloc metrics loop"),
     206           25 :     }
     207           25 : 
     208           25 :     let args = ProxyCliArgs::parse();
     209           25 :     let config = build_config(&args)?;
     210           25 : 
     211           25 :     info!("Authentication backend: {}", config.auth_backend);
     212           25 : 
     213           25 :     // Check that we can bind to address before further initialization
     214           25 :     let http_address: SocketAddr = args.http.parse()?;
     215           25 :     info!("Starting http on {http_address}");
     216           25 :     let http_listener = TcpListener::bind(http_address).await?.into_std()?;
     217           25 : 
     218           25 :     let mgmt_address: SocketAddr = args.mgmt.parse()?;
     219           25 :     info!("Starting mgmt on {mgmt_address}");
     220           25 :     let mgmt_listener = TcpListener::bind(mgmt_address).await?;
     221           25 : 
     222           25 :     let proxy_address: SocketAddr = args.proxy.parse()?;
     223           25 :     info!("Starting proxy on {proxy_address}");
     224           25 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     225           25 :     let cancellation_token = CancellationToken::new();
     226           25 : 
     227           25 :     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
     228           25 : 
     229           25 :     // client facing tasks. these will exit on error or on cancellation
     230           25 :     // cancellation returns Ok(())
     231           25 :     let mut client_tasks = JoinSet::new();
     232           25 :     client_tasks.spawn(proxy::proxy::task_main(
     233           25 :         config,
     234           25 :         proxy_listener,
     235           25 :         cancellation_token.clone(),
     236           25 :         endpoint_rate_limiter.clone(),
     237           25 :     ));
     238           25 : 
     239           25 :     // TODO: rename the argument to something like serverless.
     240           25 :     // It now covers more than just websockets; it also covers SQL over HTTP.
     241           25 :     if let Some(serverless_address) = args.wss {
     242           25 :         let serverless_address: SocketAddr = serverless_address.parse()?;
     243           25 :         info!("Starting wss on {serverless_address}");
     244           25 :         let serverless_listener = TcpListener::bind(serverless_address).await?;
     245           25 : 
     246           25 :         client_tasks.spawn(serverless::task_main(
     247           25 :             config,
     248           25 :             serverless_listener,
     249           25 :             cancellation_token.clone(),
     250           25 :             endpoint_rate_limiter.clone(),
     251           25 :         ));
     252           25 :     }
     253           25 : 
     254           25 :     client_tasks.spawn(proxy::context::parquet::worker(
     255           25 :         cancellation_token.clone(),
     256           25 :         args.parquet_upload,
     257           25 :     ));
     258           25 : 
     259           25 :     // maintenance tasks. these never return unless there's an error
     260           25 :     let mut maintenance_tasks = JoinSet::new();
     261           25 :     maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
     262           25 :     maintenance_tasks.spawn(http::health_server::task_main(http_listener));
     263           25 :     maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
     264           25 : 
     265           25 :     if let Some(metrics_config) = &config.metric_collection {
     266            1 :         maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
     267           24 :     }
     268           25 : 
     269           25 :     if let auth::BackendType::Console(api, _) = &config.auth_backend {
     270           25 :         if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
     271           25 :             let cache = api.caches.project_info.clone();
     272           25 :             if let Some(url) = args.redis_notifications {
     273           25 :                 info!("Starting redis notifications listener ({url})");
     274           25 :                 maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
     275           25 :             }
     276           25 :             maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
     277           25 :         }
     278           25 :     }
     279           25 : 
     280           25 :     let maintenance = loop {
     281           25 :         // get one complete task
     282          100 :         match futures::future::select(
     283          100 :             pin!(maintenance_tasks.join_next()),
     284          100 :             pin!(client_tasks.join_next()),
     285          100 :         )
     286           56 :         .await
     287           25 :         {
     288           25 :             // exit immediately on maintenance task completion
     289           25 :             Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
     290           25 :             // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     291           25 :             Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     292           25 :             // exit immediately on client task error
     293           75 :             Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
     294           25 :             // exit if all our client tasks have shutdown gracefully
     295           25 :             Either::Right((None, _)) => return Ok(()),
     296           25 :         }
     297           25 :     };
     298           25 : 
     299           25 :     // maintenance tasks return Infallible as their success value; since no such value
     300           25 :     // can ever exist, this empty match statically proves that this point is unreachable
     301           25 :     match maintenance {}
     302           25 : }
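// Added illustration (not part of the original source): the `match maintenance {}`
// at the end of main() matches on an uninhabited type. No value of
// `std::convert::Infallible` can ever be constructed, so an empty match is the
// compiler-checked way to mark that point as unreachable. The same idiom in isolation:
#[allow(dead_code)]
fn statically_unreachable(value: std::convert::Infallible) -> ! {
    // no arms are needed: `Infallible` has no variants to cover
    match value {}
}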
     303              : 
     304              : /// ProxyConfig is created at proxy startup, and lives forever.
     305           25 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     306           25 :     let tls_config = match (&args.tls_key, &args.tls_cert) {
     307           25 :         (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
     308           25 :             key_path,
     309           25 :             cert_path,
     310           25 :             args.certs_dir.as_ref(),
     311           25 :         )?),
     312            0 :         (None, None) => None,
     313            0 :         _ => bail!("either both or neither tls-key and tls-cert must be specified"),
     314              :     };
     315              : 
     316           25 :     if args.allow_self_signed_compute {
     317            3 :         warn!("allowing self-signed compute certificates");
     318           22 :     }
     319              : 
     320           25 :     let metric_collection = match (
     321           25 :         &args.metric_collection_endpoint,
     322           25 :         &args.metric_collection_interval,
     323              :     ) {
     324            1 :         (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
     325            1 :             endpoint: endpoint.parse()?,
     326            1 :             interval: humantime::parse_duration(interval)?,
     327              :         }),
     328           24 :         (None, None) => None,
     329            0 :         _ => bail!(
     330            0 :             "either both or neither metric-collection-endpoint \
     331            0 :              and metric-collection-interval must be specified"
     332            0 :         ),
     333              :     };
     334           25 :     let rate_limiter_config = RateLimiterConfig {
     335           25 :         disable: args.disable_dynamic_rate_limiter,
     336           25 :         algorithm: args.rate_limit_algorithm,
     337           25 :         timeout: args.rate_limiter_timeout,
     338           25 :         initial_limit: args.initial_limit,
     339           25 :         aimd_config: Some(args.aimd_config),
     340           25 :     };
     341              : 
     342           25 :     let auth_backend = match &args.auth_backend {
     343              :         AuthBackend::Console => {
     344            1 :             let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
     345            1 :             let project_info_cache_config: ProjectInfoCacheOptions =
     346            1 :                 args.project_info_cache.parse()?;
     347              : 
     348            1 :             info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
     349            1 :             info!(
     350            1 :                 "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
     351            1 :             );
     352            1 :             let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
     353            1 :                 wake_compute_cache_config,
     354            1 :                 project_info_cache_config,
     355            1 :             )));
     356              : 
     357              :             let config::WakeComputeLockOptions {
     358            1 :                 shards,
     359            1 :                 permits,
     360            1 :                 epoch,
     361            1 :                 timeout,
     362            1 :             } = args.wake_compute_lock.parse()?;
     363            1 :             info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
     364            1 :             let locks = Box::leak(Box::new(
     365            1 :                 console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
     366            1 :                     .unwrap(),
     367            1 :             ));
     368            1 :             tokio::spawn(locks.garbage_collect_worker(epoch));
     369              : 
     370            1 :             let url = args.auth_endpoint.parse()?;
     371            1 :             let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
     372            1 : 
     373            1 :             let api = console::provider::neon::Api::new(endpoint, caches, locks);
     374            1 :             let api = console::provider::ConsoleBackend::Console(api);
     375            1 :             auth::BackendType::Console(MaybeOwned::Owned(api), ())
     376              :         }
     377              :         #[cfg(feature = "testing")]
     378              :         AuthBackend::Postgres => {
     379           21 :             let url = args.auth_endpoint.parse()?;
     380           21 :             let api = console::provider::mock::Api::new(url);
     381           21 :             let api = console::provider::ConsoleBackend::Postgres(api);
     382           21 :             auth::BackendType::Console(MaybeOwned::Owned(api), ())
     383              :         }
     384              :         AuthBackend::Link => {
     385            3 :             let url = args.uri.parse()?;
     386            3 :             auth::BackendType::Link(MaybeOwned::Owned(url))
     387              :         }
     388              :     };
     389           25 :     let http_config = HttpConfig {
     390           25 :         request_timeout: args.sql_over_http.sql_over_http_timeout,
     391           25 :         pool_options: GlobalConnPoolOptions {
     392           25 :             max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
     393           25 :             gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
     394           25 :             pool_shards: args.sql_over_http.sql_over_http_pool_shards,
     395           25 :             idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
     396           25 :             opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
     397           25 :             max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
     398           25 :         },
     399           25 :     };
     400           25 :     let authentication_config = AuthenticationConfig {
     401           25 :         scram_protocol_timeout: args.scram_protocol_timeout,
     402           25 :     };
     403           25 : 
     404           25 :     let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
     405           25 :     RateBucketInfo::validate(&mut endpoint_rps_limit)?;
     406              : 
     407           25 :     let config = Box::leak(Box::new(ProxyConfig {
     408           25 :         tls_config,
     409           25 :         auth_backend,
     410           25 :         metric_collection,
     411           25 :         allow_self_signed_compute: args.allow_self_signed_compute,
     412           25 :         http_config,
     413           25 :         authentication_config,
     414           25 :         require_client_ip: args.require_client_ip,
     415           25 :         disable_ip_check_for_http: args.disable_ip_check_for_http,
     416           25 :         endpoint_rps_limit,
     417           25 :         handshake_timeout: args.handshake_timeout,
     418           25 :         // TODO: add this argument
     419           25 :         region: args.region.clone(),
     420           25 :     }));
     421           25 : 
     422           25 :     Ok(config)
     423           25 : }
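// Added illustration (not part of the original source): build_config() hands out a
// `&'static ProxyConfig` by leaking a Box, so long-running tasks can share the config
// without reference counting. `LeakedSettings` below is a hypothetical stand-in type
// used only to show the idiom.
#[allow(dead_code)]
struct LeakedSettings {
    region: String,
}

#[allow(dead_code)]
fn leak_settings(region: String) -> &'static LeakedSettings {
    // the allocation is intentionally never freed; the reference stays valid for the
    // remainder of the process, which is exactly what startup-time config needs
    Box::leak(Box::new(LeakedSettings { region }))
}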
     424              : 
     425              : #[cfg(test)]
     426              : mod tests {
     427              :     use std::time::Duration;
     428              : 
     429              :     use clap::Parser;
     430              :     use proxy::rate_limiter::RateBucketInfo;
     431              : 
     432            2 :     #[test]
     433            2 :     fn parse_endpoint_rps_limit() {
     434            2 :         let config = super::ProxyCliArgs::parse_from([
     435            2 :             "proxy",
     436            2 :             "--endpoint-rps-limit",
     437            2 :             "100@1s",
     438            2 :             "--endpoint-rps-limit",
     439            2 :             "20@30s",
     440            2 :         ]);
     441            2 : 
     442            2 :         assert_eq!(
     443            2 :             config.endpoint_rps_limit,
     444            2 :             vec![
     445            2 :                 RateBucketInfo::new(100, Duration::from_secs(1)),
     446            2 :                 RateBucketInfo::new(20, Duration::from_secs(30)),
     447            2 :             ]
     448            2 :         );
     449            2 :     }
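    // Added example (not part of the original tests): a hypothetical companion test,
    // assuming `WakeComputeLockOptions` keeps the `FromStr` impl and public
    // `shards`/`permits` fields that build_config() relies on above.
    #[test]
    fn parse_wake_compute_lock_example() {
        use proxy::config::WakeComputeLockOptions;

        let opts: WakeComputeLockOptions = "shards=32,permits=4,epoch=10m,timeout=1s"
            .parse()
            .expect("documented example string should parse");
        assert_eq!(opts.shards, 32);
        assert_eq!(opts.permits, 4);
    }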
     450              : }
        

Generated by: LCOV version 2.1-beta