LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs
Test:         aca8877be6ceba750c1be359ed71bc1799d52b30.info
Test Date:    2024-02-14 18:05:35

              Coverage    Total    Hit
Lines:        89.9 %      483      434
Functions:    77.4 %      84       65

            Line data    Source code
       1              : use std::{
       2              :     collections::hash_map::RandomState,
       3              :     hash::BuildHasher,
       4              :     sync::{
       5              :         atomic::{AtomicUsize, Ordering},
       6              :         Arc, Mutex,
       7              :     },
       8              : };
       9              : 
      10              : use anyhow::bail;
      11              : use dashmap::DashMap;
      12              : use itertools::Itertools;
      13              : use rand::{rngs::StdRng, Rng, SeedableRng};
      14              : use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
      15              : use tokio::time::{timeout, Duration, Instant};
      16              : use tracing::info;
      17              : 
      18              : use crate::EndpointId;
      19              : 
      20              : use super::{
      21              :     limit_algorithm::{LimitAlgorithm, Sample},
      22              :     RateLimiterConfig,
      23              : };
      24              : 
      25              : pub struct RedisRateLimiter {
      26              :     data: Vec<RateBucket>,
      27              :     info: &'static [RateBucketInfo],
      28              : }
      29              : 
      30              : impl RedisRateLimiter {
      31            0 :     pub fn new(info: &'static [RateBucketInfo]) -> Self {
      32            0 :         Self {
      33            0 :             data: vec![
      34            0 :                 RateBucket {
      35            0 :                     start: Instant::now(),
      36            0 :                     count: 0,
      37            0 :                 };
      38            0 :                 info.len()
      39            0 :             ],
      40            0 :             info,
      41            0 :         }
      42            0 :     }
      43              : 
       44              :     /// Check that the number of connections is below `max_rps` rps.
      45            0 :     pub fn check(&mut self) -> bool {
      46            0 :         let now = Instant::now();
      47            0 : 
      48            0 :         let should_allow_request = self
      49            0 :             .data
      50            0 :             .iter_mut()
      51            0 :             .zip(self.info)
      52            0 :             .all(|(bucket, info)| bucket.should_allow_request(info, now));
      53            0 : 
      54            0 :         if should_allow_request {
      55            0 :             // only increment the bucket counts if the request will actually be accepted
      56            0 :             self.data.iter_mut().for_each(RateBucket::inc);
      57            0 :         }
      58              : 
      59            0 :         should_allow_request
      60            0 :     }
      61              : }
      62              : 
      63              : // Simple per-endpoint rate limiter.
      64              : //
       65              : // Check that the number of connections to the endpoint is below `max_rps` rps.
       66              : // We purposefully ignore the user name and database name, as clients can
       67              : // reconnect with different names and we would still end up sending some HTTP
       68              : // requests to the control plane.
       69              : //
       70              : // We could also save quite a lot of CPU (I think) by bailing out right after we
       71              : // see SNI, before doing the TLS handshake. The user-side error message in that
       72              : // case does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for
       73              : // now I went with the more expensive approach, which yields friendlier errors.
      74              : pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
      75              :     map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
      76              :     info: &'static [RateBucketInfo],
      77              :     access_count: AtomicUsize,
      78              :     rand: Mutex<Rand>,
      79              : }
      80              : 
      81      4000056 : #[derive(Clone, Copy)]
      82              : struct RateBucket {
      83              :     start: Instant,
      84              :     count: u32,
      85              : }
      86              : 
      87              : impl RateBucket {
      88      6002959 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
      89      6002959 :         if now - self.start < info.interval {
      90      6002940 :             self.count < info.max_rpi
      91              :         } else {
      92              :             // bucket expired, reset
      93           19 :             self.count = 0;
      94           19 :             self.start = now;
      95           19 : 
      96           19 :             true
      97              :         }
      98      6002959 :     }
      99              : 
     100      6002947 :     fn inc(&mut self) {
     101      6002947 :         self.count += 1;
     102      6002947 :     }
     103              : }
     104              : 
     105            4 : #[derive(Clone, Copy, PartialEq)]
     106              : pub struct RateBucketInfo {
     107              :     pub interval: Duration,
     108              :     // requests per interval
     109              :     pub max_rpi: u32,
     110              : }
     111              : 
     112              : impl std::fmt::Display for RateBucketInfo {
     113          245 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     114          245 :         let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
     115          245 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
     116          245 :     }
     117              : }
     118              : 
     119              : impl std::fmt::Debug for RateBucketInfo {
     120           75 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     121           75 :         write!(f, "{self}")
     122           75 :     }
     123              : }
     124              : 
     125              : impl std::str::FromStr for RateBucketInfo {
     126              :     type Err = anyhow::Error;
     127              : 
     128          334 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     129          334 :         let Some((max_rps, interval)) = s.split_once('@') else {
     130            0 :             bail!("invalid rate info")
     131              :         };
     132          334 :         let max_rps = max_rps.parse()?;
     133          334 :         let interval = humantime::parse_duration(interval)?;
     134          334 :         Ok(Self::new(max_rps, interval))
     135          334 :     }
     136              : }
     137              : 
     138              : impl RateBucketInfo {
     139              :     pub const DEFAULT_SET: [Self; 3] = [
     140              :         Self::new(300, Duration::from_secs(1)),
     141              :         Self::new(200, Duration::from_secs(60)),
     142              :         Self::new(100, Duration::from_secs(600)),
     143              :     ];
     144              : 
     145           56 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     146          216 :         info.sort_unstable_by_key(|info| info.interval);
     147           56 :         let invalid = info
     148           56 :             .iter()
     149           56 :             .tuple_windows()
     150          108 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     151           56 :         if let Some((a, b)) = invalid {
     152            2 :             bail!(
     153            2 :                 "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     154            2 :                 b.max_rpi,
     155            2 :                 a.max_rpi,
     156            2 :             );
     157           54 :         }
     158           54 : 
     159           54 :         Ok(())
     160           56 :     }
     161              : 
     162          342 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     163          342 :         Self {
     164          342 :             interval,
     165          342 :             max_rpi: max_rps * interval.as_millis() as u32 / 1000,
     166          342 :         }
     167          342 :     }
     168              : }
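
Worked example of the rps-to-rpi conversion in `new` above: "300@1s" gives
max_rpi = 300 * 1000 / 1000 = 300, "200@60s" gives 200 * 60_000 / 1000 = 12_000,
and "100@600s" gives 100 * 600_000 / 1000 = 60_000. The buckets in `DEFAULT_SET`
therefore have non-decreasing `max_rpi` when sorted by interval, which is exactly
the invariant that `validate` enforces.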
     169              : 
     170              : impl EndpointRateLimiter {
     171           27 :     pub fn new(info: &'static [RateBucketInfo]) -> Self {
     172           27 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     173           27 :     }
     174              : }
     175              : 
     176              : impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
     177           29 :     fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
     178           29 :         info!(buckets = ?info, "endpoint rate limiter");
     179           29 :         Self {
     180           29 :             info,
     181           29 :             map: DashMap::with_hasher_and_shard_amount(hasher, 64),
     182           29 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     183           29 :             rand: Mutex::new(rand),
     184           29 :         }
     185           29 :     }
     186              : 
      187              :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     188      2001457 :     pub fn check(&self, endpoint: EndpointId) -> bool {
     189      2001457 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     190      2001457 :         // worst case memory usage is about:
     191      2001457 :         //    = 2 * 2048 * 64 * (48B + 72B)
     192      2001457 :         //    = 30MB
     193      2001457 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     194          976 :             self.do_gc();
     195      2000481 :         }
     196              : 
     197      2001457 :         let now = Instant::now();
     198      2001457 :         let mut entry = self.map.entry(endpoint).or_insert_with(|| {
     199      2000029 :             vec![
     200      2000029 :                 RateBucket {
     201      2000029 :                     start: now,
     202      2000029 :                     count: 0,
     203      2000029 :                 };
     204      2000029 :                 self.info.len()
     205      2000029 :             ]
     206      2001457 :         });
     207      2001457 : 
     208      2001457 :         let should_allow_request = entry
     209      2001457 :             .iter_mut()
     210      2001457 :             .zip(self.info)
     211      6002959 :             .all(|(bucket, info)| bucket.should_allow_request(info, now));
     212      2001457 : 
     213      2001457 :         if should_allow_request {
     214      2001449 :             // only increment the bucket counts if the request will actually be accepted
     215      2001449 :             entry.iter_mut().for_each(RateBucket::inc);
     216      2001449 :         }
     217              : 
     218      2001457 :         should_allow_request
     219      2001457 :     }
     220              : 
     221              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     222              :     /// At worst, we'll double the effective max_rps during the cleanup.
      223              :     /// But this way deletion does not acquire a mutex on each entry access.
     224          976 :     pub fn do_gc(&self) {
     225          976 :         info!(
     226            0 :             "cleaning up endpoint rate limiter, current size = {}",
     227            0 :             self.map.len()
     228            0 :         );
     229          976 :         let n = self.map.shards().len();
      230          976 :         // this lock is ok, as the periodic cycle of do_gc makes collisions very unlikely
      231          976 :         // (impossible, in fact, unless we have 2048 threads)
     232          976 :         let shard = self.rand.lock().unwrap().gen_range(0..n);
     233          976 :         self.map.shards()[shard].write().clear();
     234          976 :     }
     235              : }
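
A minimal usage sketch of the per-endpoint check above (the `gate_connection`
wrapper and its error string are illustrative assumptions, not code from this
module; the construction mirrors the `test_rate_limits` test further down):

    // Sketch: gate each incoming connection on the per-endpoint buckets.
    fn gate_connection(
        limiter: &EndpointRateLimiter,
        endpoint: EndpointId,
    ) -> Result<(), &'static str> {
        if limiter.check(endpoint) {
            Ok(())
        } else {
            // The caller is expected to reject with a user-facing error.
            Err("too many connections for this endpoint, try again later")
        }
    }

    // Construction: the bucket slice must be 'static, e.g. a validated, leaked Vec.
    // let mut rates: Vec<RateBucketInfo> = vec!["100@1s".parse()?, "20@30s".parse()?];
    // RateBucketInfo::validate(&mut rates)?;
    // let limiter = EndpointRateLimiter::new(Vec::leak(rates));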
     236              : 
     237              : /// Limits the number of concurrent jobs.
     238              : ///
     239              : /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
     240              : /// token once the job is finished.
     241              : ///
     242              : /// The limit will be automatically adjusted based on observed latency (delay) and/or failures
     243              : /// caused by overload (loss).
     244              : pub struct Limiter {
     245              :     limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
     246              :     semaphore: std::sync::Arc<Semaphore>,
     247              :     config: RateLimiterConfig,
     248              : 
     249              :     // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
     250              :     limits: AtomicUsize,
     251              : 
     252              :     // ONLY USE ATOMIC ADD/SUB
     253              :     in_flight: Arc<AtomicUsize>,
     254              : 
     255              :     #[cfg(test)]
     256              :     notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
     257              : }
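
A hedged sketch of the token lifecycle described in the doc comment above
(`do_work` and the error classification are placeholders for illustration, not
part of this file):

    // Sketch: acquire a Token, run the job, and hand the Token back together
    // with an Outcome so the limit algorithm can adapt the concurrency limit.
    async fn run_limited_job(limiter: &Limiter) -> anyhow::Result<()> {
        let token = limiter
            .acquire_timeout(Duration::from_secs(1))
            .await
            .ok_or_else(|| anyhow::anyhow!("no token available within the timeout"))?;
        let outcome = match do_work().await {
            Ok(()) => Outcome::Success,
            // Only failures caused by overload should be reported as Overload;
            // unrelated errors should map to Success or `None`.
            Err(_overload) => Outcome::Overload,
        };
        limiter.release(token, Some(outcome)).await;
        Ok(())
    }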
     258              : 
     259              : /// A concurrency token, required to run a job.
     260              : ///
     261              : /// Release the token back to the [Limiter] after the job is complete.
     262            0 : #[derive(Debug)]
     263              : pub struct Token<'t> {
     264              :     permit: Option<tokio::sync::SemaphorePermit<'t>>,
     265              :     start: Instant,
     266              :     in_flight: Arc<AtomicUsize>,
     267              : }
     268              : 
     269              : /// A snapshot of the state of the [Limiter].
     270              : ///
     271              : /// Not guaranteed to be consistent under high concurrency.
     272            0 : #[derive(Debug, Clone, Copy)]
     273              : pub struct LimiterState {
     274              :     limit: usize,
     275              :     in_flight: usize,
     276              : }
     277              : 
     278              : /// Whether a job succeeded or failed as a result of congestion/overload.
     279              : ///
     280              : /// Errors not considered to be caused by overload should be ignored.
     281            6 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     282              : pub enum Outcome {
     283              :     /// The job succeeded, or failed in a way unrelated to overload.
     284              :     Success,
     285              :     /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
     286              :     /// was observed.
     287              :     Overload,
     288              : }
     289              : 
     290              : impl Outcome {
     291            0 :     fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
     292            0 :         match error {
     293            0 :             reqwest_middleware::Error::Middleware(_) => Outcome::Success,
     294            0 :             reqwest_middleware::Error::Reqwest(e) => {
     295            0 :                 if let Some(status) = e.status() {
     296            0 :                     if status.is_server_error()
     297            0 :                         || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
     298              :                     {
     299            0 :                         Outcome::Overload
     300              :                     } else {
     301            0 :                         Outcome::Success
     302              :                     }
     303              :                 } else {
     304            0 :                     Outcome::Success
     305              :                 }
     306              :             }
     307              :         }
     308            0 :     }
     309            7 :     fn from_reqwest_response(response: &reqwest::Response) -> Self {
     310            7 :         if response.status().is_server_error()
     311            6 :             || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
     312              :         {
     313            2 :             Outcome::Overload
     314              :         } else {
     315            5 :             Outcome::Success
     316              :         }
     317            7 :     }
     318              : }
     319              : 
     320              : impl Limiter {
     321              :     /// Create a limiter with a given limit control algorithm.
     322           17 :     pub fn new(config: RateLimiterConfig) -> Self {
     323           17 :         assert!(config.initial_limit > 0);
     324           17 :         Self {
     325           17 :             limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
     326           17 :             semaphore: Arc::new(Semaphore::new(config.initial_limit)),
     327           17 :             config,
     328           17 :             limits: AtomicUsize::new(config.initial_limit),
     329           17 :             in_flight: Arc::new(AtomicUsize::new(0)),
     330           17 :             #[cfg(test)]
     331           17 :             notifier: None,
     332           17 :         }
     333           17 :     }
     334              :     // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
     335              :     //     assert!(initial_limit > 0);
     336              : 
     337              :     //     Self {
     338              :     //         limit_algo: AsyncMutex::new(limit_algorithm),
     339              :     //         semaphore: Arc::new(Semaphore::new(initial_limit)),
     340              :     //         timeout,
     341              :     //         limits: AtomicUsize::new(initial_limit),
     342              :     //         in_flight: Arc::new(AtomicUsize::new(0)),
     343              :     //         #[cfg(test)]
     344              :     //         notifier: None,
     345              :     //     }
     346              :     // }
     347              : 
     348              :     /// In some cases [Token]s are acquired asynchronously when updating the limit.
     349              :     #[cfg(test)]
     350            2 :     pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
     351            2 :         self.notifier = Some(n);
     352            2 :         self
     353            2 :     }
     354              : 
     355              :     /// Try to immediately acquire a concurrency [Token].
     356              :     ///
     357              :     /// Returns `None` if there are none available.
     358           26 :     pub fn try_acquire(&self) -> Option<Token> {
     359           26 :         let result = if self.config.disable {
     360              :             // If the rate limiter is disabled, we can always acquire a token.
     361            4 :             Some(Token::new(None, self.in_flight.clone()))
     362              :         } else {
     363           22 :             self.semaphore
     364           22 :                 .try_acquire()
     365           22 :                 .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     366           22 :                 .ok()
     367              :         };
     368           26 :         if result.is_some() {
     369           22 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     370           22 :         }
     371           26 :         result
     372           26 :     }
     373              : 
     374              :     /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
     375              :     ///
     376              :     /// Returns `None` if there are none available after `duration`.
     377           12 :     pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
     378            4 :         info!("acquiring token: {:?}", self.semaphore.available_permits());
     379           12 :         let result = if self.config.disable {
     380              :             // If the rate limiter is disabled, we can always acquire a token.
     381            4 :             Some(Token::new(None, self.in_flight.clone()))
     382              :         } else {
     383            8 :             match timeout(duration, self.semaphore.acquire()).await {
     384            7 :                 Ok(maybe_permit) => maybe_permit
     385            7 :                     .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     386            7 :                     .ok(),
     387            1 :                 Err(_) => None,
     388              :             }
     389              :         };
     390           12 :         if result.is_some() {
     391           11 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     392           11 :         }
     393           12 :         result
     394           12 :     }
     395              : 
     396              :     /// Return the concurrency [Token], along with the outcome of the job.
     397              :     ///
     398              :     /// The [Outcome] of the job, and the time taken to perform it, may be used
     399              :     /// to update the concurrency limit.
     400              :     ///
     401              :     /// Set the outcome to `None` to ignore the job.
     402           29 :     pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
     403            3 :         tracing::info!("outcome is {:?}", outcome);
     404           29 :         let in_flight = self.in_flight.load(Ordering::Acquire);
     405           29 :         let old_limit = self.limits.load(Ordering::Acquire);
     406           29 :         let available = if self.config.disable {
     407            8 :             0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0.
     408              :         } else {
     409           21 :             self.semaphore.available_permits()
     410              :         };
     411           29 :         let total = in_flight + available;
     412              : 
     413           29 :         let mut algo = self.limit_algo.lock().await;
     414              : 
     415           29 :         let new_limit = if let Some(outcome) = outcome {
     416           23 :             let sample = Sample {
     417           23 :                 latency: token.start.elapsed(),
     418           23 :                 in_flight,
     419           23 :                 outcome,
     420           23 :             };
     421           23 :             algo.update(old_limit, sample).await
     422              :         } else {
     423            6 :             old_limit
     424              :         };
     425            3 :         tracing::info!("new limit is {}", new_limit);
     426           29 :         let actual_limit = if new_limit < total {
     427            6 :             token.forget();
     428            6 :             total.saturating_sub(1)
     429              :         } else {
     430           23 :             if !self.config.disable {
     431           17 :                 self.semaphore.add_permits(new_limit.saturating_sub(total));
     432           17 :             }
     433           23 :             new_limit
     434              :         };
     435           26 :         crate::metrics::RATE_LIMITER_LIMIT
     436           26 :             .with_label_values(&["expected"])
     437           26 :             .set(new_limit as i64);
     438           26 :         crate::metrics::RATE_LIMITER_LIMIT
     439           26 :             .with_label_values(&["actual"])
     440           26 :             .set(actual_limit as i64);
     441           26 :         self.limits.store(new_limit, Ordering::Release);
     442            3 :         #[cfg(test)]
     443           26 :         if let Some(n) = &self.notifier {
     444            2 :             n.notify_one();
     445           24 :         }
     446           26 :     }
     447              : 
     448              :     /// The current state of the limiter.
     449           12 :     pub fn state(&self) -> LimiterState {
     450           12 :         let limit = self.limits.load(Ordering::Relaxed);
     451           12 :         let in_flight = self.in_flight.load(Ordering::Relaxed);
     452           12 :         LimiterState { limit, in_flight }
     453           12 :     }
     454              : }
     455              : 
     456              : impl<'t> Token<'t> {
     457           33 :     fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
     458           33 :         Self {
     459           33 :             permit,
     460           33 :             start: Instant::now(),
     461           33 :             in_flight,
     462           33 :         }
     463           33 :     }
     464              : 
     465            6 :     pub fn forget(&mut self) {
     466            6 :         if let Some(permit) = self.permit.take() {
     467            4 :             permit.forget();
     468            4 :         }
     469            6 :     }
     470              : }
     471              : 
     472              : impl Drop for Token<'_> {
     473           33 :     fn drop(&mut self) {
     474           33 :         self.in_flight.fetch_sub(1, Ordering::AcqRel);
     475           33 :     }
     476              : }
     477              : 
     478              : impl LimiterState {
     479              :     /// The current concurrency limit.
     480           12 :     pub fn limit(&self) -> usize {
     481           12 :         self.limit
     482           12 :     }
     483              :     /// The number of jobs in flight.
     484            2 :     pub fn in_flight(&self) -> usize {
     485            2 :         self.in_flight
     486            2 :     }
     487              : }
     488              : 
     489              : #[async_trait::async_trait]
     490              : impl reqwest_middleware::Middleware for Limiter {
     491            8 :     async fn handle(
     492            8 :         &self,
     493            8 :         req: reqwest::Request,
     494            8 :         extensions: &mut task_local_extensions::Extensions,
     495            8 :         next: reqwest_middleware::Next<'_>,
     496            8 :     ) -> reqwest_middleware::Result<reqwest::Response> {
     497            8 :         let start = Instant::now();
     498            8 :         let token = self
     499            8 :             .acquire_timeout(self.config.timeout)
     500            1 :             .await
     501            8 :             .ok_or_else(|| {
     502            1 :                 reqwest_middleware::Error::Middleware(
     503            1 :                     // TODO: Should we map it into user facing errors?
     504            1 :                     crate::console::errors::ApiError::Console {
     505            1 :                         status: crate::http::StatusCode::TOO_MANY_REQUESTS,
     506            1 :                         text: "Too many requests".into(),
     507            1 :                     }
     508            1 :                     .into(),
     509            1 :                 )
     510            8 :             })?;
     511            7 :         info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
     512            7 :         crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
     513           19 :         match next.run(req, extensions).await {
     514            7 :             Ok(response) => {
     515            7 :                 self.release(token, Some(Outcome::from_reqwest_response(&response)))
     516            0 :                     .await;
     517            7 :                 Ok(response)
     518              :             }
     519            0 :             Err(e) => {
     520            0 :                 self.release(token, Some(Outcome::from_reqwest_error(&e)))
     521            0 :                     .await;
     522            0 :                 Err(e)
     523              :             }
     524              :         }
     525           24 :     }
     526              : }
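
For context, wiring the Limiter into an HTTP client as middleware might look
like the sketch below; `build_limited_client` is an illustrative helper, and the
builder calls are the standard reqwest_middleware API this impl targets:

    // Sketch: every control-plane request now passes through
    // acquire_timeout/release via the Middleware impl above.
    fn build_limited_client(config: RateLimiterConfig) -> reqwest_middleware::ClientWithMiddleware {
        reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(Limiter::new(config))
            .build()
    }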
     527              : 
     528              : #[cfg(test)]
     529              : mod tests {
     530              :     use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};
     531              : 
     532              :     use futures::{task::noop_waker_ref, Future};
     533              :     use rand::SeedableRng;
     534              :     use rustc_hash::FxHasher;
     535              :     use tokio::time;
     536              : 
     537              :     use super::{EndpointRateLimiter, Limiter, Outcome};
     538              :     use crate::{
     539              :         rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
     540              :         EndpointId,
     541              :     };
     542              : 
     543            2 :     #[tokio::test]
     544            2 :     async fn it_works() {
     545            2 :         let config = super::RateLimiterConfig {
     546            2 :             algorithm: RateLimitAlgorithm::Fixed,
     547            2 :             timeout: Duration::from_secs(1),
     548            2 :             initial_limit: 10,
     549            2 :             disable: false,
     550            2 :             ..Default::default()
     551            2 :         };
     552            2 :         let limiter = Limiter::new(config);
     553            2 : 
     554            2 :         let token = limiter.try_acquire().unwrap();
     555            2 : 
     556            2 :         limiter.release(token, Some(Outcome::Success)).await;
     557            2 : 
     558            2 :         assert_eq!(limiter.state().limit(), 10);
     559            2 :     }
     560              : 
     561            2 :     #[tokio::test]
     562            2 :     async fn is_fair() {
     563            2 :         let config = super::RateLimiterConfig {
     564            2 :             algorithm: RateLimitAlgorithm::Fixed,
     565            2 :             timeout: Duration::from_secs(1),
     566            2 :             initial_limit: 1,
     567            2 :             disable: false,
     568            2 :             ..Default::default()
     569            2 :         };
     570            2 :         let limiter = Limiter::new(config);
     571            2 : 
     572            2 :         // === TOKEN 1 ===
     573            2 :         let token1 = limiter.try_acquire().unwrap();
     574            2 : 
     575            2 :         let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     576            2 :         assert!(
     577            2 :             token2_fut
     578            2 :                 .as_mut()
     579            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     580            2 :                 .is_pending(),
     581            2 :             "token is acquired by token1"
     582            2 :         );
     583            2 : 
     584            2 :         let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     585            2 :         assert!(
     586            2 :             token3_fut
     587            2 :                 .as_mut()
     588            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     589            2 :                 .is_pending(),
     590            2 :             "token is acquired by token1"
     591            2 :         );
     592            2 : 
     593            2 :         limiter.release(token1, Some(Outcome::Success)).await;
     594            2 :         // === END TOKEN 1 ===
     595            2 : 
     596            2 :         // === TOKEN 2 ===
     597            2 :         assert!(
     598            2 :             limiter.try_acquire().is_none(),
     599            2 :             "token is acquired by token2"
     600            2 :         );
     601            2 : 
     602            2 :         assert!(
     603            2 :             token3_fut
     604            2 :                 .as_mut()
     605            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     606            2 :                 .is_pending(),
     607            2 :             "token is acquired by token2"
     608            2 :         );
     609            2 : 
     610            2 :         let token2 = token2_fut.await.unwrap();
     611            2 : 
     612            2 :         limiter.release(token2, Some(Outcome::Success)).await;
     613            2 :         // === END TOKEN 2 ===
     614            2 : 
     615            2 :         // === TOKEN 3 ===
     616            2 :         assert!(
     617            2 :             limiter.try_acquire().is_none(),
     618            2 :             "token is acquired by token3"
     619            2 :         );
     620            2 : 
     621            2 :         let token3 = token3_fut.await.unwrap();
     622            2 :         limiter.release(token3, Some(Outcome::Success)).await;
     623            2 :         // === END TOKEN 3 ===
     624            2 : 
     625            2 :         // === TOKEN 4 ===
     626            2 :         let token4 = limiter.try_acquire().unwrap();
     627            2 :         limiter.release(token4, Some(Outcome::Success)).await;
     628            2 :     }
     629              : 
     630            2 :     #[tokio::test]
     631            2 :     async fn disable() {
     632            2 :         let config = super::RateLimiterConfig {
     633            2 :             algorithm: RateLimitAlgorithm::Fixed,
     634            2 :             timeout: Duration::from_secs(1),
     635            2 :             initial_limit: 1,
     636            2 :             disable: true,
     637            2 :             ..Default::default()
     638            2 :         };
     639            2 :         let limiter = Limiter::new(config);
     640            2 : 
     641            2 :         // === TOKEN 1 ===
     642            2 :         let token1 = limiter.try_acquire().unwrap();
     643            2 :         let token2 = limiter.try_acquire().unwrap();
     644            2 :         let state = limiter.state();
     645            2 :         assert_eq!(state.limit(), 1);
     646            2 :         assert_eq!(state.in_flight(), 2); // For disabled limiter, it's expected.
     647            2 :         limiter.release(token1, None).await;
     648            2 :         limiter.release(token2, None).await;
     649            2 :     }
     650              : 
     651            2 :     #[test]
     652            2 :     fn rate_bucket_rpi() {
     653            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     654            2 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     655              : 
     656            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     657            2 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     658            2 :     }
     659              : 
     660            2 :     #[test]
     661            2 :     fn rate_bucket_parse() {
     662            2 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     663            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     664            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     665            2 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     666              : 
     667            2 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     668            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     669            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     670            2 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     671            2 :     }
     672              : 
     673            2 :     #[test]
     674            2 :     fn default_rate_buckets() {
     675            2 :         let mut defaults = RateBucketInfo::DEFAULT_SET;
     676            2 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     677            2 :     }
     678              : 
     679            2 :     #[test]
     680              :     #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     681            2 :     fn rate_buckets_validate() {
     682            2 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     683            2 :             .into_iter()
     684            4 :             .map(|s| s.parse().unwrap())
     685            2 :             .collect();
     686            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     687            2 :     }
     688              : 
     689            2 :     #[tokio::test]
     690            2 :     async fn test_rate_limits() {
     691            2 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     692            2 :             .into_iter()
     693            4 :             .map(|s| s.parse().unwrap())
     694            2 :             .collect();
     695            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     696            2 :         let limiter = EndpointRateLimiter::new(Vec::leak(rates));
     697            2 : 
     698            2 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     699            2 : 
     700            2 :         time::pause();
     701            2 : 
     702          202 :         for _ in 0..100 {
     703          200 :             assert!(limiter.check(endpoint.clone()));
     704            2 :         }
     705            2 :         // more connections fail
     706            2 :         assert!(!limiter.check(endpoint.clone()));
     707            2 : 
     708            2 :         // fail even after 500ms as it's in the same bucket
     709            2 :         time::advance(time::Duration::from_millis(500)).await;
     710            2 :         assert!(!limiter.check(endpoint.clone()));
     711            2 : 
     712            2 :         // after a full 1s, 100 requests are allowed again
     713            2 :         time::advance(time::Duration::from_millis(500)).await;
     714           12 :         for _ in 1..6 {
     715         1010 :             for _ in 0..100 {
     716         1000 :                 assert!(limiter.check(endpoint.clone()));
     717            2 :             }
     718           10 :             time::advance(time::Duration::from_millis(1000)).await;
     719            2 :         }
     720            2 : 
     721            2 :         // more connections after 600 will exceed the 20rps@30s limit
     722            2 :         assert!(!limiter.check(endpoint.clone()));
     723            2 : 
     724            2 :         // will still fail before the 30 second limit
     725            2 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     726            2 :         assert!(!limiter.check(endpoint.clone()));
     727            2 : 
     728            2 :         // after the full 30 seconds, 100 requests are allowed again
     729            2 :         time::advance(time::Duration::from_millis(1)).await;
     730          202 :         for _ in 0..100 {
     731          200 :             assert!(limiter.check(endpoint.clone()));
     732            2 :         }
     733            2 :     }
     734              : 
     735            2 :     #[tokio::test]
     736            2 :     async fn test_rate_limits_gc() {
     737            2 :         // fixed seeded random/hasher to ensure that the test is not flaky
     738            2 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     739            2 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     740            2 : 
     741            2 :         let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
     742            2 :             &RateBucketInfo::DEFAULT_SET,
     743            2 :             rand,
     744            2 :             hasher,
     745            2 :         );
     746      2000002 :         for i in 0..1_000_000 {
     747      2000000 :             limiter.check(format!("{i}").into());
     748      2000000 :         }
     749            2 :         assert!(limiter.map.len() < 150_000);
     750            2 :     }
     751              : }
        

Generated by: LCOV version 2.1-beta