use std::{
    net::SocketAddr,
    path::{Path, PathBuf},
    pin::pin,
    sync::Arc,
    time::Duration,
};

use anyhow::{bail, ensure};
use dashmap::DashMap;
use futures::{future::Either, FutureExt};
use proxy::{
    auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
    cancellation::CancellationHandlerMain,
    config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
    console::{locks::ApiLocks, messages::JwksRoleMapping},
    http::health_server::AppMetrics,
    metrics::{Metrics, ThreadPoolMetrics},
    rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
    scram::threadpool::ThreadPool,
    serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
};

project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);

use clap::Parser;
use tokio::{net::TcpListener, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};

#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;

/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
struct LocalProxyCliArgs {
    /// listen for incoming metrics connections on ip:port
    #[clap(long, default_value = "127.0.0.1:7001")]
    metrics: String,
    /// listen for incoming http connections on ip:port
    #[clap(long)]
    http: String,
    /// timeout for the TLS handshake
    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    handshake_timeout: tokio::time::Duration,
    /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
    #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
    connect_compute_lock: String,
    #[clap(flatten)]
    sql_over_http: SqlOverHttpArgs,
    /// User rate limiter max number of requests per second.
    ///
    /// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
    /// Can be given multiple times for different bucket sizes.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
    user_rps_limit: Vec<RateBucketInfo>,
    /// Whether the auth rate limiter actually takes effect (for testing)
    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
    auth_rate_limit_enabled: bool,
    /// Authentication rate limiter max number of hashes per second.
    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
    auth_rate_limit: Vec<RateBucketInfo>,
    /// The IP subnet size to use when determining whether two IP addresses are considered the same.
    #[clap(long, default_value_t = 64)]
    auth_rate_limit_ip_subnet: u8,
    /// Whether to retry the connection to the compute node
    #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
    connect_to_compute_retry: String,
    /// Address of the postgres server
    #[clap(long, default_value = "127.0.0.1:5432")]
    compute: SocketAddr,
    /// Path to the local proxy config file
    #[clap(long, default_value = "./localproxy.json")]
    config_path: PathBuf,
}

#[derive(clap::Args, Clone, Copy, Debug)]
struct SqlOverHttpArgs {
    /// How many connections to pool for each endpoint. Excess connections are discarded
    #[clap(long, default_value_t = 200)]
    sql_over_http_pool_max_total_conns: usize,

    /// How long pooled connections should remain idle for before closing
    #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
    sql_over_http_idle_timeout: tokio::time::Duration,

    #[clap(long, default_value_t = 100)]
    sql_over_http_client_conn_threshold: u64,

    #[clap(long, default_value_t = 16)]
    sql_over_http_cancel_set_shards: usize,

    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_request_size_bytes: u64,

    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
    sql_over_http_max_response_size_bytes: usize,
}

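// Local proxy entrypoint: initialise logging, metrics and sentry, parse the CLI
// arguments, bind the metrics and HTTP listeners, load the initial JWKS config,
// then run the SQL-over-HTTP handler alongside the maintenance tasks until shutdown.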
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let _logging_guard = proxy::logging::init().await?;
    let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
    let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);

    Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));

    info!("Version: {GIT_VERSION}");
    info!("Build_tag: {BUILD_TAG}");
    let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
        revision: GIT_VERSION,
        build_tag: BUILD_TAG,
    });

    let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
        Ok(t) => Some(t),
        Err(e) => {
            tracing::error!(error = ?e, "could not start jemalloc metrics loop");
            None
        }
    };

    let args = LocalProxyCliArgs::parse();
    let config = build_config(&args)?;

    let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
    let http_listener = TcpListener::bind(args.http).await?;
    let shutdown = CancellationToken::new();

    // todo: should scale with CU
    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
        LeakyBucketConfig {
            rps: 10.0,
            max: 100.0,
        },
        16,
    ));

    refresh_config(args.config_path.clone()).await;

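    // Maintenance tasks: signal handling (with a callback that re-reads the
    // config file) and the health/metrics HTTP server.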
    let mut maintenance_tasks = JoinSet::new();
    maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
        refresh_config(args.config_path.clone()).map(Ok)
    }));
    maintenance_tasks.spawn(proxy::http::health_server::task_main(
        metrics_listener,
        AppMetrics {
            jemalloc,
            neon_metrics,
            proxy: proxy::metrics::Metrics::get(),
        },
    ));

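    // The main SQL-over-HTTP task; cancellation is tracked in a purely local
    // in-memory map, with no external cancellation publisher.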
    let task = serverless::task_main(
        config,
        http_listener,
        shutdown.clone(),
        Arc::new(CancellationHandlerMain::new(
            Arc::new(DashMap::new()),
            None,
            proxy::metrics::CancellationSource::Local,
        )),
        endpoint_rate_limiter,
    );

    match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
        // exit immediately on maintenance task completion
        Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {},
        // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
        Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
        // exit immediately on client task error
        Either::Right((res, _)) => res?,
    }

    Ok(())
}

/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
    let config::ConcurrencyLockOptions {
        shards,
        limiter,
        epoch,
        timeout,
    } = args.connect_compute_lock.parse()?;
    info!(
        ?limiter,
        shards,
        ?epoch,
        "Using NodeLocks (connect_compute)"
    );
    let connect_compute_locks = ApiLocks::new(
        "connect_compute_lock",
        limiter,
        shards,
        timeout,
        epoch,
        &Metrics::get().proxy.connect_compute_lock,
    )?;

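    // SQL-over-HTTP settings: connection pooling, the cancellation shard count,
    // and request/response size limits, mostly taken from the CLI arguments.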
    let http_config = HttpConfig {
        accept_websockets: false,
        pool_options: GlobalConnPoolOptions {
            gc_epoch: Duration::from_secs(60),
            pool_shards: 2,
            idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
            opt_in: false,

            max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
            max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
        },
        cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
        client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
        max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
        max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
    };

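    // Leak the config to give it a 'static lifetime; it is needed for the
    // lifetime of the process (see the doc comment above).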
    Ok(Box::leak(Box::new(ProxyConfig {
        tls_config: None,
        auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
            LocalBackend::new(args.compute),
        )),
        metric_collection: None,
        allow_self_signed_compute: false,
        http_config,
        authentication_config: AuthenticationConfig {
            thread_pool: ThreadPool::new(0),
            scram_protocol_timeout: Duration::from_secs(10),
            rate_limiter_enabled: false,
            rate_limiter: BucketRateLimiter::new(vec![]),
            rate_limit_ip_subnet: 64,
            ip_allowlist_check_enabled: true,
        },
        require_client_ip: false,
        handshake_timeout: Duration::from_secs(10),
        region: "local".into(),
        wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
        connect_compute_locks,
        connect_to_compute_retry_config: RetryConfig::parse(
            RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
        )?,
    })))
}

async fn refresh_config(path: PathBuf) {
    match refresh_config_inner(&path).await {
        Ok(()) => {}
        Err(e) => {
            error!(error=?e, ?path, "could not read config file");
        }
    }
}

async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
    let bytes = tokio::fs::read(&path).await?;
    let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;

    let mut settings = None;

    for mapping in data.roles.values_mut() {
        for jwks in &mut mapping.jwks {
            ensure!(
                jwks.jwks_url.has_authority()
                    && (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
                "Invalid JWKS url. Must be HTTP or HTTPS",
            );

            ensure!(
                jwks.jwks_url
                    .host()
                    .is_some_and(|h| h != url::Host::Domain("")),
                "Invalid JWKS url. No domain listed",
            );

            // clear username, password and ports
            jwks.jwks_url.set_username("").expect(
                "url can be a base and has a valid host and is not a file. should not error",
            );
            jwks.jwks_url.set_password(None).expect(
                "url can be a base and has a valid host and is not a file. should not error",
            );
            // local testing is hard if we need to have a specific restricted port
            if cfg!(not(feature = "testing")) {
                jwks.jwks_url.set_port(None).expect(
                    "url can be a base and has a valid host and is not a file. should not error",
                );
            }

            // clear the fragment and query params
            jwks.jwks_url.set_fragment(None);
            jwks.jwks_url.query_pairs_mut().clear().finish();

            if jwks.jwks_url.scheme() != "https" {
                // local testing is hard if we need to set up https support.
                if cfg!(not(feature = "testing")) {
                    jwks.jwks_url
                        .set_scheme("https")
                        .expect("should not error to set the scheme to https if it was http");
                } else {
                    warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
                }
            }

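            // All JWKS entries must agree on a single project/branch pair.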
            let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
            ensure!(
                *pr == jwks.project_id,
                "inconsistent project IDs configured"
            );
            ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
        }
    }

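    // Publish the validated role mapping so the local auth backend picks it up.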
    if let Some((project_id, branch_id)) = settings {
        JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings {
            roles: data.roles,
            project_id,
            branch_id,
        })));
    }

    Ok(())
}