Line data Source code
1 : use std::{
2 : hash::Hash,
3 : sync::atomic::{AtomicUsize, Ordering},
4 : };
5 :
6 : use ahash::RandomState;
7 : use dashmap::DashMap;
8 : use rand::{thread_rng, Rng};
9 : use tokio::time::Instant;
10 : use tracing::info;
11 : use utils::leaky_bucket::LeakyBucketState;
12 :
13 : use crate::intern::EndpointIdInt;
14 :
15 : // Simple per-endpoint rate limiter.
16 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
17 :
18 : pub struct LeakyBucketRateLimiter<Key> {
19 : map: DashMap<Key, LeakyBucketState, RandomState>,
20 : config: utils::leaky_bucket::LeakyBucketConfig,
21 : access_count: AtomicUsize,
22 : }
23 :
24 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
25 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
26 : rps: 600.0,
27 : max: 1500.0,
28 : };
29 :
30 18 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
31 18 : Self {
32 18 : map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
33 18 : config: config.into(),
34 18 : access_count: AtomicUsize::new(0),
35 18 : }
36 18 : }
37 :
38 : /// Check that number of connections to the endpoint is below `max_rps` rps.
39 18 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
40 18 : let now = Instant::now();
41 18 :
42 18 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
43 18 : self.do_gc(now);
44 18 : }
45 :
46 18 : let mut entry = self
47 18 : .map
48 18 : .entry(key)
49 18 : .or_insert_with(|| LeakyBucketState { empty_at: now });
50 18 :
51 18 : entry.add_tokens(&self.config, now, n as f64).is_ok()
52 18 : }
53 :
54 18 : fn do_gc(&self, now: Instant) {
55 18 : info!(
56 0 : "cleaning up bucket rate limiter, current size = {}",
57 0 : self.map.len()
58 : );
59 18 : let n = self.map.shards().len();
60 18 : let shard = thread_rng().gen_range(0..n);
61 18 : self.map.shards()[shard]
62 18 : .write()
63 18 : .retain(|_, value| !value.get().bucket_is_empty(now));
64 18 : }
65 : }
66 :
67 : pub struct LeakyBucketConfig {
68 : pub rps: f64,
69 : pub max: f64,
70 : }
71 :
72 : #[cfg(test)]
73 : impl LeakyBucketConfig {
74 6 : pub(crate) fn new(rps: f64, max: f64) -> Self {
75 6 : assert!(rps > 0.0, "rps must be positive");
76 6 : assert!(max > 0.0, "max must be positive");
77 6 : Self { rps, max }
78 6 : }
79 : }
80 :
81 : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
82 24 : fn from(config: LeakyBucketConfig) -> Self {
83 24 : utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
84 24 : }
85 : }
86 :
87 : #[cfg(test)]
88 : #[allow(clippy::float_cmp)]
89 : mod tests {
90 : use std::time::Duration;
91 :
92 : use tokio::time::Instant;
93 : use utils::leaky_bucket::LeakyBucketState;
94 :
95 : use super::LeakyBucketConfig;
96 :
97 : #[tokio::test(start_paused = true)]
98 6 : async fn check() {
99 6 : let config: utils::leaky_bucket::LeakyBucketConfig =
100 6 : LeakyBucketConfig::new(500.0, 2000.0).into();
101 6 : assert_eq!(config.cost, Duration::from_millis(2));
102 6 : assert_eq!(config.bucket_width, Duration::from_secs(4));
103 6 :
104 6 : let mut bucket = LeakyBucketState {
105 6 : empty_at: Instant::now(),
106 6 : };
107 6 :
108 6 : // should work for 2000 requests this second
109 12006 : for _ in 0..2000 {
110 12000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
111 12000 : }
112 6 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
113 6 : assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
114 6 :
115 6 : // in 1ms we should drain 0.5 tokens.
116 6 : // make sure we don't lose any tokens
117 6 : tokio::time::advance(Duration::from_millis(1)).await;
118 6 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
119 6 : tokio::time::advance(Duration::from_millis(1)).await;
120 6 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
121 6 :
122 6 : // in 10ms we should drain 5 tokens
123 6 : tokio::time::advance(Duration::from_millis(10)).await;
124 36 : for _ in 0..5 {
125 30 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
126 30 : }
127 6 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
128 6 :
129 6 : // in 10s we should drain 5000 tokens
130 6 : // but cap is only 2000
131 6 : tokio::time::advance(Duration::from_secs(10)).await;
132 12006 : for _ in 0..2000 {
133 12000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
134 12000 : }
135 6 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
136 6 :
137 6 : // should sustain 500rps
138 12006 : for _ in 0..2000 {
139 12000 : tokio::time::advance(Duration::from_millis(10)).await;
140 72000 : for _ in 0..5 {
141 60000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
142 60000 : }
143 6 : }
144 6 : }
145 : }
|