Line data Source code
1 : use std::env;
2 : use std::net::SocketAddr;
3 : use std::pin::pin;
4 : use std::sync::Arc;
5 : use std::time::Duration;
6 :
7 : use anyhow::bail;
8 : use arc_swap::ArcSwapOption;
9 : use camino::Utf8PathBuf;
10 : use clap::Parser;
11 : use futures::future::Either;
12 : use tokio::net::TcpListener;
13 : use tokio::sync::Notify;
14 : use tokio::task::JoinSet;
15 : use tokio_util::sync::CancellationToken;
16 : use tracing::{debug, error, info};
17 : use utils::sentry_init::init_sentry;
18 : use utils::{pid_file, project_build_tag, project_git_version};
19 :
20 : use crate::auth::backend::jwt::JwkCache;
21 : use crate::auth::backend::local::LocalBackend;
22 : use crate::auth::{self};
23 : use crate::cancellation::CancellationHandler;
24 : #[cfg(feature = "rest_broker")]
25 : use crate::config::RestConfig;
26 : use crate::config::{
27 : self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
28 : refresh_config_loop,
29 : };
30 : use crate::control_plane::locks::ApiLocks;
31 : use crate::http::health_server::AppMetrics;
32 : use crate::metrics::{Metrics, ThreadPoolMetrics};
33 : use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
34 : use crate::scram::threadpool::ThreadPool;
35 : use crate::serverless::cancel_set::CancelSet;
36 : use crate::serverless::{self, GlobalConnPoolOptions};
37 : use crate::tls::client_config::compute_client_config_with_root_certs;
38 : use crate::url::ApiUrl;
39 :
// Build metadata constants: `GIT_VERSION` feeds clap's `--version`, Sentry and the
// build-info metrics in `run`; `BUILD_TAG` is logged at startup and exported with
// the build info.
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
42 :
// Command-line arguments for the `local_proxy` binary.
//
// NOTE: `///` comments in this struct are rendered by clap as `--help` text (the
// struct-level one becomes the `about` string), so reviewer notes use `//` instead
// to keep the help output unchanged.
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
struct LocalProxyCliArgs {
    /// listen for incoming metrics connections on ip:port
    // Bound in `run` and handed to the health/metrics server task.
    #[clap(long, default_value = "127.0.0.1:7001")]
    metrics: String,
    /// listen for incoming http connections on ip:port
    // Required (no default). Bound in `run` and served by `serverless::task_main`.
    #[clap(long)]
    http: String,
    /// timeout for the TLS handshake
    // NOTE(review): not read anywhere in this file; `build_config` hard-codes a 10s
    // `handshake_timeout`. Confirm whether this flag should be wired through.
    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    handshake_timeout: tokio::time::Duration,
    /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
    // Parsed into `config::ConcurrencyLockOptions` by `build_config`.
    #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
    connect_compute_lock: String,
    // SQL-over-HTTP tuning flags, surfaced as top-level `--sql-over-http-*` options.
    #[clap(flatten)]
    sql_over_http: SqlOverHttpArgs,
    /// User rate limiter max number of requests per second.
    ///
    /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
    /// Can be given multiple times for different bucket sizes.
    // NOTE(review): not read anywhere in this file; `run` builds the endpoint rate
    // limiter from a hard-coded `LeakyBucketConfig`. Confirm intent.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
    user_rps_limit: Vec<RateBucketInfo>,
    /// Whether to retry the connection to the compute node
    // Retry policy string, presumably in the syntax accepted by `RetryConfig::parse`;
    // defaults to `RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES`.
    #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
    connect_to_compute_retry: String,
    /// Address of the postgres server
    // Passed to `LocalBackend::new` in `build_auth_backend`.
    #[clap(long, default_value = "127.0.0.1:5432")]
    postgres: SocketAddr,
    /// Address of the internal compute-ctl api service
    // Passed to `LocalBackend::new` in `build_auth_backend`.
    #[clap(long, default_value = "http://127.0.0.1:3081/")]
    compute_ctl: ApiUrl,
    /// Path of the local proxy config file
    // Handed to `refresh_config_loop`, which `run` notifies once at startup and again
    // whenever the signal handler fires.
    #[clap(long, default_value = "./local_proxy.json")]
    config_path: Utf8PathBuf,
    /// Path of the local proxy PID file
    // Claimed in `run` before any port is bound. It also acts as a single-instance lock.
    #[clap(long, default_value = "./local_proxy.pid")]
    pid_path: Utf8PathBuf,
    /// Disable pg_session_jwt extension installation
    /// This is useful for testing the local proxy with vanilla postgres.
    // Only compiled with the `testing` feature; copied into
    // `ProxyConfig::disable_pg_session_jwt` by `build_config`.
    #[clap(long, default_value = "false")]
    #[cfg(feature = "testing")]
    disable_pg_session_jwt: bool,
}
88 :
// SQL-over-HTTP tuning flags, flattened into `LocalProxyCliArgs` so each field is a
// top-level `--sql-over-http-*` option. Every value is consumed by `build_config`.
// NOTE: `///` comments here become clap `--help` text, so notes use `//`.
#[derive(clap::Args, Clone, Copy, Debug)]
struct SqlOverHttpArgs {
    /// How many connections to pool for each endpoint. Excess connections are discarded
    // `build_config` uses this single value for both `max_conns_per_endpoint` and
    // `max_total_conns` of the connection pool.
    #[clap(long, default_value_t = 200)]
    sql_over_http_pool_max_total_conns: usize,

