Line data Source code
1 : use std::hash::Hash;
2 : use std::sync::atomic::{AtomicUsize, Ordering};
3 :
4 : use ahash::RandomState;
5 : use clashmap::ClashMap;
6 : use rand::{Rng, thread_rng};
7 : use tokio::time::Instant;
8 : use tracing::info;
9 : use utils::leaky_bucket::LeakyBucketState;
10 :
11 : use crate::intern::EndpointIdInt;
12 :
13 : // Simple per-endpoint rate limiter.
14 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
15 :
16 : pub struct LeakyBucketRateLimiter<Key> {
17 : map: ClashMap<Key, LeakyBucketState, RandomState>,
18 : default_config: utils::leaky_bucket::LeakyBucketConfig,
19 : access_count: AtomicUsize,
20 : }
21 :
22 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
23 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
24 : rps: 600.0,
25 : max: 1500.0,
26 : };
27 :
28 3 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
29 3 : Self {
30 3 : map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
31 3 : default_config: config.into(),
32 3 : access_count: AtomicUsize::new(0),
33 3 : }
34 3 : }
35 :
36 : /// Check that number of connections to the endpoint is below `max_rps` rps.
37 3 : pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
38 3 : let now = Instant::now();
39 :
40 3 : let config = config.map_or(self.default_config, Into::into);
41 :
42 3 : if self
43 3 : .access_count
44 3 : .fetch_add(1, Ordering::AcqRel)
45 3 : .is_multiple_of(2048)
46 3 : {
47 3 : self.do_gc(now);
48 3 : }
49 :
50 3 : let mut entry = self
51 3 : .map
52 3 : .entry(key)
53 3 : .or_insert_with(|| LeakyBucketState { empty_at: now });
54 :
55 3 : entry.add_tokens(&config, now, n as f64).is_ok()
56 3 : }
57 :
58 3 : fn do_gc(&self, now: Instant) {
59 3 : info!(
60 0 : "cleaning up bucket rate limiter, current size = {}",
61 0 : self.map.len()
62 : );
63 3 : let n = self.map.shards().len();
64 3 : let shard = thread_rng().gen_range(0..n);
65 3 : self.map.shards()[shard]
66 3 : .write()
67 3 : .retain(|(_, value)| !value.bucket_is_empty(now));
68 3 : }
69 : }
70 :
71 : pub struct LeakyBucketConfig {
72 : pub rps: f64,
73 : pub max: f64,
74 : }
75 :
76 : impl LeakyBucketConfig {
77 1 : pub fn new(rps: f64, max: f64) -> Self {
78 1 : assert!(rps > 0.0, "rps must be positive");
79 1 : assert!(max > 0.0, "max must be positive");
80 1 : Self { rps, max }
81 1 : }
82 : }
83 :
84 : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
85 4 : fn from(config: LeakyBucketConfig) -> Self {
86 4 : utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
87 4 : }
88 : }
89 :
90 : #[cfg(test)]
91 : #[allow(clippy::float_cmp)]
92 : mod tests {
93 : use std::time::Duration;
94 :
95 : use tokio::time::Instant;
96 : use utils::leaky_bucket::LeakyBucketState;
97 :
98 : use super::LeakyBucketConfig;
99 :
100 : #[tokio::test(start_paused = true)]
101 1 : async fn check() {
102 1 : let config: utils::leaky_bucket::LeakyBucketConfig =
103 1 : LeakyBucketConfig::new(500.0, 2000.0).into();
104 1 : assert_eq!(config.cost, Duration::from_millis(2));
105 1 : assert_eq!(config.bucket_width, Duration::from_secs(4));
106 :
107 1 : let mut bucket = LeakyBucketState {
108 1 : empty_at: Instant::now(),
109 1 : };
110 :
111 : // should work for 2000 requests this second
112 2001 : for _ in 0..2000 {
113 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
114 2000 : }
115 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
116 1 : assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
117 :
118 : // in 1ms we should drain 0.5 tokens.
119 : // make sure we don't lose any tokens
120 1 : tokio::time::advance(Duration::from_millis(1)).await;
121 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
122 1 : tokio::time::advance(Duration::from_millis(1)).await;
123 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
124 :
125 : // in 10ms we should drain 5 tokens
126 1 : tokio::time::advance(Duration::from_millis(10)).await;
127 6 : for _ in 0..5 {
128 5 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
129 5 : }
130 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
131 :
132 : // in 10s we should drain 5000 tokens
133 : // but cap is only 2000
134 1 : tokio::time::advance(Duration::from_secs(10)).await;
135 2001 : for _ in 0..2000 {
136 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
137 2000 : }
138 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
139 :
140 : // should sustain 500rps
141 2001 : for _ in 0..2000 {
142 2000 : tokio::time::advance(Duration::from_millis(10)).await;
143 12000 : for _ in 0..5 {
144 10000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
145 10000 : }
146 1 : }
147 1 : }
148 : }
|