LCOV - differential code coverage report
Current view: top level - proxy/src/bin - proxy.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 75.4 % 126 95 31 95
Current Date: 2023-10-19 02:04:12 Functions: 37.0 % 46 17 29 17
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use futures::future::Either;
       2                 : use proxy::auth;
       3                 : use proxy::config::HttpConfig;
       4                 : use proxy::console;
       5                 : use proxy::http;
       6                 : use proxy::metrics;
       7                 : 
       8                 : use anyhow::bail;
       9                 : use proxy::config::{self, ProxyConfig};
      10                 : use std::pin::pin;
      11                 : use std::{borrow::Cow, net::SocketAddr};
      12                 : use tokio::net::TcpListener;
      13                 : use tokio::task::JoinSet;
      14                 : use tokio_util::sync::CancellationToken;
      15                 : use tracing::info;
      16                 : use tracing::warn;
      17                 : use utils::{project_git_version, sentry_init::init_sentry};
      18                 : 
      19                 : project_git_version!(GIT_VERSION);
      20                 : 
      21                 : use clap::{Parser, ValueEnum};
      22                 : 
      23 CBC         147 : #[derive(Clone, Debug, ValueEnum)]
      24                 : enum AuthBackend {
      25                 :     Console,
      26                 :     Postgres,
      27                 :     Link,
      28                 : }
      29                 : 
      30                 : /// Neon proxy/router
      31              32 : #[derive(Parser)]
      32                 : #[command(version = GIT_VERSION, about)]
      33                 : struct ProxyCliArgs {
      34                 :     /// listen for incoming client connections on ip:port
      35                 :     #[clap(short, long, default_value = "127.0.0.1:4432")]
      36 UBC           0 :     proxy: String,
      37 CBC          16 :     #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
      38 UBC           0 :     auth_backend: AuthBackend,
      39                 :     /// listen for management callback connection on ip:port
      40                 :     #[clap(short, long, default_value = "127.0.0.1:7000")]
      41               0 :     mgmt: String,
      42                 :     /// listen for incoming http connections (metrics, etc) on ip:port
      43                 :     #[clap(long, default_value = "127.0.0.1:7001")]
      44               0 :     http: String,
      45                 :     /// listen for incoming wss connections on ip:port
      46                 :     #[clap(long)]
      47                 :     wss: Option<String>,
      48                 :     /// redirect unauthenticated users to the given uri in case of link auth
      49                 :     #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
      50               0 :     uri: String,
      51                 :     /// cloud API endpoint for authenticating users
      52                 :     #[clap(
      53                 :         short,
      54                 :         long,
      55                 :         default_value = "http://localhost:3000/authenticate_proxy_request/"
      56                 :     )]
      57               0 :     auth_endpoint: String,
      58                 :     /// path to TLS key for client postgres connections
      59                 :     ///
      60                 :     /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
      61                 :     #[clap(short = 'k', long, alias = "ssl-key")]
      62                 :     tls_key: Option<String>,
      63                 :     /// path to TLS cert for client postgres connections
      64                 :     ///
      65                 :     /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
      66                 :     #[clap(short = 'c', long, alias = "ssl-cert")]
      67                 :     tls_cert: Option<String>,
      68                 :     /// path to directory with TLS certificates for client postgres connections
      69                 :     #[clap(long)]
      70                 :     certs_dir: Option<String>,
      71                 :     /// http endpoint to receive periodic metric updates
      72                 :     #[clap(long)]
      73                 :     metric_collection_endpoint: Option<String>,
      74                 :     /// how often metrics should be sent to a collection endpoint
      75                 :     #[clap(long)]
      76                 :     metric_collection_interval: Option<String>,
      77                 :     /// cache for `wake_compute` api method (use `size=0` to disable)
      78                 :     #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
      79               0 :     wake_compute_cache: String,
      80                 :     /// Allow self-signed certificates for compute nodes (for testing)
      81 CBC          16 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
      82 UBC           0 :     allow_self_signed_compute: bool,
      83               0 :     /// timeout for http connections
      84                 :     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
      85               0 :     sql_over_http_timeout: tokio::time::Duration,
      86                 : 
      87                 :     /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
      88 CBC          16 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
      89 UBC           0 :     require_client_ip: bool,
      90               0 : }
      91                 : 
      92                 : #[tokio::main]
      93 CBC          16 : async fn main() -> anyhow::Result<()> {
      94              16 :     let _logging_guard = proxy::logging::init().await?;
      95              16 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      96              16 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      97              16 : 
      98              16 :     info!("Version: {GIT_VERSION}");
      99              16 :     ::metrics::set_build_info_metric(GIT_VERSION);
     100              16 : 
     101              16 :     let args = ProxyCliArgs::parse();
     102              16 :     let config = build_config(&args)?;
     103                 : 
     104              16 :     info!("Authentication backend: {}", config.auth_backend);
     105                 : 
     106                 :     // Check that we can bind to address before further initialization
     107              16 :     let http_address: SocketAddr = args.http.parse()?;
     108              16 :     info!("Starting http on {http_address}");
     109              16 :     let http_listener = TcpListener::bind(http_address).await?.into_std()?;
     110                 : 
     111              16 :     let mgmt_address: SocketAddr = args.mgmt.parse()?;
     112              16 :     info!("Starting mgmt on {mgmt_address}");
     113              16 :     let mgmt_listener = TcpListener::bind(mgmt_address).await?;
     114                 : 
     115              16 :     let proxy_address: SocketAddr = args.proxy.parse()?;
     116              16 :     info!("Starting proxy on {proxy_address}");
     117              16 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     118              16 :     let cancellation_token = CancellationToken::new();
     119              16 : 
     120              16 :     // client facing tasks. these will exit on error or on cancellation
     121              16 :     // cancellation returns Ok(())
     122              16 :     let mut client_tasks = JoinSet::new();
     123              16 :     client_tasks.spawn(proxy::proxy::task_main(
     124              16 :         config,
     125              16 :         proxy_listener,
     126              16 :         cancellation_token.clone(),
     127              16 :     ));
     128                 : 
     129              16 :     if let Some(wss_address) = args.wss {
     130              16 :         let wss_address: SocketAddr = wss_address.parse()?;
     131              16 :         info!("Starting wss on {wss_address}");
     132              16 :         let wss_listener = TcpListener::bind(wss_address).await?;
     133                 : 
     134              16 :         client_tasks.spawn(http::websocket::task_main(
     135              16 :             config,
     136              16 :             wss_listener,
     137              16 :             cancellation_token.clone(),
     138              16 :         ));
     139 UBC           0 :     }
     140                 : 
     141                 :     // maintenance tasks. these never return unless there's an error
     142 CBC          16 :     let mut maintenance_tasks = JoinSet::new();
     143              16 :     maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
     144              16 :     maintenance_tasks.spawn(http::server::task_main(http_listener));
     145              16 :     maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
     146                 : 
     147              16 :     if let Some(metrics_config) = &config.metric_collection {
     148               1 :         maintenance_tasks.spawn(metrics::task_main(metrics_config));
     149              15 :     }
     150                 : 
     151                 :     let maintenance = loop {
     152                 :         // get one complete task
     153              48 :         match futures::future::select(
     154              48 :             pin!(maintenance_tasks.join_next()),
     155              48 :             pin!(client_tasks.join_next()),
     156              48 :         )
     157              22 :         .await
     158                 :         {
     159                 :             // exit immediately on maintenance task completion
     160 UBC           0 :             Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
     161                 :             // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     162               0 :             Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     163                 :             // exit immediately on client task error
     164 CBC          32 :             Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
     165                 :             // exit if all our client tasks have shutdown gracefully
     166              16 :             Either::Right((None, _)) => return Ok(()),
     167                 :         }
     168                 :     };
     169                 : 
     170                 :     // maintenance tasks return Infallible success values, this is an impossible value
     171                 :     // so this match statically ensures that there are no possibilities for that value
     172                 :     match maintenance {}
     173                 : }
     174                 : 
     175                 : /// ProxyConfig is created at proxy startup, and lives forever.
     176              16 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     177              16 :     let tls_config = match (&args.tls_key, &args.tls_cert) {
     178              16 :         (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
     179              16 :             key_path,
     180              16 :             cert_path,
     181              16 :             args.certs_dir.as_ref(),
     182              16 :         )?),
     183 UBC           0 :         (None, None) => None,
     184               0 :         _ => bail!("either both or neither tls-key and tls-cert must be specified"),
     185                 :     };
     186                 : 
     187 CBC          16 :     if args.allow_self_signed_compute {
     188               3 :         warn!("allowing self-signed compute certificates");
     189              13 :     }
     190                 : 
     191              16 :     let metric_collection = match (
     192              16 :         &args.metric_collection_endpoint,
     193              16 :         &args.metric_collection_interval,
     194                 :     ) {
     195               1 :         (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
     196               1 :             endpoint: endpoint.parse()?,
     197               1 :             interval: humantime::parse_duration(interval)?,
     198                 :         }),
     199              15 :         (None, None) => None,
     200 UBC           0 :         _ => bail!(
     201               0 :             "either both or neither metric-collection-endpoint \
     202               0 :              and metric-collection-interval must be specified"
     203               0 :         ),
     204                 :     };
     205                 : 
     206 CBC          16 :     let auth_backend = match &args.auth_backend {
     207                 :         AuthBackend::Console => {
     208 UBC           0 :             let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?;
     209                 : 
     210               0 :             info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
     211               0 :             let caches = Box::leak(Box::new(console::caches::ApiCaches {
     212               0 :                 node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
     213               0 :             }));
     214                 : 
     215               0 :             let url = args.auth_endpoint.parse()?;
     216               0 :             let endpoint = http::Endpoint::new(url, http::new_client());
     217               0 : 
     218               0 :             let api = console::provider::neon::Api::new(endpoint, caches);
     219               0 :             auth::BackendType::Console(Cow::Owned(api), ())
     220                 :         }
     221                 :         AuthBackend::Postgres => {
     222 CBC          13 :             let url = args.auth_endpoint.parse()?;
     223              13 :             let api = console::provider::mock::Api::new(url);
     224              13 :             auth::BackendType::Postgres(Cow::Owned(api), ())
     225                 :         }
     226                 :         AuthBackend::Link => {
     227               3 :             let url = args.uri.parse()?;
     228               3 :             auth::BackendType::Link(Cow::Owned(url))
     229                 :         }
     230                 :     };
     231              16 :     let http_config = HttpConfig {
     232              16 :         sql_over_http_timeout: args.sql_over_http_timeout,
     233              16 :     };
     234              16 :     let config = Box::leak(Box::new(ProxyConfig {
     235              16 :         tls_config,
     236              16 :         auth_backend,
     237              16 :         metric_collection,
     238              16 :         allow_self_signed_compute: args.allow_self_signed_compute,
     239              16 :         http_config,
     240              16 :         require_client_ip: args.require_client_ip,
     241              16 :     }));
     242              16 : 
     243              16 :     Ok(config)
     244              16 : }
        

Generated by: LCOV version 2.1-beta