Line data Source code
1 : use std::{
2 : hash::Hash,
3 : sync::atomic::{AtomicUsize, Ordering},
4 : };
5 :
6 : use ahash::RandomState;
7 : use dashmap::DashMap;
8 : use rand::{thread_rng, Rng};
9 : use tokio::time::Instant;
10 : use tracing::info;
11 : use utils::leaky_bucket::LeakyBucketState;
12 :
13 : use crate::intern::EndpointIdInt;
14 :
15 : // Simple per-endpoint rate limiter.
16 : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
17 :
18 : pub struct LeakyBucketRateLimiter<Key> {
19 : map: DashMap<Key, LeakyBucketState, RandomState>,
20 : config: utils::leaky_bucket::LeakyBucketConfig,
21 : access_count: AtomicUsize,
22 : }
23 :
24 : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
25 : pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
26 : rps: 600.0,
27 : max: 1500.0,
28 : };
29 :
30 3 : pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
31 3 : Self {
32 3 : map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
33 3 : config: config.into(),
34 3 : access_count: AtomicUsize::new(0),
35 3 : }
36 3 : }
37 :
38 : /// Check that number of connections to the endpoint is below `max_rps` rps.
39 3 : pub(crate) fn check(&self, key: K, n: u32) -> bool {
40 3 : let now = Instant::now();
41 3 :
42 3 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
43 3 : self.do_gc(now);
44 3 : }
45 :
46 3 : let mut entry = self
47 3 : .map
48 3 : .entry(key)
49 3 : .or_insert_with(|| LeakyBucketState { empty_at: now });
50 3 :
51 3 : entry.add_tokens(&self.config, now, n as f64).is_ok()
52 3 : }
53 :
54 3 : fn do_gc(&self, now: Instant) {
55 3 : info!(
56 0 : "cleaning up bucket rate limiter, current size = {}",
57 0 : self.map.len()
58 : );
59 3 : let n = self.map.shards().len();
60 3 : let shard = thread_rng().gen_range(0..n);
61 3 : self.map.shards()[shard]
62 3 : .write()
63 3 : .retain(|_, value| !value.get().bucket_is_empty(now));
64 3 : }
65 : }
66 :
67 : pub struct LeakyBucketConfig {
68 : pub rps: f64,
69 : pub max: f64,
70 : }
71 :
72 : #[cfg(test)]
73 : impl LeakyBucketConfig {
74 1 : pub(crate) fn new(rps: f64, max: f64) -> Self {
75 1 : assert!(rps > 0.0, "rps must be positive");
76 1 : assert!(max > 0.0, "max must be positive");
77 1 : Self { rps, max }
78 1 : }
79 : }
80 :
81 : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
82 4 : fn from(config: LeakyBucketConfig) -> Self {
83 4 : utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
84 4 : }
85 : }
86 :
87 : #[cfg(test)]
88 : #[allow(clippy::float_cmp)]
89 : mod tests {
90 : use std::time::Duration;
91 :
92 : use tokio::time::Instant;
93 : use utils::leaky_bucket::LeakyBucketState;
94 :
95 : use super::LeakyBucketConfig;
96 :
97 : #[tokio::test(start_paused = true)]
98 1 : async fn check() {
99 1 : let config: utils::leaky_bucket::LeakyBucketConfig =
100 1 : LeakyBucketConfig::new(500.0, 2000.0).into();
101 1 : assert_eq!(config.cost, Duration::from_millis(2));
102 1 : assert_eq!(config.bucket_width, Duration::from_secs(4));
103 1 :
104 1 : let mut bucket = LeakyBucketState {
105 1 : empty_at: Instant::now(),
106 1 : };
107 1 :
108 1 : // should work for 2000 requests this second
109 2001 : for _ in 0..2000 {
110 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
111 2000 : }
112 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
113 1 : assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
114 1 :
115 1 : // in 1ms we should drain 0.5 tokens.
116 1 : // make sure we don't lose any tokens
117 1 : tokio::time::advance(Duration::from_millis(1)).await;
118 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
119 1 : tokio::time::advance(Duration::from_millis(1)).await;
120 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
121 1 :
122 1 : // in 10ms we should drain 5 tokens
123 1 : tokio::time::advance(Duration::from_millis(10)).await;
124 6 : for _ in 0..5 {
125 5 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
126 5 : }
127 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
128 1 :
129 1 : // in 10s we should drain 5000 tokens
130 1 : // but cap is only 2000
131 1 : tokio::time::advance(Duration::from_secs(10)).await;
132 2001 : for _ in 0..2000 {
133 2000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
134 2000 : }
135 1 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
136 1 :
137 1 : // should sustain 500rps
138 2001 : for _ in 0..2000 {
139 2000 : tokio::time::advance(Duration::from_millis(10)).await;
140 12000 : for _ in 0..5 {
141 10000 : bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
142 10000 : }
143 1 : }
144 1 : }
145 : }
|