//! This module implements the Generic Cell Rate Algorithm for a simplified
//! version of the Leaky Bucket rate limiting system.
//!
//! # Leaky Bucket
//!
//! If the bucket is full, no new requests are allowed; they are throttled/errored.
//! If the bucket is partially full or empty, new requests are added to the bucket
//! in terms of "tokens".
//!
//! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate.
//!
//! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second.
//!
//! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm)
//!
//! GCRA is a continuous-rate leaky-bucket implementation that stores minimal state and requires
//! no background jobs to drain tokens, as the design uses timestamps to drain automatically over time.
//!
//! We store an "empty_at" timestamp as the only state. As time progresses, we naturally approach
//! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`.
//!
//! Another explanation can be found here: <https://brandur.org/rate-limiting>
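//!
//! A sketch of the timestamp arithmetic, using the names from
//! [`LeakyBucketConfig`] and [`LeakyBucketState`] below:
//!
//! ```text
//! tokens in bucket = (empty_at - now) / cost
//! allow_at         = empty_at - bucket_width
//!
//! now >= empty_at  => the bucket is empty
//! now <  allow_at  => the bucket is full; throttle until allow_at
//! ```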

use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use tokio::sync::Notify;
use tokio::time::Instant;

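/// Configuration for the leaky bucket, expressed as two durations.
///
/// A sketch of the conversion performed by [`LeakyBucketConfig::new`]:
///
/// ```text
/// new(rps: 100.0, bucket_size: 100.0)
///   cost         = 1 / 100rps    = 10ms per token
///   bucket_width = 10ms * 100    = 1s of burst capacity
/// ```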
#[derive(Clone, Copy)]
pub struct LeakyBucketConfig {
    /// This is the "time cost" of a single request unit.
    /// Should loosely represent how long it takes to handle a request unit in active resource time.
    /// Loosely speaking, this is the inverse of the steady-rate requests-per-second.
    pub cost: Duration,

    /// The total size of the bucket.
    pub bucket_width: Duration,
}

impl LeakyBucketConfig {
    pub fn new(rps: f64, bucket_size: f64) -> Self {
        let cost = Duration::from_secs_f64(rps.recip());
        let bucket_width = cost.mul_f64(bucket_size);
        Self { cost, bucket_width }
    }
}

pub struct LeakyBucketState {
    /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`.
    ///
    /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by the "time cost".
    /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.cost`.
    /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens.
    /// Draining the bucket will happen naturally as `now` moves forward.
    ///
    /// Let `n` be some "time cost" for the request:
    /// * If `now` is after `empty_at`, the bucket is empty and `empty_at` is reset to `now`.
    /// * If `now` is within the `bucket window + n`, we are within the time budget.
    /// * If `now` is before the `bucket window + n`, we have run out of budget.
    ///
    /// This is inspired by the generic cell rate algorithm (GCRA) and works
    /// exactly the same as a leaky bucket.
    pub empty_at: Instant,
}

impl LeakyBucketState {
    pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self {
        LeakyBucketState {
            empty_at: Instant::now() + config.cost.mul_f64(initial_tokens),
        }
    }

    pub fn bucket_is_empty(&self, now: Instant) -> bool {
        // if `self.empty_at` is after now, the bucket is not empty
        self.empty_at <= now
    }

    /// Immediately adds tokens to the bucket, if there is space.
    ///
    /// In a scenario where you are waiting for available rate
    /// rather than just erroring immediately, `started` corresponds to when this waiting started.
    ///
    /// `n` is the number of tokens that will be added to the bucket.
    ///
    /// # Errors
    ///
    /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when
    /// there will be space again.
    pub fn add_tokens(
        &mut self,
        config: &LeakyBucketConfig,
        started: Instant,
        n: f64,
    ) -> Result<(), Instant> {
        let now = Instant::now();

        // invariant: started <= now
        debug_assert!(started <= now);

        // If the bucket was empty when we started our search,
        // we should update the `empty_at` value accordingly.
        // This prevents us from having negative tokens in the bucket.
        let mut empty_at = self.empty_at;
        if empty_at < started {
            empty_at = started;
        }

        let n = config.cost.mul_f64(n);
        let new_empty_at = empty_at + n;
        let allow_at = new_empty_at.checked_sub(config.bucket_width);

        //                      empty_at
        //            allow_at  |    new_empty_at
        //           /          |   /
        // -------o-[---------o-|--]---------
        //   now1 ^      now2 ^
        //
        // at now1, the bucket would be completely filled if we add n tokens.
        // at now2, the bucket would be partially filled if we add n tokens.

        match allow_at {
            Some(allow_at) if now < allow_at => Err(allow_at),
            _ => {
                self.empty_at = new_empty_at;
                Ok(())
            }
        }
    }
}
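/// An async rate limiter wrapping [`LeakyBucketState`] behind a mutex, with a
/// [`Notify`]-based queue so that concurrent `acquire` calls are served in a
/// fair order and sleep until the bucket has room for their tokens.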
pub struct RateLimiter {
    pub config: LeakyBucketConfig,
    pub sleep_counter: AtomicU64,
    pub state: Mutex<LeakyBucketState>,
    /// A queue to provide fair FIFO ordering of waiters.
    pub queue: Notify,
}

/// A drop guard that re-notifies the wait queue, so the next waiter gets a
/// turn even if the current `acquire` is cancelled midway.
struct Requeue<'a>(&'a Notify);

impl Drop for Requeue<'_> {
    fn drop(&mut self) {
        self.0.notify_one();
    }
}

impl RateLimiter {
    pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self {
        RateLimiter {
            sleep_counter: AtomicU64::new(0),
            state: Mutex::new(LeakyBucketState::with_initial_tokens(
                &config,
                initial_tokens,
            )),
            config,
            queue: {
                let queue = Notify::new();
                queue.notify_one();
                queue
            },
        }
    }

    pub fn steady_rps(&self) -> f64 {
        self.config.cost.as_secs_f64().recip()
    }

    /// returns true if we did throttle
    pub async fn acquire(&self, count: usize) -> bool {
        let start = tokio::time::Instant::now();

        let start_count = self.sleep_counter.load(Ordering::Acquire);
        let mut end_count = start_count;

        // wait until we are the first in the queue
        let mut notified = std::pin::pin!(self.queue.notified());
        if !notified.as_mut().enable() {
            notified.await;
            end_count = self.sleep_counter.load(Ordering::Acquire);
        }

        // notify the next waiter in the queue when we are done.
        let _guard = Requeue(&self.queue);

        loop {
            let res = self
                .state
                .lock()
                .unwrap()
                .add_tokens(&self.config, start, count as f64);
            match res {
                Ok(()) => return end_count > start_count,
                Err(ready_at) => {
                    struct Increment<'a>(&'a AtomicU64);

                    impl Drop for Increment<'_> {
                        fn drop(&mut self) {
                            self.0.fetch_add(1, Ordering::AcqRel);
                        }
                    }

                    // increment the counter after we finish sleeping (or cancel this task).
                    // this ensures that tasks that have already started the acquire will observe
                    // the new sleep count when they are allowed to resume on the notify.
                    let _inc = Increment(&self.sleep_counter);
                    end_count += 1;

                    tokio::time::sleep_until(ready_at).await;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::{LeakyBucketConfig, LeakyBucketState};

    #[tokio::test(start_paused = true)]
    async fn check() {
        let config = LeakyBucketConfig {
            // average 100rps
            cost: Duration::from_millis(10),
            // burst up to 100 requests
            bucket_width: Duration::from_millis(1000),
        };

        let mut state = LeakyBucketState {
            empty_at: Instant::now(),
        };

        // supports burst
        {
            // should work for 100 requests this instant
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // doesn't overfill
        {
            // after 1s we should have an empty bucket again.
            tokio::time::advance(Duration::from_secs(1)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // after 1s more, we should not over-count the tokens and allow 200 requests.
            tokio::time::advance(Duration::from_secs(1)).await;
            for _ in 0..100 {
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }

        // supports sustained rate over a long period
        {
            tokio::time::advance(Duration::from_secs(1)).await;

            // should sustain 100rps
            for _ in 0..2000 {
                tokio::time::advance(Duration::from_millis(10)).await;
                state.add_tokens(&config, Instant::now(), 1.0).unwrap();
            }
        }

        // supports requesting more tokens than can be stored in the bucket;
        // we just wait a little bit longer upfront.
        {
            // start the bucket completely empty
            tokio::time::advance(Duration::from_secs(5)).await;
            assert!(state.bucket_is_empty(Instant::now()));

            // requesting 200 tokens of space should take 200*cost = 2s,
            // but we already have 1s available, so we wait 1s from start.
            let start = Instant::now();

            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_secs(1));

            tokio::time::advance(Duration::from_millis(500)).await;
            let ready = state.add_tokens(&config, start, 200.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(500));

            tokio::time::advance(Duration::from_millis(500)).await;
            state.add_tokens(&config, start, 200.0).unwrap();

            // bucket should be completely full now
            let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
            assert_eq!(ready - Instant::now(), Duration::from_millis(10));
        }
    }
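
    /// A minimal usage sketch for `RateLimiter::acquire`, assuming a 10rps
    /// limit with a bucket of a single token: the first acquire is admitted
    /// immediately, and the second must sleep for one `cost` (100ms).
    #[tokio::test(start_paused = true)]
    async fn acquire_sketch() {
        let limiter =
            super::RateLimiter::with_initial_tokens(LeakyBucketConfig::new(10.0, 1.0), 0.0);

        // the bucket starts empty, so the first token is admitted without throttling.
        assert!(!limiter.acquire(1).await);

        // the bucket is now full: the next acquire sleeps until a token drains.
        let start = Instant::now();
        assert!(limiter.acquire(1).await);
        assert_eq!(Instant::now() - start, Duration::from_millis(100));
    }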
}