Line data Source code
1 : //! This module implements the Generic Cell Rate Algorithm for a simplified
2 : //! version of the Leaky Bucket rate limiting system.
3 : //!
4 : //! # Leaky Bucket
5 : //!
6 : //! If the bucket is full, no new requests are allowed and are throttled/errored.
7 : //! If the bucket is partially full/empty, new requests are added to the bucket in
8 : //! terms of "tokens".
9 : //!
10 : //! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate.
11 : //!
12 : //! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second.
13 : //!
14 : //! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm)
15 : //!
16 : //! GCRA is a continuous rate leaky-bucket impl that stores minimal state and requires
17 : //! no background jobs to drain tokens, as the design utilises timestamps to drain automatically over time.
18 : //!
19 : //! We store an "empty_at" timestamp as the only state. As time progresses, we will naturally approach
20 : //! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`.
21 : //!
22 : //! Another explanation can be found here: <https://brandur.org/rate-limiting>
23 :
24 : use std::{
25 : sync::{
26 : atomic::{AtomicU64, Ordering},
27 : Mutex,
28 : },
29 : time::Duration,
30 : };
31 :
32 : use tokio::{sync::Notify, time::Instant};
33 :
/// Tuning parameters of the leaky bucket, both expressed as spans of time.
pub struct LeakyBucketConfig {
    /// The "time cost" charged for a single request unit.
    ///
    /// It should loosely match the active resource time needed to serve one unit,
    /// which makes it the inverse of the steady-rate requests-per-second.
    pub cost: Duration,

    /// Total size of the bucket, measured in time rather than in tokens.
    pub bucket_width: Duration,
}

