LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 93.5 % 415 388
Test Date: 2024-02-07 07:37:29 Functions: 80.2 % 81 65

            Line data    Source code
       1              : use std::{
       2              :     collections::hash_map::RandomState,
       3              :     hash::BuildHasher,
       4              :     sync::{
       5              :         atomic::{AtomicUsize, Ordering},
       6              :         Arc, Mutex,
       7              :     },
       8              : };
       9              : 
      10              : use anyhow::bail;
      11              : use dashmap::DashMap;
      12              : use itertools::Itertools;
      13              : use rand::{rngs::StdRng, Rng, SeedableRng};
      14              : use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
      15              : use tokio::time::{timeout, Duration, Instant};
      16              : use tracing::info;
      17              : 
      18              : use crate::EndpointId;
      19              : 
      20              : use super::{
      21              :     limit_algorithm::{LimitAlgorithm, Sample},
      22              :     RateLimiterConfig,
      23              : };
      24              : 
      25              : // Simple per-endpoint rate limiter.
      26              : //
      27              : // Check that the number of connections to the endpoint is below `max_rps` rps.
      28              : // Purposefully ignore user name and database name as clients can reconnect
      29              : // with different names, so we'll end up sending some http requests to
      30              : // the control plane.
      31              : //
      32              : // We also may save quite a lot of CPU (I think) by bailing out right after we
      33              : // saw SNI, before doing TLS handshake. User-side error messages in that case
      34              : // do not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
      35              : // I went with a more expensive way that yields user-friendlier error messages.
      36              : pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
      37              :     map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
      38              :     info: &'static [RateBucketInfo],
      39              :     access_count: AtomicUsize,
      40              :     rand: Mutex<Rand>,
      41              : }
      42              : 
      43      4000052 : #[derive(Clone, Copy)]
      44              : struct RateBucket {
      45              :     start: Instant,
      46              :     count: u32,
      47              : }
      48              : 
      49              : impl RateBucket {
      50      6002953 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
      51      6002953 :         if now - self.start < info.interval {
      52      6002934 :             self.count < info.max_rpi
      53              :         } else {
      54              :             // bucket expired, reset
      55           19 :             self.count = 0;
      56           19 :             self.start = now;
      57           19 : 
      58           19 :             true
      59              :         }
      60      6002953 :     }
      61              : 
      62      6002941 :     fn inc(&mut self) {
      63      6002941 :         self.count += 1;
      64      6002941 :     }
      65              : }
      66              : 
      67            4 : #[derive(Clone, Copy, PartialEq)]
      68              : pub struct RateBucketInfo {
      69              :     pub interval: Duration,
      70              :     // requests per interval
      71              :     pub max_rpi: u32,
      72              : }
      73              : 
      74              : impl std::fmt::Display for RateBucketInfo {
      75          152 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      76          152 :         let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
      77          152 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
      78          152 :     }
      79              : }
      80              : 
      81              : impl std::fmt::Debug for RateBucketInfo {
      82           69 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      83           69 :         write!(f, "{self}")
      84           69 :     }
      85              : }
      86              : 
      87              : impl std::str::FromStr for RateBucketInfo {
      88              :     type Err = anyhow::Error;
      89              : 
      90          160 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      91          160 :         let Some((max_rps, interval)) = s.split_once('@') else {
      92            0 :             bail!("invalid rate info")
      93              :         };
      94          160 :         let max_rps = max_rps.parse()?;
      95          160 :         let interval = humantime::parse_duration(interval)?;
      96          160 :         Ok(Self::new(max_rps, interval))
      97          160 :     }
      98              : }
      99              : 
     100              : impl RateBucketInfo {
     101              :     pub const DEFAULT_SET: [Self; 3] = [
     102              :         Self::new(300, Duration::from_secs(1)),
     103              :         Self::new(200, Duration::from_secs(60)),
     104              :         Self::new(100, Duration::from_secs(600)),
     105              :     ];
     106              : 
     107           29 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     108          108 :         info.sort_unstable_by_key(|info| info.interval);
     109           29 :         let invalid = info
     110           29 :             .iter()
     111           29 :             .tuple_windows()
     112           54 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     113           29 :         if let Some((a, b)) = invalid {
     114            2 :             bail!(
     115            2 :                 "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     116            2 :                 b.max_rpi,
     117            2 :                 a.max_rpi,
     118            2 :             );
     119           27 :         }
     120           27 : 
     121           27 :         Ok(())
     122           29 :     }
     123              : 
     124          168 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     125          168 :         Self {
     126          168 :             interval,
     127          168 :             max_rpi: max_rps * interval.as_millis() as u32 / 1000,
     128          168 :         }
     129          168 :     }
     130              : }
     131              : 
     132              : impl EndpointRateLimiter {
     133           25 :     pub fn new(info: &'static [RateBucketInfo]) -> Self {
     134           25 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     135           25 :     }
     136              : }
     137              : 
     138              : impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
     139           27 :     fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
     140           27 :         info!(buckets = ?info, "endpoint rate limiter");
     141           27 :         Self {
     142           27 :             info,
     143           27 :             map: DashMap::with_hasher_and_shard_amount(hasher, 64),
     144           27 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     145           27 :             rand: Mutex::new(rand),
     146           27 :         }
     147           27 :     }
     148              : 
     149              :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     150      2001455 :     pub fn check(&self, endpoint: EndpointId) -> bool {
     151      2001455 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     152      2001455 :         // worst case memory usage is about:
     153      2001455 :         //    = 2 * 2048 * 64 * (48B + 72B)
     154      2001455 :         //    = 30MB
     155      2001455 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     156          976 :             self.do_gc();
     157      2000479 :         }
     158              : 
     159      2001455 :         let now = Instant::now();
     160      2001455 :         let mut entry = self.map.entry(endpoint).or_insert_with(|| {
     161      2000027 :             vec![
     162      2000027 :                 RateBucket {
     163      2000027 :                     start: now,
     164      2000027 :                     count: 0,
     165      2000027 :                 };
     166      2000027 :                 self.info.len()
     167      2000027 :             ]
     168      2001455 :         });
     169      2001455 : 
     170      2001455 :         let should_allow_request = entry
     171      2001455 :             .iter_mut()
     172      2001455 :             .zip(self.info)
     173      6002953 :             .all(|(bucket, info)| bucket.should_allow_request(info, now));
     174      2001455 : 
     175      2001455 :         if should_allow_request {
     176      2001447 :             // only increment the bucket counts if the request will actually be accepted
     177      2001447 :             entry.iter_mut().for_each(RateBucket::inc);
     178      2001447 :         }
     179              : 
     180      2001455 :         should_allow_request
     181      2001455 :     }
     182              : 
     183              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     184              :     /// At worst, we'll double the effective max_rps during the cleanup.
     185              :     /// But that way deletion does not acquire a mutex on each entry access.
     186          976 :     pub fn do_gc(&self) {
     187          976 :         info!(
     188            0 :             "cleaning up endpoint rate limiter, current size = {}",
     189            0 :             self.map.len()
     190            0 :         );
     191          976 :         let n = self.map.shards().len();
     192          976 :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
     193          976 :         // (impossible, in fact, unless we have 2048 threads)
     194          976 :         let shard = self.rand.lock().unwrap().gen_range(0..n);
     195          976 :         self.map.shards()[shard].write().clear();
     196          976 :     }
     197              : }
     198              : 
     199              : /// Limits the number of concurrent jobs.
     200              : ///
     201              : /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
     202              : /// token once the job is finished.
     203              : ///
     204              : /// The limit will be automatically adjusted based on observed latency (delay) and/or failures
     205              : /// caused by overload (loss).
     206              : pub struct Limiter {
     207              :     limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
     208              :     semaphore: std::sync::Arc<Semaphore>,
     209              :     config: RateLimiterConfig,
     210              : 
     211              :     // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
     212              :     limits: AtomicUsize,
     213              : 
     214              :     // ONLY USE ATOMIC ADD/SUB
     215              :     in_flight: Arc<AtomicUsize>,
     216              : 
     217              :     #[cfg(test)]
     218              :     notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
     219              : }
     220              : 
     221              : /// A concurrency token, required to run a job.
     222              : ///
     223              : /// Release the token back to the [Limiter] after the job is complete.
     224            0 : #[derive(Debug)]
     225              : pub struct Token<'t> {
     226              :     permit: Option<tokio::sync::SemaphorePermit<'t>>,
     227              :     start: Instant,
     228              :     in_flight: Arc<AtomicUsize>,
     229              : }
     230              : 
     231              : /// A snapshot of the state of the [Limiter].
     232              : ///
     233              : /// Not guaranteed to be consistent under high concurrency.
     234            0 : #[derive(Debug, Clone, Copy)]
     235              : pub struct LimiterState {
     236              :     limit: usize,
     237              :     in_flight: usize,
     238              : }
     239              : 
     240              : /// Whether a job succeeded or failed as a result of congestion/overload.
     241              : ///
     242              : /// Errors not considered to be caused by overload should be ignored.
     243            6 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     244              : pub enum Outcome {
     245              :     /// The job succeeded, or failed in a way unrelated to overload.
     246              :     Success,
     247              :     /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
     248              :     /// was observed.
     249              :     Overload,
     250              : }
     251              : 
     252              : impl Outcome {
     253            0 :     fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
     254            0 :         match error {
     255            0 :             reqwest_middleware::Error::Middleware(_) => Outcome::Success,
     256            0 :             reqwest_middleware::Error::Reqwest(e) => {
     257            0 :                 if let Some(status) = e.status() {
     258            0 :                     if status.is_server_error()
     259            0 :                         || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
     260              :                     {
     261            0 :                         Outcome::Overload
     262              :                     } else {
     263            0 :                         Outcome::Success
     264              :                     }
     265              :                 } else {
     266            0 :                     Outcome::Success
     267              :                 }
     268              :             }
     269              :         }
     270            0 :     }
     271            7 :     fn from_reqwest_response(response: &reqwest::Response) -> Self {
     272            7 :         if response.status().is_server_error()
     273            6 :             || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
     274              :         {
     275            2 :             Outcome::Overload
     276              :         } else {
     277            5 :             Outcome::Success
     278              :         }
     279            7 :     }
     280              : }
     281              : 
     282              : impl Limiter {
     283              :     /// Create a limiter with a given limit control algorithm.
     284           17 :     pub fn new(config: RateLimiterConfig) -> Self {
     285           17 :         assert!(config.initial_limit > 0);
     286           17 :         Self {
     287           17 :             limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
     288           17 :             semaphore: Arc::new(Semaphore::new(config.initial_limit)),
     289           17 :             config,
     290           17 :             limits: AtomicUsize::new(config.initial_limit),
     291           17 :             in_flight: Arc::new(AtomicUsize::new(0)),
     292           17 :             #[cfg(test)]
     293           17 :             notifier: None,
     294           17 :         }
     295           17 :     }
     296              :     // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
     297              :     //     assert!(initial_limit > 0);
     298              : 
     299              :     //     Self {
     300              :     //         limit_algo: AsyncMutex::new(limit_algorithm),
     301              :     //         semaphore: Arc::new(Semaphore::new(initial_limit)),
     302              :     //         timeout,
     303              :     //         limits: AtomicUsize::new(initial_limit),
     304              :     //         in_flight: Arc::new(AtomicUsize::new(0)),
     305              :     //         #[cfg(test)]
     306              :     //         notifier: None,
     307              :     //     }
     308              :     // }
     309              : 
     310              :     /// In some cases [Token]s are acquired asynchronously when updating the limit.
     311              :     #[cfg(test)]
     312            2 :     pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
     313            2 :         self.notifier = Some(n);
     314            2 :         self
     315            2 :     }
     316              : 
     317              :     /// Try to immediately acquire a concurrency [Token].
     318              :     ///
     319              :     /// Returns `None` if there are none available.
     320           26 :     pub fn try_acquire(&self) -> Option<Token> {
     321           26 :         let result = if self.config.disable {
     322              :             // If the rate limiter is disabled, we can always acquire a token.
     323            4 :             Some(Token::new(None, self.in_flight.clone()))
     324              :         } else {
     325           22 :             self.semaphore
     326           22 :                 .try_acquire()
     327           22 :                 .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     328           22 :                 .ok()
     329              :         };
     330           26 :         if result.is_some() {
     331           22 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     332           22 :         }
     333           26 :         result
     334           26 :     }
     335              : 
     336              :     /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
     337              :     ///
     338              :     /// Returns `None` if there are none available after `duration`.
     339           12 :     pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
     340            4 :         info!("acquiring token: {:?}", self.semaphore.available_permits());
     341           12 :         let result = if self.config.disable {
     342              :             // If the rate limiter is disabled, we can always acquire a token.
     343            4 :             Some(Token::new(None, self.in_flight.clone()))
     344              :         } else {
     345            8 :             match timeout(duration, self.semaphore.acquire()).await {
     346            7 :                 Ok(maybe_permit) => maybe_permit
     347            7 :                     .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     348            7 :                     .ok(),
     349            1 :                 Err(_) => None,
     350              :             }
     351              :         };
     352           12 :         if result.is_some() {
     353           11 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     354           11 :         }
     355           12 :         result
     356           12 :     }
     357              : 
     358              :     /// Return the concurrency [Token], along with the outcome of the job.
     359              :     ///
     360              :     /// The [Outcome] of the job, and the time taken to perform it, may be used
     361              :     /// to update the concurrency limit.
     362              :     ///
     363              :     /// Set the outcome to `None` to ignore the job.
     364           29 :     pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
     365            3 :         tracing::info!("outcome is {:?}", outcome);
     366           29 :         let in_flight = self.in_flight.load(Ordering::Acquire);
     367           29 :         let old_limit = self.limits.load(Ordering::Acquire);
     368           29 :         let available = if self.config.disable {
     369            8 :             0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0.
     370              :         } else {
     371           21 :             self.semaphore.available_permits()
     372              :         };
     373           29 :         let total = in_flight + available;
     374              : 
     375           29 :         let mut algo = self.limit_algo.lock().await;
     376              : 
     377           29 :         let new_limit = if let Some(outcome) = outcome {
     378           23 :             let sample = Sample {
     379           23 :                 latency: token.start.elapsed(),
     380           23 :                 in_flight,
     381           23 :                 outcome,
     382           23 :             };
     383           23 :             algo.update(old_limit, sample).await
     384              :         } else {
     385            6 :             old_limit
     386              :         };
     387            3 :         tracing::info!("new limit is {}", new_limit);
     388           29 :         let actual_limit = if new_limit < total {
     389            6 :             token.forget();
     390            6 :             total.saturating_sub(1)
     391              :         } else {
     392           23 :             if !self.config.disable {
     393           17 :                 self.semaphore.add_permits(new_limit.saturating_sub(total));
     394           17 :             }
     395           23 :             new_limit
     396              :         };
     397           26 :         crate::metrics::RATE_LIMITER_LIMIT
     398           26 :             .with_label_values(&["expected"])
     399           26 :             .set(new_limit as i64);
     400           26 :         crate::metrics::RATE_LIMITER_LIMIT
     401           26 :             .with_label_values(&["actual"])
     402           26 :             .set(actual_limit as i64);
     403           26 :         self.limits.store(new_limit, Ordering::Release);
     404            3 :         #[cfg(test)]
     405           26 :         if let Some(n) = &self.notifier {
     406            2 :             n.notify_one();
     407           24 :         }
     408           26 :     }
     409              : 
     410              :     /// The current state of the limiter.
     411           12 :     pub fn state(&self) -> LimiterState {
     412           12 :         let limit = self.limits.load(Ordering::Relaxed);
     413           12 :         let in_flight = self.in_flight.load(Ordering::Relaxed);
     414           12 :         LimiterState { limit, in_flight }
     415           12 :     }
     416              : }
     417              : 
     418              : impl<'t> Token<'t> {
     419           33 :     fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
     420           33 :         Self {
     421           33 :             permit,
     422           33 :             start: Instant::now(),
     423           33 :             in_flight,
     424           33 :         }
     425           33 :     }
     426              : 
     427            6 :     pub fn forget(&mut self) {
     428            6 :         if let Some(permit) = self.permit.take() {
     429            4 :             permit.forget();
     430            4 :         }
     431            6 :     }
     432              : }
     433              : 
     434              : impl Drop for Token<'_> {
     435           33 :     fn drop(&mut self) {
     436           33 :         self.in_flight.fetch_sub(1, Ordering::AcqRel);
     437           33 :     }
     438              : }
     439              : 
     440              : impl LimiterState {
     441              :     /// The current concurrency limit.
     442           12 :     pub fn limit(&self) -> usize {
     443           12 :         self.limit
     444           12 :     }
     445              :     /// The number of jobs in flight.
     446            2 :     pub fn in_flight(&self) -> usize {
     447            2 :         self.in_flight
     448            2 :     }
     449              : }
     450              : 
     451              : #[async_trait::async_trait]
     452              : impl reqwest_middleware::Middleware for Limiter {
     453            8 :     async fn handle(
     454            8 :         &self,
     455            8 :         req: reqwest::Request,
     456            8 :         extensions: &mut task_local_extensions::Extensions,
     457            8 :         next: reqwest_middleware::Next<'_>,
     458            8 :     ) -> reqwest_middleware::Result<reqwest::Response> {
     459            8 :         let start = Instant::now();
     460            8 :         let token = self
     461            8 :             .acquire_timeout(self.config.timeout)
     462            1 :             .await
     463            8 :             .ok_or_else(|| {
     464            1 :                 reqwest_middleware::Error::Middleware(
     465            1 :                     // TODO: Should we map it into user facing errors?
     466            1 :                     crate::console::errors::ApiError::Console {
     467            1 :                         status: crate::http::StatusCode::TOO_MANY_REQUESTS,
     468            1 :                         text: "Too many requests".into(),
     469            1 :                     }
     470            1 :                     .into(),
     471            1 :                 )
     472            8 :             })?;
     473            7 :         info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
     474            7 :         crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
     475           19 :         match next.run(req, extensions).await {
     476            7 :             Ok(response) => {
     477            7 :                 self.release(token, Some(Outcome::from_reqwest_response(&response)))
     478            0 :                     .await;
     479            7 :                 Ok(response)
     480              :             }
     481            0 :             Err(e) => {
     482            0 :                 self.release(token, Some(Outcome::from_reqwest_error(&e)))
     483            0 :                     .await;
     484            0 :                 Err(e)
     485              :             }
     486              :         }
     487           16 :     }
     488              : }
     489              : 
     490              : #[cfg(test)]
     491              : mod tests {
     492              :     use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};
     493              : 
     494              :     use futures::{task::noop_waker_ref, Future};
     495              :     use rand::SeedableRng;
     496              :     use rustc_hash::FxHasher;
     497              :     use tokio::time;
     498              : 
     499              :     use super::{EndpointRateLimiter, Limiter, Outcome};
     500              :     use crate::{
     501              :         rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
     502              :         EndpointId,
     503              :     };
     504              : 
     505            2 :     #[tokio::test]
     506            2 :     async fn it_works() {
     507            2 :         let config = super::RateLimiterConfig {
     508            2 :             algorithm: RateLimitAlgorithm::Fixed,
     509            2 :             timeout: Duration::from_secs(1),
     510            2 :             initial_limit: 10,
     511            2 :             disable: false,
     512            2 :             ..Default::default()
     513            2 :         };
     514            2 :         let limiter = Limiter::new(config);
     515            2 : 
     516            2 :         let token = limiter.try_acquire().unwrap();
     517            2 : 
     518            2 :         limiter.release(token, Some(Outcome::Success)).await;
     519              : 
     520            2 :         assert_eq!(limiter.state().limit(), 10);
     521              :     }
     522              : 
     523            2 :     #[tokio::test]
     524            2 :     async fn is_fair() {
     525            2 :         let config = super::RateLimiterConfig {
     526            2 :             algorithm: RateLimitAlgorithm::Fixed,
     527            2 :             timeout: Duration::from_secs(1),
     528            2 :             initial_limit: 1,
     529            2 :             disable: false,
     530            2 :             ..Default::default()
     531            2 :         };
     532            2 :         let limiter = Limiter::new(config);
     533            2 : 
     534            2 :         // === TOKEN 1 ===
     535            2 :         let token1 = limiter.try_acquire().unwrap();
     536            2 : 
     537            2 :         let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     538            2 :         assert!(
     539            2 :             token2_fut
     540            2 :                 .as_mut()
     541            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     542            2 :                 .is_pending(),
     543            0 :             "token is acquired by token1"
     544              :         );
     545              : 
     546            2 :         let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     547            2 :         assert!(
     548            2 :             token3_fut
     549            2 :                 .as_mut()
     550            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     551            2 :                 .is_pending(),
     552            0 :             "token is acquired by token1"
     553              :         );
     554              : 
     555            2 :         limiter.release(token1, Some(Outcome::Success)).await;
     556              :         // === END TOKEN 1 ===
     557              : 
     558              :         // === TOKEN 2 ===
     559            2 :         assert!(
     560            2 :             limiter.try_acquire().is_none(),
     561            0 :             "token is acquired by token2"
     562              :         );
     563              : 
     564            2 :         assert!(
     565            2 :             token3_fut
     566            2 :                 .as_mut()
     567            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     568            2 :                 .is_pending(),
     569            0 :             "token is acquired by token2"
     570              :         );
     571              : 
     572            2 :         let token2 = token2_fut.await.unwrap();
     573            2 : 
     574            2 :         limiter.release(token2, Some(Outcome::Success)).await;
     575              :         // === END TOKEN 2 ===
     576              : 
     577              :         // === TOKEN 3 ===
     578            2 :         assert!(
     579            2 :             limiter.try_acquire().is_none(),
     580            0 :             "token is acquired by token3"
     581              :         );
     582              : 
     583            2 :         let token3 = token3_fut.await.unwrap();
     584            2 :         limiter.release(token3, Some(Outcome::Success)).await;
     585              :         // === END TOKEN 3 ===
     586              : 
     587              :         // === TOKEN 4 ===
     588            2 :         let token4 = limiter.try_acquire().unwrap();
     589            2 :         limiter.release(token4, Some(Outcome::Success)).await;
     590              :     }
     591              : 
     592            2 :     #[tokio::test]
     593            2 :     async fn disable() { // Checks that a limiter configured with `disable: true` hands out more tokens than its limit.
     594            2 :         let config = super::RateLimiterConfig {
     595            2 :             algorithm: RateLimitAlgorithm::Fixed,
     596            2 :             timeout: Duration::from_secs(1),
     597            2 :             initial_limit: 1, // a single token would be the maximum if the limit were enforced
     598            2 :             disable: true,
     599            2 :             ..Default::default()
     600            2 :         };
     601            2 :         let limiter = Limiter::new(config);
     602            2 : 
     603            2 :         // === TOKENS 1 AND 2: both are acquired at once although the limit is 1 ===
     604            2 :         let token1 = limiter.try_acquire().unwrap();
     605            2 :         let token2 = limiter.try_acquire().unwrap();
     606            2 :         let state = limiter.state();
     607            2 :         assert_eq!(state.limit(), 1); // the configured limit is still reported as 1
     608            2 :         assert_eq!(state.in_flight(), 2); // In-flight count exceeds the limit; that is expected for a disabled limiter.
     609            2 :         limiter.release(token1, None).await; // both tokens are released without an outcome (`None`)
     610            2 :         limiter.release(token2, None).await;
     611              :     }
     612              : 
     613            2 :     #[test]
     614            2 :     fn rate_bucket_rpi() { // Checks how `RateBucketInfo::new` derives `max_rpi` from its two arguments.
     615            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     616            2 :         assert_eq!(rate_bucket.max_rpi, 50 * 5); // first argument multiplied by the interval in seconds (presumably rps -> requests per interval; confirm in `RateBucketInfo`)
     617              : 
     618            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     619            2 :         assert_eq!(rate_bucket.max_rpi, 50 / 2); // a sub-second interval (0.5s) scales the value down
     620            2 :     }
     621              : 
     622            2 :     #[test]
     623            2 :     fn rate_bucket_parse() { // Checks that `RateBucketInfo` parses from a "<rate>@<interval>" string and prints back the same text.
     624            2 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     625            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     626            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     627            2 :         assert_eq!(rate_bucket.to_string(), "100@10s"); // string form round-trips
     628              : 
     629            2 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     630            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60)); // "1m" is read as 60 seconds
     631            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     632            2 :         assert_eq!(rate_bucket.to_string(), "100@1m"); // the minute unit is kept when printing, not rewritten as "60s"
     633            2 :     }
     634              : 
     635            2 :     #[test]
     636            2 :     fn default_rate_buckets() { // Checks that the built-in `RateBucketInfo::DEFAULT_SET` passes `RateBucketInfo::validate`.
     637            2 :         let mut defaults = RateBucketInfo::DEFAULT_SET; // local mutable binding, since `validate` is given a `&mut` slice
     638            2 :         RateBucketInfo::validate(&mut defaults[..]).unwrap(); // `unwrap` fails the test if validation reports an error
     639            2 :     }
     640              : 
     641            2 :     #[test]
     642              :     #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     643            2 :     fn rate_buckets_validate() { // Checks that `validate` rejects a set whose 10s bucket permits fewer requests (100) than its 1s bucket (300).
     644            2 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     645            2 :             .into_iter()
     646            4 :             .map(|s| s.parse().unwrap())
     647            2 :             .collect();
     648            2 :         RateBucketInfo::validate(&mut rates).unwrap(); // expected to panic here with the message above (presumably from unwrapping the `Err` returned by `validate`)
     649            2 :     }
     650              : 
     651            2 :     #[tokio::test]
     652            2 :     async fn test_rate_limits() { // Drives one endpoint through the 100@1s and 20@30s buckets under a manually advanced clock.
     653            2 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     654            2 :             .into_iter()
     655            4 :             .map(|s| s.parse().unwrap())
     656            2 :             .collect();
     657            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     658            2 :         let limiter = EndpointRateLimiter::new(Vec::leak(rates)); // leaked on purpose; presumably the limiter wants a `'static` slice (confirm in `EndpointRateLimiter::new`)
     659            2 : 
     660            2 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     661            2 : 
     662            2 :         time::pause(); // from here on time only moves through the explicit `time::advance` calls below
     663              : 
     664          202 :         for _ in 0..100 {
     665          200 :             assert!(limiter.check(endpoint.clone()));
     666              :         }
     667              :         // the 101st connection within the first second is rejected
     668            2 :         assert!(!limiter.check(endpoint.clone()));
     669              : 
     670              :         // fail even after 500ms as it's in the same bucket
     671            2 :         time::advance(time::Duration::from_millis(500)).await;
     672            2 :         assert!(!limiter.check(endpoint.clone()));
     673              : 
     674              :         // after a full 1s, 100 requests are allowed again
     675            2 :         time::advance(time::Duration::from_millis(500)).await;
     676           12 :         for _ in 1..6 { // 5 rounds of 100 requests, one second apart
     677         1010 :             for _ in 0..100 {
     678         1000 :                 assert!(limiter.check(endpoint.clone()));
     679              :             }
     680           10 :             time::advance(time::Duration::from_millis(1000)).await;
     681              :         }
     682              : 
     683              :         // 100 + 5 * 100 = 600 requests were accepted so far; any more exceed the 20rps@30s limit (20 * 30 = 600)
     684            2 :         assert!(!limiter.check(endpoint.clone()));
     685              : 
     686              :         // 6s have elapsed so far; 1ms before the 30 second mark requests still fail
     687            2 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     688            2 :         assert!(!limiter.check(endpoint.clone()));
     689              : 
     690              :         // after the full 30 seconds, 100 requests are allowed again
     691            2 :         time::advance(time::Duration::from_millis(1)).await;
     692          202 :         for _ in 0..100 {
     693          200 :             assert!(limiter.check(endpoint.clone()));
     694              :         }
     695              :     }
     696              : 
     697            2 :     #[tokio::test]
     698            2 :     async fn test_rate_limits_gc() { // Checks that the limiter's map stays well below the number of distinct endpoints it has seen.
     699            2 :         // use a fixed-seed RNG and a default-constructed FxHasher so the test is not flaky
     700            2 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     701            2 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     702            2 : 
     703            2 :         let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
     704            2 :             &RateBucketInfo::DEFAULT_SET,
     705            2 :             rand,
     706            2 :             hasher,
     707            2 :         );
     708      2000002 :         for i in 0..1_000_000 {
     709      2000000 :             limiter.check(format!("{i}").into()); // a new endpoint id on every iteration; the result of `check` is ignored
     710      2000000 :         }
     711            2 :         assert!(limiter.map.len() < 150_000); // far fewer entries than endpoints checked; presumably `check` evicts old entries (confirm in `EndpointRateLimiter`)
     712              :     }
     713              : }
        

Generated by: LCOV version 2.1-beta