LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 97.3 % 75 73
Test Date: 2025-07-16 12:29:03 Functions: 57.1 % 14 8

            Line data    Source code
       1              : use std::hash::Hash;
       2              : use std::sync::atomic::{AtomicUsize, Ordering};
       3              : 
       4              : use ahash::RandomState;
       5              : use clashmap::ClashMap;
       6              : use rand::{Rng, thread_rng};
       7              : use tokio::time::Instant;
       8              : use tracing::info;
       9              : use utils::leaky_bucket::LeakyBucketState;
      10              : 
      11              : use crate::intern::EndpointIdInt;
      12              : 
      13              : // Simple per-endpoint rate limiter.
      14              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      15              : 
      16              : pub struct LeakyBucketRateLimiter<Key> {
      17              :     map: ClashMap<Key, LeakyBucketState, RandomState>,
      18              :     default_config: utils::leaky_bucket::LeakyBucketConfig,
      19              :     access_count: AtomicUsize,
      20              : }
      21              : 
      22              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      23              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      24              :         rps: 600.0,
      25              :         max: 1500.0,
      26              :     };
      27              : 
      28            3 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      29            3 :         Self {
      30            3 :             map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      31            3 :             default_config: config.into(),
      32            3 :             access_count: AtomicUsize::new(0),
      33            3 :         }
      34            3 :     }
      35              : 
      36              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      37            3 :     pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
      38            3 :         let now = Instant::now();
      39              : 
      40            3 :         let config = config.map_or(self.default_config, Into::into);
      41              : 
      42            3 :         if self
      43            3 :             .access_count
      44            3 :             .fetch_add(1, Ordering::AcqRel)
      45            3 :             .is_multiple_of(2048)
      46            3 :         {
      47            3 :             self.do_gc(now);
      48            3 :         }
      49              : 
      50            3 :         let mut entry = self
      51            3 :             .map
      52            3 :             .entry(key)
      53            3 :             .or_insert_with(|| LeakyBucketState { empty_at: now });
      54              : 
      55            3 :         entry.add_tokens(&config, now, n as f64).is_ok()
      56            3 :     }
      57              : 
      58            3 :     fn do_gc(&self, now: Instant) {
      59            3 :         info!(
      60            0 :             "cleaning up bucket rate limiter, current size = {}",
      61            0 :             self.map.len()
      62              :         );
      63            3 :         let n = self.map.shards().len();
      64            3 :         let shard = thread_rng().gen_range(0..n);
      65            3 :         self.map.shards()[shard]
      66            3 :             .write()
      67            3 :             .retain(|(_, value)| !value.bucket_is_empty(now));
      68            3 :     }
      69              : }
      70              : 
      71              : pub struct LeakyBucketConfig {
      72              :     pub rps: f64,
      73              :     pub max: f64,
      74              : }
      75              : 
      76              : impl LeakyBucketConfig {
      77            1 :     pub fn new(rps: f64, max: f64) -> Self {
      78            1 :         assert!(rps > 0.0, "rps must be positive");
      79            1 :         assert!(max > 0.0, "max must be positive");
      80            1 :         Self { rps, max }
      81            1 :     }
      82              : }
      83              : 
      84              : impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
      85            4 :     fn from(config: LeakyBucketConfig) -> Self {
      86            4 :         utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max)
      87            4 :     }
      88              : }
      89              : 
      90              : #[cfg(test)]
      91              : #[allow(clippy::float_cmp)]
      92              : mod tests {
      93              :     use std::time::Duration;
      94              : 
      95              :     use tokio::time::Instant;
      96              :     use utils::leaky_bucket::LeakyBucketState;
      97              : 
      98              :     use super::LeakyBucketConfig;
      99              : 
     100              :     #[tokio::test(start_paused = true)]
     101            1 :     async fn check() {
     102            1 :         let config: utils::leaky_bucket::LeakyBucketConfig =
     103            1 :             LeakyBucketConfig::new(500.0, 2000.0).into();
     104            1 :         assert_eq!(config.cost, Duration::from_millis(2));
     105            1 :         assert_eq!(config.bucket_width, Duration::from_secs(4));
     106              : 
     107            1 :         let mut bucket = LeakyBucketState {
     108            1 :             empty_at: Instant::now(),
     109            1 :         };
     110              : 
     111              :         // should work for 2000 requests this second
     112         2001 :         for _ in 0..2000 {
     113         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     114         2000 :         }
     115            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     116            1 :         assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
     117              : 
     118              :         // in 1ms we should drain 0.5 tokens.
     119              :         // make sure we don't lose any tokens
     120            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     121            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     122            1 :         tokio::time::advance(Duration::from_millis(1)).await;
     123            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     124              : 
     125              :         // in 10ms we should drain 5 tokens
     126            1 :         tokio::time::advance(Duration::from_millis(10)).await;
     127            6 :         for _ in 0..5 {
     128            5 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     129            5 :         }
     130            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     131              : 
     132              :         // in 10s we should drain 5000 tokens
     133              :         // but cap is only 2000
     134            1 :         tokio::time::advance(Duration::from_secs(10)).await;
     135         2001 :         for _ in 0..2000 {
     136         2000 :             bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     137         2000 :         }
     138            1 :         bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     139              : 
     140              :         // should sustain 500rps
     141         2001 :         for _ in 0..2000 {
     142         2000 :             tokio::time::advance(Duration::from_millis(10)).await;
     143        12000 :             for _ in 0..5 {
     144        10000 :                 bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
     145        10000 :             }
     146            1 :         }
     147            1 :     }
     148              : }
        

Generated by: LCOV version 2.1-beta