Line data Source code
1 : use std::borrow::Cow;
2 : use std::collections::hash_map::RandomState;
3 : use std::hash::{BuildHasher, Hash};
4 : use std::sync::atomic::{AtomicUsize, Ordering};
5 : use std::sync::Mutex;
6 :
7 : use anyhow::bail;
8 : use clashmap::ClashMap;
9 : use itertools::Itertools;
10 : use rand::rngs::StdRng;
11 : use rand::{Rng, SeedableRng};
12 : use tokio::time::{Duration, Instant};
13 : use tracing::info;
14 :
15 : use crate::ext::LockExt;
16 : use crate::intern::EndpointIdInt;
17 :
18 : pub struct GlobalRateLimiter {
19 : data: Vec<RateBucket>,
20 : info: Vec<RateBucketInfo>,
21 : }
22 :
23 : impl GlobalRateLimiter {
24 0 : pub fn new(info: Vec<RateBucketInfo>) -> Self {
25 0 : Self {
26 0 : data: vec![
27 0 : RateBucket {
28 0 : start: Instant::now(),
29 0 : count: 0,
30 0 : };
31 0 : info.len()
32 0 : ],
33 0 : info,
34 0 : }
35 0 : }
36 :
37 : /// Check that number of connections is below `max_rps` rps.
38 0 : pub fn check(&mut self) -> bool {
39 0 : let now = Instant::now();
40 0 :
41 0 : let should_allow_request = self
42 0 : .data
43 0 : .iter_mut()
44 0 : .zip(&self.info)
45 0 : .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
46 0 :
47 0 : if should_allow_request {
48 0 : // only increment the bucket counts if the request will actually be accepted
49 0 : self.data.iter_mut().for_each(|b| b.inc(1));
50 0 : }
51 :
52 0 : should_allow_request
53 0 : }
54 : }
55 :
56 : // Simple per-endpoint rate limiter.
57 : //
58 : // Check that number of connections to the endpoint is below `max_rps` rps.
59 : // Purposefully ignore user name and database name as clients can reconnect
60 : // with different names, so we'll end up sending some http requests to
61 : // the control plane.
62 : pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
63 :
64 : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
65 : map: ClashMap<Key, Vec<RateBucket>, Hasher>,
66 : info: Cow<'static, [RateBucketInfo]>,
67 : access_count: AtomicUsize,
68 : rand: Mutex<Rand>,
69 : }
70 :
71 : #[derive(Clone, Copy)]
72 : struct RateBucket {
73 : start: Instant,
74 : count: u32,
75 : }
76 :
77 : impl RateBucket {
78 3000915 : fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
79 3000915 : if now - self.start < info.interval {
80 3000907 : self.count + n <= info.max_rpi
81 : } else {
82 : // bucket expired, reset
83 8 : self.count = 0;
84 8 : self.start = now;
85 8 :
86 8 : true
87 : }
88 3000915 : }
89 :
90 3000909 : fn inc(&mut self, n: u32) {
91 3000909 : self.count += n;
92 3000909 : }
93 : }
94 :
/// Description of one rate-limit bucket: at most `max_rpi` requests per
/// `interval`.
#[derive(PartialEq, Clone, Copy)]
pub struct RateBucketInfo {
    /// Length of the fixed window.
    pub(crate) interval: Duration,
    /// Maximum number of requests admitted per `interval` (requests per interval).
    pub(crate) max_rpi: u32,
}
101 :
102 : impl std::fmt::Display for RateBucketInfo {
103 18 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 18 : let rps = self.rps().floor() as u64;
105 18 : write!(f, "{rps}@{}", humantime::format_duration(self.interval))
106 18 : }
107 : }
108 :
109 : impl std::fmt::Debug for RateBucketInfo {
110 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 0 : write!(f, "{self}")
112 0 : }
113 : }
114 :
115 : impl std::str::FromStr for RateBucketInfo {
116 : type Err = anyhow::Error;
117 :
118 19 : fn from_str(s: &str) -> Result<Self, Self::Err> {
119 19 : let Some((max_rps, interval)) = s.split_once('@') else {
120 0 : bail!("invalid rate info")
121 : };
122 19 : let max_rps = max_rps.parse()?;
123 19 : let interval = humantime::parse_duration(interval)?;
124 19 : Ok(Self::new(max_rps, interval))
125 19 : }
126 : }
127 :
128 : impl RateBucketInfo {
129 : pub const DEFAULT_SET: [Self; 3] = [
130 : Self::new(300, Duration::from_secs(1)),
131 : Self::new(200, Duration::from_secs(60)),
132 : Self::new(100, Duration::from_secs(600)),
133 : ];
134 :
135 : pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
136 : Self::new(500, Duration::from_secs(1)),
137 : Self::new(300, Duration::from_secs(60)),
138 : Self::new(200, Duration::from_secs(600)),
139 : ];
140 :
141 : // For all the sessions will be cancel key. So this limit is essentially global proxy limit.
142 : pub const DEFAULT_REDIS_SET: [Self; 2] = [
143 : Self::new(100_000, Duration::from_secs(1)),
144 : Self::new(50_000, Duration::from_secs(10)),
145 : ];
146 :
147 : /// All of these are per endpoint-maskedip pair.
148 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
149 : ///
150 : /// First bucket: 1000mcpus total per endpoint-ip pair
151 : /// * 4096000 requests per second with 1 hash rounds.
152 : /// * 1000 requests per second with 4096 hash rounds.
153 : /// * 6.8 requests per second with 600000 hash rounds.
154 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
155 : Self::new(1000 * 4096, Duration::from_secs(1)),
156 : Self::new(600 * 4096, Duration::from_secs(60)),
157 : Self::new(300 * 4096, Duration::from_secs(600)),
158 : ];
159 :
160 18 : pub fn rps(&self) -> f64 {
161 18 : (self.max_rpi as f64) / self.interval.as_secs_f64()
162 18 : }
163 :
164 3 : pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
165 8 : info.sort_unstable_by_key(|info| info.interval);
166 3 : let invalid = info
167 3 : .iter()
168 3 : .tuple_windows()
169 4 : .find(|(a, b)| a.max_rpi > b.max_rpi);
170 3 : if let Some((a, b)) = invalid {
171 1 : bail!(
172 1 : "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
173 1 : b.max_rpi,
174 1 : a.max_rpi,
175 1 : );
176 2 : }
177 2 :
178 2 : Ok(())
179 3 : }
180 :
181 23 : pub const fn new(max_rps: u32, interval: Duration) -> Self {
182 23 : Self {
183 23 : interval,
184 23 : max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
185 23 : }
186 23 : }
187 : }
188 :
189 : impl<K: Hash + Eq> BucketRateLimiter<K> {
190 4 : pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
191 4 : Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
192 4 : }
193 : }
194 :
195 : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
196 5 : fn new_with_rand_and_hasher(
197 5 : info: impl Into<Cow<'static, [RateBucketInfo]>>,
198 5 : rand: R,
199 5 : hasher: S,
200 5 : ) -> Self {
201 5 : let info = info.into();
202 5 : info!(buckets = ?info, "endpoint rate limiter");
203 5 : Self {
204 5 : info,
205 5 : map: ClashMap::with_hasher_and_shard_amount(hasher, 64),
206 5 : access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
207 5 : rand: Mutex::new(rand),
208 5 : }
209 5 : }
210 :
211 : /// Check that number of connections to the endpoint is below `max_rps` rps.
212 1000457 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
213 1000457 : // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
214 1000457 : // worst case memory usage is about:
215 1000457 : // = 2 * 2048 * 64 * (48B + 72B)
216 1000457 : // = 30MB
217 1000457 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
218 488 : self.do_gc();
219 999969 : }
220 :
221 1000457 : let now = Instant::now();
222 1000457 : let mut entry = self.map.entry(key).or_insert_with(|| {
223 1000004 : vec![
224 1000004 : RateBucket {
225 1000004 : start: now,
226 1000004 : count: 0,
227 1000004 : };
228 1000004 : self.info.len()
229 1000004 : ]
230 1000457 : });
231 1000457 :
232 1000457 : let should_allow_request = entry
233 1000457 : .iter_mut()
234 1000457 : .zip(&*self.info)
235 3000915 : .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
236 1000457 :
237 1000457 : if should_allow_request {
238 1000453 : // only increment the bucket counts if the request will actually be accepted
239 3000909 : entry.iter_mut().for_each(|b| b.inc(n));
240 1000453 : }
241 :
242 1000457 : should_allow_request
243 1000457 : }
244 :
245 : /// Clean the map. Simple strategy: remove all entries in a random shard.
246 : /// At worst, we'll double the effective max_rps during the cleanup.
247 : /// But that way deletion does not aquire mutex on each entry access.
248 488 : pub(crate) fn do_gc(&self) {
249 488 : info!(
250 0 : "cleaning up bucket rate limiter, current size = {}",
251 0 : self.map.len()
252 : );
253 488 : let n = self.map.shards().len();
254 488 : // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
255 488 : // (impossible, infact, unless we have 2048 threads)
256 488 : let shard = self.rand.lock_propagate_poison().gen_range(0..n);
257 488 : self.map.shards()[shard].write().clear();
258 488 : }
259 : }
260 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use std::hash::BuildHasherDefault;
    use std::time::Duration;

    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{BucketRateLimiter, WakeComputeRateLimiter};
    use crate::intern::EndpointIdInt;
    use crate::rate_limiter::RateBucketInfo;
    use crate::types::EndpointId;

    /// Parses each `<rps>@<interval>` spec, panicking on malformed input.
    fn parse_rates(specs: &[&str]) -> Vec<RateBucketInfo> {
        specs.iter().map(|spec| spec.parse().unwrap()).collect()
    }

    #[test]
    fn rate_bucket_rpi() {
        let five_seconds = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(five_seconds.max_rpi, 50 * 5);

        let half_second = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(half_second.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        for (spec, secs, rpi) in [("100@10s", 10, 100 * 10), ("100@1m", 60, 100 * 60)] {
            let bucket: RateBucketInfo = spec.parse().unwrap();
            assert_eq!(bucket.interval, Duration::from_secs(secs));
            assert_eq!(bucket.max_rpi, rpi);
            assert_eq!(bucket.to_string(), spec);
        }
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates = parse_rates(&["300@1s", "10@10s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates = parse_rates(&["100@1s", "20@30s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = WakeComputeRateLimiter::new(rates);

        let endpoint = EndpointIdInt::from(EndpointId::from("ep-my-endpoint-1234"));

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
        // more connections fail
        assert!(!limiter.check(endpoint, 1));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint, 1));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 0..5 {
            for _ in 0..50 {
                assert!(limiter.check(endpoint, 2));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!limiter.check(endpoint, 1));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint, 1));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let rand = rand::rngs::StdRng::from_seed([1; 32]);
        let hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter =
            BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
        for key in 0..1_000_000 {
            limiter.check(key, 1);
        }
        assert!(limiter.map.len() < 150_000);
    }
}
|