Line data Source code
use futures::future::Either;
use proxy::auth;
use proxy::console;
use proxy::http;
use proxy::metrics;

use anyhow::{bail, Context};
use proxy::config::{self, ProxyConfig};
use std::pin::pin;
use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use utils::{project_git_version, sentry_init::init_sentry};
17 :
18 : project_git_version!(GIT_VERSION);
19 :
20 : use clap::{Parser, ValueEnum};
21 :
22 129 : #[derive(Clone, Debug, ValueEnum)]
23 : enum AuthBackend {
24 : Console,
25 : Postgres,
26 : Link,
27 : }
28 :
29 : /// Neon proxy/router
30 28 : #[derive(Parser)]
31 : #[command(version = GIT_VERSION, about)]
32 : struct ProxyCliArgs {
33 : /// listen for incoming client connections on ip:port
34 : #[clap(short, long, default_value = "127.0.0.1:4432")]
35 0 : proxy: String,
36 14 : #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
37 0 : auth_backend: AuthBackend,
38 : /// listen for management callback connection on ip:port
39 : #[clap(short, long, default_value = "127.0.0.1:7000")]
40 0 : mgmt: String,
41 : /// listen for incoming http connections (metrics, etc) on ip:port
42 : #[clap(long, default_value = "127.0.0.1:7001")]
43 0 : http: String,
44 : /// listen for incoming wss connections on ip:port
45 : #[clap(long)]
46 : wss: Option<String>,
47 : /// redirect unauthenticated users to the given uri in case of link auth
48 : #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
49 0 : uri: String,
50 : /// cloud API endpoint for authenticating users
51 : #[clap(
52 : short,
53 : long,
54 : default_value = "http://localhost:3000/authenticate_proxy_request/"
55 : )]
56 0 : auth_endpoint: String,
57 : /// path to TLS key for client postgres connections
58 : ///
59 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
60 : #[clap(short = 'k', long, alias = "ssl-key")]
61 : tls_key: Option<String>,
62 : /// path to TLS cert for client postgres connections
63 : ///
64 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
65 : #[clap(short = 'c', long, alias = "ssl-cert")]
66 : tls_cert: Option<String>,
67 : /// path to directory with TLS certificates for client postgres connections
68 : #[clap(long)]
69 : certs_dir: Option<String>,
70 : /// http endpoint to receive periodic metric updates
71 : #[clap(long)]
72 : metric_collection_endpoint: Option<String>,
73 : /// how often metrics should be sent to a collection endpoint
74 : #[clap(long)]
75 : metric_collection_interval: Option<String>,
76 : /// cache for `wake_compute` api method (use `size=0` to disable)
77 : #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
78 0 : wake_compute_cache: String,
79 : /// Allow self-signed certificates for compute nodes (for testing)
80 14 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
81 0 : allow_self_signed_compute: bool,
82 0 : }
83 :
84 : #[tokio::main]
85 14 : async fn main() -> anyhow::Result<()> {
86 14 : let _logging_guard = proxy::logging::init().await?;
87 14 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
88 14 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
89 14 :
90 14 : info!("Version: {GIT_VERSION}");
91 14 : ::metrics::set_build_info_metric(GIT_VERSION);
92 14 :
93 14 : let args = ProxyCliArgs::parse();
94 14 : let config = build_config(&args)?;
95 :
96 14 : info!("Authentication backend: {}", config.auth_backend);
97 :
98 : // Check that we can bind to address before further initialization
99 14 : let http_address: SocketAddr = args.http.parse()?;
100 14 : info!("Starting http on {http_address}");
101 14 : let http_listener = TcpListener::bind(http_address).await?.into_std()?;
102 :
103 14 : let mgmt_address: SocketAddr = args.mgmt.parse()?;
104 14 : info!("Starting mgmt on {mgmt_address}");
105 14 : let mgmt_listener = TcpListener::bind(mgmt_address).await?;
106 :
107 14 : let proxy_address: SocketAddr = args.proxy.parse()?;
108 14 : info!("Starting proxy on {proxy_address}");
109 14 : let proxy_listener = TcpListener::bind(proxy_address).await?;
110 14 : let cancellation_token = CancellationToken::new();
111 14 :
112 14 : // client facing tasks. these will exit on error or on cancellation
113 14 : // cancellation returns Ok(())
114 14 : let mut client_tasks = JoinSet::new();
115 14 : client_tasks.spawn(proxy::proxy::task_main(
116 14 : config,
117 14 : proxy_listener,
118 14 : cancellation_token.clone(),
119 14 : ));
120 :
121 14 : if let Some(wss_address) = args.wss {
122 14 : let wss_address: SocketAddr = wss_address.parse()?;
123 14 : info!("Starting wss on {wss_address}");
124 14 : let wss_listener = TcpListener::bind(wss_address).await?;
125 :
126 14 : client_tasks.spawn(http::websocket::task_main(
127 14 : config,
128 14 : wss_listener,
129 14 : cancellation_token.clone(),
130 14 : ));
131 0 : }
132 :
133 : // maintenance tasks. these never return unless there's an error
134 14 : let mut maintenance_tasks = JoinSet::new();
135 14 : maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
136 14 : maintenance_tasks.spawn(http::server::task_main(http_listener));
137 14 : maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
138 :
139 14 : if let Some(metrics_config) = &config.metric_collection {
140 1 : maintenance_tasks.spawn(metrics::task_main(metrics_config));
141 13 : }
142 :
143 : let maintenance = loop {
144 : // get one complete task
145 42 : match futures::future::select(
146 42 : pin!(maintenance_tasks.join_next()),
147 42 : pin!(client_tasks.join_next()),
148 42 : )
149 23 : .await
150 : {
151 : // exit immediately on maintenance task completion
152 0 : Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
153 : // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
154 0 : Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
155 : // exit immediately on client task error
156 28 : Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
157 : // exit if all our client tasks have shutdown gracefully
158 14 : Either::Right((None, _)) => return Ok(()),
159 : }
160 : };
161 :
162 : // maintenance tasks return Infallible success values, this is an impossible value
163 : // so this match statically ensures that there are no possibilities for that value
164 : match maintenance {}
165 : }
166 :
167 : /// ProxyConfig is created at proxy startup, and lives forever.
168 14 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
169 14 : let tls_config = match (&args.tls_key, &args.tls_cert) {
170 14 : (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
171 14 : key_path,
172 14 : cert_path,
173 14 : args.certs_dir.as_ref(),
174 14 : )?),
175 0 : (None, None) => None,
176 0 : _ => bail!("either both or neither tls-key and tls-cert must be specified"),
177 : };
178 :
179 14 : if args.allow_self_signed_compute {
180 3 : warn!("allowing self-signed compute certificates");
181 11 : }
182 :
183 14 : let metric_collection = match (
184 14 : &args.metric_collection_endpoint,
185 14 : &args.metric_collection_interval,
186 : ) {
187 1 : (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
188 1 : endpoint: endpoint.parse()?,
189 1 : interval: humantime::parse_duration(interval)?,
190 : }),
191 13 : (None, None) => None,
192 0 : _ => bail!(
193 0 : "either both or neither metric-collection-endpoint \
194 0 : and metric-collection-interval must be specified"
195 0 : ),
196 : };
197 :
198 14 : let auth_backend = match &args.auth_backend {
199 : AuthBackend::Console => {
200 0 : let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?;
201 :
202 0 : info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
203 0 : let caches = Box::leak(Box::new(console::caches::ApiCaches {
204 0 : node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
205 0 : }));
206 :
207 0 : let url = args.auth_endpoint.parse()?;
208 0 : let endpoint = http::Endpoint::new(url, http::new_client());
209 0 :
210 0 : let api = console::provider::neon::Api::new(endpoint, caches);
211 0 : auth::BackendType::Console(Cow::Owned(api), ())
212 : }
213 : AuthBackend::Postgres => {
214 11 : let url = args.auth_endpoint.parse()?;
215 11 : let api = console::provider::mock::Api::new(url);
216 11 : auth::BackendType::Postgres(Cow::Owned(api), ())
217 : }
218 : AuthBackend::Link => {
219 3 : let url = args.uri.parse()?;
220 3 : auth::BackendType::Link(Cow::Owned(url))
221 : }
222 : };
223 :
224 14 : let config = Box::leak(Box::new(ProxyConfig {
225 14 : tls_config,
226 14 : auth_backend,
227 14 : metric_collection,
228 14 : allow_self_signed_compute: args.allow_self_signed_compute,
229 14 : }));
230 14 :
231 14 : Ok(config)
232 14 : }
|