LCOV - code coverage report
Current view: top level - libs/utils/src - leaky_bucket.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 76.7 % 163 125
Test Date: 2025-03-12 18:05:02 Functions: 58.3 % 12 7

            Line data    Source code
       1              : //! This module implements the Generic Cell Rate Algorithm for a simplified
       2              : //! version of the Leaky Bucket rate limiting system.
       3              : //!
       4              : //! # Leaky Bucket
       5              : //!
       6              : //! If the bucket is full, no new requests are allowed and are throttled/errored.
       7              : //! If the bucket is partially full/empty, new requests are added to the bucket in
       8              : //! terms of "tokens".
       9              : //!
      10              : //! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate.
      11              : //!
      12              : //! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second.
      13              : //!
      14              : //! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm)
      15              : //!
      16              : //! GCRA is a continuous rate leaky-bucket impl that stores minimal state and requires
      17              : //! no background jobs to drain tokens, as the design utilises timestamps to drain automatically over time.
      18              : //!
      19              : //! We store an "empty_at" timestamp as the only state. As time progresses, we will naturally approach
      20              : //! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`.
      21              : //!
      22              : //! Another explanation can be found here: <https://brandur.org/rate-limiting>
      23              : 
      24              : use std::sync::Mutex;
      25              : use std::sync::atomic::{AtomicU64, Ordering};
      26              : use std::time::Duration;
      27              : 
      28              : use tokio::sync::Notify;
      29              : use tokio::time::Instant;
      30              : 
      31              : pub struct LeakyBucketConfig {
      32              :     /// This is the "time cost" of a single request unit.
      33              :     /// Should loosely represent how long it takes to handle a request unit in active resource time.
      34              :     /// Loosely speaking this is the inverse of the steady-rate requests-per-second
      35              :     pub cost: Duration,
      36              : 
      37              :     /// total size of the bucket
      38              :     pub bucket_width: Duration,
      39              : }
      40              : 
      41              : impl LeakyBucketConfig {
      42          456 :     pub fn new(rps: f64, bucket_size: f64) -> Self {
      43          456 :         let cost = Duration::from_secs_f64(rps.recip());
      44          456 :         let bucket_width = cost.mul_f64(bucket_size);
      45          456 :         Self { cost, bucket_width }
      46          456 :     }
      47              : }
      48              : 
      49              : pub struct LeakyBucketState {
      50              :     /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`.
      51              :     ///
      52              :     /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by the "time_cost".
      53              :     /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.time_cost`.
      54              :     /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens.
      55              :     /// Draining the bucket will happen naturally as `now` moves forward.
      56              :     ///
      57              :     /// Let `n` be some "time cost" for the request,
      58              :     /// If now is after empty_at, the bucket is empty and the empty_at is reset to now,
      59              :     /// If now is within the `bucket window + n`, we are within time budget.
      60              :     /// If now is before the `bucket window + n`, we have run out of budget.
      61              :     ///
      62              :     /// This is inspired by the generic cell rate algorithm (GCRA) and works
      63              :     /// exactly the same as a leaky-bucket.
      64              :     pub empty_at: Instant,
      65              : }
      66              : 
      67              : impl LeakyBucketState {
      68          452 :     pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self {
      69          452 :         LeakyBucketState {
      70          452 :             empty_at: Instant::now() + config.cost.mul_f64(initial_tokens),
      71          452 :         }
      72          452 :     }
      73              : 
      74            2 :     pub fn bucket_is_empty(&self, now: Instant) -> bool {
      75            2 :         // if self.end is after now, the bucket is not empty
      76            2 :         self.empty_at <= now
      77            2 :     }
      78              : 
      79              :     /// Immediately adds tokens to the bucket, if there is space.
      80              :     ///
      81              :     /// In a scenario where you are waiting for available rate,
      82              :     /// rather than just erroring immediately, `started` corresponds to when this waiting started.
      83              :     ///
      84              :     /// `n` is the number of tokens that will be filled in the bucket.
      85              :     ///
      86              :     /// # Errors
      87              :     ///
      88              :     /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when
      89              :     /// there will be space again.
      90        16219 :     pub fn add_tokens(
      91        16219 :         &mut self,
      92        16219 :         config: &LeakyBucketConfig,
      93        16219 :         started: Instant,
      94        16219 :         n: f64,
      95        16219 :     ) -> Result<(), Instant> {
      96        16219 :         let now = Instant::now();
      97        16219 : 
      98        16219 :         // invariant: started <= now
      99        16219 :         debug_assert!(started <= now);
     100              : 
     101              :         // If the bucket was empty when we started our search,
     102              :         // we should update the `empty_at` value accordingly.
     103              :         // this prevents us from having negative tokens in the bucket.
     104        16219 :         let mut empty_at = self.empty_at;
     105        16219 :         if empty_at < started {
     106            6 :             empty_at = started;
     107        16213 :         }
     108              : 
     109        16219 :         let n = config.cost.mul_f64(n);
     110        16219 :         let new_empty_at = empty_at + n;
     111        16219 :         let allow_at = new_empty_at.checked_sub(config.bucket_width);
     112              : 
     113              :         //                     empty_at
     114              :         //          allow_at    |   new_empty_at
     115              :         //           /          |   /
     116              :         // -------o-[---------o-|--]---------
     117              :         //   now1 ^      now2 ^
     118              :         //
     119              :         // at now1, the bucket would be completely filled if we add n tokens.
     120              :         // at now2, the bucket would be partially filled if we add n tokens.
     121              : 
     122        16219 :         match allow_at {
     123        16219 :             Some(allow_at) if now < allow_at => Err(allow_at),
     124              :             _ => {
     125        16210 :                 self.empty_at = new_empty_at;
     126        16210 :                 Ok(())
     127              :             }
     128              :         }
     129        16219 :     }
     130              : }
     131              : 
     132              : pub struct RateLimiter {
     133              :     pub config: LeakyBucketConfig,
     134              :     pub sleep_counter: AtomicU64,
     135              :     pub state: Mutex<LeakyBucketState>,
     136              :     /// a queue to provide this fair ordering.
     137              :     pub queue: Notify,
     138              : }
     139              : 
     140              : struct Requeue<'a>(&'a Notify);
     141              : 
     142              : impl Drop for Requeue<'_> {
     143            0 :     fn drop(&mut self) {
     144            0 :         self.0.notify_one();
     145            0 :     }
     146              : }
     147              : 
