LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 553e39c2773e5840c720c90d86e56f89a4330d43.info Lines: 97.7 % 87 85
Test Date: 2025-06-13 20:01:21 Functions: 57.1 % 14 8

            Line data    Source code
       1              : use std::hash::Hash;
       2              : use std::sync::atomic::{AtomicUsize, Ordering};
       3              : 
       4              : use ahash::RandomState;
       5              : use clashmap::ClashMap;
       6              : use rand::{Rng, thread_rng};
       7              : use tokio::time::Instant;
       8              : use tracing::info;
       9              : use utils::leaky_bucket::LeakyBucketState;
      10              : 
      11              : use crate::intern::EndpointIdInt;
      12              : 
      13              : // Simple per-endpoint rate limiter.
      14              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      15              : 
      16              : pub struct LeakyBucketRateLimiter<Key> {
      17              :     map: ClashMap<Key, LeakyBucketState, RandomState>,
      18              :     default_config: utils::leaky_bucket::LeakyBucketConfig,
      19              :     access_count: AtomicUsize,
      20              : }
      21              : 
      22              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      23              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      24              :         rps: 600.0,
      25              :         max: 1500.0,
      26              :     };
      27              : 
      28            3 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      29            3 :         Self {
      30            3 :             map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      31            3 :             default_config: config.into(),
      32            3 :             access_count: AtomicUsize::new(0),
      33            3 :         }
      34            3 :     }
      35              : 
      36              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      37            3 :     pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
      38            3 :         let now = Instant::now();
      39            3 : 
      40            3 :         let config = config.map_or(self.default_config, Into::into);
      41            3 : 
      42            3 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
      43            3 :             self.do_gc(now);
      44            3 :         }
      45              : 
      46            3 :         let mut entry = self
      47            3 :             .map
      48            3 :             .entry(key)
      49            3 :             .or_insert_with(|| LeakyBucketState { empty_at: now });
      50            3 : 
      51            3 :         entry.add_tokens(&config, now, n as f64).is_ok()
      52            3 :     }
      53              : 
      54            3 :     fn do_gc(&self, now: Instant) {
      55            3 :         info!(
      56            0 :             "cleaning up bucket rate limiter, current size = {}",
      57            0 :             self.map.len()
      58              :         );
      59            3 :         let n = self.map.shards().len();
      60            3 :         let shard = thread_rng().gen_range(0..n);
      61            3 :         self.map.shards()[shard]
      62            3 :             .write()
      63            3 :             .retain(|(_, value)| !value.bucket_is_empty(now));
      64            3 :     }
      65              : }
      66              : 
      67              : pub struct LeakyBucketConfig {
      68              :     pub rps: f64,
      69              :     pub max: f64,
      70              : }
      71              : 
      72              : impl LeakyBucketConfig {
      73            1 :     pub fn new(rps: f64, max: f64) -> Self {
      74            1 :         assert!(rps > 0.0, "rps must be positive");
      75            1 :         assert!(max > 0.0, "max must be positive");
      76            1 :         Self { rps, max }
      77            1 :     }
      78              : }
      79              : 
      80              : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
      81            4 :     fn from(config: LeakyBucketConfig) -> Self {
      82            4 :         utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
      83            4 :     }
      84              : }
      85              : 
      86              : #[cfg(test)]
      87              : #[allow(clippy::float_cmp)]
      88              : mod tests {
      89              :     use std::time::Duration;
      90              : 
      91              :     use tokio::time::Instant;
      92              :     use utils::leaky_bucket::LeakyBucketState;
      93              : 
      94              :     use super::LeakyBucketConfig;
      95              : 
      96              :     #[tokio::test(start_paused = true)]
      97            1 :     async fn check() {
      98            1 :         let config: utils::leaky_bucket::LeakyBucketConfig =
      99            1 :             LeakyBucketConfig::new(500.0, 2000.0).into();
     100            1 :         assert_eq!(config.cost, Duration::from_millis(2));
     101            1 :         assert_eq!(config.bucket_width, Duration::from_secs(4));
     102            1 : 
     103            1 :         let mut bucket = LeakyBucketState {
     104            1 :             empty_at: Instant::now(),
     105            1 :         };
     106            1 : 
     107            1 :         // should work for 2000 requests this second
     108         2001 :         for _ in 0..2000 {
     109         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     110         2000 :         }
     111            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     112            1 :         assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
     113            1 : 
     114            1 :         // in 1ms we should drain 0.5 tokens.
     115            1 :         // make sure we don't lose any tokens
     116            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     117            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     118            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     119            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     120            1 : 
     121            1 :         // in 10ms we should drain 5 tokens
     122            1 :         tokio::time::advance(Duration::from_millis(10)).await;
     123            6 :         for _ in 0..5 {
     124            5 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     125            5 :         }
     126            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     127            1 : 
     128            1 :         // in 10s we should drain 5000 tokens
     129            1 :         // but cap is only 2000
     130            1 :         tokio::time::advance(Duration::from_secs(10)).await;
     131         2001 :         for _ in 0..2000 {
     132         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     133         2000 :         }
     134            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     135            1 : 
     136            1 :         // should sustain 500rps
     137         2001 :         for _ in 0..2000 {
     138         2000 :             tokio::time::advance(Duration::from_millis(10)).await;
     139        12000 :             for _ in 0..5 {
     140        10000 :                 bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     141        10000 :             }
     142            1 :         }
     143            1 :     }
     144              : }
        

Generated by: LCOV version 2.1-beta