use std::{
    collections::hash_map::RandomState,
    hash::BuildHasher,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};

use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use smol_str::SmolStr;
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Duration, Instant};
use tracing::info;

use super::{
    limit_algorithm::{LimitAlgorithm, Sample},
    RateLimiterConfig,
};

// Simple per-endpoint rate limiter.
//
// Checks that the number of connections to the endpoint stays below `max_rps` rps.
// Purposefully ignores the user name and database name, as clients can reconnect
// with different names; otherwise we would still end up sending some HTTP requests
// to the control plane.
//
// We could also save quite a lot of CPU (I think) by bailing out right after we
// see the SNI, before doing the TLS handshake. User-side error messages in that
// case do not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
// I went with a more expensive approach that yields friendlier error messages.
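/// Illustrative usage (a sketch added for documentation, not taken from the
/// original source; it assumes the default bucket set):
///
/// ```ignore
/// let limiter = EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_SET);
/// if !limiter.check("ep-my-endpoint-1234".into()) {
///     // reject the connection: this endpoint exceeded one of its rate buckets
/// }
/// ```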
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
    map: DashMap<SmolStr, Vec<RateBucket>, Hasher>,
    info: &'static [RateBucketInfo],
    access_count: AtomicUsize,
    rand: Mutex<Rand>,
}

#[derive(Clone, Copy)]
struct RateBucket {
    start: Instant,
    count: u32,
}

impl RateBucket {
    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
        if now - self.start < info.interval {
            self.count < info.max_rpi
        } else {
            // bucket expired, reset
            self.count = 0;
            self.start = now;

            true
        }
    }

    fn inc(&mut self) {
        self.count += 1;
    }
}

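/// A rate-limit bucket: at most `max_rpi` requests per `interval`.
///
/// Parsed from and displayed in the `rps@interval` format. A sketch of the
/// round trip (mirrors the `rate_bucket_parse` test below):
///
/// ```ignore
/// let bucket: RateBucketInfo = "100@10s".parse().unwrap();
/// assert_eq!(bucket.interval, Duration::from_secs(10));
/// assert_eq!(bucket.max_rpi, 1000); // 100 rps * 10 s
/// assert_eq!(bucket.to_string(), "100@10s");
/// ```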
#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    pub interval: Duration,
    // requests per interval
    pub max_rpi: u32,
}

impl std::fmt::Display for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
        write!(f, "{rps}@{}", humantime::format_duration(self.interval))
    }
}

impl std::fmt::Debug for RateBucketInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self}")
    }
}

impl std::str::FromStr for RateBucketInfo {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some((max_rps, interval)) = s.split_once('@') else {
            bail!("invalid rate info")
        };
        let max_rps = max_rps.parse()?;
        let interval = humantime::parse_duration(interval)?;
        Ok(Self::new(max_rps, interval))
    }
}

impl RateBucketInfo {
    pub const DEFAULT_SET: [Self; 3] = [
        Self::new(300, Duration::from_secs(1)),
        Self::new(200, Duration::from_secs(60)),
        Self::new(100, Duration::from_secs(600)),
    ];

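    /// Sort the buckets by interval and check that a longer interval never
    /// allows fewer requests per bucket than a shorter one.
    ///
    /// A sketch of the failure mode (mirrors the `rate_buckets_validate` test
    /// below):
    ///
    /// ```ignore
    /// // "300@1s" permits 300 requests per bucket, "10@10s" only 100.
    /// let mut rates: Vec<RateBucketInfo> =
    ///     vec!["300@1s".parse().unwrap(), "10@10s".parse().unwrap()];
    /// assert!(RateBucketInfo::validate(&mut rates).is_err());
    /// ```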
    pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
        info.sort_unstable_by_key(|info| info.interval);
        let invalid = info
            .iter()
            .tuple_windows()
            .find(|(a, b)| a.max_rpi > b.max_rpi);
        if let Some((a, b)) = invalid {
            bail!(
                "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
                b.max_rpi,
                a.max_rpi,
            );
        }

        Ok(())
    }

    pub const fn new(max_rps: u32, interval: Duration) -> Self {
        Self {
            interval,
            max_rpi: max_rps * interval.as_millis() as u32 / 1000,
        }
    }
}

impl EndpointRateLimiter {
    pub fn new(info: &'static [RateBucketInfo]) -> Self {
        Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
    }
}

impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
    fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
        info!(buckets = ?info, "endpoint rate limiter");
        Self {
            info,
            map: DashMap::with_hasher_and_shard_amount(hasher, 64),
            access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
            rand: Mutex::new(rand),
        }
    }

    /// Check whether the endpoint may accept another request, i.e. that its
    /// request count stays below `max_rpi` in every configured rate bucket.
    pub fn check(&self, endpoint: SmolStr) -> bool {
        // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
        // worst case memory usage is about:
        //   = 2 * 2048 * 64 * (48B + 72B)
        //   = 30MB
        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
            self.do_gc();
        }

        let now = Instant::now();
        let mut entry = self.map.entry(endpoint).or_insert_with(|| {
            vec![
                RateBucket {
                    start: now,
                    count: 0,
                };
                self.info.len()
            ]
        });

        let should_allow_request = entry
            .iter_mut()
            .zip(self.info)
            .all(|(bucket, info)| bucket.should_allow_request(info, now));

        if should_allow_request {
            // only increment the bucket counts if the request will actually be accepted
            entry.iter_mut().for_each(RateBucket::inc);
        }

        should_allow_request
    }

    /// Clean the map. Simple strategy: remove all entries in a random shard.
    /// At worst, we'll double the effective max_rps during the cleanup.
    /// But that way deletion does not acquire a mutex on each entry access.
    pub fn do_gc(&self) {
        info!(
            "cleaning up endpoint rate limiter, current size = {}",
            self.map.len()
        );
        let n = self.map.shards().len();
        // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
        // (impossible, in fact, unless we have 2048 threads)
        let shard = self.rand.lock().unwrap().gen_range(0..n);
        self.map.shards()[shard].write().clear();
    }
}

/// Limits the number of concurrent jobs.
///
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
/// token once the job is finished.
///
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
/// caused by overload (loss).
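///
/// Illustrative lifecycle (a sketch for documentation, not from the original
/// source; `config` and `run_job` are hypothetical):
///
/// ```ignore
/// let limiter = Limiter::new(config);
/// if let Some(token) = limiter.try_acquire() {
///     let outcome = run_job().await; // returns an Outcome
///     limiter.release(token, Some(outcome)).await;
/// }
/// ```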
pub struct Limiter {
    limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
    semaphore: std::sync::Arc<Semaphore>,
    config: RateLimiterConfig,

    // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
    limits: AtomicUsize,

    // ONLY USE ATOMIC ADD/SUB
    in_flight: Arc<AtomicUsize>,

    #[cfg(test)]
    notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
}

/// A concurrency token, required to run a job.
///
/// Release the token back to the [Limiter] after the job is complete.
#[derive(Debug)]
pub struct Token<'t> {
    permit: Option<tokio::sync::SemaphorePermit<'t>>,
    start: Instant,
    in_flight: Arc<AtomicUsize>,
}

/// A snapshot of the state of the [Limiter].
///
/// Not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
    limit: usize,
    in_flight: usize,
}

/// Whether a job succeeded or failed as a result of congestion/overload.
///
/// Errors not considered to be caused by overload should be ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The job succeeded, or failed in a way unrelated to overload.
    Success,
    /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
    /// was observed.
    Overload,
}

impl Outcome {
    fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
        match error {
            reqwest_middleware::Error::Middleware(_) => Outcome::Success,
            reqwest_middleware::Error::Reqwest(e) => {
                if let Some(status) = e.status() {
                    if status.is_server_error()
                        || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
                    {
                        Outcome::Overload
                    } else {
                        Outcome::Success
                    }
                } else {
                    Outcome::Success
                }
            }
        }
    }

    fn from_reqwest_response(response: &reqwest::Response) -> Self {
        if response.status().is_server_error()
            || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
        {
            Outcome::Overload
        } else {
            Outcome::Success
        }
    }
}

impl Limiter {
    /// Create a limiter with a given limit control algorithm.
    pub fn new(config: RateLimiterConfig) -> Self {
        assert!(config.initial_limit > 0);
        Self {
            limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
            semaphore: Arc::new(Semaphore::new(config.initial_limit)),
            config,
            limits: AtomicUsize::new(config.initial_limit),
            in_flight: Arc::new(AtomicUsize::new(0)),
            #[cfg(test)]
            notifier: None,
        }
    }

    /// In some cases [Token]s are acquired asynchronously when updating the limit.
    #[cfg(test)]
    pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
        self.notifier = Some(n);
        self
    }

    /// Try to immediately acquire a concurrency [Token].
    ///
    /// Returns `None` if there are none available.
    pub fn try_acquire(&self) -> Option<Token> {
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            self.semaphore
                .try_acquire()
                .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                .ok()
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
    ///
    /// Returns `None` if there are none available after `duration`.
    pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
        info!("acquiring token: {:?}", self.semaphore.available_permits());
        let result = if self.config.disable {
            // If the rate limiter is disabled, we can always acquire a token.
            Some(Token::new(None, self.in_flight.clone()))
        } else {
            match timeout(duration, self.semaphore.acquire()).await {
                Ok(maybe_permit) => maybe_permit
                    .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
                    .ok(),
                Err(_) => None,
            }
        };
        if result.is_some() {
            self.in_flight.fetch_add(1, Ordering::AcqRel);
        }
        result
    }

    /// Return the concurrency [Token], along with the outcome of the job.
    ///
    /// The [Outcome] of the job, and the time taken to perform it, may be used
    /// to update the concurrency limit.
    ///
    /// Set the outcome to `None` to ignore the job.
    pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
        tracing::info!("outcome is {:?}", outcome);
        let in_flight = self.in_flight.load(Ordering::Acquire);
        let old_limit = self.limits.load(Ordering::Acquire);
        let available = if self.config.disable {
            // This is not used by the algorithm and can be anything; if the
            // limiter is disabled, it makes sense to set it to 0.
            0
        } else {
            self.semaphore.available_permits()
        };
        let total = in_flight + available;

        let mut algo = self.limit_algo.lock().await;

        let new_limit = if let Some(outcome) = outcome {
            let sample = Sample {
                latency: token.start.elapsed(),
                in_flight,
                outcome,
            };
            algo.update(old_limit, sample).await
        } else {
            old_limit
        };
        tracing::info!("new limit is {}", new_limit);
        let actual_limit = if new_limit < total {
            // shrink the pool: drop this permit permanently instead of
            // returning it to the semaphore
            token.forget();
            total.saturating_sub(1)
        } else {
            if !self.config.disable {
                self.semaphore.add_permits(new_limit.saturating_sub(total));
            }
            new_limit
        };
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["expected"])
            .set(new_limit as i64);
        crate::metrics::RATE_LIMITER_LIMIT
            .with_label_values(&["actual"])
            .set(actual_limit as i64);
        self.limits.store(new_limit, Ordering::Release);
        #[cfg(test)]
        if let Some(n) = &self.notifier {
            n.notify_one();
        }
    }

    /// The current state of the limiter.
    pub fn state(&self) -> LimiterState {
        let limit = self.limits.load(Ordering::Relaxed);
        let in_flight = self.in_flight.load(Ordering::Relaxed);
        LimiterState { limit, in_flight }
    }
}

impl<'t> Token<'t> {
    fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
        Self {
            permit,
            start: Instant::now(),
            in_flight,
        }
    }

    pub fn forget(&mut self) {
        if let Some(permit) = self.permit.take() {
            permit.forget();
        }
    }
}

impl Drop for Token<'_> {
    fn drop(&mut self) {
        self.in_flight.fetch_sub(1, Ordering::AcqRel);
    }
}

impl LimiterState {
    /// The current concurrency limit.
    pub fn limit(&self) -> usize {
        self.limit
    }
    /// The number of jobs in flight.
    pub fn in_flight(&self) -> usize {
        self.in_flight
    }
}

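/// Use the [Limiter] as a `reqwest` middleware: each outgoing request must
/// acquire a token (within `config.timeout`) before it is sent, and the token
/// is released with an [Outcome] derived from the response or error.
///
/// Illustrative wiring (a sketch, not from the original source; `config` is a
/// hypothetical [RateLimiterConfig]):
///
/// ```ignore
/// let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
///     .with(Limiter::new(config))
///     .build();
/// ```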
449 :
450 : #[async_trait::async_trait]
451 : impl reqwest_middleware::Middleware for Limiter {
452 6 : async fn handle(
453 6 : &self,
454 6 : req: reqwest::Request,
455 6 : extensions: &mut task_local_extensions::Extensions,
456 6 : next: reqwest_middleware::Next<'_>,
457 6 : ) -> reqwest_middleware::Result<reqwest::Response> {
458 6 : let start = Instant::now();
459 6 : let token = self
460 6 : .acquire_timeout(self.config.timeout)
461 1 : .await
462 6 : .ok_or_else(|| {
463 1 : reqwest_middleware::Error::Middleware(
464 1 : // TODO: Should we map it into user facing errors?
465 1 : crate::console::errors::ApiError::Console {
466 1 : status: crate::http::StatusCode::TOO_MANY_REQUESTS,
467 1 : text: "Too many requests".into(),
468 1 : }
469 1 : .into(),
470 1 : )
471 6 : })?;
472 5 : info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
473 5 : crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
474 15 : match next.run(req, extensions).await {
475 5 : Ok(response) => {
476 5 : self.release(token, Some(Outcome::from_reqwest_response(&response)))
477 UBC 0 : .await;
478 CBC 5 : Ok(response)
479 : }
480 UBC 0 : Err(e) => {
481 0 : self.release(token, Some(Outcome::from_reqwest_error(&e)))
482 0 : .await;
483 0 : Err(e)
484 : }
485 : }
486 CBC 12 : }
487 : }
488 :
#[cfg(test)]
mod tests {
    use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};

    use futures::{task::noop_waker_ref, Future};
    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use smol_str::SmolStr;
    use tokio::time;

    use super::{EndpointRateLimiter, Limiter, Outcome};
    use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm};

    #[tokio::test]
    async fn it_works() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 10,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        let token = limiter.try_acquire().unwrap();

        limiter.release(token, Some(Outcome::Success)).await;

        assert_eq!(limiter.state().limit(), 10);
    }

    #[tokio::test]
    async fn is_fair() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: false,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();

        let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token2_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token1"
        );

        limiter.release(token1, Some(Outcome::Success)).await;
        // === END TOKEN 1 ===

        // === TOKEN 2 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token2"
        );

        assert!(
            token3_fut
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()))
                .is_pending(),
            "token is acquired by token2"
        );

        let token2 = token2_fut.await.unwrap();

        limiter.release(token2, Some(Outcome::Success)).await;
        // === END TOKEN 2 ===

        // === TOKEN 3 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token3"
        );

        let token3 = token3_fut.await.unwrap();
        limiter.release(token3, Some(Outcome::Success)).await;
        // === END TOKEN 3 ===

        // === TOKEN 4 ===
        let token4 = limiter.try_acquire().unwrap();
        limiter.release(token4, Some(Outcome::Success)).await;
    }

    #[tokio::test]
    async fn disable() {
        let config = super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit: 1,
            disable: true,
            ..Default::default()
        };
        let limiter = Limiter::new(config);

        // === TOKEN 1 ===
        let token1 = limiter.try_acquire().unwrap();
        let token2 = limiter.try_acquire().unwrap();
        let state = limiter.state();
        assert_eq!(state.limit(), 1);
        assert_eq!(state.in_flight(), 2); // For a disabled limiter, this is expected.
        limiter.release(token1, None).await;
        limiter.release(token2, None).await;
    }

    #[test]
    fn rate_bucket_rpi() {
        let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(rate_bucket.max_rpi, 50 * 5);

        let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(rate_bucket.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(10));
        assert_eq!(rate_bucket.max_rpi, 100 * 10);
        assert_eq!(rate_bucket.to_string(), "100@10s");

        let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(rate_bucket.interval, Duration::from_secs(60));
        assert_eq!(rate_bucket.max_rpi, 100 * 60);
        assert_eq!(rate_bucket.to_string(), "100@1m");
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_SET;
        RateBucketInfo::validate(&mut defaults[..]).unwrap();
    }

    #[test]
    #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
            .into_iter()
            .map(|s| s.parse().unwrap())
            .collect();
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = EndpointRateLimiter::new(Vec::leak(rates));

        let endpoint = SmolStr::from("ep-my-endpoint-1234");

        time::pause();

        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
        // more connections fail
        assert!(!limiter.check(endpoint.clone()));

        // fail even after 500ms as it's in the same bucket
        time::advance(time::Duration::from_millis(500)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after a full 1s, 100 requests are allowed again
        time::advance(time::Duration::from_millis(500)).await;
        for _ in 1..6 {
            for _ in 0..100 {
                assert!(limiter.check(endpoint.clone()));
            }
            time::advance(time::Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!limiter.check(endpoint.clone()));

        // will still fail before the 30 second limit
        time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!limiter.check(endpoint.clone()));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(time::Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(limiter.check(endpoint.clone()));
        }
    }

693 :
694 1 : #[tokio::test]
695 1 : async fn test_rate_limits_gc() {
696 1 : // fixed seeded random/hasher to ensure that the test is not flaky
697 1 : let rand = rand::rngs::StdRng::from_seed([1; 32]);
698 1 : let hasher = BuildHasherDefault::<FxHasher>::default();
699 1 :
700 1 : let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
701 1 : &RateBucketInfo::DEFAULT_SET,
702 1 : rand,
703 1 : hasher,
704 1 : );
705 1000001 : for i in 0..1_000_000 {
706 1000000 : limiter.check(format!("{i}").into());
707 1000000 : }
708 1 : assert!(limiter.map.len() < 150_000);
709 : }
710 : }