impl RateLimiter {
    /// Creates a rate limiter whose bucket already holds `initial_tokens` tokens.
    ///
    /// # Panics
    ///
    /// Panics if `config.cost * initial_tokens` is negative, NaN, or overflows a `Duration`.
    pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self {
        RateLimiter {
            sleep_counter: AtomicU64::new(0),
            // `&config` is borrowed here, before `config` is moved into the field below.
            state: Mutex::new(LeakyBucketState::with_initial_tokens(
                &config,
                initial_tokens,
            )),
            config,
            queue: {
                let queue = Notify::new();
                // Store one permit up front so the first `acquire` caller gets its turn
                // immediately. Each caller then hands the turn on via `Requeue`.
                queue.notify_one();
                queue
            },
        }
    }

    /// Returns the steady-state rate in request units per second, i.e. `1 / config.cost`.
    pub fn steady_rps(&self) -> f64 {
        self.config.cost.as_secs_f64().recip()
    }

    /// Waits for this caller's turn, then waits until `count` request units fit in the bucket and
    /// adds them.
    ///
    /// The start time is captured before queueing and passed to
    /// [`LeakyBucketState::add_tokens`] as `started`. As a result, time spent waiting for a turn
    /// counts toward draining an empty bucket.
    ///
    /// returns true if we did throttle: either this call slept waiting for bucket space, or
    /// `sleep_counter` increased (another caller finished a throttling sleep) between the start
    /// of this call and the moment it got its turn in the queue.
    ///
    /// # Panics
    ///
    /// Panics if the `state` mutex is poisoned.
    pub async fn acquire(&self, count: usize) -> bool {
        let start = tokio::time::Instant::now();

        let start_count = self.sleep_counter.load(Ordering::Acquire);
        let mut end_count = start_count;

        // wait until we are the first in the queue.
        // `enable` returns true when a stored permit was consumed (nobody is ahead of us),
        // in which case there is nothing to await.
        let mut notified = std::pin::pin!(self.queue.notified());
        if !notified.as_mut().enable() {
            notified.await;
            // Observe sleeps performed by the callers that were ahead of us.
            end_count = self.sleep_counter.load(Ordering::Acquire);
        }

        // notify the next waiter in the queue when we are done (including on cancellation).
        let _guard = Requeue(&self.queue);

        loop {
            // The mutex guard is a temporary dropped at the end of this statement, so the
            // std mutex is never held across the `.await` below.
            let res = self
                .state
                .lock()
                .unwrap()
                .add_tokens(&self.config, start, count as f64);
            match res {
                Ok(()) => return end_count > start_count,
                Err(ready_at) => {
                    struct Increment<'a>(&'a AtomicU64);

                    impl Drop for Increment<'_> {
                        fn drop(&mut self) {
                            self.0.fetch_add(1, Ordering::AcqRel);
                        }
                    }

                    // increment the counter after we finish sleeping (or cancel this task).
                    // this ensures that tasks that have already started the acquire will observe
                    // the new sleep count when they are allowed to resume on the notify.
                    let _inc = Increment(&self.sleep_counter);
                    end_count += 1;

                    // Retry after sleeping; `add_tokens` is re-checked in case of early wakeup.
                    tokio::time::sleep_until(ready_at).await;
                }
            }
        }
    }
}
     215              : 
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::{LeakyBucketConfig, LeakyBucketState};

    // Runs with a paused tokio clock so `Instant::now()` only moves via `tokio::time::advance`,
    // making the exact-duration assertions below deterministic.
    #[tokio::test(start_paused = true)]
    async fn check() {
        let config = LeakyBucketConfig {
            // average 100rps
            cost: Duration::from_millis(10),
            // burst up to 100 requests
            bucket_width: Duration::from_millis(1000),
        };

        // Start with an empty bucket.
        let mut state = LeakyBucketState {
            empty_at: Instant::now(),
        };

        // supports burst
        {
            // should work for 100 requests this instant
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            // The 101st request overflows; space frees up after one `cost` (10ms).
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // doesn't overfill
        {
            // after 1s we should have an empty bucket again.
            tokio::time::advance(Duration::from_secs(1)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // after another 1s of idling, the bucket must still only allow a burst of 100
            // requests; the extra idle time must not be banked as capacity for 200.
            tokio::time::advance(Duration::from_secs(1)).await;
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // supports sustained rate over a long period
        {
            tokio::time::advance(Duration::from_secs(1)).await;

            // should sustain 100rps: one request every 10ms never fails.
            for _ in 0..2000 {
                tokio::time::advance(Duration::from_millis(10)).await;
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
        }

        // supports requesting more tokens than can be stored in the bucket
        // we just wait a little bit longer upfront.
        {
            // start the bucket completely empty
            tokio::time::advance(Duration::from_secs(5)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // requesting 200 tokens of space should take 200*cost = 2s
            // but we already have 1s available, so we wait 1s from start.
            let start = Instant::now();

            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_secs(1));

            // Retrying with the same `start` keeps the same target instant.
            tokio::time::advance(Duration::from_millis(500)).await;
            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(500));

            tokio::time::advance(Duration::from_millis(500)).await;
            state.add_tokens(&config, start, 200.0).unwrap();

            // bucket should be completely full now
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }
    }
}
        

Generated by: LCOV version 2.1-beta