LCOV - differential code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 93.5 % 415 388 27 388
Current Date: 2024-01-09 02:06:09 Functions: 80.2 % 81 65 16 65
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use std::{
       2                 :     collections::hash_map::RandomState,
       3                 :     hash::BuildHasher,
       4                 :     sync::{
       5                 :         atomic::{AtomicUsize, Ordering},
       6                 :         Arc, Mutex,
       7                 :     },
       8                 : };
       9                 : 
      10                 : use anyhow::bail;
      11                 : use dashmap::DashMap;
      12                 : use itertools::Itertools;
      13                 : use rand::{rngs::StdRng, Rng, SeedableRng};
      14                 : use smol_str::SmolStr;
      15                 : use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
      16                 : use tokio::time::{timeout, Duration, Instant};
      17                 : use tracing::info;
      18                 : 
      19                 : use super::{
      20                 :     limit_algorithm::{LimitAlgorithm, Sample},
      21                 :     RateLimiterConfig,
      22                 : };
      23                 : 
      24                 : // Simple per-endpoint rate limiter.
      25                 : //
      26                 : // Check that the number of connections to the endpoint is below `max_rps` rps.
      27                 : // Purposefully ignore user name and database name as clients can reconnect
      28                 : // with different names, so we'll end up sending some http requests to
      29                 : // the control plane.
      30                 : //
      31                 : // We also may save quite a lot of CPU (I think) by bailing out right after we
      32                 : // saw SNI, before doing TLS handshake. User-side error messages in that case
      33                 : // do not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
      34                 : // I went with a more expensive way that yields user-friendlier error messages.
      35                 : pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
      36                 :     map: DashMap<SmolStr, Vec<RateBucket>, Hasher>,
      37                 :     info: &'static [RateBucketInfo],
      38                 :     access_count: AtomicUsize,
      39                 :     rand: Mutex<Rand>,
      40                 : }
      41                 : 
      42 CBC     2000049 : #[derive(Clone, Copy)]
      43                 : struct RateBucket {
      44                 :     start: Instant,
      45                 :     count: u32,
      46                 : }
      47                 : 
      48                 : impl RateBucket {
      49         3001544 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
      50         3001544 :         if now - self.start < info.interval {
      51         3001533 :             self.count < info.max_rpi
      52                 :         } else {
      53                 :             // bucket expired, reset
      54              11 :             self.count = 0;
      55              11 :             self.start = now;
      56              11 : 
      57              11 :             true
      58                 :         }
      59         3001544 :     }
      60                 : 
      61         3001538 :     fn inc(&mut self) {
      62         3001538 :         self.count += 1;
      63         3001538 :     }
      64                 : }
      65                 : 
      66               2 : #[derive(Clone, Copy, PartialEq)]
      67                 : pub struct RateBucketInfo {
      68                 :     pub interval: Duration,
      69                 :     // requests per interval
      70                 :     pub max_rpi: u32,
      71                 : }
      72                 : 
      73                 : impl std::fmt::Display for RateBucketInfo {
      74             139 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      75             139 :         let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
      76             139 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
      77             139 :     }
      78                 : }
      79                 : 
      80                 : impl std::fmt::Debug for RateBucketInfo {
      81              66 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      82              66 :         write!(f, "{self}")
      83              66 :     }
      84                 : }
      85                 : 
      86                 : impl std::str::FromStr for RateBucketInfo {
      87                 :     type Err = anyhow::Error;
      88                 : 
      89             143 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      90             143 :         let Some((max_rps, interval)) = s.split_once('@') else {
      91 UBC           0 :             bail!("invalid rate info")
      92                 :         };
      93 CBC         143 :         let max_rps = max_rps.parse()?;
      94             143 :         let interval = humantime::parse_duration(interval)?;
      95             143 :         Ok(Self::new(max_rps, interval))
      96             143 :     }
      97                 : }
      98                 : 
      99                 : impl RateBucketInfo {
     100                 :     pub const DEFAULT_SET: [Self; 3] = [
     101                 :         Self::new(300, Duration::from_secs(1)),
     102                 :         Self::new(200, Duration::from_secs(60)),
     103                 :         Self::new(100, Duration::from_secs(600)),
     104                 :     ];
     105                 : 
     106              25 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     107              96 :         info.sort_unstable_by_key(|info| info.interval);
     108              25 :         let invalid = info
     109              25 :             .iter()
     110              25 :             .tuple_windows()
     111              48 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     112              25 :         if let Some((a, b)) = invalid {
     113               1 :             bail!(
     114               1 :                 "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     115               1 :                 b.max_rpi,
     116               1 :                 a.max_rpi,
     117               1 :             );
     118              24 :         }
     119              24 : 
     120              24 :         Ok(())
     121              25 :     }
     122                 : 
     123             147 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     124             147 :         Self {
     125             147 :             interval,
     126             147 :             max_rpi: max_rps * interval.as_millis() as u32 / 1000,
     127             147 :         }
     128             147 :     }
     129                 : }
     130                 : 
     131                 : impl EndpointRateLimiter {
     132              23 :     pub fn new(info: &'static [RateBucketInfo]) -> Self {
     133              23 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     134              23 :     }
     135                 : }
     136                 : 
     137                 : impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
     138              24 :     fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
     139              24 :         info!(buckets = ?info, "endpoint rate limiter");
     140              24 :         Self {
     141              24 :             info,
     142              24 :             map: DashMap::with_hasher_and_shard_amount(hasher, 64),
     143              24 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     144              24 :             rand: Mutex::new(rand),
     145              24 :         }
     146              24 :     }
     147                 : 
     148                 :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     149         1000750 :     pub fn check(&self, endpoint: SmolStr) -> bool {
     150         1000750 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     151         1000750 :         // worst case memory usage is about:
     152         1000750 :         //    = 2 * 2048 * 64 * (48B + 72B)
     153         1000750 :         //    = 30MB
     154         1000750 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     155             488 :             self.do_gc();
     156         1000262 :         }
     157                 : 
     158         1000750 :         let now = Instant::now();
     159         1000750 :         let mut entry = self.map.entry(endpoint).or_insert_with(|| {
     160         1000025 :             vec![
     161         1000025 :                 RateBucket {
     162         1000025 :                     start: now,
     163         1000025 :                     count: 0,
     164         1000025 :                 };
     165         1000025 :                 self.info.len()
     166         1000025 :             ]
     167         1000750 :         });
     168         1000750 : 
     169         1000750 :         let should_allow_request = entry
     170         1000750 :             .iter_mut()
     171         1000750 :             .zip(self.info)
     172         3001544 :             .all(|(bucket, info)| bucket.should_allow_request(info, now));
     173         1000750 : 
     174         1000750 :         if should_allow_request {
     175         1000746 :             // only increment the bucket counts if the request will actually be accepted
     176         1000746 :             entry.iter_mut().for_each(RateBucket::inc);
     177         1000746 :         }
     178                 : 
     179         1000750 :         should_allow_request
     180         1000750 :     }
     181                 : 
     182                 :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     183                 :     /// At worst, we'll double the effective max_rps during the cleanup.
     184                 :     /// But that way deletion does not acquire a mutex on each entry access.
     185             488 :     pub fn do_gc(&self) {
     186             488 :         info!(
     187 UBC           0 :             "cleaning up endpoint rate limiter, current size = {}",
     188               0 :             self.map.len()
     189               0 :         );
     190 CBC         488 :         let n = self.map.shards().len();
     191             488 :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
     192             488 :         // (impossible, in fact, unless we have 2048 threads)
     193             488 :         let shard = self.rand.lock().unwrap().gen_range(0..n);
     194             488 :         self.map.shards()[shard].write().clear();
     195             488 :     }
     196                 : }
     197                 : 
     198                 : /// Limits the number of concurrent jobs.
     199                 : ///
     200                 : /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
     201                 : /// token once the job is finished.
     202                 : ///
     203                 : /// The limit will be automatically adjusted based on observed latency (delay) and/or failures
     204                 : /// caused by overload (loss).
     205                 : pub struct Limiter {
     206                 :     limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
     207                 :     semaphore: std::sync::Arc<Semaphore>,
     208                 :     config: RateLimiterConfig,
     209                 : 
     210                 :     // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
     211                 :     limits: AtomicUsize,
     212                 : 
     213                 :     // ONLY USE ATOMIC ADD/SUB
     214                 :     in_flight: Arc<AtomicUsize>,
     215                 : 
     216                 :     #[cfg(test)]
     217                 :     notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
     218                 : }
     219                 : 
     220                 : /// A concurrency token, required to run a job.
     221                 : ///
     222                 : /// Release the token back to the [Limiter] after the job is complete.
     223 UBC           0 : #[derive(Debug)]
     224                 : pub struct Token<'t> {
     225                 :     permit: Option<tokio::sync::SemaphorePermit<'t>>,
     226                 :     start: Instant,
     227                 :     in_flight: Arc<AtomicUsize>,
     228                 : }
     229                 : 
     230                 : /// A snapshot of the state of the [Limiter].
     231                 : ///
     232                 : /// Not guaranteed to be consistent under high concurrency.
     233               0 : #[derive(Debug, Clone, Copy)]
     234                 : pub struct LimiterState {
     235                 :     limit: usize,
     236                 :     in_flight: usize,
     237                 : }
     238                 : 
     239                 : /// Whether a job succeeded or failed as a result of congestion/overload.
     240                 : ///
     241                 : /// Errors not considered to be caused by overload should be ignored.
     242 CBC           6 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     243                 : pub enum Outcome {
     244                 :     /// The job succeeded, or failed in a way unrelated to overload.
     245                 :     Success,
     246                 :     /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
     247                 :     /// was observed.
     248                 :     Overload,
     249                 : }
     250                 : 
     251                 : impl Outcome {
     252 UBC           0 :     fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
     253               0 :         match error {
     254               0 :             reqwest_middleware::Error::Middleware(_) => Outcome::Success,
     255               0 :             reqwest_middleware::Error::Reqwest(e) => {
     256               0 :                 if let Some(status) = e.status() {
     257               0 :                     if status.is_server_error()
     258               0 :                         || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
     259                 :                     {
     260               0 :                         Outcome::Overload
     261                 :                     } else {
     262               0 :                         Outcome::Success
     263                 :                     }
     264                 :                 } else {
     265               0 :                     Outcome::Success
     266                 :                 }
     267                 :             }
     268                 :         }
     269               0 :     }
     270 CBC           5 :     fn from_reqwest_response(response: &reqwest::Response) -> Self {
     271               5 :         if response.status().is_server_error()
     272               4 :             || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
     273                 :         {
     274               2 :             Outcome::Overload
     275                 :         } else {
     276               3 :             Outcome::Success
     277                 :         }
     278               5 :     }
     279                 : }
     280                 : 
     281                 : impl Limiter {
     282                 :     /// Create a limiter with a given limit control algorithm.
     283               9 :     pub fn new(config: RateLimiterConfig) -> Self {
     284               9 :         assert!(config.initial_limit > 0);
     285               9 :         Self {
     286               9 :             limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
     287               9 :             semaphore: Arc::new(Semaphore::new(config.initial_limit)),
     288               9 :             config,
     289               9 :             limits: AtomicUsize::new(config.initial_limit),
     290               9 :             in_flight: Arc::new(AtomicUsize::new(0)),
     291               9 :             #[cfg(test)]
     292               9 :             notifier: None,
     293               9 :         }
     294               9 :     }
     295                 :     // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
     296                 :     //     assert!(initial_limit > 0);
     297                 : 
     298                 :     //     Self {
     299                 :     //         limit_algo: AsyncMutex::new(limit_algorithm),
     300                 :     //         semaphore: Arc::new(Semaphore::new(initial_limit)),
     301                 :     //         timeout,
     302                 :     //         limits: AtomicUsize::new(initial_limit),
     303                 :     //         in_flight: Arc::new(AtomicUsize::new(0)),
     304                 :     //         #[cfg(test)]
     305                 :     //         notifier: None,
     306                 :     //     }
     307                 :     // }
     308                 : 
     309                 :     /// In some cases [Token]s are acquired asynchronously when updating the limit.
     310                 :     #[cfg(test)]
     311               1 :     pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
     312               1 :         self.notifier = Some(n);
     313               1 :         self
     314               1 :     }
     315                 : 
     316                 :     /// Try to immediately acquire a concurrency [Token].
     317                 :     ///
     318                 :     /// Returns `None` if there are none available.
     319              13 :     pub fn try_acquire(&self) -> Option<Token> {
     320              13 :         let result = if self.config.disable {
     321                 :             // If the rate limiter is disabled, we can always acquire a token.
     322               2 :             Some(Token::new(None, self.in_flight.clone()))
     323                 :         } else {
     324              11 :             self.semaphore
     325              11 :                 .try_acquire()
     326              11 :                 .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     327              11 :                 .ok()
     328                 :         };
     329              13 :         if result.is_some() {
     330              11 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     331              11 :         }
     332              13 :         result
     333              13 :     }
     334                 : 
     335                 :     /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
     336                 :     ///
     337                 :     /// Returns `None` if there are none available after `duration`.
     338               8 :     pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
     339               4 :         info!("acquiring token: {:?}", self.semaphore.available_permits());
     340               8 :         let result = if self.config.disable {
     341                 :             // If the rate limiter is disabled, we can always acquire a token.
     342               2 :             Some(Token::new(None, self.in_flight.clone()))
     343                 :         } else {
     344               6 :             match timeout(duration, self.semaphore.acquire()).await {
     345               5 :                 Ok(maybe_permit) => maybe_permit
     346               5 :                     .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     347               5 :                     .ok(),
     348               1 :                 Err(_) => None,
     349                 :             }
     350                 :         };
     351               8 :         if result.is_some() {
     352               7 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     353               7 :         }
     354               8 :         result
     355               8 :     }
     356                 : 
     357                 :     /// Return the concurrency [Token], along with the outcome of the job.
     358                 :     ///
     359                 :     /// The [Outcome] of the job, and the time taken to perform it, may be used
     360                 :     /// to update the concurrency limit.
     361                 :     ///
     362                 :     /// Set the outcome to `None` to ignore the job.
     363              16 :     pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
     364               3 :         tracing::info!("outcome is {:?}", outcome);
     365              16 :         let in_flight = self.in_flight.load(Ordering::Acquire);
     366              16 :         let old_limit = self.limits.load(Ordering::Acquire);
     367              16 :         let available = if self.config.disable {
     368               4 :             0 // This is not used in the algorithm and can be anything. If the limiter is disabled in the config, it makes sense to set it to 0.
     369                 :         } else {
     370              12 :             self.semaphore.available_permits()
     371                 :         };
     372              16 :         let total = in_flight + available;
     373                 : 
     374              16 :         let mut algo = self.limit_algo.lock().await;
     375                 : 
     376              16 :         let new_limit = if let Some(outcome) = outcome {
     377              13 :             let sample = Sample {
     378              13 :                 latency: token.start.elapsed(),
     379              13 :                 in_flight,
     380              13 :                 outcome,
     381              13 :             };
     382              13 :             algo.update(old_limit, sample).await
     383                 :         } else {
     384               3 :             old_limit
     385                 :         };
     386               3 :         tracing::info!("new limit is {}", new_limit);
     387              16 :         let actual_limit = if new_limit < total {
     388               4 :             token.forget();
     389               4 :             total.saturating_sub(1)
     390                 :         } else {
     391              12 :             if !self.config.disable {
     392               9 :                 self.semaphore.add_permits(new_limit.saturating_sub(total));
     393               9 :             }
     394              12 :             new_limit
     395                 :         };
     396              13 :         crate::metrics::RATE_LIMITER_LIMIT
     397              13 :             .with_label_values(&["expected"])
     398              13 :             .set(new_limit as i64);
     399              13 :         crate::metrics::RATE_LIMITER_LIMIT
     400              13 :             .with_label_values(&["actual"])
     401              13 :             .set(actual_limit as i64);
     402              13 :         self.limits.store(new_limit, Ordering::Release);
     403               3 :         #[cfg(test)]
     404              13 :         if let Some(n) = &self.notifier {
     405               1 :             n.notify_one();
     406              12 :         }
     407              13 :     }
     408                 : 
     409                 :     /// The current state of the limiter.
     410               6 :     pub fn state(&self) -> LimiterState {
     411               6 :         let limit = self.limits.load(Ordering::Relaxed);
     412               6 :         let in_flight = self.in_flight.load(Ordering::Relaxed);
     413               6 :         LimiterState { limit, in_flight }
     414               6 :     }
     415                 : }
     416                 : 
     417                 : impl<'t> Token<'t> {
     418              18 :     fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
     419              18 :         Self {
     420              18 :             permit,
     421              18 :             start: Instant::now(),
     422              18 :             in_flight,
     423              18 :         }
     424              18 :     }
     425                 : 
     426               4 :     pub fn forget(&mut self) {
     427               4 :         if let Some(permit) = self.permit.take() {
     428               3 :             permit.forget();
     429               3 :         }
     430               4 :     }
     431                 : }
     432                 : 
     433                 : impl Drop for Token<'_> {
     434              18 :     fn drop(&mut self) {
     435              18 :         self.in_flight.fetch_sub(1, Ordering::AcqRel);
     436              18 :     }
     437                 : }
     438                 : 
     439                 : impl LimiterState {
     440                 :     /// The current concurrency limit.
     441               6 :     pub fn limit(&self) -> usize {
     442               6 :         self.limit
     443               6 :     }
     444                 :     /// The number of jobs in flight.
     445               1 :     pub fn in_flight(&self) -> usize {
     446               1 :         self.in_flight
     447               1 :     }
     448                 : }
     449                 : 
     450                 : #[async_trait::async_trait]
     451                 : impl reqwest_middleware::Middleware for Limiter {
     452               6 :     async fn handle(
     453               6 :         &self,
     454               6 :         req: reqwest::Request,
     455               6 :         extensions: &mut task_local_extensions::Extensions,
     456               6 :         next: reqwest_middleware::Next<'_>,
     457               6 :     ) -> reqwest_middleware::Result<reqwest::Response> {
     458               6 :         let start = Instant::now();
     459               6 :         let token = self
     460               6 :             .acquire_timeout(self.config.timeout)
     461               1 :             .await
     462               6 :             .ok_or_else(|| {
     463               1 :                 reqwest_middleware::Error::Middleware(
     464               1 :                     // TODO: Should we map it into user facing errors?
     465               1 :                     crate::console::errors::ApiError::Console {
     466               1 :                         status: crate::http::StatusCode::TOO_MANY_REQUESTS,
     467               1 :                         text: "Too many requests".into(),
     468               1 :                     }
     469               1 :                     .into(),
     470               1 :                 )
     471               6 :             })?;
     472               5 :         info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
     473               5 :         crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
     474              15 :         match next.run(req, extensions).await {
     475               5 :             Ok(response) => {
     476               5 :                 self.release(token, Some(Outcome::from_reqwest_response(&response)))
     477 UBC           0 :                     .await;
     478 CBC           5 :                 Ok(response)
     479                 :             }
     480 UBC           0 :             Err(e) => {
     481               0 :                 self.release(token, Some(Outcome::from_reqwest_error(&e)))
     482               0 :                     .await;
     483               0 :                 Err(e)
     484                 :             }
     485                 :         }
     486 CBC          12 :     }
     487                 : }
     488                 : 
     489                 : #[cfg(test)]
     490                 : mod tests {
     491                 :     use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};
     492                 : 
     493                 :     use futures::{task::noop_waker_ref, Future};
     494                 :     use rand::SeedableRng;
     495                 :     use rustc_hash::FxHasher;
     496                 :     use smol_str::SmolStr;
     497                 :     use tokio::time;
     498                 : 
     499                 :     use super::{EndpointRateLimiter, Limiter, Outcome};
     500                 :     use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm};
     501                 : 
     502               1 :     #[tokio::test]
     503               1 :     async fn it_works() {
     504               1 :         let config = super::RateLimiterConfig {
     505               1 :             algorithm: RateLimitAlgorithm::Fixed,
     506               1 :             timeout: Duration::from_secs(1),
     507               1 :             initial_limit: 10,
     508               1 :             disable: false,
     509               1 :             ..Default::default()
     510               1 :         };
     511               1 :         let limiter = Limiter::new(config);
     512               1 : 
     513               1 :         let token = limiter.try_acquire().unwrap();
     514               1 : 
     515               1 :         limiter.release(token, Some(Outcome::Success)).await;
     516                 : 
     517               1 :         assert_eq!(limiter.state().limit(), 10);
     518                 :     }
     519                 : 
     520               1 :     #[tokio::test]
     521               1 :     async fn is_fair() {
     522               1 :         let config = super::RateLimiterConfig {
     523               1 :             algorithm: RateLimitAlgorithm::Fixed,
     524               1 :             timeout: Duration::from_secs(1),
     525               1 :             initial_limit: 1,
     526               1 :             disable: false,
     527               1 :             ..Default::default()
     528               1 :         };
     529               1 :         let limiter = Limiter::new(config);
     530               1 : 
     531               1 :         // === TOKEN 1 ===
     532               1 :         let token1 = limiter.try_acquire().unwrap();
     533               1 : 
     534               1 :         let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     535               1 :         assert!(
     536               1 :             token2_fut
     537               1 :                 .as_mut()
     538               1 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     539               1 :                 .is_pending(),
     540 UBC           0 :             "token is acquired by token1"
     541                 :         );
     542                 : 
     543 CBC           1 :         let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     544               1 :         assert!(
     545               1 :             token3_fut
     546               1 :                 .as_mut()
     547               1 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     548               1 :                 .is_pending(),
     549 UBC           0 :             "token is acquired by token1"
     550                 :         );
     551                 : 
     552 CBC           1 :         limiter.release(token1, Some(Outcome::Success)).await;
     553                 :         // === END TOKEN 1 ===
     554                 : 
     555                 :         // === TOKEN 2 ===
     556               1 :         assert!(
     557               1 :             limiter.try_acquire().is_none(),
     558 UBC           0 :             "token is acquired by token2"
     559                 :         );
     560                 : 
     561 CBC           1 :         assert!(
     562               1 :             token3_fut
     563               1 :                 .as_mut()
     564               1 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     565               1 :                 .is_pending(),
     566 UBC           0 :             "token is acquired by token2"
     567                 :         );
     568                 : 
     569 CBC           1 :         let token2 = token2_fut.await.unwrap();
     570               1 : 
     571               1 :         limiter.release(token2, Some(Outcome::Success)).await;
     572                 :         // === END TOKEN 2 ===
     573                 : 
     574                 :         // === TOKEN 3 ===
     575               1 :         assert!(
     576               1 :             limiter.try_acquire().is_none(),
     577 UBC           0 :             "token is acquired by token3"
     578                 :         );
     579                 : 
     580 CBC           1 :         let token3 = token3_fut.await.unwrap();
     581               1 :         limiter.release(token3, Some(Outcome::Success)).await;
     582                 :         // === END TOKEN 3 ===
     583                 : 
     584                 :         // === TOKEN 4 ===
     585               1 :         let token4 = limiter.try_acquire().unwrap();
     586               1 :         limiter.release(token4, Some(Outcome::Success)).await;
     587                 :     }
     588                 : 
     589               1 :     #[tokio::test]
     590               1 :     async fn disable() {
     591               1 :         let config = super::RateLimiterConfig {
     592               1 :             algorithm: RateLimitAlgorithm::Fixed,
     593               1 :             timeout: Duration::from_secs(1),
     594               1 :             initial_limit: 1,
     595               1 :             disable: true,
     596               1 :             ..Default::default()
     597               1 :         };
     598               1 :         let limiter = Limiter::new(config);
     599               1 : 
     600               1 :         // === TOKEN 1 ===
     601               1 :         let token1 = limiter.try_acquire().unwrap();
     602               1 :         let token2 = limiter.try_acquire().unwrap();
     603               1 :         let state = limiter.state();
     604               1 :         assert_eq!(state.limit(), 1);
     605               1 :         assert_eq!(state.in_flight(), 2); // For disabled limiter, it's expected.
     606               1 :         limiter.release(token1, None).await;
     607               1 :         limiter.release(token2, None).await;
     608                 :     }
     609                 : 
     610               1 :     #[test]
     611               1 :     fn rate_bucket_rpi() {
     612               1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     613               1 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     614                 : 
     615               1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     616               1 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     617               1 :     }
     618                 : 
     619               1 :     #[test]
     620               1 :     fn rate_bucket_parse() {
     621               1 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     622               1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     623               1 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     624               1 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     625                 : 
     626               1 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     627               1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     628               1 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     629               1 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     630               1 :     }
     631                 : 
     632               1 :     #[test]
     633               1 :     fn default_rate_buckets() {
     634               1 :         let mut defaults = RateBucketInfo::DEFAULT_SET;
     635               1 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     636               1 :     }
     637                 : 
     638               1 :     #[test]
     639                 :     #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     640               1 :     fn rate_buckets_validate() {
     641               1 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     642               1 :             .into_iter()
     643               2 :             .map(|s| s.parse().unwrap())
     644               1 :             .collect();
     645               1 :         RateBucketInfo::validate(&mut rates).unwrap();
     646               1 :     }
     647                 : 
     648               1 :     #[tokio::test]
     649               1 :     async fn test_rate_limits() {
     650               1 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     651               1 :             .into_iter()
     652               2 :             .map(|s| s.parse().unwrap())
     653               1 :             .collect();
     654               1 :         RateBucketInfo::validate(&mut rates).unwrap();
     655               1 :         let limiter = EndpointRateLimiter::new(Vec::leak(rates));
     656               1 : 
     657               1 :         let endpoint = SmolStr::from("ep-my-endpoint-1234");
     658               1 : 
     659               1 :         time::pause();
     660                 : 
     661             101 :         for _ in 0..100 {
     662             100 :             assert!(limiter.check(endpoint.clone()));
     663                 :         }
     664                 :         // more connections fail
     665               1 :         assert!(!limiter.check(endpoint.clone()));
     666                 : 
     667                 :         // fail even after 500ms as it's in the same bucket
     668               1 :         time::advance(time::Duration::from_millis(500)).await;
     669               1 :         assert!(!limiter.check(endpoint.clone()));
     670                 : 
     671                 :         // after a full 1s, 100 requests are allowed again
     672               1 :         time::advance(time::Duration::from_millis(500)).await;
     673               6 :         for _ in 1..6 {
     674             505 :             for _ in 0..100 {
     675             500 :                 assert!(limiter.check(endpoint.clone()));
     676                 :             }
     677               5 :             time::advance(time::Duration::from_millis(1000)).await;
     678                 :         }
     679                 : 
     680                 :         // more connections after 600 will exceed the 20rps@30s limit
     681               1 :         assert!(!limiter.check(endpoint.clone()));
     682                 : 
     683                 :         // will still fail before the 30 second limit
     684               1 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     685               1 :         assert!(!limiter.check(endpoint.clone()));
     686                 : 
     687                 :         // after the full 30 seconds, 100 requests are allowed again
     688               1 :         time::advance(time::Duration::from_millis(1)).await;
     689             101 :         for _ in 0..100 {
     690             100 :             assert!(limiter.check(endpoint.clone()));
     691                 :         }
     692                 :     }
     693                 : 
     694               1 :     #[tokio::test]
     695               1 :     async fn test_rate_limits_gc() {
     696               1 :         // fixed seeded random/hasher to ensure that the test is not flaky
     697               1 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     698               1 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     699               1 : 
     700               1 :         let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
     701               1 :             &RateBucketInfo::DEFAULT_SET,
     702               1 :             rand,
     703               1 :             hasher,
     704               1 :         );
     705         1000001 :         for i in 0..1_000_000 {
     706         1000000 :             limiter.check(format!("{i}").into());
     707         1000000 :         }
     708               1 :         assert!(limiter.map.len() < 150_000);
     709                 :     }
     710                 : }
        

Generated by: LCOV version 2.1-beta