    /// How long pooled connections should remain idle for before closing
    // Becomes `GlobalConnPoolOptions::idle_timeout`.
    #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
    sql_over_http_idle_timeout: tokio::time::Duration,

    // Copied into `HttpConfig::client_conn_threshold`; its semantics live there.
    #[clap(long, default_value_t = 100)]
    sql_over_http_client_conn_threshold: u64,

    // Shard count passed to `CancelSet::new`.
    #[clap(long, default_value_t = 16)]
    sql_over_http_cancel_set_shards: usize,

    // Copied into `HttpConfig::max_request_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_request_size_bytes: usize,

    // Copied into `HttpConfig::max_response_size_bytes`.
    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_response_size_bytes: usize,
}
111 :
112 0 : pub async fn run() -> anyhow::Result<()> {
113 0 : let _logging_guard = crate::logging::init_local_proxy()?;
114 0 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
115 0 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
116 :
117 0 : Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
118 :
119 : // TODO: refactor these to use labels
120 0 : debug!("Version: {GIT_VERSION}");
121 0 : debug!("Build_tag: {BUILD_TAG}");
122 0 : let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
123 0 : revision: GIT_VERSION,
124 0 : build_tag: BUILD_TAG,
125 0 : });
126 :
127 0 : let jemalloc = match crate::jemalloc::MetricRecorder::new() {
128 0 : Ok(t) => Some(t),
129 0 : Err(e) => {
130 0 : tracing::error!(error = ?e, "could not start jemalloc metrics loop");
131 0 : None
132 : }
133 : };
134 :
135 0 : let args = LocalProxyCliArgs::parse();
136 0 : let config = build_config(&args)?;
137 0 : let auth_backend = build_auth_backend(&args);
138 :
139 : // before we bind to any ports, write the process ID to a file
140 : // so that compute-ctl can find our process later
141 : // in order to trigger the appropriate SIGHUP on config change.
142 : //
143 : // This also claims a "lock" that makes sure only one instance
144 : // of local_proxy runs at a time.
145 0 : let _process_guard = loop {
146 0 : match pid_file::claim_for_current_process(&args.pid_path) {
147 0 : Ok(guard) => break guard,
148 0 : Err(e) => {
149 : // compute-ctl might have tried to read the pid-file to let us
150 : // know about some config change. We should try again.
151 0 : error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
152 0 : tokio::time::sleep(Duration::from_secs(1)).await;
153 : }
154 : }
155 : };
156 :
157 0 : let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
158 0 : let http_listener = TcpListener::bind(args.http).await?;
159 0 : let shutdown = CancellationToken::new();
160 :
161 : // todo: should scale with CU
162 0 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
163 0 : LeakyBucketConfig {
164 0 : rps: 10.0,
165 0 : max: 100.0,
166 0 : },
167 : 16,
168 : ));
169 :
170 0 : let mut maintenance_tasks = JoinSet::new();
171 :
172 0 : let refresh_config_notify = Arc::new(Notify::new());
173 0 : maintenance_tasks.spawn(crate::signals::handle(shutdown.clone(), {
174 0 : let refresh_config_notify = Arc::clone(&refresh_config_notify);
175 0 : move || {
176 0 : refresh_config_notify.notify_one();
177 0 : }
178 : }));
179 :
180 : // trigger the first config load **after** setting up the signal hook
181 : // to avoid the race condition where:
182 : // 1. No config file registered when local_proxy starts up
183 : // 2. The config file is written but the signal hook is not yet received
184 : // 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
185 0 : refresh_config_notify.notify_one();
186 0 : tokio::spawn(refresh_config_loop(
187 0 : config,
188 0 : args.config_path,
189 0 : refresh_config_notify,
190 : ));
191 :
192 0 : maintenance_tasks.spawn(crate::http::health_server::task_main(
193 0 : metrics_listener,
194 0 : AppMetrics {
195 0 : jemalloc,
196 0 : neon_metrics,
197 0 : proxy: crate::metrics::Metrics::get(),
198 0 : },
199 : ));
200 :
201 0 : let task = serverless::task_main(
202 0 : config,
203 0 : auth_backend,
204 0 : http_listener,
205 0 : shutdown.clone(),
206 0 : Arc::new(CancellationHandler::new(&config.connect_to_compute)),
207 0 : endpoint_rate_limiter,
208 : );
209 :
210 0 : match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
211 : // exit immediately on maintenance task completion
212 0 : Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {},
213 : // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
214 0 : Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
215 : // exit immediately on client task error
216 0 : Either::Right((res, _)) => res?,
217 : }
218 :
219 0 : Ok(())
220 0 : }
221 :
222 : /// ProxyConfig is created at proxy startup, and lives forever.
223 0 : fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
224 : let config::ConcurrencyLockOptions {
225 0 : shards,
226 0 : limiter,
227 0 : epoch,
228 0 : timeout,
229 0 : } = args.connect_compute_lock.parse()?;
230 0 : info!(
231 : ?limiter,
232 : shards,
233 : ?epoch,
234 0 : "Using NodeLocks (connect_compute)"
235 : );
236 0 : let connect_compute_locks = ApiLocks::new(
237 : "connect_compute_lock",
238 0 : limiter,
239 0 : shards,
240 0 : timeout,
241 0 : epoch,
242 0 : &Metrics::get().proxy.connect_compute_lock,
243 : );
244 :
245 0 : let http_config = HttpConfig {
246 0 : accept_websockets: false,
247 0 : pool_options: GlobalConnPoolOptions {
248 0 : gc_epoch: Duration::from_secs(60),
249 0 : pool_shards: 2,
250 0 : idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
251 0 : opt_in: false,
252 0 :
253 0 : max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
254 0 : max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
255 0 : },
256 0 : cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
257 0 : client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
258 0 : max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
259 0 : max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
260 0 : };
261 :
262 0 : let compute_config = ComputeConfig {
263 0 : retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
264 0 : tls: Arc::new(compute_client_config_with_root_certs()?),
265 0 : timeout: Duration::from_secs(2),
266 : };
267 :
268 0 : let greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() {
269 0 : Ok(s) => s,
270 : Err(_) => {
271 0 : debug!("NEON_MOTD environment variable is not valid UTF-8");
272 0 : String::new()
273 : }
274 0 : });
275 :
276 0 : Ok(Box::leak(Box::new(ProxyConfig {
277 0 : tls_config: ArcSwapOption::from(None),
278 0 : metric_collection: None,
279 0 : http_config,
280 0 : authentication_config: AuthenticationConfig {
281 0 : jwks_cache: JwkCache::default(),
282 0 : thread_pool: ThreadPool::new(0),
283 0 : scram_protocol_timeout: Duration::from_secs(10),
284 0 : ip_allowlist_check_enabled: true,
285 0 : is_vpc_acccess_proxy: false,
286 0 : is_auth_broker: false,
287 0 : accept_jwts: true,
288 0 : console_redirect_confirmation_timeout: Duration::ZERO,
289 0 : },
290 : #[cfg(feature = "rest_broker")]
291 0 : rest_config: RestConfig {
292 0 : is_rest_broker: false,
293 0 : db_schema_cache: None,
294 0 : max_schema_size: 0,
295 0 : hostname_prefix: String::new(),
296 0 : },
297 0 : proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
298 0 : handshake_timeout: Duration::from_secs(10),
299 0 : wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
300 0 : connect_compute_locks,
301 0 : connect_to_compute: compute_config,
302 0 : greetings,
303 : #[cfg(feature = "testing")]
304 0 : disable_pg_session_jwt: args.disable_pg_session_jwt,
305 : })))
306 0 : }
307 :
308 : /// auth::Backend is created at proxy startup, and lives forever.
309 0 : fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'static, ()> {
310 0 : let auth_backend = crate::auth::Backend::Local(crate::auth::backend::MaybeOwned::Owned(
311 0 : LocalBackend::new(args.postgres, args.compute_ctl.clone()),
312 0 : ));
313 :
314 0 : Box::leak(Box::new(auth_backend))
315 0 : }
|