LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 2aa98e37cd3250b9a68c97ef6050b16fe702ab33.info Lines: 97.6 % 85 83
Test Date: 2024-08-29 11:33:10 Functions: 57.1 % 14 8

            Line data    Source code
       1              : use std::{
       2              :     hash::Hash,
       3              :     sync::atomic::{AtomicUsize, Ordering},
       4              : };
       5              : 
       6              : use ahash::RandomState;
       7              : use dashmap::DashMap;
       8              : use rand::{thread_rng, Rng};
       9              : use tokio::time::Instant;
      10              : use tracing::info;
      11              : use utils::leaky_bucket::LeakyBucketState;
      12              : 
      13              : use crate::intern::EndpointIdInt;
      14              : 
      15              : // Simple per-endpoint rate limiter.
      16              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      17              : 
      18              : pub struct LeakyBucketRateLimiter<Key> {
      19              :     map: DashMap<Key, LeakyBucketState, RandomState>,
      20              :     config: utils::leaky_bucket::LeakyBucketConfig,
      21              :     access_count: AtomicUsize,
      22              : }
      23              : 
      24              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      25              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      26              :         rps: 600.0,
      27              :         max: 1500.0,
      28              :     };
      29              : 
      30           18 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      31           18 :         Self {
      32           18 :             map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      33           18 :             config: config.into(),
      34           18 :             access_count: AtomicUsize::new(0),
      35           18 :         }
      36           18 :     }
      37              : 
      38              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      39           18 :     pub(crate) fn check(&self, key: K, n: u32) -> bool {
      40           18 :         let now = Instant::now();
      41           18 : 
      42           18 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
      43           18 :             self.do_gc(now);
      44           18 :         }
      45              : 
      46           18 :         let mut entry = self
      47           18 :             .map
      48           18 :             .entry(key)
      49           18 :             .or_insert_with(|| LeakyBucketState { empty_at: now });
      50           18 : 
      51           18 :         entry.add_tokens(&self.config, now, n as f64).is_ok()
      52           18 :     }
      53              : 
      54           18 :     fn do_gc(&self, now: Instant) {
      55           18 :         info!(
      56            0 :             "cleaning up bucket rate limiter, current size = {}",
      57            0 :             self.map.len()
      58              :         );
      59           18 :         let n = self.map.shards().len();
      60           18 :         let shard = thread_rng().gen_range(0..n);
      61           18 :         self.map.shards()[shard]
      62           18 :             .write()
      63           18 :             .retain(|_, value| !value.get().bucket_is_empty(now));
      64           18 :     }
      65              : }
      66              : 
      67              : pub struct LeakyBucketConfig {
      68              :     pub rps: f64,
      69              :     pub max: f64,
      70              : }
      71              : 
      72              : #[cfg(test)]
      73              : impl LeakyBucketConfig {
      74            6 :     pub(crate) fn new(rps: f64, max: f64) -> Self {
      75            6 :         assert!(rps > 0.0, "rps must be positive");
      76            6 :         assert!(max > 0.0, "max must be positive");
      77            6 :         Self { rps, max }
      78            6 :     }
      79              : }
      80              : 
      81              : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
      82           24 :     fn from(config: LeakyBucketConfig) -> Self {
      83           24 :         utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
      84           24 :     }
      85              : }
      86              : 
      87              : #[cfg(test)]
      88              : #[allow(clippy::float_cmp)]
      89              : mod tests {
      90              :     use std::time::Duration;
      91              : 
      92              :     use tokio::time::Instant;
      93              :     use utils::leaky_bucket::LeakyBucketState;
      94              : 
      95              :     use super::LeakyBucketConfig;
      96              : 
      97              :     #[tokio::test(start_paused = true)]
      98            6 :     async fn check() {
      99            6 :         let config: utils::leaky_bucket::LeakyBucketConfig =
     100            6 :             LeakyBucketConfig::new(500.0, 2000.0).into();
     101            6 :         assert_eq!(config.cost, Duration::from_millis(2));
     102            6 :         assert_eq!(config.bucket_width, Duration::from_secs(4));
     103            6 : 
     104            6 :         let mut bucket = LeakyBucketState {
     105            6 :             empty_at: Instant::now(),
     106            6 :         };
     107            6 : 
     108            6 :         // should work for 2000 requests this second
     109        12006 :         for _ in 0..2000 {
     110        12000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     111        12000 :         }
     112            6 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     113            6 :         assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
     114            6 : 
     115            6 :         // in 1ms we should drain 0.5 tokens.
     116            6 :         // make sure we don't lose any tokens
     117            6 :         tokio::time::advance(Duration::from_millis(1)).await;
     118            6 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     119            6 :         tokio::time::advance(Duration::from_millis(1)).await;
     120            6 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     121            6 : 
     122            6 :         // in 10ms we should drain 5 tokens
     123            6 :         tokio::time::advance(Duration::from_millis(10)).await;
     124           36 :         for _ in 0..5 {
     125           30 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     126           30 :         }
     127            6 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     128            6 : 
     129            6 :         // in 10s we should drain 5000 tokens
     130            6 :         // but cap is only 2000
     131            6 :         tokio::time::advance(Duration::from_secs(10)).await;
     132        12006 :         for _ in 0..2000 {
     133        12000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     134        12000 :         }
     135            6 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     136            6 : 
     137            6 :         // should sustain 500rps
     138        12006 :         for _ in 0..2000 {
     139        12000 :             tokio::time::advance(Duration::from_millis(10)).await;
     140        72000 :             for _ in 0..5 {
     141        60000 :                 bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     142        60000 :             }
     143            6 :         }
     144            6 :     }
     145              : }
        

Generated by: LCOV version 2.1-beta