Line data Source code
1 : use std::{
2 : borrow::Cow,
3 : collections::hash_map::RandomState,
4 : hash::{BuildHasher, Hash},
5 : sync::{
6 : atomic::{AtomicUsize, Ordering},
7 : Mutex,
8 : },
9 : };
10 :
11 : use anyhow::bail;
12 : use dashmap::DashMap;
13 : use itertools::Itertools;
14 : use rand::{rngs::StdRng, Rng, SeedableRng};
15 : use tokio::time::{Duration, Instant};
16 : use tracing::info;
17 :
18 : use crate::intern::EndpointIdInt;
19 :
20 : pub(crate) struct GlobalRateLimiter {
21 : data: Vec<RateBucket>,
22 : info: Vec<RateBucketInfo>,
23 : }
24 :
25 : impl GlobalRateLimiter {
26 0 : pub(crate) fn new(info: Vec<RateBucketInfo>) -> Self {
27 0 : Self {
28 0 : data: vec![
29 0 : RateBucket {
30 0 : start: Instant::now(),
31 0 : count: 0,
32 0 : };
33 0 : info.len()
34 0 : ],
35 0 : info,
36 0 : }
37 0 : }
38 :
39 : /// Check that number of connections is below `max_rps` rps.
40 0 : pub(crate) fn check(&mut self) -> bool {
41 0 : let now = Instant::now();
42 0 :
43 0 : let should_allow_request = self
44 0 : .data
45 0 : .iter_mut()
46 0 : .zip(&self.info)
47 0 : .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
48 0 :
49 0 : if should_allow_request {
50 0 : // only increment the bucket counts if the request will actually be accepted
51 0 : self.data.iter_mut().for_each(|b| b.inc(1));
52 0 : }
53 :
54 0 : should_allow_request
55 0 : }
56 : }
57 :
58 : // Simple per-endpoint rate limiter.
59 : //
60 : // Check that number of connections to the endpoint is below `max_rps` rps.
61 : // Purposefully ignore user name and database name as clients can reconnect
62 : // with different names, so we'll end up sending some http requests to
63 : // the control plane.
64 : pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
65 :
66 : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
67 : map: DashMap<Key, Vec<RateBucket>, Hasher>,
68 : info: Cow<'static, [RateBucketInfo]>,
69 : access_count: AtomicUsize,
70 : rand: Mutex<Rand>,
71 : }
72 :
73 : #[derive(Clone, Copy)]
74 : struct RateBucket {
75 : start: Instant,
76 : count: u32,
77 : }
78 :
79 : impl RateBucket {
80 18005490 : fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
81 18005490 : if now - self.start < info.interval {
82 18005442 : self.count + n <= info.max_rpi
83 : } else {
84 : // bucket expired, reset
85 48 : self.count = 0;
86 48 : self.start = now;
87 48 :
88 48 : true
89 : }
90 18005490 : }
91 :
92 18005454 : fn inc(&mut self, n: u32) {
93 18005454 : self.count += n;
94 18005454 : }
95 : }
96 :
/// Configuration of one rate bucket: a window length and the request budget for that window.
#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    /// Length of one window.
    pub(crate) interval: Duration,
    /// Maximum number of requests per interval.
    pub(crate) max_rpi: u32,
}
103 :
104 : impl std::fmt::Display for RateBucketInfo {
105 114 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 114 : let rps = self.rps().floor() as u64;
107 114 : write!(f, "{rps}@{}", humantime::format_duration(self.interval))
108 114 : }
109 : }
110 :
111 : impl std::fmt::Debug for RateBucketInfo {
112 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 0 : write!(f, "{self}")
114 0 : }
115 : }
116 :
117 : impl std::str::FromStr for RateBucketInfo {
118 : type Err = anyhow::Error;
119 :
120 192 : fn from_str(s: &str) -> Result<Self, Self::Err> {
121 192 : let Some((max_rps, interval)) = s.split_once('@') else {
122 0 : bail!("invalid rate info")
123 : };
124 192 : let max_rps = max_rps.parse()?;
125 192 : let interval = humantime::parse_duration(interval)?;
126 192 : Ok(Self::new(max_rps, interval))
127 192 : }
128 : }
129 :
130 : impl RateBucketInfo {
131 : pub const DEFAULT_SET: [Self; 3] = [
132 : Self::new(300, Duration::from_secs(1)),
133 : Self::new(200, Duration::from_secs(60)),
134 : Self::new(100, Duration::from_secs(600)),
135 : ];
136 :
137 : pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
138 : Self::new(500, Duration::from_secs(1)),
139 : Self::new(300, Duration::from_secs(60)),
140 : Self::new(200, Duration::from_secs(600)),
141 : ];
142 :
143 114 : pub fn rps(&self) -> f64 {
144 114 : (self.max_rpi as f64) / self.interval.as_secs_f64()
145 114 : }
146 :
147 18 : pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
148 48 : info.sort_unstable_by_key(|info| info.interval);
149 18 : let invalid = info
150 18 : .iter()
151 18 : .tuple_windows()
152 24 : .find(|(a, b)| a.max_rpi > b.max_rpi);
153 18 : if let Some((a, b)) = invalid {
154 6 : bail!(
155 6 : "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
156 6 : b.max_rpi,
157 6 : a.max_rpi,
158 6 : );
159 12 : }
160 12 :
161 12 : Ok(())
162 18 : }
163 :
164 216 : pub const fn new(max_rps: u32, interval: Duration) -> Self {
165 216 : Self {
166 216 : interval,
167 216 : max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
168 216 : }
169 216 : }
170 : }
171 :
172 : impl<K: Hash + Eq> BucketRateLimiter<K> {
173 24 : pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
174 24 : Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
175 24 : }
176 : }
177 :
178 : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
179 30 : fn new_with_rand_and_hasher(
180 30 : info: impl Into<Cow<'static, [RateBucketInfo]>>,
181 30 : rand: R,
182 30 : hasher: S,
183 30 : ) -> Self {
184 30 : let info = info.into();
185 30 : info!(buckets = ?info, "endpoint rate limiter");
186 30 : Self {
187 30 : info,
188 30 : map: DashMap::with_hasher_and_shard_amount(hasher, 64),
189 30 : access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
190 30 : rand: Mutex::new(rand),
191 30 : }
192 30 : }
193 :
194 : /// Check that number of connections to the endpoint is below `max_rps` rps.
195 6002742 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
196 6002742 : // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
197 6002742 : // worst case memory usage is about:
198 6002742 : // = 2 * 2048 * 64 * (48B + 72B)
199 6002742 : // = 30MB
200 6002742 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
201 2928 : self.do_gc();
202 5999814 : }
203 :
204 6002742 : let now = Instant::now();
205 6002742 : let mut entry = self.map.entry(key).or_insert_with(|| {
206 6000024 : vec![
207 6000024 : RateBucket {
208 6000024 : start: now,
209 6000024 : count: 0,
210 6000024 : };
211 6000024 : self.info.len()
212 6000024 : ]
213 6002742 : });
214 6002742 :
215 6002742 : let should_allow_request = entry
216 6002742 : .iter_mut()
217 6002742 : .zip(&*self.info)
218 18005490 : .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
219 6002742 :
220 6002742 : if should_allow_request {
221 6002718 : // only increment the bucket counts if the request will actually be accepted
222 18005454 : entry.iter_mut().for_each(|b| b.inc(n));
223 6002718 : }
224 :
225 6002742 : should_allow_request
226 6002742 : }
227 :
228 : /// Clean the map. Simple strategy: remove all entries in a random shard.
229 : /// At worst, we'll double the effective max_rps during the cleanup.
230 : /// But that way deletion does not aquire mutex on each entry access.
231 2928 : pub(crate) fn do_gc(&self) {
232 2928 : info!(
233 0 : "cleaning up bucket rate limiter, current size = {}",
234 0 : self.map.len()
235 : );
236 2928 : let n = self.map.shards().len();
237 2928 : // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
238 2928 : // (impossible, infact, unless we have 2048 threads)
239 2928 : let shard = self.rand.lock().unwrap().gen_range(0..n);
240 2928 : self.map.shards()[shard].write().clear();
241 2928 : }
242 : }
243 :
#[cfg(test)]
mod tests {
    use std::{hash::BuildHasherDefault, time::Duration};

    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{BucketRateLimiter, WakeComputeRateLimiter};
    use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};

    /// Parses every spec into a `RateBucketInfo`, panicking on malformed input.
    fn parse_rates(specs: &[&str]) -> Vec<RateBucketInfo> {
        specs.iter().map(|spec| spec.parse().unwrap()).collect()
    }

    #[test]
    fn rate_bucket_rpi() {
        let five_seconds = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(five_seconds.max_rpi, 50 * 5);

        let half_second = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(half_second.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        // (text, expected interval, expected requests per interval)
        let cases = [
            ("100@10s", Duration::from_secs(10), 100 * 10),
            ("100@1m", Duration::from_secs(60), 100 * 60),
        ];

        for &(text, interval, max_rpi) in &cases {
            let parsed: RateBucketInfo = text.parse().unwrap();
            assert_eq!(parsed.interval, interval);
            assert_eq!(parsed.max_rpi, max_rpi);
            // formatting must round-trip back to the input text
            assert_eq!(parsed.to_string(), text);
        }
    }

    #[test]
    fn default_rate_buckets() {
        let mut default_set = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut default_set[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut inconsistent = parse_rates(&["300@1s", "10@10s"]);
        RateBucketInfo::validate(&mut inconsistent).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut limits = parse_rates(&["100@1s", "20@30s"]);
        RateBucketInfo::validate(&mut limits).unwrap();
        let limiter = WakeComputeRateLimiter::new(limits);

        let endpoint = EndpointIdInt::from(EndpointId::from("ep-my-endpoint-1234"));

        time::pause();

        // the first 100 connections within the second are accepted
        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
        // more connections fail
        assert!(!limiter.check(endpoint, 1));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint, 1));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 0..5 {
            for _ in 0..50 {
                assert!(limiter.check(endpoint, 2));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!limiter.check(endpoint, 1));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint, 1));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint, 1));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let seeded_rng = rand::rngs::StdRng::from_seed([1; 32]);
        let fixed_hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter = BucketRateLimiter::new_with_rand_and_hasher(
            &RateBucketInfo::DEFAULT_SET,
            seeded_rng,
            fixed_hasher,
        );
        for key in 0..1_000_000 {
            limiter.check(key, 1);
        }
        // periodic shard clearing must keep the map well below the number of distinct keys
        assert!(limiter.map.len() < 150_000);
    }
}
|