TLA Line data Source code
use futures::future::Either;
use proxy::auth;
use proxy::config::HttpConfig;
use proxy::console;
use proxy::http;
use proxy::metrics;

use anyhow::bail;
use anyhow::Context;
use proxy::config::{self, ProxyConfig};
use std::pin::pin;
use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use utils::{project_git_version, sentry_init::init_sentry};

project_git_version!(GIT_VERSION);

use clap::{Parser, ValueEnum};
22 :
23 CBC 147 : #[derive(Clone, Debug, ValueEnum)]
24 : enum AuthBackend {
25 : Console,
26 : Postgres,
27 : Link,
28 : }
29 :
30 : /// Neon proxy/router
31 32 : #[derive(Parser)]
32 : #[command(version = GIT_VERSION, about)]
33 : struct ProxyCliArgs {
34 : /// listen for incoming client connections on ip:port
35 : #[clap(short, long, default_value = "127.0.0.1:4432")]
36 UBC 0 : proxy: String,
37 CBC 16 : #[clap(value_enum, long, default_value_t = AuthBackend::Link)]
38 UBC 0 : auth_backend: AuthBackend,
39 : /// listen for management callback connection on ip:port
40 : #[clap(short, long, default_value = "127.0.0.1:7000")]
41 0 : mgmt: String,
42 : /// listen for incoming http connections (metrics, etc) on ip:port
43 : #[clap(long, default_value = "127.0.0.1:7001")]
44 0 : http: String,
45 : /// listen for incoming wss connections on ip:port
46 : #[clap(long)]
47 : wss: Option<String>,
48 : /// redirect unauthenticated users to the given uri in case of link auth
49 : #[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
50 0 : uri: String,
51 : /// cloud API endpoint for authenticating users
52 : #[clap(
53 : short,
54 : long,
55 : default_value = "http://localhost:3000/authenticate_proxy_request/"
56 : )]
57 0 : auth_endpoint: String,
58 : /// path to TLS key for client postgres connections
59 : ///
60 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
61 : #[clap(short = 'k', long, alias = "ssl-key")]
62 : tls_key: Option<String>,
63 : /// path to TLS cert for client postgres connections
64 : ///
65 : /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
66 : #[clap(short = 'c', long, alias = "ssl-cert")]
67 : tls_cert: Option<String>,
68 : /// path to directory with TLS certificates for client postgres connections
69 : #[clap(long)]
70 : certs_dir: Option<String>,
71 : /// http endpoint to receive periodic metric updates
72 : #[clap(long)]
73 : metric_collection_endpoint: Option<String>,
74 : /// how often metrics should be sent to a collection endpoint
75 : #[clap(long)]
76 : metric_collection_interval: Option<String>,
77 : /// cache for `wake_compute` api method (use `size=0` to disable)
78 : #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
79 0 : wake_compute_cache: String,
80 : /// Allow self-signed certificates for compute nodes (for testing)
81 CBC 16 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
82 UBC 0 : allow_self_signed_compute: bool,
83 0 : /// timeout for http connections
84 : #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
85 0 : sql_over_http_timeout: tokio::time::Duration,
86 :
87 : /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
88 CBC 16 : #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
89 UBC 0 : require_client_ip: bool,
90 0 : }
91 :
92 : #[tokio::main]
93 CBC 16 : async fn main() -> anyhow::Result<()> {
94 16 : let _logging_guard = proxy::logging::init().await?;
95 16 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
96 16 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
97 16 :
98 16 : info!("Version: {GIT_VERSION}");
99 16 : ::metrics::set_build_info_metric(GIT_VERSION);
100 16 :
101 16 : let args = ProxyCliArgs::parse();
102 16 : let config = build_config(&args)?;
103 :
104 16 : info!("Authentication backend: {}", config.auth_backend);
105 :
106 : // Check that we can bind to address before further initialization
107 16 : let http_address: SocketAddr = args.http.parse()?;
108 16 : info!("Starting http on {http_address}");
109 16 : let http_listener = TcpListener::bind(http_address).await?.into_std()?;
110 :
111 16 : let mgmt_address: SocketAddr = args.mgmt.parse()?;
112 16 : info!("Starting mgmt on {mgmt_address}");
113 16 : let mgmt_listener = TcpListener::bind(mgmt_address).await?;
114 :
115 16 : let proxy_address: SocketAddr = args.proxy.parse()?;
116 16 : info!("Starting proxy on {proxy_address}");
117 16 : let proxy_listener = TcpListener::bind(proxy_address).await?;
118 16 : let cancellation_token = CancellationToken::new();
119 16 :
120 16 : // client facing tasks. these will exit on error or on cancellation
121 16 : // cancellation returns Ok(())
122 16 : let mut client_tasks = JoinSet::new();
123 16 : client_tasks.spawn(proxy::proxy::task_main(
124 16 : config,
125 16 : proxy_listener,
126 16 : cancellation_token.clone(),
127 16 : ));
128 :
129 16 : if let Some(wss_address) = args.wss {
130 16 : let wss_address: SocketAddr = wss_address.parse()?;
131 16 : info!("Starting wss on {wss_address}");
132 16 : let wss_listener = TcpListener::bind(wss_address).await?;
133 :
134 16 : client_tasks.spawn(http::websocket::task_main(
135 16 : config,
136 16 : wss_listener,
137 16 : cancellation_token.clone(),
138 16 : ));
139 UBC 0 : }
140 :
141 : // maintenance tasks. these never return unless there's an error
142 CBC 16 : let mut maintenance_tasks = JoinSet::new();
143 16 : maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
144 16 : maintenance_tasks.spawn(http::server::task_main(http_listener));
145 16 : maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
146 :
147 16 : if let Some(metrics_config) = &config.metric_collection {
148 1 : maintenance_tasks.spawn(metrics::task_main(metrics_config));
149 15 : }
150 :
151 : let maintenance = loop {
152 : // get one complete task
153 48 : match futures::future::select(
154 48 : pin!(maintenance_tasks.join_next()),
155 48 : pin!(client_tasks.join_next()),
156 48 : )
157 22 : .await
158 : {
159 : // exit immediately on maintenance task completion
160 UBC 0 : Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
161 : // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
162 0 : Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
163 : // exit immediately on client task error
164 CBC 32 : Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
165 : // exit if all our client tasks have shutdown gracefully
166 16 : Either::Right((None, _)) => return Ok(()),
167 : }
168 : };
169 :
170 : // maintenance tasks return Infallible success values, this is an impossible value
171 : // so this match statically ensures that there are no possibilities for that value
172 : match maintenance {}
173 : }
174 :
175 : /// ProxyConfig is created at proxy startup, and lives forever.
176 16 : fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
177 16 : let tls_config = match (&args.tls_key, &args.tls_cert) {
178 16 : (Some(key_path), Some(cert_path)) => Some(config::configure_tls(
179 16 : key_path,
180 16 : cert_path,
181 16 : args.certs_dir.as_ref(),
182 16 : )?),
183 UBC 0 : (None, None) => None,
184 0 : _ => bail!("either both or neither tls-key and tls-cert must be specified"),
185 : };
186 :
187 CBC 16 : if args.allow_self_signed_compute {
188 3 : warn!("allowing self-signed compute certificates");
189 13 : }
190 :
191 16 : let metric_collection = match (
192 16 : &args.metric_collection_endpoint,
193 16 : &args.metric_collection_interval,
194 : ) {
195 1 : (Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
196 1 : endpoint: endpoint.parse()?,
197 1 : interval: humantime::parse_duration(interval)?,
198 : }),
199 15 : (None, None) => None,
200 UBC 0 : _ => bail!(
201 0 : "either both or neither metric-collection-endpoint \
202 0 : and metric-collection-interval must be specified"
203 0 : ),
204 : };
205 :
206 CBC 16 : let auth_backend = match &args.auth_backend {
207 : AuthBackend::Console => {
208 UBC 0 : let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?;
209 :
210 0 : info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
211 0 : let caches = Box::leak(Box::new(console::caches::ApiCaches {
212 0 : node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
213 0 : }));
214 :
215 0 : let url = args.auth_endpoint.parse()?;
216 0 : let endpoint = http::Endpoint::new(url, http::new_client());
217 0 :
218 0 : let api = console::provider::neon::Api::new(endpoint, caches);
219 0 : auth::BackendType::Console(Cow::Owned(api), ())
220 : }
221 : AuthBackend::Postgres => {
222 CBC 13 : let url = args.auth_endpoint.parse()?;
223 13 : let api = console::provider::mock::Api::new(url);
224 13 : auth::BackendType::Postgres(Cow::Owned(api), ())
225 : }
226 : AuthBackend::Link => {
227 3 : let url = args.uri.parse()?;
228 3 : auth::BackendType::Link(Cow::Owned(url))
229 : }
230 : };
231 16 : let http_config = HttpConfig {
232 16 : sql_over_http_timeout: args.sql_over_http_timeout,
233 16 : };
234 16 : let config = Box::leak(Box::new(ProxyConfig {
235 16 : tls_config,
236 16 : auth_backend,
237 16 : metric_collection,
238 16 : allow_self_signed_compute: args.allow_self_signed_compute,
239 16 : http_config,
240 16 : require_client_ip: args.require_client_ip,
241 16 : }));
242 16 :
243 16 : Ok(config)
244 16 : }
|