LCOV - code coverage report
Current view: top level - proxy/src/bin - proxy.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 76.1 % 117 89
Test Date: 2023-09-06 10:18:01 Functions: 40.5 % 42 17

            Line data    Source code
       1              : use futures::future::Either;
       2              : use proxy::auth;
       3              : use proxy::console;
       4              : use proxy::http;
       5              : use proxy::metrics;
       6              : 
       7              : use anyhow::bail;
       8              : use proxy::config::{self, ProxyConfig};
       9              : use std::pin::pin;
      10              : use std::{borrow::Cow, net::SocketAddr};
      11              : use tokio::net::TcpListener;
      12              : use tokio::task::JoinSet;
      13              : use tokio_util::sync::CancellationToken;
      14              : use tracing::info;
      15              : use tracing::warn;
      16              : use utils::{project_git_version, sentry_init::init_sentry};
      17              : 
      18              : project_git_version!(GIT_VERSION);
      19              : 
      20              : use clap::{Parser, ValueEnum};
      21              : 
      22          129 : #[derive(Clone, Debug, ValueEnum)]
      23              : enum AuthBackend {
      24              :     Console,
      25              :     Postgres,
      26              :     Link,
      27              : }
      28              : 
      29              : /// Neon proxy/router
      30           28 : #[derive(Parser)]
      31              : #[command(version = GIT_VERSION, about)]
      32              : struct ProxyCliArgs {
      33              :     /// listen for incoming client connections on ip:port
      34              :     #[clap(short, long, default_value = "127.0.0.1:4432")]
      35            0 :     proxy: String,
      36           14 :     #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
      37            0 :     auth_backend: AuthBackend,
      38              :     /// listen for management callback connection on ip:port
      39              :     #[clap(short, long, default_value = "127.0.0.1:7000")]
      40            0 :     mgmt: String,
      41              :     /// listen for incoming http connections (metrics, etc) on ip:port
      42              :     #[clap(long, default_value = "127.0.0.1:7001")]
      43            0 :     http: String,
      44              :     /// listen for incoming wss connections on ip:port
      45              :     #[clap(long)]
      46              :     wss: Option<String>,
      47              :     /// redirect unauthenticated users to the given uri in case of link auth
      48              :     #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
      49            0 :     uri: String,
      50              :     /// cloud API endpoint for authenticating users
      51              :     #[clap(
      52              :         short,
      53              :         long,
      54              :         default_value = "http://localhost:3000/authenticate_proxy_request/"
      55              :     )]
      56            0 :     auth_endpoint: String,
      57              :     /// path to TLS key for client postgres connections
      58              :     ///
      59              :     /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
      60              :     #[clap(short = 'k', long, alias = "ssl-key")]
      61              :     tls_key: Option<String>,
      62              :     /// path to TLS cert for client postgres connections
      63              :     ///
      64              :     /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
      65              :     #[clap(short = 'c', long, alias = "ssl-cert")]
      66              :     tls_cert: Option<String>,
      67              :     /// path to directory with TLS certificates for client postgres connections
      68              :     #[clap(long)]
      69              :     certs_dir: Option<String>,
      70              :     /// http endpoint to receive periodic metric updates
      71              :     #[clap(long)]
      72              :     metric_collection_endpoint: Option<String>,
      73              :     /// how often metrics should be sent to a collection endpoint
      74              :     #[clap(long)]
      75              :     metric_collection_interval: Option<String>,
      76              :     /// cache for `wake_compute` api method (use `size=0` to disable)
      77              :     #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
      78            0 :     wake_compute_cache: String,
      79              :     /// Allow self-signed certificates for compute nodes (for testing)
      80           14 :     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
      81            0 :     allow_self_signed_compute: bool,
      82            0 : }
      83              : 
      84              : #[tokio::main]
      85           14 : async fn main() -> anyhow::Result<()> {
      86           14 :     let _logging_guard = proxy::logging::init().await?;
      87           14 :     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
      88           14 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
      89           14 : 
      90           14 :     info!("Version: {GIT_VERSION}");
      91           14 :     ::metrics::set_build_info_metric(GIT_VERSION);
      92           14 : 
      93           14 :     let args = ProxyCliArgs::parse();
      94           14 :     let config = build_config(&args)?;
      95              : 
      96           14 :     info!("Authentication backend: {}", config.auth_backend);
      97              : 
      98              :     // Check that we can bind to address before further initialization
      99           14 :     let http_address: SocketAddr = args.http.parse()?;
     100           14 :     info!("Starting http on {http_address}");
     101           14 :     let http_listener = TcpListener::bind(http_address).await?.into_std()?;
     102              : 
     103           14 :     let mgmt_address: SocketAddr = args.mgmt.parse()?;
     104           14 :     info!("Starting mgmt on {mgmt_address}");
     105           14 :     let mgmt_listener = TcpListener::bind(mgmt_address).await?;
     106              : 
     107           14 :     let proxy_address: SocketAddr = args.proxy.parse()?;
     108           14 :     info!("Starting proxy on {proxy_address}");
     109           14 :     let proxy_listener = TcpListener::bind(proxy_address).await?;
     110           14 :     let cancellation_token = CancellationToken::new();
     111           14 : 
     112           14 :     // client facing tasks. these will exit on error or on cancellation
     113           14 :     // cancellation returns Ok(())
     114           14 :     let mut client_tasks = JoinSet::new();
     115           14 :     client_tasks.spawn(proxy::proxy::task_main(
     116           14 :         config,
     117           14 :         proxy_listener,
     118           14 :         cancellation_token.clone(),
     119           14 :     ));
     120              : 
     121           14 :     if let Some(wss_address) = args.wss {
     122           14 :         let wss_address: SocketAddr = wss_address.parse()?;
     123           14 :         info!("Starting wss on {wss_address}");
     124           14 :         let wss_listener = TcpListener::bind(wss_address).await?;
     125              : 
     126           14 :         client_tasks.spawn(http::websocket::task_main(
     127           14 :             config,
     128           14 :             wss_listener,
     129           14 :             cancellation_token.clone(),
     130           14 :         ));
     131            0 :     }
     132              : 
     133              :     // maintenance tasks. these never return unless there's an error
     134           14 :     let mut maintenance_tasks = JoinSet::new();
     135           14 :     maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
     136           14 :     maintenance_tasks.spawn(http::server::task_main(http_listener));
     137           14 :     maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
     138              : 
     139           14 :     if let Some(metrics_config) = &config.metric_collection {
     140            1 :         maintenance_tasks.spawn(metrics::task_main(metrics_config));
     141           13 :     }
     142              : 
     143              :     let maintenance = loop {
     144              :         // get one complete task
     145           42 :         match futures::future::select(
     146           42 :             pin!(maintenance_tasks.join_next()),
     147           42 :             pin!(client_tasks.join_next()),
     148           42 :         )
     149           23 :         .await
     150              :         {
     151              :             // exit immediately on maintenance task completion
     152            0 :             Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
     153              :             // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
     154            0 :             Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
     155              :             // exit immediately on client task error
     156           28 :             Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
     157              :             // exit if all our client tasks have shutdown gracefully
     158           14 :             Either::Right((None, _)) => return Ok(()),
     159              :         }
     160              :     };
     161              : 
     162              :     // maintenance tasks return Infallible success values, this is an impossible value
     163              :     // so this match statically ensures that there are no possibilities for that value
     164              :     match maintenance {}
     165              : }
     166              : 
     167              : /// ProxyConfig is created at proxy startup, and lives forever.
     168           14 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     169           14 :     let tls_config = match (&args.tls_key, &args.tls_cert) {
     170           14 :         (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
     171           14 :             key_path,
     172           14 :             cert_path,
     173           14 :             args.certs_dir.as_ref(),
     174           14 :         )?),
     175            0 :         (None, None) => None,
     176            0 :         _ => bail!("either both or neither tls-key and tls-cert must be specified"),
     177              :     };
     178              : 
     179           14 :     if args.allow_self_signed_compute {
     180            3 :         warn!("allowing self-signed compute certificates");
     181           11 :     }
     182              : 
     183           14 :     let metric_collection = match (
     184           14 :         &args.metric_collection_endpoint,
     185           14 :         &args.metric_collection_interval,
     186              :     ) {
     187            1 :         (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
     188            1 :             endpoint: endpoint.parse()?,
     189            1 :             interval: humantime::parse_duration(interval)?,
     190              :         }),
     191           13 :         (None, None) => None,
     192            0 :         _ => bail!(
     193            0 :             "either both or neither metric-collection-endpoint \
     194            0 :              and metric-collection-interval must be specified"
     195            0 :         ),
     196              :     };
     197              : 
     198           14 :     let auth_backend = match &args.auth_backend {
     199              :         AuthBackend::Console => {
     200            0 :             let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?;
     201              : 
     202            0 :             info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
     203            0 :             let caches = Box::leak(Box::new(console::caches::ApiCaches {
     204            0 :                 node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
     205            0 :             }));
     206              : 
     207            0 :             let url = args.auth_endpoint.parse()?;
     208            0 :             let endpoint = http::Endpoint::new(url, http::new_client());
     209            0 : 
     210            0 :             let api = console::provider::neon::Api::new(endpoint, caches);
     211            0 :             auth::BackendType::Console(Cow::Owned(api), ())
     212              :         }
     213              :         AuthBackend::Postgres => {
     214           11 :             let url = args.auth_endpoint.parse()?;
     215           11 :             let api = console::provider::mock::Api::new(url);
     216           11 :             auth::BackendType::Postgres(Cow::Owned(api), ())
     217              :         }
     218              :         AuthBackend::Link => {
     219            3 :             let url = args.uri.parse()?;
     220            3 :             auth::BackendType::Link(Cow::Owned(url))
     221              :         }
     222              :     };
     223              : 
     224           14 :     let config = Box::leak(Box::new(ProxyConfig {
     225           14 :         tls_config,
     226           14 :         auth_backend,
     227           14 :         metric_collection,
     228           14 :         allow_self_signed_compute: args.allow_self_signed_compute,
     229           14 :     }));
     230           14 : 
     231           14 :     Ok(config)
     232           14 : }
        

Generated by: LCOV version 2.1-beta