Line data Source code
1 : use std::hash::Hash;
2 : use std::sync::atomic::{AtomicUsize, Ordering};
3 :
4 : use ahash::RandomState;
5 : use clashmap::ClashMap;
6 : use rand::{Rng, thread_rng};
7 : use tokio::time::Instant;
8 : use tracing::info;
9 : use utils::leaky_bucket::LeakyBucketState;
10 :
11 : use crate::intern::EndpointIdInt;
12 :
13 : // Simple per-endpoint rate limiter.
14 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
15 :
16 : pub struct LeakyBucketRateLimiter<Key> {
17 : map: ClashMap<Key, LeakyBucketState, RandomState>,
18 : default_config: utils::leaky_bucket::LeakyBucketConfig,
19 : access_count: AtomicUsize,
20 : }
21 :
22 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
23 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
24 : rps: 600.0,
25 : max: 1500.0,
26 : };
27 :
28 3 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
29 3 : Self {
30 3 : map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
31 3 : default_config: config.into(),
32 3 : access_count: AtomicUsize::new(0),
33 3 : }
34 3 : }
35 :
36 : /// Check that number of connections to the endpoint is below `max_rps` rps.
37 3 : pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
38 3 : let now = Instant::now();
39 3 :
40 3 : let config = config.map_or(self.default_config, Into::into);
41 3 :
42 3 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
43 3 : self.do_gc(now);
44 3 : }
45 :
46 3 : let mut entry = self
47 3 : .map
48 3 : .entry(key)
49 3 : .or_insert_with(|| LeakyBucketState { empty_at: now });
50 3 :
51 3 : entry.add_tokens(&config, now, n as f64).is_ok()
52 3 : }
53 :
54 3 : fn do_gc(&self, now: Instant) {
55 3 : info!(
56 0 : "cleaning up bucket rate limiter, current size = {}",
57 0 : self.map.len()
58 : );
59 3 : let n = self.map.shards().len();
60 3 : let shard = thread_rng().gen_range(0..n);
61 3 : self.map.shards()[shard]
62 3 : .write()
63 3 : .retain(|(_, value)| !value.bucket_is_empty(now));
64 3 : }
65 : }
66 :
67 : pub struct LeakyBucketConfig {
68 : pub rps: f64,
69 : pub max: f64,
70 : }
71 :
72 : impl LeakyBucketConfig {
73 1 : pub fn new(rps: f64, max: f64) -> Self {
74 1 : assert!(rps > 0.0, "rps must be positive");
75 1 : assert!(max > 0.0, "max must be positive");
76 1 : Self { rps, max }
77 1 : }
78 : }
79 :
80 : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
81 4 : fn from(config: LeakyBucketConfig) -> Self {
82 4 : utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
83 4 : }
84 : }
85 :
86 : #[cfg(test)]
87 : #[allow(clippy::float_cmp)]
88 : mod tests {
89 : use std::time::Duration;
90 :
91 : use tokio::time::Instant;
92 : use utils::leaky_bucket::LeakyBucketState;
93 :
94 : use super::LeakyBucketConfig;
95 :
96 : #[tokio::test(start_paused = true)]
97 1 : async fn check() {
98 1 : let config: utils::leaky_bucket::LeakyBucketConfig =
99 1 : LeakyBucketConfig::new(500.0, 2000.0).into();
100 1 : assert_eq!(config.cost, Duration::from_millis(2));
101 1 : assert_eq!(config.bucket_width, Duration::from_secs(4));
102 1 :
103 1 : let mut bucket = LeakyBucketState {
104 1 : empty_at: Instant::now(),
105 1 : };
106 1 :
107 1 : // should work for 2000 requests this second
108 2001 : for _ in 0..2000 {
109 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
110 2000 : }
111 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
112 1 : assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
113 1 :
114 1 : // in 1ms we should drain 0.5 tokens.
115 1 : // make sure we don't lose any tokens
116 1 : tokio::time::advance(Duration::from_millis(1)).await;
117 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
118 1 : tokio::time::advance(Duration::from_millis(1)).await;
119 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
120 1 :
121 1 : // in 10ms we should drain 5 tokens
122 1 : tokio::time::advance(Duration::from_millis(10)).await;
123 6 : for _ in 0..5 {
124 5 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
125 5 : }
126 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
127 1 :
128 1 : // in 10s we should drain 5000 tokens
129 1 : // but cap is only 2000
130 1 : tokio::time::advance(Duration::from_secs(10)).await;
131 2001 : for _ in 0..2000 {
132 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
133 2000 : }
134 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
135 1 :
136 1 : // should sustain 500rps
137 2001 : for _ in 0..2000 {
138 2000 : tokio::time::advance(Duration::from_millis(10)).await;
139 12000 : for _ in 0..5 {
140 10000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
141 10000 : }
142 1 : }
143 1 : }
144 : }
|