use std::{
    collections::hash_map::RandomState,
    hash::BuildHasher,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};

use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Duration, Instant};
use tracing::info;

use crate::EndpointId;

use super::{
    limit_algorithm::{LimitAlgorithm, Sample},
    RateLimiterConfig,
};

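/// A simple multi-bucket rate limiter: one [RateBucket] per configured
/// [RateBucketInfo], all of which must have capacity for a request to be
/// allowed. See [Self::check].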
pub struct RedisRateLimiter {
    data: Vec<RateBucket>,
    info: &'static [RateBucketInfo],
}

impl RedisRateLimiter {
    pub fn new(info: &'static [RateBucketInfo]) -> Self {
        Self {
            data: vec![
                RateBucket {
                    start: Instant::now(),
                    count: 0,
                };
                info.len()
            ],
            info,
        }
    }

    /// Check that the number of connections is below `max_rps` rps.
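    ///
    /// A minimal usage sketch (the bucket configuration must be leaked to get
    /// a `'static` slice, as the tests in this module do with `Vec::leak`):
    ///
    /// ```ignore
    /// let mut rates: Vec<RateBucketInfo> =
    ///     ["100@1s"].into_iter().map(|s| s.parse().unwrap()).collect();
    /// RateBucketInfo::validate(&mut rates)?;
    /// let mut limiter = RedisRateLimiter::new(Vec::leak(rates));
    /// assert!(limiter.check()); // first request in a fresh bucket passes
    /// ```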
    pub fn check(&mut self) -> bool {
        let now = Instant::now();

        let should_allow_request = self
            .data
            .iter_mut()
            .zip(self.info)
            .all(|(bucket, info)| bucket.should_allow_request(info, now));

        if should_allow_request {
            // only increment the bucket counts if the request will actually be accepted
            self.data.iter_mut().for_each(RateBucket::inc);
        }

        should_allow_request
    }
}

// Simple per-endpoint rate limiter.
//
// Checks that the number of connections to the endpoint stays below `max_rps` rps.
// We purposefully ignore the user and database names: clients can reconnect with
// different names, so limiting by them would still let some http requests through
// to the control plane.
//
// We could also save quite a lot of CPU (I think) by bailing out right after we
// see the SNI, before doing the TLS handshake. The user-side error message in that
// case does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for
// now I went with the more expensive way that yields user-friendlier error messages.
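//
// A minimal usage sketch (mirroring `test_rate_limits` below):
//
//     let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
//         .into_iter()
//         .map(|s| s.parse().unwrap())
//         .collect();
//     RateBucketInfo::validate(&mut rates)?;
//     let limiter = EndpointRateLimiter::new(Vec::leak(rates));
//     assert!(limiter.check(EndpointId::from("ep-my-endpoint-1234")));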
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
    map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
    info: &'static [RateBucketInfo],
    access_count: AtomicUsize,
    rand: Mutex<Rand>,
}

#[derive(Clone, Copy)]
struct RateBucket {
    start: Instant,
    count: u32,
}

impl RateBucket {
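    /// Fixed-window check: while the current window is live, allow the request
    /// only if `count < max_rpi`; once `interval` has elapsed, reset the window.
    /// The caller is responsible for incrementing `count` via [Self::inc].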
    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
        if now - self.start < info.interval {
            self.count < info.max_rpi
        } else {
            // bucket expired, reset
            self.count = 0;
            self.start = now;

            true
        }
    }

    fn inc(&mut self) {
        self.count += 1;
    }
}

#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    pub interval: Duration,
    // requests per interval
    pub max_rpi: u32,
}

impl std::fmt::Display for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
        write!(f, "{rps}@{}", humantime::format_duration(self.interval))
    }
}

impl std::fmt::Debug for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self}")
    }
}

impl std::str::FromStr for RateBucketInfo {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some((max_rps, interval)) = s.split_once('@') else {
            bail!("invalid rate info")
        };
        let max_rps = max_rps.parse()?;
        let interval = humantime::parse_duration(interval)?;
        Ok(Self::new(max_rps, interval))
    }
}

impl RateBucketInfo {
    pub const DEFAULT_SET: [Self; 3] = [
        Self::new(300, Duration::from_secs(1)),
        Self::new(200, Duration::from_secs(60)),
        Self::new(100, Duration::from_secs(600)),
    ];

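    /// Sort the buckets by interval and require that each longer interval also
    /// allows at least as many requests per bucket; otherwise the longer bucket
    /// would saturate first and make the shorter limits unreachable.
    ///
    /// A sketch of the failure mode (mirroring `rate_buckets_validate` below):
    ///
    /// ```ignore
    /// let mut rates: Vec<RateBucketInfo> =
    ///     ["300@1s", "10@10s"].into_iter().map(|s| s.parse().unwrap()).collect();
    /// // 10@10s allows only 100 requests per bucket, fewer than 300@1s's 300
    /// assert!(RateBucketInfo::validate(&mut rates).is_err());
    /// ```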
    pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
        info.sort_unstable_by_key(|info| info.interval);
        let invalid = info
            .iter()
            .tuple_windows()
            .find(|(a, b)| a.max_rpi > b.max_rpi);
        if let Some((a, b)) = invalid {
            bail!(
                "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
                b.max_rpi,
                a.max_rpi,
            );
        }

        Ok(())
    }

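    /// Convert a requests-per-second limit into a requests-per-interval budget.
    /// E.g. (mirroring `rate_bucket_rpi` below) `new(50, 5s)` yields
    /// `max_rpi == 250` and `new(50, 500ms)` yields `max_rpi == 25`.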
    pub const fn new(max_rps: u32, interval: Duration) -> Self {
        Self {
            interval,
            max_rpi: max_rps * interval.as_millis() as u32 / 1000,
        }
    }
}

impl EndpointRateLimiter {
    pub fn new(info: &'static [RateBucketInfo]) -> Self {
        Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
    }
}

impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
    fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
        info!(buckets = ?info, "endpoint rate limiter");
        Self {
            info,
            map: DashMap::with_hasher_and_shard_amount(hasher, 64),
            access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
            rand: Mutex::new(rand),
        }
    }

    /// Check that the number of connections to the endpoint is below `max_rps` rps.
    pub fn check(&self, endpoint: EndpointId) -> bool {
        // Do a partial GC every 2k requests; each pass cleans up ~1/64th of the map.
        // Worst-case memory usage is about:
        //   2 * 2048 * 64 * (48B + 72B) ~ 30MB
        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
            self.do_gc();
        }

        let now = Instant::now();
        let mut entry = self.map.entry(endpoint).or_insert_with(|| {
            vec![
                RateBucket {
                    start: now,
                    count: 0,
                };
                self.info.len()
            ]
        });

        let should_allow_request = entry
            .iter_mut()
            .zip(self.info)
            .all(|(bucket, info)| bucket.should_allow_request(info, now));

        if should_allow_request {
            // only increment the bucket counts if the request will actually be accepted
            entry.iter_mut().for_each(RateBucket::inc);
        }

        should_allow_request
    }

    /// Clean the map. Simple strategy: remove all entries in a random shard.
    /// At worst, we'll double the effective max_rps during the cleanup.
    /// But that way deletion does not acquire a mutex on each entry access.
    pub fn do_gc(&self) {
        info!(
            "cleaning up endpoint rate limiter, current size = {}",
            self.map.len()
        );
        let n = self.map.shards().len();
        // this lock is ok as the periodic cadence of do_gc makes contention very
        // unlikely (impossible, in fact, unless we have 2048 threads)
        let shard = self.rand.lock().unwrap().gen_range(0..n);
        self.map.shards()[shard].write().clear();
    }
}

/// Limits the number of concurrent jobs.
///
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
/// token once the job is finished.
///
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
/// caused by overload (loss).
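///
/// A minimal usage sketch (mirroring the `it_works` test below):
///
/// ```ignore
/// let limiter = Limiter::new(RateLimiterConfig {
///     algorithm: RateLimitAlgorithm::Fixed,
///     timeout: Duration::from_secs(1),
///     initial_limit: 10,
///     disable: false,
///     ..Default::default()
/// });
/// let token = limiter.try_acquire().expect("within the initial limit");
/// // ... run the job ...
/// limiter.release(token, Some(Outcome::Success)).await;
/// ```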
pub struct Limiter {
    limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
    semaphore: std::sync::Arc<Semaphore>,
    config: RateLimiterConfig,

    // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
    limits: AtomicUsize,

    // ONLY USE ATOMIC ADD/SUB
    in_flight: Arc<AtomicUsize>,

    #[cfg(test)]
    notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
}

/// A concurrency token, required to run a job.
///
/// Release the token back to the [Limiter] after the job is complete.
#[derive(Debug)]
pub struct Token<'t> {
    permit: Option<tokio::sync::SemaphorePermit<'t>>,
    start: Instant,
    in_flight: Arc<AtomicUsize>,
}

/// A snapshot of the state of the [Limiter].
///
/// Not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
    limit: usize,
    in_flight: usize,
}

/// Whether a job succeeded or failed as a result of congestion/overload.
///
/// Errors not considered to be caused by overload should be ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The job succeeded, or failed in a way unrelated to overload.
    Success,
    /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
    /// was observed.
    Overload,
}

impl Outcome {
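    /// Classify a reqwest error for the limit algorithm: 5xx and 429 responses
    /// count as [Outcome::Overload]; everything else, including middleware
    /// errors, counts as [Outcome::Success].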
    fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
        match error {
            reqwest_middleware::Error::Middleware(_) => Outcome::Success,
            reqwest_middleware::Error::Reqwest(e) => {
                if let Some(status) = e.status() {
                    if status.is_server_error()
                        || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
                    {
                        Outcome::Overload
                    } else {
                        Outcome::Success
                    }
                } else {
                    Outcome::Success
                }
            }
        }
    }
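    /// Classify a response the same way: 5xx and 429 count as [Outcome::Overload].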
    fn from_reqwest_response(response: &reqwest::Response) -> Self {
        if response.status().is_server_error()
            || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
        {
            Outcome::Overload
        } else {
            Outcome::Success
        }
    }
}

impl Limiter {
    /// Create a limiter with a given limit control algorithm.
    pub fn new(config: RateLimiterConfig) -> Self {
        assert!(config.initial_limit > 0);
        Self {
            limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
            semaphore: Arc::new(Semaphore::new(config.initial_limit)),
            config,
            limits: AtomicUsize::new(config.initial_limit),
            in_flight: Arc::new(AtomicUsize::new(0)),
            #[cfg(test)]
            notifier: None,
        }
    }

    /// In some cases [Token]s are acquired asynchronously when updating the limit.
    #[cfg(test)]
    pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
        self.notifier = Some(n);
        self
    }

    /// Try to immediately acquire a concurrency [Token].
    ///
    /// Returns `None` if there are none available.
    pub fn try_acquire(&self) -> Option<Token> {
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            self.semaphore
                .try_acquire()
                .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                .ok()
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
    ///
    /// Returns `None` if there are none available after `duration`.
    pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
        info!("acquiring token: {:?}", self.semaphore.available_permits());
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            match timeout(duration, self.semaphore.acquire()).await {
                Ok(maybe_permit) => maybe_permit
                    .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                    .ok(),
                Err(_) => None,
            }
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Return the concurrency [Token], along with the outcome of the job.
    ///
    /// The [Outcome] of the job, and the time taken to perform it, may be used
    /// to update the concurrency limit.
    ///
    /// Set the outcome to `None` to ignore the job.
    pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
        tracing::info!("outcome is {:?}", outcome);
        let in_flight = self.in_flight.load(Ordering::Acquire);
        let old_limit = self.limits.load(Ordering::Acquire);
        let available = if self.config.disable {
            // The algorithm does not use this value; when the limiter is
            // disabled it makes sense to set it to 0.
            0
        } else {
            self.semaphore.available_permits()
        };
        let total = in_flight + available;

        let mut algo = self.limit_algo.lock().await;

        let new_limit = if let Some(outcome) = outcome {
            let sample = Sample {
                latency: token.start.elapsed(),
                in_flight,
                outcome,
            };
            algo.update(old_limit, sample).await
        } else {
            old_limit
        };
        tracing::info!("new limit is {}", new_limit);
        let actual_limit = if new_limit < total {
            token.forget();
            total.saturating_sub(1)
        } else {
            if !self.config.disable {
                self.semaphore.add_permits(new_limit.saturating_sub(total));
            }
            new_limit
        };
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["expected"])
            .set(new_limit as i64);
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["actual"])
            .set(actual_limit as i64);
        self.limits.store(new_limit, Ordering::Release);
        #[cfg(test)]
        if let Some(n) = &self.notifier {
            n.notify_one();
        }
    }

    /// The current state of the limiter.
    pub fn state(&self) -> LimiterState {
        let limit = self.limits.load(Ordering::Relaxed);
        let in_flight = self.in_flight.load(Ordering::Relaxed);
        LimiterState { limit, in_flight }
    }
}

impl<'t> Token<'t> {
    fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
        Self {
            permit,
            start: Instant::now(),
            in_flight,
        }
    }

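    /// Drop the inner semaphore permit without returning it to the [Limiter],
    /// permanently removing one slot (used by [Limiter::release] when the new
    /// limit is lower than the number of tokens already handed out).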
    pub fn forget(&mut self) {
        if let Some(permit) = self.permit.take() {
            permit.forget();
        }
    }
}

impl Drop for Token<'_> {
    fn drop(&mut self) {
        self.in_flight.fetch_sub(1, Ordering::AcqRel);
    }
}

impl LimiterState {
    /// The current concurrency limit.
    pub fn limit(&self) -> usize {
        self.limit
    }
    /// The number of jobs in flight.
    pub fn in_flight(&self) -> usize {
        self.in_flight
    }
}

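// As a `reqwest` middleware, the limiter gates every outgoing request on a
// concurrency token and feeds the response status (or error) back into the
// limit algorithm as an [Outcome].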
#[async_trait::async_trait]
impl reqwest_middleware::Middleware for Limiter {
    async fn handle(
        &self,
        req: reqwest::Request,
        extensions: &mut task_local_extensions::Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> reqwest_middleware::Result<reqwest::Response> {
        let start = Instant::now();
        let token = self
            .acquire_timeout(self.config.timeout)
            .await
            .ok_or_else(|| {
                reqwest_middleware::Error::Middleware(
                    // TODO: Should we map it into user facing errors?
                    crate::console::errors::ApiError::Console {
                        status: crate::http::StatusCode::TOO_MANY_REQUESTS,
                        text: "Too many requests".into(),
                    }
                    .into(),
                )
            })?;
        info!(duration = ?start.elapsed(), "waited for a token to connect to the control plane");
        crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
        match next.run(req, extensions).await {
            Ok(response) => {
                self.release(token, Some(Outcome::from_reqwest_response(&response)))
                    .await;
                Ok(response)
            }
            Err(e) => {
                self.release(token, Some(Outcome::from_reqwest_error(&e)))
                    .await;
                Err(e)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};

    use futures::{task::noop_waker_ref, Future};
    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{EndpointRateLimiter, Limiter, Outcome};
    use crate::{
        rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
        EndpointId,
    };

    #[tokio::test]
    async fn it_works() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 10,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        let token = limiter.try_acquire().unwrap();

        limiter.release(token, Some(Outcome::Success)).await;

        assert_eq!(limiter.state().limit(), 10);
    }

    #[tokio::test]
    async fn is_fair() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();

        let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token2_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        limiter.release(token1, Some(Outcome::Success)).await;
        // === END TOKEN 1 ===

        // === TOKEN 2 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token2"
        );

        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token2"
        );

        let token2 = token2_fut.await.unwrap();

        limiter.release(token2, Some(Outcome::Success)).await;
        // === END TOKEN 2 ===

        // === TOKEN 3 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token3"
        );

        let token3 = token3_fut.await.unwrap();
        limiter.release(token3, Some(Outcome::Success)).await;
        // === END TOKEN 3 ===

        // === TOKEN 4 ===
        let token4 = limiter.try_acquire().unwrap();
        limiter.release(token4, Some(Outcome::Success)).await;
    }

    #[tokio::test]
    async fn disable() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: true,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();
        let token2 = limiter.try_acquire().unwrap();
        let state = limiter.state();
        assert_eq!(state.limit(), 1);
        assert_eq!(state.in_flight(), 2); // For a disabled limiter, this is expected.
        limiter.release(token1, None).await;
        limiter.release(token2, None).await;
    }

    #[test]
    fn rate_bucket_rpi() {
        let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(rate_bucket.max_rpi, 50 * 5);

        let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(rate_bucket.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(10));
        assert_eq!(rate_bucket.max_rpi, 100 * 10);
        assert_eq!(rate_bucket.to_string(), "100@10s");

        let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(60));
        assert_eq!(rate_bucket.max_rpi, 100 * 60);
        assert_eq!(rate_bucket.to_string(), "100@1m");
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = EndpointRateLimiter::new(Vec::leak(rates));

        let endpoint = EndpointId::from("ep-my-endpoint-1234");

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
        // more connections fail
        assert!(!limiter.check(endpoint.clone()));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 1..6 {
            for _ in 0..100 {
                assert!(limiter.check(endpoint.clone()));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!limiter.check(endpoint.clone()));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let rand = rand::rngs::StdRng::from_seed([1; 32]);
        let hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
            &RateBucketInfo::DEFAULT_SET,
            rand,
            hasher,
        );
        for i in 0..1_000_000 {
            limiter.check(format!("{i}").into());
        }
        assert!(limiter.map.len() < 150_000);
    }
}