LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 8ff8efadb0253cf618c612650348666c0c564111.info Lines: 97.6 % 85 83
Test Date: 2024-11-20 17:53:50 Functions: 57.1 % 14 8

            Line data    Source code
       1              : use std::hash::Hash;
       2              : use std::sync::atomic::{AtomicUsize, Ordering};
       3              : 
       4              : use ahash::RandomState;
       5              : use dashmap::DashMap;
       6              : use rand::{thread_rng, Rng};
       7              : use tokio::time::Instant;
       8              : use tracing::info;
       9              : use utils::leaky_bucket::LeakyBucketState;
      10              : 
      11              : use crate::intern::EndpointIdInt;
      12              : 
      13              : // Simple per-endpoint rate limiter.
      14              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      15              : 
      16              : pub struct LeakyBucketRateLimiter<Key> {
      17              :     map: DashMap<Key, LeakyBucketState, RandomState>,
      18              :     config: utils::leaky_bucket::LeakyBucketConfig,
      19              :     access_count: AtomicUsize,
      20              : }
      21              : 
      22              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      23              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      24              :         rps: 600.0,
      25              :         max: 1500.0,
      26              :     };
      27              : 
      28            3 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      29            3 :         Self {
      30            3 :             map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      31            3 :             config: config.into(),
      32            3 :             access_count: AtomicUsize::new(0),
      33            3 :         }
      34            3 :     }
      35              : 
      36              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      37            3 :     pub(crate) fn check(&self, key: K, n: u32) -> bool {
      38            3 :         let now = Instant::now();
      39            3 : 
      40            3 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
      41            3 :             self.do_gc(now);
      42            3 :         }
      43              : 
      44            3 :         let mut entry = self
      45            3 :             .map
      46            3 :             .entry(key)
      47            3 :             .or_insert_with(|| LeakyBucketState { empty_at: now });
      48            3 : 
      49            3 :         entry.add_tokens(&self.config, now, n as f64).is_ok()
      50            3 :     }
      51              : 
      52            3 :     fn do_gc(&self, now: Instant) {
      53            3 :         info!(
      54            0 :             "cleaning up bucket rate limiter, current size = {}",
      55            0 :             self.map.len()
      56              :         );
      57            3 :         let n = self.map.shards().len();
      58            3 :         let shard = thread_rng().gen_range(0..n);
      59            3 :         self.map.shards()[shard]
      60            3 :             .write()
      61            3 :             .retain(|_, value| !value.get().bucket_is_empty(now));
      62            3 :     }
      63              : }
      64              : 
      65              : pub struct LeakyBucketConfig {
      66              :     pub rps: f64,
      67              :     pub max: f64,
      68              : }
      69              : 
      70              : #[cfg(test)]
      71              : impl LeakyBucketConfig {
      72            1 :     pub(crate) fn new(rps: f64, max: f64) -> Self {
      73            1 :         assert!(rps > 0.0, "rps must be positive");
      74            1 :         assert!(max > 0.0, "max must be positive");
      75            1 :         Self { rps, max }
      76            1 :     }
      77              : }
      78              : 
      79              : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
      80            4 :     fn from(config: LeakyBucketConfig) -> Self {
      81            4 :         utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
      82            4 :     }
      83              : }
      84              : 
      85              : #[cfg(test)]
      86              : #[allow(clippy::float_cmp)]
      87              : mod tests {
      88              :     use std::time::Duration;
      89              : 
      90              :     use tokio::time::Instant;
      91              :     use utils::leaky_bucket::LeakyBucketState;
      92              : 
      93              :     use super::LeakyBucketConfig;
      94              : 
      95              :     #[tokio::test(start_paused = true)]
      96            1 :     async fn check() {
      97            1 :         let config: utils::leaky_bucket::LeakyBucketConfig =
      98            1 :             LeakyBucketConfig::new(500.0, 2000.0).into();
      99            1 :         assert_eq!(config.cost, Duration::from_millis(2));
     100            1 :         assert_eq!(config.bucket_width, Duration::from_secs(4));
     101            1 : 
     102            1 :         let mut bucket = LeakyBucketState {
     103            1 :             empty_at: Instant::now(),
     104            1 :         };
     105            1 : 
     106            1 :         // should work for 2000 requests this second
     107         2001 :         for _ in 0..2000 {
     108         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     109         2000 :         }
     110            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     111            1 :         assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
     112            1 : 
     113            1 :         // in 1ms we should drain 0.5 tokens.
     114            1 :         // make sure we don't lose any tokens
     115            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     116            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     117            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     118            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     119            1 : 
     120            1 :         // in 10ms we should drain 5 tokens
     121            1 :         tokio::time::advance(Duration::from_millis(10)).await;
     122            6 :         for _ in 0..5 {
     123            5 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     124            5 :         }
     125            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     126            1 : 
     127            1 :         // in 10s we should drain 5000 tokens
     128            1 :         // but cap is only 2000
     129            1 :         tokio::time::advance(Duration::from_secs(10)).await;
     130         2001 :         for _ in 0..2000 {
     131         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     132         2000 :         }
     133            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     134            1 : 
     135            1 :         // should sustain 500rps
     136         2001 :         for _ in 0..2000 {
     137         2000 :             tokio::time::advance(Duration::from_millis(10)).await;
     138        12000 :             for _ in 0..5 {
     139        10000 :                 bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     140        10000 :             }
     141            1 :         }
     142            1 :     }
     143              : }
        

Generated by: LCOV version 2.1-beta