Line data Source code
1 : use std::hash::Hash;
2 : use std::sync::atomic::{AtomicUsize, Ordering};
3 :
4 : use ahash::RandomState;
5 : use dashmap::DashMap;
6 : use rand::{thread_rng, Rng};
7 : use tokio::time::Instant;
8 : use tracing::info;
9 : use utils::leaky_bucket::LeakyBucketState;
10 :
11 : use crate::intern::EndpointIdInt;
12 :
13 : // Simple per-endpoint rate limiter.
14 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
15 :
16 : pub struct LeakyBucketRateLimiter<Key> {
17 : map: DashMap<Key, LeakyBucketState, RandomState>,
18 : config: utils::leaky_bucket::LeakyBucketConfig,
19 : access_count: AtomicUsize,
20 : }
21 :
22 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
23 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
24 : rps: 600.0,
25 : max: 1500.0,
26 : };
27 :
28 3 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
29 3 : Self {
30 3 : map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
31 3 : config: config.into(),
32 3 : access_count: AtomicUsize::new(0),
33 3 : }
34 3 : }
35 :
36 : /// Check that number of connections to the endpoint is below `max_rps` rps.
37 3 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
38 3 : let now = Instant::now();
39 3 :
40 3 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
41 3 : self.do_gc(now);
42 3 : }
43 :
44 3 : let mut entry = self
45 3 : .map
46 3 : .entry(key)
47 3 : .or_insert_with(|| LeakyBucketState { empty_at: now });
48 3 :
49 3 : entry.add_tokens(&self.config, now, n as f64).is_ok()
50 3 : }
51 :
52 3 : fn do_gc(&self, now: Instant) {
53 3 : info!(
54 0 : "cleaning up bucket rate limiter, current size = {}",
55 0 : self.map.len()
56 : );
57 3 : let n = self.map.shards().len();
58 3 : let shard = thread_rng().gen_range(0..n);
59 3 : self.map.shards()[shard]
60 3 : .write()
61 3 : .retain(|_, value| !value.get().bucket_is_empty(now));
62 3 : }
63 : }
64 :
65 : pub struct LeakyBucketConfig {
66 : pub rps: f64,
67 : pub max: f64,
68 : }
69 :
70 : #[cfg(test)]
71 : impl LeakyBucketConfig {
72 1 : pub(crate) fn new(rps: f64, max: f64) -> Self {
73 1 : assert!(rps > 0.0, "rps must be positive");
74 1 : assert!(max > 0.0, "max must be positive");
75 1 : Self { rps, max }
76 1 : }
77 : }
78 :
79 : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
80 4 : fn from(config: LeakyBucketConfig) -> Self {
81 4 : utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
82 4 : }
83 : }
84 :
85 : #[cfg(test)]
86 : #[allow(clippy::float_cmp)]
87 : mod tests {
88 : use std::time::Duration;
89 :
90 : use tokio::time::Instant;
91 : use utils::leaky_bucket::LeakyBucketState;
92 :
93 : use super::LeakyBucketConfig;
94 :
95 : #[tokio::test(start_paused = true)]
96 1 : async fn check() {
97 1 : let config: utils::leaky_bucket::LeakyBucketConfig =
98 1 : LeakyBucketConfig::new(500.0, 2000.0).into();
99 1 : assert_eq!(config.cost, Duration::from_millis(2));
100 1 : assert_eq!(config.bucket_width, Duration::from_secs(4));
101 1 :
102 1 : let mut bucket = LeakyBucketState {
103 1 : empty_at: Instant::now(),
104 1 : };
105 1 :
106 1 : // should work for 2000 requests this second
107 2001 : for _ in 0..2000 {
108 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
109 2000 : }
110 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
111 1 : assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
112 1 :
113 1 : // in 1ms we should drain 0.5 tokens.
114 1 : // make sure we don't lose any tokens
115 1 : tokio::time::advance(Duration::from_millis(1)).await;
116 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
117 1 : tokio::time::advance(Duration::from_millis(1)).await;
118 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
119 1 :
120 1 : // in 10ms we should drain 5 tokens
121 1 : tokio::time::advance(Duration::from_millis(10)).await;
122 6 : for _ in 0..5 {
123 5 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
124 5 : }
125 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
126 1 :
127 1 : // in 10s we should drain 5000 tokens
128 1 : // but cap is only 2000
129 1 : tokio::time::advance(Duration::from_secs(10)).await;
130 2001 : for _ in 0..2000 {
131 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
132 2000 : }
133 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
134 1 :
135 1 : // should sustain 500rps
136 2001 : for _ in 0..2000 {
137 2000 : tokio::time::advance(Duration::from_millis(10)).await;
138 12000 : for _ in 0..5 {
139 10000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
140 10000 : }
141 1 : }
142 1 : }
143 : }
|