Line data Source code
1 : use std::{
2 : borrow::Cow,
3 : collections::hash_map::RandomState,
4 : hash::{BuildHasher, Hash},
5 : net::IpAddr,
6 : sync::{
7 : atomic::{AtomicUsize, Ordering},
8 : Arc, Mutex,
9 : },
10 : };
11 :
12 : use anyhow::bail;
13 : use dashmap::DashMap;
14 : use itertools::Itertools;
15 : use rand::{rngs::StdRng, Rng, SeedableRng};
16 : use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
17 : use tokio::time::{timeout, Duration, Instant};
18 : use tracing::info;
19 :
20 : use crate::{intern::EndpointIdInt, EndpointId};
21 :
22 : use super::{
23 : limit_algorithm::{LimitAlgorithm, Sample},
24 : RateLimiterConfig,
25 : };
26 :
27 : pub struct RedisRateLimiter {
28 : data: Vec<RateBucket>,
29 : info: &'static [RateBucketInfo],
30 : }
31 :
32 : impl RedisRateLimiter {
33 0 : pub fn new(info: &'static [RateBucketInfo]) -> Self {
34 0 : Self {
35 0 : data: vec![
36 0 : RateBucket {
37 0 : start: Instant::now(),
38 0 : count: 0,
39 0 : };
40 0 : info.len()
41 0 : ],
42 0 : info,
43 0 : }
44 0 : }
45 :
46 : /// Check that number of connections is below `max_rps` rps.
47 0 : pub fn check(&mut self) -> bool {
48 0 : let now = Instant::now();
49 0 :
50 0 : let should_allow_request = self
51 0 : .data
52 0 : .iter_mut()
53 0 : .zip(self.info)
54 0 : .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
55 0 :
56 0 : if should_allow_request {
57 0 : // only increment the bucket counts if the request will actually be accepted
58 0 : self.data.iter_mut().for_each(|b| b.inc(1));
59 0 : }
60 :
61 0 : should_allow_request
62 0 : }
63 : }
64 :
65 : // Simple per-endpoint rate limiter.
66 : //
67 : // Check that number of connections to the endpoint is below `max_rps` rps.
68 : // Purposefully ignore user name and database name as clients can reconnect
69 : // with different names, so we'll end up sending some http requests to
70 : // the control plane.
71 : //
72 : // We also may save quite a lot of CPU (I think) by bailing out right after we
73 : // saw SNI, before doing TLS handshake. User-side error messages in that case
74 : // does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
75 : // I went with a more expensive way that yields user-friendlier error messages.
76 : pub type EndpointRateLimiter = BucketRateLimiter<EndpointId, StdRng, RandomState>;
77 :
78 : // This can't be just per IP because that would limit some PaaS that share IP addresses
79 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>;
80 :
81 : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
82 : map: DashMap<Key, Vec<RateBucket>, Hasher>,
83 : info: Cow<'static, [RateBucketInfo]>,
84 : access_count: AtomicUsize,
85 : rand: Mutex<Rand>,
86 : }
87 :
88 : #[derive(Clone, Copy)]
89 : struct RateBucket {
90 : start: Instant,
91 : count: u32,
92 : }
93 :
94 : impl RateBucket {
95 6001830 : fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
96 6001830 : if now - self.start < info.interval {
97 6001814 : self.count + n <= info.max_rpi
98 : } else {
99 : // bucket expired, reset
100 16 : self.count = 0;
101 16 : self.start = now;
102 16 :
103 16 : true
104 : }
105 6001830 : }
106 :
107 6001818 : fn inc(&mut self, n: u32) {
108 6001818 : self.count += n;
109 6001818 : }
110 : }
111 :
112 : #[derive(Clone, Copy, PartialEq)]
113 : pub struct RateBucketInfo {
114 : pub interval: Duration,
115 : // requests per interval
116 : pub max_rpi: u32,
117 : }
118 :
119 : impl std::fmt::Display for RateBucketInfo {
120 32 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 32 : let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
122 32 : write!(f, "{rps}@{}", humantime::format_duration(self.interval))
123 32 : }
124 : }
125 :
126 : impl std::fmt::Debug for RateBucketInfo {
127 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 0 : write!(f, "{self}")
129 0 : }
130 : }
131 :
132 : impl std::str::FromStr for RateBucketInfo {
133 : type Err = anyhow::Error;
134 :
135 52 : fn from_str(s: &str) -> Result<Self, Self::Err> {
136 52 : let Some((max_rps, interval)) = s.split_once('@') else {
137 0 : bail!("invalid rate info")
138 : };
139 52 : let max_rps = max_rps.parse()?;
140 52 : let interval = humantime::parse_duration(interval)?;
141 52 : Ok(Self::new(max_rps, interval))
142 52 : }
143 : }
144 :
145 : impl RateBucketInfo {
146 : pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
147 : Self::new(300, Duration::from_secs(1)),
148 : Self::new(200, Duration::from_secs(60)),
149 : Self::new(100, Duration::from_secs(600)),
150 : ];
151 :
152 : /// All of these are per endpoint-ip pair.
153 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
154 : ///
155 : /// First bucket: 300mcpus total per endpoint-ip pair
156 : /// * 1228800 requests per second with 1 hash rounds. (endpoint rate limiter will catch this first)
157 : /// * 300 requests per second with 4096 hash rounds.
158 : /// * 2 requests per second with 600000 hash rounds.
159 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
160 : Self::new(300 * 4096, Duration::from_secs(1)),
161 : Self::new(200 * 4096, Duration::from_secs(60)),
162 : Self::new(100 * 4096, Duration::from_secs(600)),
163 : ];
164 :
165 6 : pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
166 16 : info.sort_unstable_by_key(|info| info.interval);
167 6 : let invalid = info
168 6 : .iter()
169 6 : .tuple_windows()
170 8 : .find(|(a, b)| a.max_rpi > b.max_rpi);
171 6 : if let Some((a, b)) = invalid {
172 2 : bail!(
173 2 : "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
174 2 : b.max_rpi,
175 2 : a.max_rpi,
176 2 : );
177 4 : }
178 4 :
179 4 : Ok(())
180 6 : }
181 :
182 60 : pub const fn new(max_rps: u32, interval: Duration) -> Self {
183 60 : Self {
184 60 : interval,
185 60 : max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
186 60 : }
187 60 : }
188 : }
189 :
190 : impl<K: Hash + Eq> BucketRateLimiter<K> {
191 8 : pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
192 8 : Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
193 8 : }
194 : }
195 :
196 : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
197 10 : fn new_with_rand_and_hasher(
198 10 : info: impl Into<Cow<'static, [RateBucketInfo]>>,
199 10 : rand: R,
200 10 : hasher: S,
201 10 : ) -> Self {
202 10 : let info = info.into();
203 10 : info!(buckets = ?info, "endpoint rate limiter");
204 10 : Self {
205 10 : info,
206 10 : map: DashMap::with_hasher_and_shard_amount(hasher, 64),
207 10 : access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
208 10 : rand: Mutex::new(rand),
209 10 : }
210 10 : }
211 :
212 : /// Check that number of connections to the endpoint is below `max_rps` rps.
213 2000914 : pub fn check(&self, key: K, n: u32) -> bool {
214 2000914 : // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
215 2000914 : // worst case memory usage is about:
216 2000914 : // = 2 * 2048 * 64 * (48B + 72B)
217 2000914 : // = 30MB
218 2000914 : if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
219 976 : self.do_gc();
220 1999938 : }
221 :
222 2000914 : let now = Instant::now();
223 2000914 : let mut entry = self.map.entry(key).or_insert_with(|| {
224 2000008 : vec![
225 2000008 : RateBucket {
226 2000008 : start: now,
227 2000008 : count: 0,
228 2000008 : };
229 2000008 : self.info.len()
230 2000008 : ]
231 2000914 : });
232 2000914 :
233 2000914 : let should_allow_request = entry
234 2000914 : .iter_mut()
235 2000914 : .zip(&*self.info)
236 6001830 : .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
237 2000914 :
238 2000914 : if should_allow_request {
239 2000906 : // only increment the bucket counts if the request will actually be accepted
240 6001818 : entry.iter_mut().for_each(|b| b.inc(n));
241 2000906 : }
242 :
243 2000914 : should_allow_request
244 2000914 : }
245 :
246 : /// Clean the map. Simple strategy: remove all entries in a random shard.
247 : /// At worst, we'll double the effective max_rps during the cleanup.
248 : /// But that way deletion does not aquire mutex on each entry access.
249 976 : pub fn do_gc(&self) {
250 976 : info!(
251 0 : "cleaning up bucket rate limiter, current size = {}",
252 0 : self.map.len()
253 0 : );
254 976 : let n = self.map.shards().len();
255 976 : // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
256 976 : // (impossible, infact, unless we have 2048 threads)
257 976 : let shard = self.rand.lock().unwrap().gen_range(0..n);
258 976 : self.map.shards()[shard].write().clear();
259 976 : }
260 : }
261 :
262 : /// Limits the number of concurrent jobs.
263 : ///
264 : /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
265 : /// token once the job is finished.
266 : ///
267 : /// The limit will be automatically adjusted based on observed latency (delay) and/or failures
268 : /// caused by overload (loss).
269 : pub struct Limiter {
270 : limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
271 : semaphore: std::sync::Arc<Semaphore>,
272 : config: RateLimiterConfig,
273 :
274 : // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
275 : limits: AtomicUsize,
276 :
277 : // ONLY USE ATOMIC ADD/SUB
278 : in_flight: Arc<AtomicUsize>,
279 :
280 : #[cfg(test)]
281 : notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
282 : }
283 :
284 : /// A concurrency token, required to run a job.
285 : ///
286 : /// Release the token back to the [Limiter] after the job is complete.
287 : #[derive(Debug)]
288 : pub struct Token<'t> {
289 : permit: Option<tokio::sync::SemaphorePermit<'t>>,
290 : start: Instant,
291 : in_flight: Arc<AtomicUsize>,
292 : }
293 :
/// A point-in-time view of a [Limiter].
///
/// The two values are read separately, so under high concurrency they are not
/// guaranteed to be consistent with each other.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
    /// The concurrency limit at the time of the snapshot.
    limit: usize,
    /// The number of jobs in flight at the time of the snapshot.
    in_flight: usize,
}
302 :
/// How a job ended, as far as congestion/overload is concerned.
///
/// Failures that were not caused by overload should not be reported as overload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The job succeeded, or it failed for a reason unrelated to overload.
    Success,
    /// The job failed because of overload, e.g. it timed out or an explicit
    /// backpressure signal was observed.
    Overload,
}
314 :
315 : impl Outcome {
316 0 : fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
317 0 : match error {
318 0 : reqwest_middleware::Error::Middleware(_) => Outcome::Success,
319 0 : reqwest_middleware::Error::Reqwest(e) => {
320 0 : if let Some(status) = e.status() {
321 0 : if status.is_server_error()
322 0 : || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
323 : {
324 0 : Outcome::Overload
325 : } else {
326 0 : Outcome::Success
327 : }
328 : } else {
329 0 : Outcome::Success
330 : }
331 : }
332 : }
333 0 : }
334 4 : fn from_reqwest_response(response: &reqwest::Response) -> Self {
335 4 : if response.status().is_server_error()
336 4 : || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
337 : {
338 0 : Outcome::Overload
339 : } else {
340 4 : Outcome::Success
341 : }
342 4 : }
343 : }
344 :
345 : impl Limiter {
346 : /// Create a limiter with a given limit control algorithm.
347 16 : pub fn new(config: RateLimiterConfig) -> Self {
348 16 : assert!(config.initial_limit > 0);
349 16 : Self {
350 16 : limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
351 16 : semaphore: Arc::new(Semaphore::new(config.initial_limit)),
352 16 : config,
353 16 : limits: AtomicUsize::new(config.initial_limit),
354 16 : in_flight: Arc::new(AtomicUsize::new(0)),
355 16 : #[cfg(test)]
356 16 : notifier: None,
357 16 : }
358 16 : }
359 : // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
360 : // assert!(initial_limit > 0);
361 :
362 : // Self {
363 : // limit_algo: AsyncMutex::new(limit_algorithm),
364 : // semaphore: Arc::new(Semaphore::new(initial_limit)),
365 : // timeout,
366 : // limits: AtomicUsize::new(initial_limit),
367 : // in_flight: Arc::new(AtomicUsize::new(0)),
368 : // #[cfg(test)]
369 : // notifier: None,
370 : // }
371 : // }
372 :
373 : /// In some cases [Token]s are acquired asynchronously when updating the limit.
374 : #[cfg(test)]
375 2 : pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
376 2 : self.notifier = Some(n);
377 2 : self
378 2 : }
379 :
380 : /// Try to immediately acquire a concurrency [Token].
381 : ///
382 : /// Returns `None` if there are none available.
383 26 : pub fn try_acquire(&self) -> Option<Token> {
384 26 : let result = if self.config.disable {
385 : // If the rate limiter is disabled, we can always acquire a token.
386 4 : Some(Token::new(None, self.in_flight.clone()))
387 : } else {
388 22 : self.semaphore
389 22 : .try_acquire()
390 22 : .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
391 22 : .ok()
392 : };
393 26 : if result.is_some() {
394 22 : self.in_flight.fetch_add(1, Ordering::AcqRel);
395 22 : }
396 26 : result
397 26 : }
398 :
399 : /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
400 : ///
401 : /// Returns `None` if there are none available after `duration`.
402 8 : pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
403 8 : info!("acquiring token: {:?}", self.semaphore.available_permits());
404 8 : let result = if self.config.disable {
405 : // If the rate limiter is disabled, we can always acquire a token.
406 4 : Some(Token::new(None, self.in_flight.clone()))
407 : } else {
408 6 : match timeout(duration, self.semaphore.acquire()).await {
409 4 : Ok(maybe_permit) => maybe_permit
410 4 : .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
411 4 : .ok(),
412 0 : Err(_) => None,
413 : }
414 : };
415 8 : if result.is_some() {
416 8 : self.in_flight.fetch_add(1, Ordering::AcqRel);
417 8 : }
418 8 : result
419 8 : }
420 :
421 : /// Return the concurrency [Token], along with the outcome of the job.
422 : ///
423 : /// The [Outcome] of the job, and the time taken to perform it, may be used
424 : /// to update the concurrency limit.
425 : ///
426 : /// Set the outcome to `None` to ignore the job.
427 26 : pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
428 26 : tracing::info!("outcome is {:?}", outcome);
429 26 : let in_flight = self.in_flight.load(Ordering::Acquire);
430 26 : let old_limit = self.limits.load(Ordering::Acquire);
431 26 : let available = if self.config.disable {
432 8 : 0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0.
433 : } else {
434 18 : self.semaphore.available_permits()
435 : };
436 26 : let total = in_flight + available;
437 :
438 26 : let mut algo = self.limit_algo.lock().await;
439 :
440 26 : let new_limit = if let Some(outcome) = outcome {
441 20 : let sample = Sample {
442 20 : latency: token.start.elapsed(),
443 20 : in_flight,
444 20 : outcome,
445 20 : };
446 20 : algo.update(old_limit, sample).await
447 : } else {
448 6 : old_limit
449 : };
450 26 : tracing::info!("new limit is {}", new_limit);
451 26 : let actual_limit = if new_limit < total {
452 4 : token.forget();
453 4 : total.saturating_sub(1)
454 : } else {
455 22 : if !self.config.disable {
456 16 : self.semaphore.add_permits(new_limit.saturating_sub(total));
457 16 : }
458 22 : new_limit
459 : };
460 26 : crate::metrics::RATE_LIMITER_LIMIT
461 26 : .with_label_values(&["expected"])
462 26 : .set(new_limit as i64);
463 26 : crate::metrics::RATE_LIMITER_LIMIT
464 26 : .with_label_values(&["actual"])
465 26 : .set(actual_limit as i64);
466 26 : self.limits.store(new_limit, Ordering::Release);
467 0 : #[cfg(test)]
468 26 : if let Some(n) = &self.notifier {
469 2 : n.notify_one();
470 24 : }
471 26 : }
472 :
473 : /// The current state of the limiter.
474 12 : pub fn state(&self) -> LimiterState {
475 12 : let limit = self.limits.load(Ordering::Relaxed);
476 12 : let in_flight = self.in_flight.load(Ordering::Relaxed);
477 12 : LimiterState { limit, in_flight }
478 12 : }
479 : }
480 :
481 : impl<'t> Token<'t> {
482 30 : fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
483 30 : Self {
484 30 : permit,
485 30 : start: Instant::now(),
486 30 : in_flight,
487 30 : }
488 30 : }
489 :
490 4 : pub fn forget(&mut self) {
491 4 : if let Some(permit) = self.permit.take() {
492 2 : permit.forget();
493 2 : }
494 4 : }
495 : }
496 :
497 : impl Drop for Token<'_> {
498 30 : fn drop(&mut self) {
499 30 : self.in_flight.fetch_sub(1, Ordering::AcqRel);
500 30 : }
501 : }
502 :
503 : impl LimiterState {
504 : /// The current concurrency limit.
505 12 : pub fn limit(&self) -> usize {
506 12 : self.limit
507 12 : }
508 : /// The number of jobs in flight.
509 2 : pub fn in_flight(&self) -> usize {
510 2 : self.in_flight
511 2 : }
512 : }
513 :
514 : #[async_trait::async_trait]
515 : impl reqwest_middleware::Middleware for Limiter {
516 4 : async fn handle(
517 4 : &self,
518 4 : req: reqwest::Request,
519 4 : extensions: &mut task_local_extensions::Extensions,
520 4 : next: reqwest_middleware::Next<'_>,
521 4 : ) -> reqwest_middleware::Result<reqwest::Response> {
522 4 : let start = Instant::now();
523 4 : let token = self
524 4 : .acquire_timeout(self.config.timeout)
525 0 : .await
526 4 : .ok_or_else(|| {
527 0 : reqwest_middleware::Error::Middleware(
528 0 : // TODO: Should we map it into user facing errors?
529 0 : crate::console::errors::ApiError::Console {
530 0 : status: crate::http::StatusCode::TOO_MANY_REQUESTS,
531 0 : text: "Too many requests".into(),
532 0 : }
533 0 : .into(),
534 0 : )
535 4 : })?;
536 4 : info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
537 4 : crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
538 8 : match next.run(req, extensions).await {
539 4 : Ok(response) => {
540 4 : self.release(token, Some(Outcome::from_reqwest_response(&response)))
541 0 : .await;
542 4 : Ok(response)
543 : }
544 0 : Err(e) => {
545 0 : self.release(token, Some(Outcome::from_reqwest_error(&e)))
546 0 : .await;
547 0 : Err(e)
548 : }
549 : }
550 12 : }
551 : }
552 :
#[cfg(test)]
mod tests {
    use std::{
        hash::BuildHasherDefault,
        pin::{pin, Pin},
        task::Context,
        time::Duration,
    };

