Line data Source code
1 : //! This module implements the Generic Cell Rate Algorithm for a simplified
2 : //! version of the Leaky Bucket rate limiting system.
3 : //!
4 : //! # Leaky Bucket
5 : //!
6 : //! If the bucket is full, no new requests are allowed and are throttled/errored.
7 : //! If the bucket is partially full/empty, new requests are added to the bucket in
8 : //! terms of "tokens".
9 : //!
10 : //! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate.
11 : //!
12 : //! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second.
13 : //!
14 : //! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm)
15 : //!
16 : //! GCRA is a continuous rate leaky-bucket impl that stores minimal state and requires
17 : //! no background jobs to drain tokens, as the design utilises timestamps to drain automatically over time.
18 : //!
19 : //! We store an "empty_at" timestamp as the only state. As time progresses, we will naturally approach
20 : //! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`.
21 : //!
22 : //! Another explanation can be found here: <https://brandur.org/rate-limiting>
23 :
24 : use std::sync::Mutex;
25 : use std::sync::atomic::{AtomicU64, Ordering};
26 : use std::time::Duration;
27 :
28 : use tokio::sync::Notify;
29 : use tokio::time::Instant;
30 :
/// Configuration of a leaky bucket, expressed entirely in units of time.
#[derive(Debug, Clone, Copy)]
pub struct LeakyBucketConfig {
    /// This is the "time cost" of a single request unit.
    /// Should loosely represent how long it takes to handle a request unit in active resource time.
    /// Loosely speaking this is the inverse of the steady-rate requests-per-second.
    pub cost: Duration,

    /// Total size of the bucket, expressed as time.
    ///
    /// This bounds how large a burst can be admitted at once. When built with
    /// [`LeakyBucketConfig::new`], this is `bucket_size * cost`.
    pub bucket_width: Duration,
}

