|             Line data    Source code 
       1              : //! This module implements the Generic Cell Rate Algorithm for a simplified
       2              : //! version of the Leaky Bucket rate limiting system.
       3              : //!
       4              : //! # Leaky Bucket
       5              : //!
       6              : //! If the bucket is full, new requests are not allowed: they are throttled or rejected with an error.
       7              : //! If the bucket is partially full/empty, new requests are added to the bucket in
       8              : //! terms of "tokens".
       9              : //!
      10              : //! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate.
      11              : //!
      12              : //! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second.
      13              : //!
      14              : //! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm)
      15              : //!
      16              : //! GCRA is a continuous rate leaky-bucket impl that stores minimal state and requires
      17              : //! no background jobs to drain tokens, as the design utilises timestamps to drain automatically over time.
      18              : //!
      19              : //! We store an "empty_at" timestamp as the only state. As time progresses, we will naturally approach
      20              : //! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`.
      21              : //!
      22              : //! Another explanation can be found here: <https://brandur.org/rate-limiting>
      23              : 
      24              : use std::sync::Mutex;
      25              : use std::sync::atomic::{AtomicU64, Ordering};
      26              : use std::time::Duration;
      27              : 
      28              : use tokio::sync::Notify;
      29              : use tokio::time::Instant;
      30              : 
      31              : #[derive(Clone, Copy)]
      32              : pub struct LeakyBucketConfig { // Static parameters of a leaky bucket: per-token time cost and total capacity.
      33              :     /// This is the "time cost" of a single request unit.
      34              :     /// Should loosely represent how long it takes to handle a request unit in active resource time.
      35              :     /// Loosely speaking this is the inverse of the steady-rate requests-per-second, e.g. 10ms for 100 rps.
      36              :     pub cost: Duration,
      37              : 
      38              :     /// Total size of the bucket, expressed as time: the number of tokens it can hold multiplied by `cost`.
      39              :     pub bucket_width: Duration,
      40              : }
      41              : 
      42              : impl LeakyBucketConfig {
      43          124 :     pub fn new(rps: f64, bucket_size: f64) -> Self { // Builds a config from a steady rate (requests per second) and a bucket capacity in tokens.
      44          124 :         let cost = Duration::from_secs_f64(rps.recip()); // cost = 1 / rps seconds. NOTE(review): `from_secs_f64` panics on negative, non-finite or overflowing input (e.g. `rps` of 0.0 or NaN) - confirm callers validate.
      45          124 :         let bucket_width = cost.mul_f64(bucket_size); // width = capacity in tokens * per-token cost
      46          124 :         Self { cost, bucket_width }
      47          124 :     }
      48              : }
      49              : 
      50              : pub struct LeakyBucketState {
      51              :     /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`.
      52              :     ///
      53              :     /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by `config.cost`.
      54              :     /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.cost`.
      55              :     /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens.
      56              :     /// Draining the bucket will happen naturally as `now` moves forward.
      57              :     ///
      58              :     /// Let `n` be some "time cost" for the request:
      59              :     /// If `empty_at` is before the time the request started waiting, the bucket was empty and `empty_at` is treated as that start time.
      60              :     /// If now is within the `bucket window + n`, we are within time budget.
      61              :     /// If now is before the `bucket window + n`, we have run out of budget.
      62              :     ///
      63              :     /// This is inspired by the generic cell rate algorithm (GCRA) and works
      64              :     /// exactly the same as a leaky-bucket.
      65              :     pub empty_at: Instant,
      66              : }
      67              : 
      68              : impl LeakyBucketState {
      69          120 :     pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self { // Creates a bucket that already holds `initial_tokens` tokens.
      70          120 :         LeakyBucketState {
      71          120 :             empty_at: Instant::now() + config.cost.mul_f64(initial_tokens), // the bucket drains `initial_tokens * cost` from now
      72          120 :         }
      73          120 :     }
      74              : 
      75            2 :     pub fn bucket_is_empty(&self, now: Instant) -> bool { // True once every token has drained by `now`.
      76              :         // if `self.empty_at` is after now, the bucket is not empty
      77            2 :         self.empty_at <= now
      78            2 :     }
      79              : 
      80              :     /// Immediately adds tokens to the bucket, if there is space.
      81              :     ///
      82              :     /// In a scenario where you are waiting for available rate,
      83              :     /// rather than just erroring immediately, `started` corresponds to when this waiting started.
      84              :     ///
      85              :     /// `n` is the number of tokens that will be filled in the bucket. It may be fractional.
      86              :     ///
      87              :     /// # Errors
      88              :     ///
      89              :     /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when
      90              :     /// there will be space again.
      91        16219 :     pub fn add_tokens(
      92        16219 :         &mut self,
      93        16219 :         config: &LeakyBucketConfig,
      94        16219 :         started: Instant,
      95        16219 :         n: f64,
      96        16219 :     ) -> Result<(), Instant> {
      97        16219 :         let now = Instant::now();
      98              : 
      99              :         // invariant: started <= now (only checked in debug builds)
     100        16219 :         debug_assert!(started <= now);
     101              : 
     102              :         // If the bucket was empty when we started our search,
     103              :         // we should update the `empty_at` value accordingly.
     104              :         // this prevents us from having negative tokens in the bucket.
     105        16219 :         let mut empty_at = self.empty_at;
     106        16219 :         if empty_at < started {
     107            6 :             empty_at = started;
     108        16213 :         }
     109              : 
     110        16219 :         let n = config.cost.mul_f64(n); // from here on `n` is the time cost of the requested tokens
     111        16219 :         let new_empty_at = empty_at + n;
     112        16219 :         let allow_at = new_empty_at.checked_sub(config.bucket_width); // `None` if the subtraction is not representable; handled as "allowed" below
     113              : 
     114              :         //                     empty_at
     115              :         //          allow_at    |   new_empty_at
     116              :         //           /          |   /
     117              :         // -------o-[---------o-|--]---------
     118              :         //   now1 ^      now2 ^
     119              :         //
     120              :         // at now1, the bucket would be completely filled if we add n tokens.
     121              :         // at now2, the bucket would be partially filled if we add n tokens.
     122              : 
     123        16219 :         match allow_at {
     124        16219 :             Some(allow_at) if now < allow_at => Err(allow_at), // too early: state is left untouched, caller may retry at `allow_at`
     125              :             _ => {
     126        16210 :                 self.empty_at = new_empty_at;
     127        16210 :                 Ok(())
     128              :             }
     129              :         }
     130        16219 :     }
     131              : }
     132              : 
     133              : pub struct RateLimiter {
     134              :     pub config: LeakyBucketConfig,
     135              :     pub sleep_counter: AtomicU64, // bumped each time an `acquire` call stops sleeping (finished or cancelled)
     136              :     pub state: Mutex<LeakyBucketState>,
     137              :     /// Queue that orders `acquire` callers: each waits for a notification before touching the bucket and passes one on when done.
     138              :     pub queue: Notify,
     139              : }
     140              : 
     141              : struct Requeue<'a>(&'a Notify); // Guard that notifies the queue when dropped.
     142              : 
     143              : impl Drop for Requeue<'_> {
     144            0 :     fn drop(&mut self) {
     145            0 :         self.0.notify_one(); // let the next waiter (if any) proceed
     146            0 :     }
     147              : }
     148              : 
     149              : impl RateLimiter {
     150          120 :     pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self { // Creates a limiter whose bucket already holds `initial_tokens` tokens.
     151          120 :         RateLimiter {
     152          120 :             sleep_counter: AtomicU64::new(0),
     153          120 :             state: Mutex::new(LeakyBucketState::with_initial_tokens(
     154          120 :                 &config,
     155          120 :                 initial_tokens,
     156          120 :             )),
     157          120 :             config,
     158          120 :             queue: {
     159          120 :                 let queue = Notify::new();
     160          120 :                 queue.notify_one(); // issue one notification up front so the first `acquire` does not have to wait
     161          120 :                 queue
     162          120 :             },
     163          120 :         }
     164          120 :     }
     165              : 
     166            0 :     pub fn steady_rps(&self) -> f64 { // Steady-state rate in requests per second (inverse of `config.cost`).
     167            0 :         self.config.cost.as_secs_f64().recip()
     168            0 :     }
     169              : 
     170              :     /// Waits until `count` tokens fit in the bucket. Returns true if we did throttle, i.e. the sleep counter advanced while we queued or we slept ourselves.
     171            0 :     pub async fn acquire(&self, count: usize) -> bool {
     172            0 :         let start = tokio::time::Instant::now();
     173              : 
     174            0 :         let start_count = self.sleep_counter.load(Ordering::Acquire);
     175            0 :         let mut end_count = start_count;
     176              : 
     177              :         // wait until we are the first in the queue
     178            0 :         let mut notified = std::pin::pin!(self.queue.notified());
     179            0 :         if !notified.as_mut().enable() { // `enable()` returning true means a notification was already available
     180            0 :             notified.await;
     181            0 :             end_count = self.sleep_counter.load(Ordering::Acquire); // pick up sleeps that finished while we were queued
     182            0 :         }
     183              : 
     184              :         // notify the next waiter in the queue when we are done.
     185            0 :         let _guard = Requeue(&self.queue);
     186              : 
     187              :         loop {
     188            0 :             let res = self
     189            0 :                 .state
     190            0 :                 .lock()
     191            0 :                 .unwrap()
     192            0 :                 .add_tokens(&self.config, start, count as f64);
     193            0 :             match res {
     194            0 :                 Ok(()) => return end_count > start_count, // throttled if any sleep was counted since we started
     195            0 :                 Err(ready_at) => {
     196              :                     struct Increment<'a>(&'a AtomicU64);
     197              : 
     198              :                     impl Drop for Increment<'_> {
     199            0 :                         fn drop(&mut self) {
     200            0 :                             self.0.fetch_add(1, Ordering::AcqRel);
     201            0 :                         }
     202              :                     }
     203              : 
     204              :                     // increment the counter after we finish sleeping (or cancel this task).
     205              :                     // this ensures that tasks that have already started the acquire will observe
     206              :                     // the new sleep count when they are allowed to resume on the notify.
     207            0 :                     let _inc = Increment(&self.sleep_counter);
     208            0 :                     end_count += 1; // count our own sleep locally; the shared counter is bumped when `_inc` drops
     209              : 
     210            0 :                     tokio::time::sleep_until(ready_at).await;
     211              :                 }
     212              :             }
     213              :         }
     214            0 :     }
     215              : }
     216              : 
     217              : #[cfg(test)]
     218              : mod tests {
     219              :     use std::time::Duration;
     220              : 
     221              :     use tokio::time::Instant;
     222              : 
     223              :     use super::{LeakyBucketConfig, LeakyBucketState};
     224              : 
     225              :     #[tokio::test(start_paused = true)]
     226            1 :     async fn check() {
     227            1 :         let config = LeakyBucketConfig {
     228            1 :             // average 100rps
     229            1 :             cost: Duration::from_millis(10),
     230            1 :             // burst up to 100 requests
     231            1 :             bucket_width: Duration::from_millis(1000),
     232            1 :         };
     233              : 
     234            1 :         let mut state = LeakyBucketState {
     235            1 :             empty_at: Instant::now(),
     236            1 :         };
     237              : 
     238              :         // supports burst
     239            1 :         {
     240            1 :             // should work for 100 requests this instant
     241          101 :             for _ in 0..100 {
     242          100 :                 state.add_tokens(&config, Instant::now(), 1.0).unwrap();
     243          100 :             }
     244            1 :             let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     245            1 :             assert_eq!(ready - Instant::now(), Duration::from_millis(10));
     246            1 :         }
     247            1 : 
     248            1 :         // doesn't overfill
     249            1 :         {
     250            1 :             // after 1s we should have an empty bucket again.
     251            1 :             tokio::time::advance(Duration::from_secs(1)).await;
     252            1 :             assert!(state.bucket_is_empty(Instant::now()));
     253            1 : 
     254            1 :             // after 1s more, the idle time must not be banked: only 100 requests are allowed, not 200.
     255            1 :             tokio::time::advance(Duration::from_secs(1)).await;
     256          101 :             for _ in 0..100 {
     257          100 :                 state.add_tokens(&config, Instant::now(), 1.0).unwrap();
     258          100 :             }
     259            1 :             let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     260            1 :             assert_eq!(ready - Instant::now(), Duration::from_millis(10));
     261            1 :         }
     262            1 : 
     263            1 :         // supports sustained rate over a long period
     264            1 :         {
     265            1 :             tokio::time::advance(Duration::from_secs(1)).await;
     266            1 : 
     267            1 :             // should sustain 100rps (one request every 10ms, 2000 times)
     268         2001 :             for _ in 0..2000 {
     269         2000 :                 tokio::time::advance(Duration::from_millis(10)).await;
     270         2000 :                 state.add_tokens(&config, Instant::now(), 1.0).unwrap();
     271            1 :             }
     272            1 :         }
     273            1 : 
     274            1 :         // supports requesting more tokens than can be stored in the bucket
     275            1 :         // we just wait a little bit longer upfront.
     276            1 :         {
     277            1 :             // start the bucket completely empty
     278            1 :             tokio::time::advance(Duration::from_secs(5)).await;
     279            1 :             assert!(state.bucket_is_empty(Instant::now()));
     280            1 : 
     281            1 :             // requesting 200 tokens of space should take 200*cost = 2s
     282            1 :             // but we already have 1s available, so we wait 1s from start.
     283            1 :             let start = Instant::now();
     284            1 : 
     285            1 :             let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
     286            1 :             assert_eq!(ready - Instant::now(), Duration::from_secs(1));
     287            1 : 
     288            1 :             tokio::time::advance(Duration::from_millis(500)).await;
     289            1 :             let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
     290            1 :             assert_eq!(ready - Instant::now(), Duration::from_millis(500));
     291            1 : 
     292            1 :             tokio::time::advance(Duration::from_millis(500)).await;
     293            1 :             state.add_tokens(&config, start, 200.0).unwrap();
     294            1 : 
     295            1 :             // bucket should be completely full now
     296            1 :             let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
     297            1 :             assert_eq!(ready - Instant::now(), Duration::from_millis(10));
     298            1 :         }
     299            1 :     }
     300              : }
         |