use std::{
    collections::hash_map::RandomState,
    hash::BuildHasher,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};

use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Duration, Instant};
use tracing::info;

use crate::EndpointId;

use super::{
    limit_algorithm::{LimitAlgorithm, Sample},
    RateLimiterConfig,
};

// Simple per-endpoint rate limiter.
//
// Checks that the number of connections to the endpoint stays below `max_rps`.
// We purposefully ignore the user name and database name: clients can reconnect
// with different names, so keying on them would let reconnects slip past the
// limit while still costing us HTTP requests to the control plane.
//
// We may also save quite a lot of CPU (I think) by bailing out right after we
// see the SNI, before doing the TLS handshake. The user-side error message in
// that case does not look very nice (`SSL SYSCALL error: Undefined error: 0`),
// so for now I went with the more expensive way that yields a user-friendlier
// error message.
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
    map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
    info: &'static [RateBucketInfo],
    access_count: AtomicUsize,
    rand: Mutex<Rand>,
}

#[derive(Clone, Copy)]
struct RateBucket {
    start: Instant,
    count: u32,
}

impl RateBucket {
    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
        if now - self.start < info.interval {
            self.count < info.max_rpi
        } else {
            // bucket expired, reset
            self.count = 0;
            self.start = now;

            true
        }
    }

    fn inc(&mut self) {
        self.count += 1;
    }
}
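
// Illustrative semantics: with `interval = 1s` and `max_rpi = 100`, the 101st
// `should_allow_request` call within the same second returns `false`, while the
// first call at least one second after `start` resets the bucket and passes.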

#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    pub interval: Duration,
    // requests per interval
    pub max_rpi: u32,
}

impl std::fmt::Display for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
        write!(f, "{rps}@{}", humantime::format_duration(self.interval))
    }
}

impl std::fmt::Debug for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self}")
    }
}

impl std::str::FromStr for RateBucketInfo {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some((max_rps, interval)) = s.split_once('@') else {
            bail!("invalid rate info")
        };
        let max_rps = max_rps.parse()?;
        let interval = humantime::parse_duration(interval)?;
        Ok(Self::new(max_rps, interval))
    }
}
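
// Illustrative round-trip through `FromStr`/`Display` (mirrors the
// `rate_bucket_parse` test below):
//
//     let info: RateBucketInfo = "100@10s".parse().unwrap();
//     assert_eq!(info.interval, Duration::from_secs(10));
//     assert_eq!(info.max_rpi, 1000); // 100 rps sustained for 10 seconds
//     assert_eq!(info.to_string(), "100@10s");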

impl RateBucketInfo {
    pub const DEFAULT_SET: [Self; 3] = [
        Self::new(300, Duration::from_secs(1)),
        Self::new(200, Duration::from_secs(60)),
        Self::new(100, Duration::from_secs(600)),
    ];
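
    // Reading the default set: a client may burst up to 300 requests within any
    // one second, but may sustain only 200 rps over a minute (12000 per bucket)
    // and 100 rps over ten minutes (60000 per bucket).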

    pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
        info.sort_unstable_by_key(|info| info.interval);
        let invalid = info
            .iter()
            .tuple_windows()
            .find(|(a, b)| a.max_rpi > b.max_rpi);
        if let Some((a, b)) = invalid {
            bail!(
                "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
                b.max_rpi,
                a.max_rpi,
            );
        }

        Ok(())
    }

    pub const fn new(max_rps: u32, interval: Duration) -> Self {
        Self {
            interval,
            // scale the per-second rate to the bucket interval,
            // e.g. 50 rps over a 500ms interval allows 25 requests per bucket
            max_rpi: max_rps * interval.as_millis() as u32 / 1000,
        }
    }
}

impl EndpointRateLimiter {
    pub fn new(info: &'static [RateBucketInfo]) -> Self {
        Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
    }
}

impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
    fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
        info!(buckets = ?info, "endpoint rate limiter");
        Self {
            info,
            map: DashMap::with_hasher_and_shard_amount(hasher, 64),
            access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
            rand: Mutex::new(rand),
        }
    }

    /// Check that the number of connections to the endpoint is below `max_rps`.
    pub fn check(&self, endpoint: EndpointId) -> bool {
        // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
        // worst case memory usage is about:
        //   2 * 2048 * 64 * (48B + 72B) = 30MB
        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
            self.do_gc();
        }

        let now = Instant::now();
        let mut entry = self.map.entry(endpoint).or_insert_with(|| {
            vec![
                RateBucket {
                    start: now,
                    count: 0,
                };
                self.info.len()
            ]
        });

        let should_allow_request = entry
            .iter_mut()
            .zip(self.info)
            .all(|(bucket, info)| bucket.should_allow_request(info, now));

        if should_allow_request {
            // only increment the bucket counts if the request will actually be accepted
            entry.iter_mut().for_each(RateBucket::inc);
        }

        should_allow_request
    }

    /// Clean the map. Simple strategy: remove all entries in a random shard.
    /// At worst, we'll double the effective max_rps during the cleanup.
    /// But that way deletion does not acquire a mutex on each entry access.
    pub fn do_gc(&self) {
        info!(
            "cleaning up endpoint rate limiter, current size = {}",
            self.map.len()
        );
        let n = self.map.shards().len();
        // this lock is ok as the periodic cycle of do_gc makes a collision very unlikely
        // (impossible, in fact, unless we have 2048 threads)
        let shard = self.rand.lock().unwrap().gen_range(0..n);
        self.map.shards()[shard].write().clear();
    }
}
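
// Illustrative usage (a sketch mirroring the tests below; `check` returns
// `false` once any bucket for the endpoint is exhausted):
//
//     let limiter = EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_SET);
//     if !limiter.check("ep-example-1234".into()) {
//         // reject the connection before it costs us a control-plane request
//     }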

/// Limits the number of concurrent jobs.
///
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
/// token once the job is finished.
///
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
/// caused by overload (loss).
pub struct Limiter {
    limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
    semaphore: std::sync::Arc<Semaphore>,
    config: RateLimiterConfig,

    // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
    limits: AtomicUsize,

    // ONLY USE ATOMIC ADD/SUB
    in_flight: Arc<AtomicUsize>,

    #[cfg(test)]
    notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
}

/// A concurrency token, required to run a job.
///
/// Release the token back to the [Limiter] after the job is complete.
#[derive(Debug)]
pub struct Token<'t> {
    permit: Option<tokio::sync::SemaphorePermit<'t>>,
    start: Instant,
    in_flight: Arc<AtomicUsize>,
}

/// A snapshot of the state of the [Limiter].
///
/// Not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
    limit: usize,
    in_flight: usize,
}

/// Whether a job succeeded or failed as a result of congestion/overload.
///
/// Errors not considered to be caused by overload should be ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The job succeeded, or failed in a way unrelated to overload.
    Success,
    /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
    /// was observed.
    Overload,
}

impl Outcome {
    fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
        match error {
            reqwest_middleware::Error::Middleware(_) => Outcome::Success,
            reqwest_middleware::Error::Reqwest(e) => {
                if let Some(status) = e.status() {
                    if status.is_server_error()
                        || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
                    {
                        Outcome::Overload
                    } else {
                        Outcome::Success
                    }
                } else {
                    Outcome::Success
                }
            }
        }
    }

    fn from_reqwest_response(response: &reqwest::Response) -> Self {
        if response.status().is_server_error()
            || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
        {
            Outcome::Overload
        } else {
            Outcome::Success
        }
    }
}
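
// Both conversions treat HTTP 5xx and 429 (Too Many Requests) as overload
// signals; any other status, and middleware-level errors, count as `Success`,
// so failures unrelated to load do not shrink the concurrency limit.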

impl Limiter {
    /// Create a limiter with a given limit control algorithm.
    pub fn new(config: RateLimiterConfig) -> Self {
        assert!(config.initial_limit > 0);
        Self {
            limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
            semaphore: Arc::new(Semaphore::new(config.initial_limit)),
            config,
            limits: AtomicUsize::new(config.initial_limit),
            in_flight: Arc::new(AtomicUsize::new(0)),
            #[cfg(test)]
            notifier: None,
        }
    }

    /// In some cases [Token]s are acquired asynchronously when updating the limit.
    #[cfg(test)]
    pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
        self.notifier = Some(n);
        self
    }

    /// Try to immediately acquire a concurrency [Token].
    ///
    /// Returns `None` if there are none available.
    pub fn try_acquire(&self) -> Option<Token> {
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            self.semaphore
                .try_acquire()
                .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                .ok()
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
    ///
    /// Returns `None` if there are none available after `duration`.
    pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
        info!("acquiring token: {:?}", self.semaphore.available_permits());
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            match timeout(duration, self.semaphore.acquire()).await {
                Ok(maybe_permit) => maybe_permit
                    .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                    .ok(),
                Err(_) => None,
            }
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Return the concurrency [Token], along with the outcome of the job.
    ///
    /// The [Outcome] of the job, and the time taken to perform it, may be used
    /// to update the concurrency limit.
    ///
    /// Set the outcome to `None` to ignore the job.
    pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
        tracing::info!("outcome is {:?}", outcome);
        let in_flight = self.in_flight.load(Ordering::Acquire);
        let old_limit = self.limits.load(Ordering::Acquire);
        let available = if self.config.disable {
            0 // not used by the algorithm; when the limiter is disabled, 0 is as good as anything
        } else {
            self.semaphore.available_permits()
        };
        let total = in_flight + available;

        let mut algo = self.limit_algo.lock().await;

        let new_limit = if let Some(outcome) = outcome {
            let sample = Sample {
                latency: token.start.elapsed(),
                in_flight,
                outcome,
            };
            algo.update(old_limit, sample).await
        } else {
            old_limit
        };
        tracing::info!("new limit is {}", new_limit);
        let actual_limit = if new_limit < total {
            token.forget();
            total.saturating_sub(1)
        } else {
            if !self.config.disable {
                self.semaphore.add_permits(new_limit.saturating_sub(total));
            }
            new_limit
        };
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["expected"])
            .set(new_limit as i64);
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["actual"])
            .set(actual_limit as i64);
        self.limits.store(new_limit, Ordering::Release);
        #[cfg(test)]
        if let Some(n) = &self.notifier {
            n.notify_one();
        }
    }

    /// The current state of the limiter.
    pub fn state(&self) -> LimiterState {
        let limit = self.limits.load(Ordering::Relaxed);
        let in_flight = self.in_flight.load(Ordering::Relaxed);
        LimiterState { limit, in_flight }
    }
}
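
// Typical flow (an illustrative sketch, not taken from call sites; `run_job`
// is a hypothetical function returning an `Outcome`):
//
//     let limiter = Limiter::new(config);
//     if let Some(token) = limiter.acquire_timeout(Duration::from_secs(1)).await {
//         let outcome = run_job().await;
//         limiter.release(token, Some(outcome)).await;
//     } // dropping a token also decrements `in_flight`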

impl<'t> Token<'t> {
    fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
        Self {
            permit,
            start: Instant::now(),
            in_flight,
        }
    }

    pub fn forget(&mut self) {
        if let Some(permit) = self.permit.take() {
            permit.forget();
        }
    }
}

impl Drop for Token<'_> {
    fn drop(&mut self) {
        self.in_flight.fetch_sub(1, Ordering::AcqRel);
    }
}

impl LimiterState {
    /// The current concurrency limit.
    pub fn limit(&self) -> usize {
        self.limit
    }
    /// The number of jobs in flight.
    pub fn in_flight(&self) -> usize {
        self.in_flight
    }
}

#[async_trait::async_trait]
impl reqwest_middleware::Middleware for Limiter {
    async fn handle(
        &self,
        req: reqwest::Request,
        extensions: &mut task_local_extensions::Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> reqwest_middleware::Result<reqwest::Response> {
        let start = Instant::now();
        let token = self
            .acquire_timeout(self.config.timeout)
            .await
            .ok_or_else(|| {
                reqwest_middleware::Error::Middleware(
                    // TODO: Should we map it into user facing errors?
                    crate::console::errors::ApiError::Console {
                        status: crate::http::StatusCode::TOO_MANY_REQUESTS,
                        text: "Too many requests".into(),
                    }
                    .into(),
                )
            })?;
        info!(duration = ?start.elapsed(), "waited for a token to connect to the control plane");
        crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
        match next.run(req, extensions).await {
            Ok(response) => {
                self.release(token, Some(Outcome::from_reqwest_response(&response)))
                    .await;
                Ok(response)
            }
            Err(e) => {
                self.release(token, Some(Outcome::from_reqwest_error(&e)))
                    .await;
                Err(e)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};

    use futures::{task::noop_waker_ref, Future};
    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{EndpointRateLimiter, Limiter, Outcome};
    use crate::{
        rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
        EndpointId,
    };

    #[tokio::test]
    async fn it_works() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 10,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        let token = limiter.try_acquire().unwrap();

        limiter.release(token, Some(Outcome::Success)).await;

        assert_eq!(limiter.state().limit(), 10);
    }

    #[tokio::test]
    async fn is_fair() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();

        let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token2_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        limiter.release(token1, Some(Outcome::Success)).await;
        // === END TOKEN 1 ===

        // === TOKEN 2 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token2"
        );

        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token2"
        );

        let token2 = token2_fut.await.unwrap();

        limiter.release(token2, Some(Outcome::Success)).await;
        // === END TOKEN 2 ===

        // === TOKEN 3 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token3"
        );

        let token3 = token3_fut.await.unwrap();
        limiter.release(token3, Some(Outcome::Success)).await;
        // === END TOKEN 3 ===

        // === TOKEN 4 ===
        let token4 = limiter.try_acquire().unwrap();
        limiter.release(token4, Some(Outcome::Success)).await;
    }

    #[tokio::test]
    async fn disable() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: true,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();
        let token2 = limiter.try_acquire().unwrap();
        let state = limiter.state();
        assert_eq!(state.limit(), 1);
        assert_eq!(state.in_flight(), 2); // For a disabled limiter, this is expected.
        limiter.release(token1, None).await;
        limiter.release(token2, None).await;
    }

    #[test]
    fn rate_bucket_rpi() {
        let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(rate_bucket.max_rpi, 50 * 5);

        let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(rate_bucket.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(10));
        assert_eq!(rate_bucket.max_rpi, 100 * 10);
        assert_eq!(rate_bucket.to_string(), "100@10s");

        let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(60));
        assert_eq!(rate_bucket.max_rpi, 100 * 60);
        assert_eq!(rate_bucket.to_string(), "100@1m");
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = EndpointRateLimiter::new(Vec::leak(rates));

        let endpoint = EndpointId::from("ep-my-endpoint-1234");

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
        // more connections fail
        assert!(!limiter.check(endpoint.clone()));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 1..6 {
            for _ in 0..100 {
                assert!(limiter.check(endpoint.clone()));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // after 600 requests in total, more connections exceed the 20rps@30s limit (600 per bucket)
        assert!(!limiter.check(endpoint.clone()));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let rand = rand::rngs::StdRng::from_seed([1; 32]);
        let hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
            &RateBucketInfo::DEFAULT_SET,
            rand,
            hasher,
        );
        for i in 0..1_000_000 {
            limiter.check(format!("{i}").into());
        }
        assert!(limiter.map.len() < 150_000);
    }
}