LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 42f947419473a288706e86ecdf7c2863d760d5d7.info Lines: 95.2 % 104 99
Test Date: 2024-08-02 21:34:27 Functions: 58.8 % 17 10

            Line data    Source code
       1              : use std::{
       2              :     hash::Hash,
       3              :     sync::atomic::{AtomicUsize, Ordering},
       4              : };
       5              : 
       6              : use ahash::RandomState;
       7              : use dashmap::DashMap;
       8              : use rand::{thread_rng, Rng};
       9              : use tokio::time::Instant;
      10              : use tracing::info;
      11              : 
      12              : use crate::intern::EndpointIdInt;
      13              : 
      14              : // Simple per-endpoint rate limiter.
      15              : pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
      16              : 
      17              : pub struct LeakyBucketRateLimiter<Key> {
      18              :     map: DashMap<Key, LeakyBucketState, RandomState>,
      19              :     config: LeakyBucketConfig,
      20              :     access_count: AtomicUsize,
      21              : }
      22              : 
      23              : impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
      24              :     pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
      25              :         rps: 600.0,
      26              :         max: 1500.0,
      27              :     };
      28              : 
      29            6 :     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
      30            6 :         Self {
      31            6 :             map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
      32            6 :             config,
      33            6 :             access_count: AtomicUsize::new(0),
      34            6 :         }
      35            6 :     }
      36              : 
      37              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
      38            6 :     pub fn check(&self, key: K, n: u32) -> bool {
      39            6 :         let now = Instant::now();
      40            6 : 
      41            6 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
      42            6 :             self.do_gc(now);
      43            6 :         }
      44              : 
      45            6 :         let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState {
      46            6 :             time: now,
      47            6 :             filled: 0.0,
      48            6 :         });
      49            6 : 
      50            6 :         entry.check(&self.config, now, n as f64)
      51            6 :     }
      52              : 
      53            6 :     fn do_gc(&self, now: Instant) {
      54            6 :         info!(
      55            0 :             "cleaning up bucket rate limiter, current size = {}",
      56            0 :             self.map.len()
      57              :         );
      58            6 :         let n = self.map.shards().len();
      59            6 :         let shard = thread_rng().gen_range(0..n);
      60            6 :         self.map.shards()[shard]
      61            6 :             .write()
      62            6 :             .retain(|_, value| !value.get_mut().update(&self.config, now));
      63            6 :     }
      64              : }
      65              : 
      66              : pub struct LeakyBucketConfig {
      67              :     pub rps: f64,
      68              :     pub max: f64,
      69              : }
      70              : 
      71              : pub struct LeakyBucketState {
      72              :     filled: f64,
      73              :     time: Instant,
      74              : }
      75              : 
      76              : impl LeakyBucketConfig {
      77            2 :     pub fn new(rps: f64, max: f64) -> Self {
      78            2 :         assert!(rps > 0.0, "rps must be positive");
      79            2 :         assert!(max > 0.0, "max must be positive");
      80            2 :         Self { rps, max }
      81            2 :     }
      82              : }
      83              : 
      84              : impl LeakyBucketState {
      85            2 :     pub fn new() -> Self {
      86            2 :         Self {
      87            2 :             filled: 0.0,
      88            2 :             time: Instant::now(),
      89            2 :         }
      90            2 :     }
      91              : 
      92              :     /// updates the timer and returns true if the bucket is empty
      93        28026 :     fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool {
      94        28026 :         let drain = now.duration_since(self.time);
      95        28026 :         let drain = drain.as_secs_f64() * info.rps;
      96        28026 : 
      97        28026 :         self.filled = (self.filled - drain).clamp(0.0, info.max);
      98        28026 :         self.time = now;
      99        28026 : 
     100        28026 :         self.filled == 0.0
     101        28026 :     }
     102              : 
     103        28026 :     pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
     104        28026 :         self.update(info, now);
     105        28026 : 
     106        28026 :         if self.filled + n > info.max {
     107            8 :             return false;
     108        28018 :         }
     109        28018 :         self.filled += n;
     110        28018 : 
     111        28018 :         true
     112        28026 :     }
     113              : }
     114              : 
     115              : impl Default for LeakyBucketState {
     116            0 :     fn default() -> Self {
     117            0 :         Self::new()
     118            0 :     }
     119              : }
     120              : 
     121              : #[cfg(test)]
     122              : mod tests {
     123              :     use std::time::Duration;
     124              : 
     125              :     use tokio::time::Instant;
     126              : 
     127              :     use super::{LeakyBucketConfig, LeakyBucketState};
     128              : 
     129              :     #[tokio::test(start_paused = true)]
     130            2 :     async fn check() {
     131            2 :         let info = LeakyBucketConfig::new(500.0, 2000.0);
     132            2 :         let mut bucket = LeakyBucketState::new();
     133            2 : 
     134            2 :         // should work for 2000 requests this second
     135         4002 :         for _ in 0..2000 {
     136         4000 :             assert!(bucket.check(&info, Instant::now(), 1.0));
     137            2 :         }
     138            2 :         assert!(!bucket.check(&info, Instant::now(), 1.0));
     139            2 :         assert_eq!(bucket.filled, 2000.0);
     140            2 : 
     141            2 :         // in 1ms we should drain 0.5 tokens.
     142            2 :         // make sure we don't lose any tokens
     143            2 :         tokio::time::advance(Duration::from_millis(1)).await;
     144            2 :         assert!(!bucket.check(&info, Instant::now(), 1.0));
     145            2 :         tokio::time::advance(Duration::from_millis(1)).await;
     146            2 :         assert!(bucket.check(&info, Instant::now(), 1.0));
     147            2 : 
     148            2 :         // in 10ms we should drain 5 tokens
     149            2 :         tokio::time::advance(Duration::from_millis(10)).await;
     150           12 :         for _ in 0..5 {
     151           10 :             assert!(bucket.check(&info, Instant::now(), 1.0));
     152            2 :         }
     153            2 :         assert!(!bucket.check(&info, Instant::now(), 1.0));
     154            2 : 
     155            2 :         // in 10s we should drain 5000 tokens
     156            2 :         // but cap is only 2000
     157            2 :         tokio::time::advance(Duration::from_secs(10)).await;
     158         4002 :         for _ in 0..2000 {
     159         4000 :             assert!(bucket.check(&info, Instant::now(), 1.0));
     160            2 :         }
     161            2 :         assert!(!bucket.check(&info, Instant::now(), 1.0));
     162            2 : 
     163            2 :         // should sustain 500rps
     164         4002 :         for _ in 0..2000 {
     165         4000 :             tokio::time::advance(Duration::from_millis(10)).await;
     166        24000 :             for _ in 0..5 {
     167        20000 :                 assert!(bucket.check(&info, Instant::now(), 1.0));
     168            2 :             }
     169            2 :         }
     170            2 :     }
     171              : }
        

Generated by: LCOV version 2.1-beta