|             Line data    Source code 
       1              : use std::{
       2              :     hash::Hash,
       3              :     sync::atomic::{AtomicUsize, Ordering},
       4              : };
       5              : 
       6              : use ahash::RandomState;
       7              : use dashmap::DashMap;
       8              : use rand::{thread_rng, Rng};
       9              : use tokio::time::Instant;
      10              : use tracing::info;
      11              : use utils::leaky_bucket::LeakyBucketState;
      12              : 
      13              : use crate::intern::EndpointIdInt;
      14              : 
      15              : // Simple per-endpoint rate limiter.
      16              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      17              : 
      18              : pub struct LeakyBucketRateLimiter<Key> {
      19              :     map: DashMap<Key, LeakyBucketState, RandomState>,
      20              :     config: utils::leaky_bucket::LeakyBucketConfig,
      21              :     access_count: AtomicUsize,
      22              : }
      23              : 
      24              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      25              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      26              :         rps: 600.0,
      27              :         max: 1500.0,
      28              :     };
      29              : 
      30            3 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      31            3 :         Self {
      32            3 :             map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      33            3 :             config: config.into(),
      34            3 :             access_count: AtomicUsize::new(0),
      35            3 :         }
      36            3 :     }
      37              : 
      38              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      39            3 :     pub(crate) fn check(&self, key: K, n: u32) -> bool {
      40            3 :         let now = Instant::now();
      41            3 : 
      42            3 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
      43            3 :             self.do_gc(now);
      44            3 :         }
      45              : 
      46            3 :         let mut entry = self
      47            3 :             .map
      48            3 :             .entry(key)
      49            3 :             .or_insert_with(|| LeakyBucketState { empty_at: now });
      50            3 : 
      51            3 :         entry.add_tokens(&self.config, now, n as f64).is_ok()
      52            3 :     }
      53              : 
      54            3 :     fn do_gc(&self, now: Instant) {
      55            3 :         info!(
      56            0 :             "cleaning up bucket rate limiter, current size = {}",
      57            0 :             self.map.len()
      58              :         );
      59            3 :         let n = self.map.shards().len();
      60            3 :         let shard = thread_rng().gen_range(0..n);
      61            3 :         self.map.shards()[shard]
      62            3 :             .write()
      63            3 :             .retain(|_, value| !value.get().bucket_is_empty(now));
      64            3 :     }
      65              : }
      66              : 
      67              : pub struct LeakyBucketConfig {
      68              :     pub rps: f64,
      69              :     pub max: f64,
      70              : }
      71              : 
      72              : #[cfg(test)]
      73              : impl LeakyBucketConfig {
      74            1 :     pub(crate) fn new(rps: f64, max: f64) -> Self {
      75            1 :         assert!(rps > 0.0, "rps must be positive");
      76            1 :         assert!(max > 0.0, "max must be positive");
      77            1 :         Self { rps, max }
      78            1 :     }
      79              : }
      80              : 
      81              : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
      82            4 :     fn from(config: LeakyBucketConfig) -> Self {
      83            4 :         utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
      84            4 :     }
      85              : }
      86              : 
      87              : #[cfg(test)]
      88              : #[allow(clippy::float_cmp)]
      89              : mod tests {
      90              :     use std::time::Duration;
      91              : 
      92              :     use tokio::time::Instant;
      93              :     use utils::leaky_bucket::LeakyBucketState;
      94              : 
      95              :     use super::LeakyBucketConfig;
      96              : 
      97              :     #[tokio::test(start_paused = true)]
      98            1 :     async fn check() {
      99            1 :         let config: utils::leaky_bucket::LeakyBucketConfig =
     100            1 :             LeakyBucketConfig::new(500.0, 2000.0).into();
     101            1 :         assert_eq!(config.cost, Duration::from_millis(2));
     102            1 :         assert_eq!(config.bucket_width, Duration::from_secs(4));
     103            1 : 
     104            1 :         let mut bucket = LeakyBucketState {
     105            1 :             empty_at: Instant::now(),
     106            1 :         };
     107            1 : 
     108            1 :         // should work for 2000 requests this second
     109         2001 :         for _ in 0..2000 {
     110         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     111         2000 :         }
     112            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     113            1 :         assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
     114            1 : 
     115            1 :         // in 1ms we should drain 0.5 tokens.
     116            1 :         // make sure we don't lose any tokens
     117            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     118            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     119            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     120            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     121            1 : 
     122            1 :         // in 10ms we should drain 5 tokens
     123            1 :         tokio::time::advance(Duration::from_millis(10)).await;
     124            6 :         for _ in 0..5 {
     125            5 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     126            5 :         }
     127            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     128            1 : 
     129            1 :         // in 10s we should drain 5000 tokens
     130            1 :         // but cap is only 2000
     131            1 :         tokio::time::advance(Duration::from_secs(10)).await;
     132         2001 :         for _ in 0..2000 {
     133         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     134         2000 :         }
     135            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     136            1 : 
     137            1 :         // should sustain 500rps
     138         2001 :         for _ in 0..2000 {
     139         2000 :             tokio::time::advance(Duration::from_millis(10)).await;
     140        12000 :             for _ in 0..5 {
     141        10000 :                 bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     142        10000 :             }
     143            1 :         }
     144            1 :     }
     145              : }
         |