use std::borrow::Cow;
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash};
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::bail;
use clashmap::ClashMap;
use itertools::Itertools;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;

use super::LeakyBucketConfig;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;

// Simple per-endpoint rate limiter.
//
// Checks that the number of connections to the endpoint stays below `max_rps` rps.
// User name and database name are purposefully ignored: clients can reconnect
// with different names, so keying on them would let such clients bypass the
// limit and send extra HTTP requests to the control plane anyway.
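//
// A minimal usage sketch (the endpoint id below is made up; `check` is
// `pub(crate)`, so this only applies from within the crate):
//
//     let limiter = WakeComputeRateLimiter::new(RateBucketInfo::DEFAULT_SET.to_vec());
//     let endpoint = EndpointIdInt::from(EndpointId::from("ep-example-1234"));
//     if !limiter.check(endpoint, 1) {
//         // over the limit: reject the request
//     }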
pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;

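/// Generic fixed-window limiter: each key owns one `RateBucket` per configured
/// `RateBucketInfo`, and a request is admitted only if every bucket for that
/// key still has budget (see `check` below).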
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
    map: ClashMap<Key, Vec<RateBucket>, Hasher>,
    info: Cow<'static, [RateBucketInfo]>,
    access_count: AtomicUsize,
    rand: Mutex<Rand>,
}

#[derive(Clone, Copy)]
struct RateBucket {
    start: Instant,
    count: u32,
}

impl RateBucket {
    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
        if now - self.start < info.interval {
            self.count + n <= info.max_rpi
        } else {
            // bucket expired, reset
            self.count = 0;
            self.start = now;

            true
        }
    }

    fn inc(&mut self, n: u32) {
        self.count += n;
    }
}

#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    pub(crate) interval: Duration,
    // requests per interval
    pub(crate) max_rpi: u32,
}

impl std::fmt::Display for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rps = self.rps().floor() as u64;
        write!(f, "{rps}@{}", humantime::format_duration(self.interval))
    }
}

impl std::fmt::Debug for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self}")
    }
}

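// Parses the `{rps}@{interval}` format that `Display` above produces,
// e.g. "100@10s" or "100@1m".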
impl std::str::FromStr for RateBucketInfo {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some((max_rps, interval)) = s.split_once('@') else {
            bail!("invalid rate info")
        };
        let max_rps = max_rps.parse()?;
        let interval = humantime::parse_duration(interval)?;
        Ok(Self::new(max_rps, interval))
    }
}

impl RateBucketInfo {
    pub const DEFAULT_SET: [Self; 3] = [
        Self::new(300, Duration::from_secs(1)),
        Self::new(200, Duration::from_secs(60)),
        Self::new(100, Duration::from_secs(600)),
    ];

    pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
        Self::new(500, Duration::from_secs(1)),
        Self::new(300, Duration::from_secs(60)),
        Self::new(200, Duration::from_secs(600)),
    ];

    pub fn rps(&self) -> f64 {
        (self.max_rpi as f64) / self.interval.as_secs_f64()
    }

    pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
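        // After sorting by interval, each successive (longer) bucket must allow
        // at least as many requests per interval as the previous one, otherwise
        // the shorter bucket's limit could never actually be reached.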
        info.sort_unstable_by_key(|info| info.interval);
        let invalid = info
            .iter()
            .tuple_windows()
            .find(|(a, b)| a.max_rpi > b.max_rpi);
        if let Some((a, b)) = invalid {
            bail!(
                "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
                b.max_rpi,
                a.max_rpi,
            );
        }

        Ok(())
    }

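    /// Converts a per-second rate into a per-interval budget:
    /// `max_rpi = max_rps * interval_ms / 1000`. For example,
    /// `new(50, Duration::from_secs(5))` yields a `max_rpi` of 250, and
    /// `new(50, Duration::from_millis(500))` yields 25.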
    pub const fn new(max_rps: u32, interval: Duration) -> Self {
        Self {
            interval,
            max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
        }
    }

    pub fn to_leaky_bucket(this: &[Self]) -> Option<LeakyBucketConfig> {
        // Bit of a hack: find the min and max rps supported across the buckets
        // and turn them into a leaky bucket config instead.
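        //
        // e.g. for DEFAULT_SET = [300@1s, 200@1m, 100@10m] the per-bucket rates
        // are 300, 200 and 100 rps respectively, so this yields
        // LeakyBucketConfig { rps: 100.0, max: 300.0 }.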

        let mut iter = this.iter().map(|info| info.rps());
        let first = iter.next()?;

        let (min, max) = iter.fold((first, first), |(min, max), rps| {
            (f64::min(min, rps), f64::max(max, rps))
        });

        Some(LeakyBucketConfig { rps: min, max })
    }
}

impl<K: Hash + Eq> BucketRateLimiter<K> {
    pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
        Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
    }
}

impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
    fn new_with_rand_and_hasher(
        info: impl Into<Cow<'static, [RateBucketInfo]>>,
        rand: R,
        hasher: S,
    ) -> Self {
        let info = info.into();
        info!(buckets = ?info, "endpoint rate limiter");
        Self {
            info,
            map: ClashMap::with_hasher_and_shard_amount(hasher, 64),
            access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
            rand: Mutex::new(rand),
        }
    }

    /// Check that the number of connections to the endpoint is below `max_rps` rps.
    pub(crate) fn check(&self, key: K, n: u32) -> bool {
        // do a partial GC every 2k requests. This cleans up ~1/64th of the map.
        // worst case memory usage is about:
        //   2 * 2048 * 64 * (48B + 72B) = 30MB
        if self
            .access_count
            .fetch_add(1, Ordering::AcqRel)
            .is_multiple_of(2048)
        {
            self.do_gc();
        }

        let now = Instant::now();
        let mut entry = self.map.entry(key).or_insert_with(|| {
            vec![
                RateBucket {
                    start: now,
                    count: 0,
                };
                self.info.len()
            ]
        });

        let should_allow_request = entry
            .iter_mut()
            .zip(&*self.info)
            .all(|(bucket, info)| bucket.should_allow_request(info, now, n));

        if should_allow_request {
            // only increment the bucket counts if the request will actually be accepted
            entry.iter_mut().for_each(|b| b.inc(n));
        }

        should_allow_request
    }

    /// Clean the map. Simple strategy: remove all entries in a random shard.
    /// At worst, we'll double the effective max_rps during the cleanup.
    /// But that way deletion does not acquire a mutex on each entry access.
    pub(crate) fn do_gc(&self) {
        info!(
            "cleaning up bucket rate limiter, current size = {}",
            self.map.len()
        );
        let n = self.map.shards().len();
        // this lock is ok, as the periodic cadence of do_gc makes a collision
        // very unlikely (impossible, in fact, unless we have 2048 threads)
        let shard = self.rand.lock_propagate_poison().gen_range(0..n);
        self.map.shards()[shard].write().clear();
    }
}

#[cfg(test)]
mod tests {
    use std::hash::BuildHasherDefault;
    use std::time::Duration;

    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{BucketRateLimiter, WakeComputeRateLimiter};
    use crate::intern::EndpointIdInt;
    use crate::rate_limiter::RateBucketInfo;
    use crate::types::EndpointId;

    #[test]
    fn rate_bucket_rpi() {
        let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(rate_bucket.max_rpi, 50 * 5);

        let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(rate_bucket.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(10));
        assert_eq!(rate_bucket.max_rpi, 100 * 10);
        assert_eq!(rate_bucket.to_string(), "100@10s");

        let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(60));
        assert_eq!(rate_bucket.max_rpi, 100 * 60);
        assert_eq!(rate_bucket.to_string(), "100@1m");
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = WakeComputeRateLimiter::new(rates);

        let endpoint = EndpointId::from("ep-my-endpoint-1234");
        let endpoint = EndpointIdInt::from(endpoint);

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
        // more connections fail
        assert!(!limiter.check(endpoint, 1));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint, 1));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 1..6 {
            for _ in 0..50 {
                assert!(limiter.check(endpoint, 2));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after the 600th will exceed the 20rps@30s limit
        // (600 requests per 30s bucket)
        assert!(!limiter.check(endpoint, 1));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint, 1));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let rand = rand::rngs::StdRng::from_seed([1; 32]);
        let hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter =
            BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
        for i in 0..1_000_000 {
            limiter.check(i, 1);
        }
        assert!(limiter.map.len() < 150_000);
    }
}