Line data Source code
1 : use std::borrow::Cow;
2 : use std::collections::hash_map::RandomState;
3 : use std::hash::{BuildHasher, Hash};
4 : use std::sync::Mutex;
5 : use std::sync::atomic::{AtomicUsize, Ordering};
6 :
7 : use anyhow::bail;
8 : use clashmap::ClashMap;
9 : use itertools::Itertools;
10 : use rand::rngs::StdRng;
11 : use rand::{Rng, SeedableRng};
12 : use tokio::time::{Duration, Instant};
13 : use tracing::info;
14 :
15 : use super::LeakyBucketConfig;
16 : use crate::ext::LockExt;
17 : use crate::intern::EndpointIdInt;
18 :
19 : pub struct GlobalRateLimiter {
20 : data: Vec<RateBucket>,
21 : info: Vec<RateBucketInfo>,
22 : }
23 :
24 : impl GlobalRateLimiter {
25 0 : pub fn new(info: Vec<RateBucketInfo>) -> Self {
26 0 : Self {
27 0 : data: vec![
28 0 : RateBucket {
29 0 : start: Instant::now(),
30 0 : count: 0,
31 0 : };
32 0 : info.len()
33 0 : ],
34 0 : info,
35 0 : }
36 0 : }
37 :
38 : /// Check that number of connections is below `max_rps` rps.
39 0 : pub fn check(&mut self) -> bool {
40 0 : let now = Instant::now();
41 0 :
42 0 : let should_allow_request = self
43 0 : .data
44 0 : .iter_mut()
45 0 : .zip(&self.info)
46 0 : .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
47 0 :
48 0 : if should_allow_request {
49 0 : // only increment the bucket counts if the request will actually be accepted
50 0 : self.data.iter_mut().for_each(|b| b.inc(1));
51 0 : }
52 :
53 0 : should_allow_request
54 0 : }
55 : }
56 :
57 : // Simple per-endpoint rate limiter.
58 : //
59 : // Check that number of connections to the endpoint is below `max_rps` rps.
60 : // Purposefully ignore user name and database name as clients can reconnect
61 : // with different names, so we'll end up sending some http requests to
62 : // the control plane.
63 : pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
64 :
65 : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
66 : map: ClashMap<Key, Vec<RateBucket>, Hasher>,
67 : info: Cow<'static, [RateBucketInfo]>,
68 : access_count: AtomicUsize,
69 : rand: Mutex<Rand>,
70 : }
71 :
72 : #[derive(Clone, Copy)]
73 : struct RateBucket {
74 : start: Instant,
75 : count: u32,
76 : }
77 :
78 : impl RateBucket {
79 3000906 : fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
80 3000906 : if now - self.start < info.interval {
81 3000898 : self.count + n <= info.max_rpi
82 : } else {
83 : // bucket expired, reset
84 8 : self.count = 0;
85 8 : self.start = now;
86 8 :
87 8 : true
88 : }
89 3000906 : }
90 :
91 3000900 : fn inc(&mut self, n: u32) {
92 3000900 : self.count += n;
93 3000900 : }
94 : }
95 :
/// Configuration of one rate-limit window: at most `max_rpi` requests may be
/// admitted within any single `interval`.
#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    /// Length of the window.
    pub(crate) interval: Duration,
    /// Maximum number of requests per interval.
    pub(crate) max_rpi: u32,
}
102 :
103 : impl std::fmt::Display for RateBucketInfo {
104 12 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 12 : let rps = self.rps().floor() as u64;
106 12 : write!(f, "{rps}@{}", humantime::format_duration(self.interval))
107 12 : }
108 : }
109 :
110 : impl std::fmt::Debug for RateBucketInfo {
111 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 0 : write!(f, "{self}")
113 0 : }
114 : }
115 :
116 : impl std::str::FromStr for RateBucketInfo {
117 : type Err = anyhow::Error;
118 :
119 13 : fn from_str(s: &str) -> Result<Self, Self::Err> {
120 13 : let Some((max_rps, interval)) = s.split_once('@') else {
121 0 : bail!("invalid rate info")
122 : };
123 13 : let max_rps = max_rps.parse()?;
124 13 : let interval = humantime::parse_duration(interval)?;
125 13 : Ok(Self::new(max_rps, interval))
126 13 : }
127 : }
128 :
129 : impl RateBucketInfo {
130 : pub const DEFAULT_SET: [Self; 3] = [
131 : Self::new(300, Duration::from_secs(1)),
132 : Self::new(200, Duration::from_secs(60)),
133 : Self::new(100, Duration::from_secs(600)),
134 : ];
135 :
136 : pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
137 : Self::new(500, Duration::from_secs(1)),
138 : Self::new(300, Duration::from_secs(60)),
139 : Self::new(200, Duration::from_secs(600)),
140 : ];
141 :
142 : // For all the sessions will be cancel key. So this limit is essentially global proxy limit.
143 : pub const DEFAULT_REDIS_SET: [Self; 2] = [
144 : Self::new(100_000, Duration::from_secs(1)),
145 : Self::new(50_000, Duration::from_secs(10)),
146 : ];
147 :
148 12 : pub fn rps(&self) -> f64 {
149 12 : (self.max_rpi as f64) / self.interval.as_secs_f64()
150 12 : }
151 :
152 3 : pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
153 8 : info.sort_unstable_by_key(|info| info.interval);
154 3 : let invalid = info
155 3 : .iter()
156 3 : .tuple_windows()
157 4 : .find(|(a, b)| a.max_rpi > b.max_rpi);
158 3 : if let Some((a, b)) = invalid {
159 1 : bail!(
160 1 : "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
161 1 : b.max_rpi,
162 1 : a.max_rpi,
163 1 : );
164 2 : }
165 2 :
166 2 : Ok(())
167 3 : }
168 :
169 17 : pub const fn new(max_rps: u32, interval: Duration) -> Self {
170 17 : Self {
171 17 : interval,
172 17 : max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
173 17 : }
174 17 : }
175 :
176 0 : pub fn to_leaky_bucket(this: &[Self]) -> Option<LeakyBucketConfig> {
177 0 : // bit of a hack - find the min rps and max rps supported and turn it into
178 0 : // leaky bucket config instead
179 0 :
180 0 : let mut iter = this.iter().map(|info| info.rps());
181 0 : let first = iter.next()?;
182 :
183 0 : let (min, max) = (first, first);
184 0 : let (min, max) = iter.fold((min, max), |(min, max), rps| {
185 0 : (f64::min(min, rps), f64::max(max, rps))
186 0 : });
187 0 :
188 0 : Some(LeakyBucketConfig { rps: min, max })
189 0 : }
190 : }
191 :
192 : impl<K: Hash + Eq> BucketRateLimiter<K> {
193 1 : pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
194 1 : Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
195 1 : }
196 : }
197 :
198 : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
199 2 : fn new_with_rand_and_hasher(
200 2 : info: impl Into<Cow<'static, [RateBucketInfo]>>,
201 2 : rand: R,
202 2 : hasher: S,
203 2 : ) -> Self {
204 2 : let info = info.into();
205 2 : info!(buckets = ?info, "endpoint rate limiter");
206 2 : Self {
207 2 : info,
208 2 : map: ClashMap::with_hasher_and_shard_amount(hasher, 64),
209 2 : access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
210 2 : rand: Mutex::new(rand),
211 2 : }
212 2 : }
213 :
214 : /// Check that number of connections to the endpoint is below `max_rps` rps.
215 1000454 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
216 1000454 : // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
217 1000454 : // worst case memory usage is about:
218 1000454 : // = 2 * 2048 * 64 * (48B + 72B)
219 1000454 : // = 30MB
220 1000454 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
221 488 : self.do_gc();
222 999966 : }
223 :
224 1000454 : let now = Instant::now();
225 1000454 : let mut entry = self.map.entry(key).or_insert_with(|| {
226 1000001 : vec![
227 1000001 : RateBucket {
228 1000001 : start: now,
229 1000001 : count: 0,
230 1000001 : };
231 1000001 : self.info.len()
232 1000001 : ]
233 1000454 : });
234 1000454 :
235 1000454 : let should_allow_request = entry
236 1000454 : .iter_mut()
237 1000454 : .zip(&*self.info)
238 3000906 : .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
239 1000454 :
240 1000454 : if should_allow_request {
241 1000450 : // only increment the bucket counts if the request will actually be accepted
242 3000900 : entry.iter_mut().for_each(|b| b.inc(n));
243 1000450 : }
244 :
245 1000454 : should_allow_request
246 1000454 : }
247 :
248 : /// Clean the map. Simple strategy: remove all entries in a random shard.
249 : /// At worst, we'll double the effective max_rps during the cleanup.
250 : /// But that way deletion does not aquire mutex on each entry access.
251 488 : pub(crate) fn do_gc(&self) {
252 488 : info!(
253 0 : "cleaning up bucket rate limiter, current size = {}",
254 0 : self.map.len()
255 : );
256 488 : let n = self.map.shards().len();
257 488 : // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
258 488 : // (impossible, infact, unless we have 2048 threads)
259 488 : let shard = self.rand.lock_propagate_poison().gen_range(0..n);
260 488 : self.map.shards()[shard].write().clear();
261 488 : }
262 : }
263 :
#[cfg(test)]
mod tests {
    use std::hash::BuildHasherDefault;
    use std::time::Duration;

    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{BucketRateLimiter, WakeComputeRateLimiter};
    use crate::intern::EndpointIdInt;
    use crate::rate_limiter::RateBucketInfo;
    use crate::types::EndpointId;