impl LeakyBucketConfig {
    /// Builds a config from a steady rate and a burst capacity.
    ///
    /// * `rps` - steady-state request units per second; `cost` becomes `1 / rps` seconds.
    /// * `bucket_size` - number of request units the bucket can hold; `bucket_width`
    ///   becomes `cost * bucket_size`.
    ///
    /// # Panics
    ///
    /// Panics when either derived value cannot be represented as a [`Duration`]
    /// (negative, non-finite or overflowing), e.g. for `rps == 0.0`.
    pub fn new(rps: f64, bucket_size: f64) -> Self {
        let seconds_per_unit = 1.0 / rps;
        let cost = Duration::from_secs_f64(seconds_per_unit);

        Self {
            bucket_width: cost.mul_f64(bucket_size),
            cost,
        }
    }
}
51 :
52 : pub struct LeakyBucketState {
53 : /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`.
54 : ///
55 : /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by the "time_cost".
56 : /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.time_cost`.
57 : /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens.
58 : /// Draining the bucket will happen naturally as `now` moves forward.
59 : ///
60 : /// Let `n` be some "time cost" for the request,
61 : /// If now is after empty_at, the bucket is empty and the empty_at is reset to now,
62 : /// If now is within the `bucket window + n`, we are within time budget.
63 : /// If now is before the `bucket window + n`, we have run out of budget.
64 : ///
65 : /// This is inspired by the generic cell rate algorithm (GCRA) and works
66 : /// exactly the same as a leaky-bucket.
67 : pub empty_at: Instant,
68 : }
69 :
70 : impl LeakyBucketState {
71 594 : pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self {
72 594 : LeakyBucketState {
73 594 : empty_at: Instant::now() + config.cost.mul_f64(initial_tokens),
74 594 : }
75 594 : }
76 :
77 2 : pub fn bucket_is_empty(&self, now: Instant) -> bool {
78 2 : // if self.end is after now, the bucket is not empty
79 2 : self.empty_at <= now
80 2 : }
81 :
82 : /// Immediately adds tokens to the bucket, if there is space.
83 : ///
84 : /// In a scenario where you are waiting for available rate,
85 : /// rather than just erroring immediately, `started` corresponds to when this waiting started.
86 : ///
87 : /// `n` is the number of tokens that will be filled in the bucket.
88 : ///
89 : /// # Errors
90 : ///
91 : /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when
92 : /// there will be space again.
93 16219 : pub fn add_tokens(
94 16219 : &mut self,
95 16219 : config: &LeakyBucketConfig,
96 16219 : started: Instant,
97 16219 : n: f64,
98 16219 : ) -> Result<(), Instant> {
99 16219 : let now = Instant::now();
100 16219 :
101 16219 : // invariant: started <= now
102 16219 : debug_assert!(started <= now);
103 :
104 : // If the bucket was empty when we started our search,
105 : // we should update the `empty_at` value accordingly.
106 : // this prevents us from having negative tokens in the bucket.
107 16219 : let mut empty_at = self.empty_at;
108 16219 : if empty_at < started {
109 6 : empty_at = started;
110 16213 : }
111 :
112 16219 : let n = config.cost.mul_f64(n);
113 16219 : let new_empty_at = empty_at + n;
114 16219 : let allow_at = new_empty_at.checked_sub(config.bucket_width);
115 :
116 : // empty_at
117 : // allow_at | new_empty_at
118 : // / | /
119 : // -------o-[---------o-|--]---------
120 : // now1 ^ now2 ^
121 : //
122 : // at now1, the bucket would be completely filled if we add n tokens.
123 : // at now2, the bucket would be partially filled if we add n tokens.
124 :
125 16219 : match allow_at {
126 16219 : Some(allow_at) if now < allow_at => Err(allow_at),
127 : _ => {
128 16210 : self.empty_at = new_empty_at;
129 16210 : Ok(())
130 : }
131 : }
132 16219 : }
133 : }
134 :
135 : pub struct RateLimiter {
136 : pub config: LeakyBucketConfig,
137 : pub sleep_counter: AtomicU64,
138 : pub state: Mutex<LeakyBucketState>,
139 : /// a queue to provide this fair ordering.
140 : pub queue: Notify,
141 : }
142 :
143 : struct Requeue<'a>(&'a Notify);
144 :
145 : impl Drop for Requeue<'_> {
146 0 : fn drop(&mut self) {
147 0 : self.0.notify_one();
148 0 : }
149 : }
150 :
151 : impl RateLimiter {
152 594 : pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self {
153 594 : RateLimiter {
154 594 : sleep_counter: AtomicU64::new(0),
155 594 : state: Mutex::new(LeakyBucketState::with_initial_tokens(
156 594 : &config,
157 594 : initial_tokens,
158 594 : )),
159 594 : config,
160 594 : queue: {
161 594 : let queue = Notify::new();
162 594 : queue.notify_one();
163 594 : queue
164 594 : },
165 594 : }
166 594 : }
167 :
168 0 : pub fn steady_rps(&self) -> f64 {
169 0 : self.config.cost.as_secs_f64().recip()
170 0 : }
171 :
172 : /// returns true if we did throttle
173 0 : pub async fn acquire(&self, count: usize) -> bool {
174 0 : let start = tokio::time::Instant::now();
175 0 :
176 0 : let start_count = self.sleep_counter.load(Ordering::Acquire);
177 0 : let mut end_count = start_count;
178 0 :
179 0 : // wait until we are the first in the queue
180 0 : let mut notified = std::pin::pin!(self.queue.notified());
181 0 : if !notified.as_mut().enable() {
182 0 : notified.await;
183 0 : end_count = self.sleep_counter.load(Ordering::Acquire);
184 0 : }
185 :
186 : // notify the next waiter in the queue when we are done.
187 0 : let _guard = Requeue(&self.queue);
188 :
189 : loop {
190 0 : let res = self
191 0 : .state
192 0 : .lock()
193 0 : .unwrap()
194 0 : .add_tokens(&self.config, start, count as f64);
195 0 : match res {
196 0 : Ok(()) => return end_count > start_count,
197 0 : Err(ready_at) => {
198 0 : struct Increment<'a>(&'a AtomicU64);
199 0 :
200 0 : impl Drop for Increment<'_> {
201 0 : fn drop(&mut self) {
202 0 : self.0.fetch_add(1, Ordering::AcqRel);
203 0 : }
204 : }
205 :
206 : // increment the counter after we finish sleeping (or cancel this task).
207 : // this ensures that tasks that have already started the acquire will observe
208 : // the new sleep count when they are allowed to resume on the notify.
209 0 : let _inc = Increment(&self.sleep_counter);
210 0 : end_count += 1;
211 0 :
212 0 : tokio::time::sleep_until(ready_at).await;
213 : }
214 : }
215 : }
216 0 : }
217 : }
218 :
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::{LeakyBucketConfig, LeakyBucketState};

    /// Adds `count` single tokens at the current instant; every one of them must fit.
    fn fill(state: &mut LeakyBucketState, config: &LeakyBucketConfig, count: usize) {
        for _ in 0..count {
            state.add_tokens(config, Instant::now(), 1.0).unwrap();
        }
    }

    /// Asserts that one more token is rejected and that space frees up `wait` from now.
    fn assert_full_for(state: &mut LeakyBucketState, config: &LeakyBucketConfig, wait: Duration) {
        let ready = state.add_tokens(config, Instant::now(), 1.0).unwrap_err();
        assert_eq!(ready - Instant::now(), wait);
    }

    /// Walks one bucket through bursts, draining, a sustained rate and an oversized
    /// request, with the tokio clock paused so every duration is exact.
    #[tokio::test(start_paused = true)]
    async fn check() {
        let one_token = Duration::from_millis(10);

        let config = LeakyBucketConfig {
            // average 100rps
            cost: one_token,
            // burst up to 100 requests
            bucket_width: Duration::from_millis(1000),
        };

        let mut state = LeakyBucketState {
            empty_at: Instant::now(),
        };

        // supports burst: 100 requests fit this instant, the 101st has to wait.
        fill(&mut state, &config, 100);
        assert_full_for(&mut state, &config, one_token);

        // doesn't overfill
        {
            // after 1s we should have an empty bucket again.
            tokio::time::advance(Duration::from_secs(1)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // after 1s more, we should not over count the tokens and allow more than 200 requests.
            tokio::time::advance(Duration::from_secs(1)).await;
            fill(&mut state, &config, 100);
            assert_full_for(&mut state, &config, one_token);
        }

        // supports sustained rate over a long period
        {
            tokio::time::advance(Duration::from_secs(1)).await;

            // should sustain 100rps
            for _ in 0..2000 {
                tokio::time::advance(one_token).await;
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
        }

        // supports requesting more tokens than can be stored in the bucket
        // we just wait a little bit longer upfront.
        {
            // start the bucket completely empty
            tokio::time::advance(Duration::from_secs(5)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // requesting 200 tokens of space should take 200*cost = 2s
            // but we already have 1s available, so we wait 1s from start.
            let start = Instant::now();
            let half_second = Duration::from_millis(500);

            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_secs(1));

            tokio::time::advance(half_second).await;
            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), half_second);

            tokio::time::advance(half_second).await;
            state.add_tokens(&config, start, 200.0).unwrap();

            // bucket should be completely full now
            assert_full_for(&mut state, &config, one_token);
        }
    }
}
|