impl LeakyBucketConfig {
    /// Builds a config from a steady rate and a burst size.
    ///
    /// * `rps` — steady-state request units per second; `cost` becomes `1 / rps` seconds.
    /// * `bucket_size` — how many request units the bucket can hold (the maximum burst);
    ///   `bucket_width` becomes `bucket_size * cost`.
    ///
    /// # Panics
    ///
    /// Panics if `rps` is not strictly positive (this includes NaN), or if the resulting
    /// durations are negative, non-finite, or overflow [`Duration`] (see
    /// [`Duration::from_secs_f64`] and [`Duration::mul_f64`]).
    pub fn new(rps: f64, bucket_size: f64) -> Self {
        // `Duration::from_secs_f64` would reject these values as well, but only with an opaque
        // float-conversion message; fail early with one that names the offending argument.
        assert!(rps > 0.0, "LeakyBucketConfig: rps must be positive, got {rps}");
        let cost = Duration::from_secs_f64(rps.recip());
        let bucket_width = cost.mul_f64(bucket_size);
        Self { cost, bucket_width }
    }
}
48 :
49 : pub struct LeakyBucketState {
50 : /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`.
51 : ///
52 : /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by the "time_cost".
53 : /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.time_cost`.
54 : /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens.
55 : /// Draining the bucket will happen naturally as `now` moves forward.
56 : ///
57 : /// Let `n` be some "time cost" for the request,
58 : /// If now is after empty_at, the bucket is empty and the empty_at is reset to now,
59 : /// If now is within the `bucket window + n`, we are within time budget.
60 : /// If now is before the `bucket window + n`, we have run out of budget.
61 : ///
62 : /// This is inspired by the generic cell rate algorithm (GCRA) and works
63 : /// exactly the same as a leaky-bucket.
64 : pub empty_at: Instant,
65 : }
66 :
67 : impl LeakyBucketState {
68 452 : pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self {
69 452 : LeakyBucketState {
70 452 : empty_at: Instant::now() + config.cost.mul_f64(initial_tokens),
71 452 : }
72 452 : }
73 :
74 2 : pub fn bucket_is_empty(&self, now: Instant) -> bool {
75 2 : // if self.end is after now, the bucket is not empty
76 2 : self.empty_at <= now
77 2 : }
78 :
79 : /// Immediately adds tokens to the bucket, if there is space.
80 : ///
81 : /// In a scenario where you are waiting for available rate,
82 : /// rather than just erroring immediately, `started` corresponds to when this waiting started.
83 : ///
84 : /// `n` is the number of tokens that will be filled in the bucket.
85 : ///
86 : /// # Errors
87 : ///
88 : /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when
89 : /// there will be space again.
90 16219 : pub fn add_tokens(
91 16219 : &mut self,
92 16219 : config: &LeakyBucketConfig,
93 16219 : started: Instant,
94 16219 : n: f64,
95 16219 : ) -> Result<(), Instant> {
96 16219 : let now = Instant::now();
97 16219 :
98 16219 : // invariant: started <= now
99 16219 : debug_assert!(started <= now);
100 :
101 : // If the bucket was empty when we started our search,
102 : // we should update the `empty_at` value accordingly.
103 : // this prevents us from having negative tokens in the bucket.
104 16219 : let mut empty_at = self.empty_at;
105 16219 : if empty_at < started {
106 6 : empty_at = started;
107 16213 : }
108 :
109 16219 : let n = config.cost.mul_f64(n);
110 16219 : let new_empty_at = empty_at + n;
111 16219 : let allow_at = new_empty_at.checked_sub(config.bucket_width);
112 :
113 : // empty_at
114 : // allow_at | new_empty_at
115 : // / | /
116 : // -------o-[---------o-|--]---------
117 : // now1 ^ now2 ^
118 : //
119 : // at now1, the bucket would be completely filled if we add n tokens.
120 : // at now2, the bucket would be partially filled if we add n tokens.
121 :
122 16219 : match allow_at {
123 16219 : Some(allow_at) if now < allow_at => Err(allow_at),
124 : _ => {
125 16210 : self.empty_at = new_empty_at;
126 16210 : Ok(())
127 : }
128 : }
129 16219 : }
130 : }
131 :
132 : pub struct RateLimiter {
133 : pub config: LeakyBucketConfig,
134 : pub sleep_counter: AtomicU64,
135 : pub state: Mutex<LeakyBucketState>,
136 : /// a queue to provide this fair ordering.
137 : pub queue: Notify,
138 : }
139 :
140 : struct Requeue<'a>(&'a Notify);
141 :
142 : impl Drop for Requeue<'_> {
143 0 : fn drop(&mut self) {
144 0 : self.0.notify_one();
145 0 : }
146 : }
147 :
148 : impl RateLimiter {
149 452 : pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self {
150 452 : RateLimiter {
151 452 : sleep_counter: AtomicU64::new(0),
152 452 : state: Mutex::new(LeakyBucketState::with_initial_tokens(
153 452 : &config,
154 452 : initial_tokens,
155 452 : )),
156 452 : config,
157 452 : queue: {
158 452 : let queue = Notify::new();
159 452 : queue.notify_one();
160 452 : queue
161 452 : },
162 452 : }
163 452 : }
164 :
165 0 : pub fn steady_rps(&self) -> f64 {
166 0 : self.config.cost.as_secs_f64().recip()
167 0 : }
168 :
169 : /// returns true if we did throttle
170 0 : pub async fn acquire(&self, count: usize) -> bool {
171 0 : let start = tokio::time::Instant::now();
172 0 :
173 0 : let start_count = self.sleep_counter.load(Ordering::Acquire);
174 0 : let mut end_count = start_count;
175 0 :
176 0 : // wait until we are the first in the queue
177 0 : let mut notified = std::pin::pin!(self.queue.notified());
178 0 : if !notified.as_mut().enable() {
179 0 : notified.await;
180 0 : end_count = self.sleep_counter.load(Ordering::Acquire);
181 0 : }
182 :
183 : // notify the next waiter in the queue when we are done.
184 0 : let _guard = Requeue(&self.queue);
185 :
186 : loop {
187 0 : let res = self
188 0 : .state
189 0 : .lock()
190 0 : .unwrap()
191 0 : .add_tokens(&self.config, start, count as f64);
192 0 : match res {
193 0 : Ok(()) => return end_count > start_count,
194 0 : Err(ready_at) => {
195 0 : struct Increment<'a>(&'a AtomicU64);
196 0 :
197 0 : impl Drop for Increment<'_> {
198 0 : fn drop(&mut self) {
199 0 : self.0.fetch_add(1, Ordering::AcqRel);
200 0 : }
201 : }
202 :
203 : // increment the counter after we finish sleeping (or cancel this task).
204 : // this ensures that tasks that have already started the acquire will observe
205 : // the new sleep count when they are allowed to resume on the notify.
206 0 : let _inc = Increment(&self.sleep_counter);
207 0 : end_count += 1;
208 0 :
209 0 : tokio::time::sleep_until(ready_at).await;
210 : }
211 : }
212 : }
213 0 : }
214 : }
215 :
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use tokio::time::Instant;

    use super::{LeakyBucketConfig, LeakyBucketState, RateLimiter};

    #[tokio::test(start_paused = true)]
    async fn check() {
        let config = LeakyBucketConfig {
            // average 100rps
            cost: Duration::from_millis(10),
            // burst up to 100 requests
            bucket_width: Duration::from_millis(1000),
        };

        let mut state = LeakyBucketState {
            empty_at: Instant::now(),
        };

        // supports burst
        {
            // should work for 100 requests this instant
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // doesn't overfill
        {
            // after 1s we should have an empty bucket again.
            tokio::time::advance(Duration::from_secs(1)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // after 1s more, the extra idle time must not be banked as tokens:
            // the burst is still capped at 100 requests, not 200.
            tokio::time::advance(Duration::from_secs(1)).await;
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // supports sustained rate over a long period
        {
            tokio::time::advance(Duration::from_secs(1)).await;

            // should sustain 100rps
            for _ in 0..2000 {
                tokio::time::advance(Duration::from_millis(10)).await;
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
        }

        // supports requesting more tokens than can be stored in the bucket
        // we just wait a little bit longer upfront.
        {
            // start the bucket completely empty
            tokio::time::advance(Duration::from_secs(5)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // requesting 200 tokens of space should take 200*cost = 2s
            // but we already have 1s available, so we wait 1s from start.
            let start = Instant::now();

            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_secs(1));

            tokio::time::advance(Duration::from_millis(500)).await;
            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(500));

            tokio::time::advance(Duration::from_millis(500)).await;
            state.add_tokens(&config, start, 200.0).unwrap();

            // bucket should be completely full now
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }
    }

    /// Exercises `RateLimiter` end to end: burst admission, throttling via sleep, the
    /// throttle flag returned by `acquire`, and the shared sleep counter.
    #[tokio::test(start_paused = true)]
    async fn rate_limiter_throttles_after_burst() {
        // average 100rps, burst up to 10 requests, starting with an empty bucket
        let limiter = RateLimiter::with_initial_tokens(LeakyBucketConfig::new(100.0, 10.0), 0.0);
        assert!((limiter.steady_rps() - 100.0).abs() < 1e-6);

        let start = Instant::now();

        // the burst fits in the bucket, so nobody is throttled
        for _ in 0..10 {
            assert!(!limiter.acquire(1).await);
        }
        assert_eq!(limiter.sleep_counter.load(Ordering::Acquire), 0);

        // the bucket is now full: the next request sleeps for about one token's cost
        assert!(limiter.acquire(1).await);
        assert!(Instant::now() - start >= Duration::from_millis(10));
        assert_eq!(limiter.sleep_counter.load(Ordering::Acquire), 1);
    }
}
|