    /// Parses every `<rps>@<interval>` spec, panicking on malformed input.
    fn parse_rates(specs: &[&str]) -> Vec<RateBucketInfo> {
        specs.iter().map(|spec| spec.parse().unwrap()).collect()
    }

    /// The per-interval budget scales the per-second rate by the window length.
    #[test]
    fn rate_bucket_rpi() {
        let five_seconds = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(five_seconds.max_rpi, 50 * 5);

        let half_second = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(half_second.max_rpi, 50 / 2);
    }

    /// Parsing and formatting round-trip through the `<rps>@<interval>` notation.
    #[test]
    fn rate_bucket_parse() {
        let ten_seconds: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(ten_seconds.interval, Duration::from_secs(10));
        assert_eq!(ten_seconds.max_rpi, 100 * 10);
        assert_eq!(ten_seconds.to_string(), "100@10s");

        let one_minute: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(one_minute.interval, Duration::from_secs(60));
        assert_eq!(one_minute.max_rpi, 100 * 60);
        assert_eq!(one_minute.to_string(), "100@1m");
    }

    /// The built-in default set must pass its own validation.
    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    /// A longer window with a smaller per-interval budget is rejected.
    #[test]
    #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates = parse_rates(&["300@1s", "10@10s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    /// Walks one endpoint through both the 1s and the 30s window under paused time.
    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates = parse_rates(&["100@1s", "20@30s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = WakeComputeRateLimiter::new(rates);

        let endpoint = EndpointIdInt::from(EndpointId::from("ep-my-endpoint-1234"));

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
        // more connections fail
        assert!(!limiter.check(endpoint, 1));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint, 1));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 0..5 {
            for _ in 0..50 {
                assert!(limiter.check(endpoint, 2));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!limiter.check(endpoint, 1));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint, 1));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
    }

    /// Periodic shard clearing keeps the map well below the number of distinct keys seen.
    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let seeded_rng = rand::rngs::StdRng::from_seed([1; 32]);
        let fixed_hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter = BucketRateLimiter::new_with_rand_and_hasher(
            &RateBucketInfo::DEFAULT_SET,
            seeded_rng,
            fixed_hasher,
        );
        for key in 0..1_000_000 {
            limiter.check(key, 1);
        }
        assert!(limiter.map.len() < 150_000);
    }
}
|