    use futures::{task::noop_waker_ref, Future};
    use rand::SeedableRng;
    use rustc_hash::FxHasher;
    use tokio::time;

    use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome};
    use crate::{
        rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
        EndpointId,
    };

    /// Config for a fixed-limit limiter with a one second timeout.
    fn fixed_config(initial_limit: usize, disable: bool) -> super::RateLimiterConfig {
        super::RateLimiterConfig {
            algorithm: RateLimitAlgorithm::Fixed,
            timeout: Duration::from_secs(1),
            initial_limit,
            disable,
            ..Default::default()
        }
    }

    /// Polls `fut` once with a no-op waker and reports whether it is still pending.
    fn still_pending<F: Future>(fut: Pin<&mut F>) -> bool {
        fut.poll(&mut Context::from_waker(noop_waker_ref()))
            .is_pending()
    }

    /// Parses every spec in `specs`, panicking on the first invalid one.
    fn parse_rates(specs: &[&str]) -> Vec<RateBucketInfo> {
        specs.iter().map(|s| s.parse().unwrap()).collect()
    }

    #[tokio::test]
    async fn it_works() {
        let limiter = Limiter::new(fixed_config(10, false));

        let token = limiter.try_acquire().unwrap();

        limiter.release(token, Some(Outcome::Success)).await;

        assert_eq!(limiter.state().limit(), 10);
    }

    #[tokio::test]
    async fn is_fair() {
        let limiter = Limiter::new(fixed_config(1, false));

        // === TOKEN 1 ===
        let first = limiter.try_acquire().unwrap();

        let mut second_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            still_pending(second_fut.as_mut()),
            "token is acquired by token1"
        );

        let mut third_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
        assert!(
            still_pending(third_fut.as_mut()),
            "token is acquired by token1"
        );

        limiter.release(first, Some(Outcome::Success)).await;
        // === END TOKEN 1 ===

        // === TOKEN 2 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token2"
        );

        assert!(
            still_pending(third_fut.as_mut()),
            "token is acquired by token2"
        );

        let second = second_fut.await.unwrap();

        limiter.release(second, Some(Outcome::Success)).await;
        // === END TOKEN 2 ===

        // === TOKEN 3 ===
        assert!(
            limiter.try_acquire().is_none(),
            "token is acquired by token3"
        );

        let third = third_fut.await.unwrap();
        limiter.release(third, Some(Outcome::Success)).await;
        // === END TOKEN 3 ===

        // === TOKEN 4 ===
        let fourth = limiter.try_acquire().unwrap();
        limiter.release(fourth, Some(Outcome::Success)).await;
    }

    #[tokio::test]
    async fn disable() {
        let limiter = Limiter::new(fixed_config(1, true));

        // === TOKEN 1 ===
        let first = limiter.try_acquire().unwrap();
        let second = limiter.try_acquire().unwrap();
        let snapshot = limiter.state();
        assert_eq!(snapshot.limit(), 1);
        assert_eq!(snapshot.in_flight(), 2); // For disabled limiter, it's expected.
        limiter.release(first, None).await;
        limiter.release(second, None).await;
    }

    #[test]
    fn rate_bucket_rpi() {
        let five_seconds = RateBucketInfo::new(50, Duration::from_secs(5));
        assert_eq!(five_seconds.max_rpi, 50 * 5);

        let half_second = RateBucketInfo::new(50, Duration::from_millis(500));
        assert_eq!(half_second.max_rpi, 50 / 2);
    }

    #[test]
    fn rate_bucket_parse() {
        let ten_seconds: RateBucketInfo = "100@10s".parse().unwrap();
        assert_eq!(ten_seconds.interval, Duration::from_secs(10));
        assert_eq!(ten_seconds.max_rpi, 100 * 10);
        assert_eq!(ten_seconds.to_string(), "100@10s");

        let one_minute: RateBucketInfo = "100@1m".parse().unwrap();
        assert_eq!(one_minute.interval, Duration::from_secs(60));
        assert_eq!(one_minute.max_rpi, 100 * 60);
        assert_eq!(one_minute.to_string(), "100@1m");
    }

    #[test]
    fn default_rate_buckets() {
        let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
        RateBucketInfo::validate(&mut defaults).unwrap();
    }

    #[test]
    #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
    fn rate_buckets_validate() {
        let mut rates = parse_rates(&["300@1s", "10@10s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
    }

    #[tokio::test]
    async fn test_rate_limits() {
        let mut rates = parse_rates(&["100@1s", "20@30s"]);
        RateBucketInfo::validate(&mut rates).unwrap();
        let limiter = EndpointRateLimiter::new(rates);

        let endpoint = EndpointId::from("ep-my-endpoint-1234");
        let allowed = |n: u32| limiter.check(endpoint.clone(), n);

        time::pause();

        for _ in 0..100 {
            assert!(allowed(1));
        }
        // more connections fail
        assert!(!allowed(1));

        // fail even after 500ms as it's in the same bucket
        time::advance(Duration::from_millis(500)).await;
        assert!(!allowed(1));

        // after a full 1s, 100 requests are allowed again
        time::advance(Duration::from_millis(500)).await;
        for _round in 0..5 {
            for _ in 0..50 {
                assert!(allowed(2));
            }
            time::advance(Duration::from_millis(1000)).await;
        }

        // more connections after 600 will exceed the 20rps@30s limit
        assert!(!allowed(1));

        // will still fail before the 30 second limit
        time::advance(Duration::from_millis(30_000 - 6_000 - 1)).await;
        assert!(!allowed(1));

        // after the full 30 seconds, 100 requests are allowed again
        time::advance(Duration::from_millis(1)).await;
        for _ in 0..100 {
            assert!(allowed(1));
        }
    }

    #[tokio::test]
    async fn test_rate_limits_gc() {
        // fixed seeded random/hasher to ensure that the test is not flaky
        let seeded_rng = rand::rngs::StdRng::from_seed([1; 32]);
        let fx_hasher = BuildHasherDefault::<FxHasher>::default();

        let limiter = BucketRateLimiter::new_with_rand_and_hasher(
            &RateBucketInfo::DEFAULT_ENDPOINT_SET,
            seeded_rng,
            fx_hasher,
        );
        for key in 0..1_000_000 {
            limiter.check(key, 1);
        }
        assert!(limiter.map.len() < 150_000);
    }

    #[test]
    fn test_default_auth_set() {
        // these values used to exceed u32::MAX
        let per_interval: [(Duration, u32); 3] = [
            (Duration::from_secs(1), 300 * 4096),
            (Duration::from_secs(60), 200 * 4096 * 60),
            (Duration::from_secs(600), 100 * 4096 * 600),
        ];
        let expected = per_interval.map(|(interval, max_rpi)| RateBucketInfo { interval, max_rpi });
        assert_eq!(RateBucketInfo::DEFAULT_AUTH_SET, expected);

        // every default must survive a Display -> FromStr round trip
        for bucket in RateBucketInfo::DEFAULT_AUTH_SET {
            let reparsed: RateBucketInfo = bucket.to_string().parse().unwrap();
            assert_eq!(bucket, reparsed);
        }
    }
}
|