Line data Source code
1 : use std::{
2 : hash::Hash,
3 : sync::atomic::{AtomicUsize, Ordering},
4 : };
5 :
6 : use ahash::RandomState;
7 : use dashmap::DashMap;
8 : use rand::{thread_rng, Rng};
9 : use tokio::time::Instant;
10 : use tracing::info;
11 :
12 : use crate::intern::EndpointIdInt;
13 :
14 : // Simple per-endpoint rate limiter.
15 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
16 :
17 : pub struct LeakyBucketRateLimiter<Key> {
18 : map: DashMap<Key, LeakyBucketState, RandomState>,
19 : config: LeakyBucketConfig,
20 : access_count: AtomicUsize,
21 : }
22 :
23 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
24 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
25 : rps: 600.0,
26 : max: 1500.0,
27 : };
28 :
29 6 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
30 6 : Self {
31 6 : map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
32 6 : config,
33 6 : access_count: AtomicUsize::new(0),
34 6 : }
35 6 : }
36 :
37 : /// Check that number of connections to the endpoint is below `max_rps` rps.
38 6 : pub fn check(&self, key: K, n: u32) -> bool {
39 6 : let now = Instant::now();
40 6 :
41 6 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
42 6 : self.do_gc(now);
43 6 : }
44 :
45 6 : let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState {
46 6 : time: now,
47 6 : filled: 0.0,
48 6 : });
49 6 :
50 6 : entry.check(&self.config, now, n as f64)
51 6 : }
52 :
53 6 : fn do_gc(&self, now: Instant) {
54 6 : info!(
55 0 : "cleaning up bucket rate limiter, current size = {}",
56 0 : self.map.len()
57 : );
58 6 : let n = self.map.shards().len();
59 6 : let shard = thread_rng().gen_range(0..n);
60 6 : self.map.shards()[shard]
61 6 : .write()
62 6 : .retain(|_, value| !value.get_mut().update(&self.config, now));
63 6 : }
64 : }
65 :
66 : pub struct LeakyBucketConfig {
67 : pub rps: f64,
68 : pub max: f64,
69 : }
70 :
71 : pub struct LeakyBucketState {
72 : filled: f64,
73 : time: Instant,
74 : }
75 :
76 : impl LeakyBucketConfig {
77 2 : pub fn new(rps: f64, max: f64) -> Self {
78 2 : assert!(rps > 0.0, "rps must be positive");
79 2 : assert!(max > 0.0, "max must be positive");
80 2 : Self { rps, max }
81 2 : }
82 : }
83 :
84 : impl LeakyBucketState {
85 2 : pub fn new() -> Self {
86 2 : Self {
87 2 : filled: 0.0,
88 2 : time: Instant::now(),
89 2 : }
90 2 : }
91 :
92 : /// updates the timer and returns true if the bucket is empty
93 28026 : fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool {
94 28026 : let drain = now.duration_since(self.time);
95 28026 : let drain = drain.as_secs_f64() * info.rps;
96 28026 :
97 28026 : self.filled = (self.filled - drain).clamp(0.0, info.max);
98 28026 : self.time = now;
99 28026 :
100 28026 : self.filled == 0.0
101 28026 : }
102 :
103 28026 : pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
104 28026 : self.update(info, now);
105 28026 :
106 28026 : if self.filled + n > info.max {
107 8 : return false;
108 28018 : }
109 28018 : self.filled += n;
110 28018 :
111 28018 : true
112 28026 : }
113 : }
114 :
115 : impl Default for LeakyBucketState {
116 0 : fn default() -> Self {
117 0 : Self::new()
118 0 : }
119 : }
120 :
121 : #[cfg(test)]
122 : mod tests {
123 : use std::time::Duration;
124 :
125 : use tokio::time::Instant;
126 :
127 : use super::{LeakyBucketConfig, LeakyBucketState};
128 :
129 : #[tokio::test(start_paused = true)]
130 2 : async fn check() {
131 2 : let info = LeakyBucketConfig::new(500.0, 2000.0);
132 2 : let mut bucket = LeakyBucketState::new();
133 2 :
134 2 : // should work for 2000 requests this second
135 4002 : for _ in 0..2000 {
136 4000 : assert!(bucket.check(&info, Instant::now(), 1.0));
137 2 : }
138 2 : assert!(!bucket.check(&info, Instant::now(), 1.0));
139 2 : assert_eq!(bucket.filled, 2000.0);
140 2 :
141 2 : // in 1ms we should drain 0.5 tokens.
142 2 : // make sure we don't lose any tokens
143 2 : tokio::time::advance(Duration::from_millis(1)).await;
144 2 : assert!(!bucket.check(&info, Instant::now(), 1.0));
145 2 : tokio::time::advance(Duration::from_millis(1)).await;
146 2 : assert!(bucket.check(&info, Instant::now(), 1.0));
147 2 :
148 2 : // in 10ms we should drain 5 tokens
149 2 : tokio::time::advance(Duration::from_millis(10)).await;
150 12 : for _ in 0..5 {
151 10 : assert!(bucket.check(&info, Instant::now(), 1.0));
152 2 : }
153 2 : assert!(!bucket.check(&info, Instant::now(), 1.0));
154 2 :
155 2 : // in 10s we should drain 5000 tokens
156 2 : // but cap is only 2000
157 2 : tokio::time::advance(Duration::from_secs(10)).await;
158 4002 : for _ in 0..2000 {
159 4000 : assert!(bucket.check(&info, Instant::now(), 1.0));
160 2 : }
161 2 : assert!(!bucket.check(&info, Instant::now(), 1.0));
162 2 :
163 2 : // should sustain 500rps
164 4002 : for _ in 0..2000 {
165 4000 : tokio::time::advance(Duration::from_millis(10)).await;
166 24000 : for _ in 0..5 {
167 20000 : assert!(bucket.check(&info, Instant::now(), 1.0));
168 2 : }
169 2 : }
170 2 : }
171 : }
|