LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 87.5 % 497 435
Test Date: 2024-04-08 10:22:05 Functions: 68.5 % 89 61

            Line data    Source code
       1              : use std::{
       2              :     borrow::Cow,
       3              :     collections::hash_map::RandomState,
       4              :     hash::{BuildHasher, Hash},
       5              :     net::IpAddr,
       6              :     sync::{
       7              :         atomic::{AtomicUsize, Ordering},
       8              :         Arc, Mutex,
       9              :     },
      10              : };
      11              : 
      12              : use anyhow::bail;
      13              : use dashmap::DashMap;
      14              : use itertools::Itertools;
      15              : use rand::{rngs::StdRng, Rng, SeedableRng};
      16              : use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
      17              : use tokio::time::{timeout, Duration, Instant};
      18              : use tracing::info;
      19              : 
      20              : use crate::{intern::EndpointIdInt, EndpointId};
      21              : 
      22              : use super::{
      23              :     limit_algorithm::{LimitAlgorithm, Sample},
      24              :     RateLimiterConfig,
      25              : };
      26              : 
      27              : pub struct RedisRateLimiter {
      28              :     data: Vec<RateBucket>,
      29              :     info: &'static [RateBucketInfo],
      30              : }
      31              : 
      32              : impl RedisRateLimiter {
      33            0 :     pub fn new(info: &'static [RateBucketInfo]) -> Self {
      34            0 :         Self {
      35            0 :             data: vec![
      36            0 :                 RateBucket {
      37            0 :                     start: Instant::now(),
      38            0 :                     count: 0,
      39            0 :                 };
      40            0 :                 info.len()
      41            0 :             ],
      42            0 :             info,
      43            0 :         }
      44            0 :     }
      45              : 
      46              :     /// Check that number of connections is below `max_rps` rps.
      47            0 :     pub fn check(&mut self) -> bool {
      48            0 :         let now = Instant::now();
      49            0 : 
      50            0 :         let should_allow_request = self
      51            0 :             .data
      52            0 :             .iter_mut()
      53            0 :             .zip(self.info)
      54            0 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
      55            0 : 
      56            0 :         if should_allow_request {
      57            0 :             // only increment the bucket counts if the request will actually be accepted
      58            0 :             self.data.iter_mut().for_each(|b| b.inc(1));
      59            0 :         }
      60              : 
      61            0 :         should_allow_request
      62            0 :     }
      63              : }
      64              : 
      65              : // Simple per-endpoint rate limiter.
      66              : //
      67              : // Check that number of connections to the endpoint is below `max_rps` rps.
      68              : // Purposefully ignore user name and database name as clients can reconnect
      69              : // with different names, so we'll end up sending some http requests to
      70              : // the control plane.
      71              : //
      72              : // We also may save quite a lot of CPU (I think) by bailing out right after we
      73              : // saw SNI, before doing TLS handshake. User-side error messages in that case
      74              : // does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
      75              : // I went with a more expensive way that yields user-friendlier error messages.
      76              : pub type EndpointRateLimiter = BucketRateLimiter<EndpointId, StdRng, RandomState>;
      77              : 
      78              : // This can't be just per IP because that would limit some PaaS that share IP addresses
      79              : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>;
      80              : 
      81              : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
      82              :     map: DashMap<Key, Vec<RateBucket>, Hasher>,
      83              :     info: Cow<'static, [RateBucketInfo]>,
      84              :     access_count: AtomicUsize,
      85              :     rand: Mutex<Rand>,
      86              : }
      87              : 
      88              : #[derive(Clone, Copy)]
      89              : struct RateBucket {
      90              :     start: Instant,
      91              :     count: u32,
      92              : }
      93              : 
      94              : impl RateBucket {
      95      6001830 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
      96      6001830 :         if now - self.start < info.interval {
      97      6001814 :             self.count + n <= info.max_rpi
      98              :         } else {
      99              :             // bucket expired, reset
     100           16 :             self.count = 0;
     101           16 :             self.start = now;
     102           16 : 
     103           16 :             true
     104              :         }
     105      6001830 :     }
     106              : 
     107      6001818 :     fn inc(&mut self, n: u32) {
     108      6001818 :         self.count += n;
     109      6001818 :     }
     110              : }
     111              : 
     112              : #[derive(Clone, Copy, PartialEq)]
     113              : pub struct RateBucketInfo {
     114              :     pub interval: Duration,
     115              :     // requests per interval
     116              :     pub max_rpi: u32,
     117              : }
     118              : 
     119              : impl std::fmt::Display for RateBucketInfo {
     120           32 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     121           32 :         let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
     122           32 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
     123           32 :     }
     124              : }
     125              : 
     126              : impl std::fmt::Debug for RateBucketInfo {
     127            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     128            0 :         write!(f, "{self}")
     129            0 :     }
     130              : }
     131              : 
     132              : impl std::str::FromStr for RateBucketInfo {
     133              :     type Err = anyhow::Error;
     134              : 
     135           52 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     136           52 :         let Some((max_rps, interval)) = s.split_once('@') else {
     137            0 :             bail!("invalid rate info")
     138              :         };
     139           52 :         let max_rps = max_rps.parse()?;
     140           52 :         let interval = humantime::parse_duration(interval)?;
     141           52 :         Ok(Self::new(max_rps, interval))
     142           52 :     }
     143              : }
     144              : 
     145              : impl RateBucketInfo {
     146              :     pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
     147              :         Self::new(300, Duration::from_secs(1)),
     148              :         Self::new(200, Duration::from_secs(60)),
     149              :         Self::new(100, Duration::from_secs(600)),
     150              :     ];
     151              : 
     152              :     /// All of these are per endpoint-ip pair.
     153              :     /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
     154              :     ///
     155              :     /// First bucket: 300mcpus total per endpoint-ip pair
     156              :     /// * 1228800 requests per second with 1 hash rounds. (endpoint rate limiter will catch this first)
     157              :     /// * 300 requests per second with 4096 hash rounds.
     158              :     /// * 2 requests per second with 600000 hash rounds.
     159              :     pub const DEFAULT_AUTH_SET: [Self; 3] = [
     160              :         Self::new(300 * 4096, Duration::from_secs(1)),
     161              :         Self::new(200 * 4096, Duration::from_secs(60)),
     162              :         Self::new(100 * 4096, Duration::from_secs(600)),
     163              :     ];
     164              : 
     165            6 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     166           16 :         info.sort_unstable_by_key(|info| info.interval);
     167            6 :         let invalid = info
     168            6 :             .iter()
     169            6 :             .tuple_windows()
     170            8 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     171            6 :         if let Some((a, b)) = invalid {
     172            2 :             bail!(
     173            2 :                 "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     174            2 :                 b.max_rpi,
     175            2 :                 a.max_rpi,
     176            2 :             );
     177            4 :         }
     178            4 : 
     179            4 :         Ok(())
     180            6 :     }
     181              : 
     182           60 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     183           60 :         Self {
     184           60 :             interval,
     185           60 :             max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
     186           60 :         }
     187           60 :     }
     188              : }
     189              : 
     190              : impl<K: Hash + Eq> BucketRateLimiter<K> {
     191            8 :     pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
     192            8 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     193            8 :     }
     194              : }
     195              : 
     196              : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
     197           10 :     fn new_with_rand_and_hasher(
     198           10 :         info: impl Into<Cow<'static, [RateBucketInfo]>>,
     199           10 :         rand: R,
     200           10 :         hasher: S,
     201           10 :     ) -> Self {
     202           10 :         let info = info.into();
     203           10 :         info!(buckets = ?info, "endpoint rate limiter");
     204           10 :         Self {
     205           10 :             info,
     206           10 :             map: DashMap::with_hasher_and_shard_amount(hasher, 64),
     207           10 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     208           10 :             rand: Mutex::new(rand),
     209           10 :         }
     210           10 :     }
     211              : 
     212              :     /// Check that number of connections to the endpoint is below `max_rps` rps.
     213      2000914 :     pub fn check(&self, key: K, n: u32) -> bool {
     214      2000914 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     215      2000914 :         // worst case memory usage is about:
     216      2000914 :         //    = 2 * 2048 * 64 * (48B + 72B)
     217      2000914 :         //    = 30MB
     218      2000914 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     219          976 :             self.do_gc();
     220      1999938 :         }
     221              : 
     222      2000914 :         let now = Instant::now();
     223      2000914 :         let mut entry = self.map.entry(key).or_insert_with(|| {
     224      2000008 :             vec![
     225      2000008 :                 RateBucket {
     226      2000008 :                     start: now,
     227      2000008 :                     count: 0,
     228      2000008 :                 };
     229      2000008 :                 self.info.len()
     230      2000008 :             ]
     231      2000914 :         });
     232      2000914 : 
     233      2000914 :         let should_allow_request = entry
     234      2000914 :             .iter_mut()
     235      2000914 :             .zip(&*self.info)
     236      6001830 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
     237      2000914 : 
     238      2000914 :         if should_allow_request {
     239      2000906 :             // only increment the bucket counts if the request will actually be accepted
     240      6001818 :             entry.iter_mut().for_each(|b| b.inc(n));
     241      2000906 :         }
     242              : 
     243      2000914 :         should_allow_request
     244      2000914 :     }
     245              : 
     246              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     247              :     /// At worst, we'll double the effective max_rps during the cleanup.
     248              :     /// But that way deletion does not aquire mutex on each entry access.
     249          976 :     pub fn do_gc(&self) {
     250          976 :         info!(
     251            0 :             "cleaning up bucket rate limiter, current size = {}",
     252            0 :             self.map.len()
     253            0 :         );
     254          976 :         let n = self.map.shards().len();
     255          976 :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
     256          976 :         // (impossible, infact, unless we have 2048 threads)
     257          976 :         let shard = self.rand.lock().unwrap().gen_range(0..n);
     258          976 :         self.map.shards()[shard].write().clear();
     259          976 :     }
     260              : }
     261              : 
     262              : /// Limits the number of concurrent jobs.
     263              : ///
     264              : /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
     265              : /// token once the job is finished.
     266              : ///
     267              : /// The limit will be automatically adjusted based on observed latency (delay) and/or failures
     268              : /// caused by overload (loss).
     269              : pub struct Limiter {
     270              :     limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
     271              :     semaphore: std::sync::Arc<Semaphore>,
     272              :     config: RateLimiterConfig,
     273              : 
     274              :     // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
     275              :     limits: AtomicUsize,
     276              : 
     277              :     // ONLY USE ATOMIC ADD/SUB
     278              :     in_flight: Arc<AtomicUsize>,
     279              : 
     280              :     #[cfg(test)]
     281              :     notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
     282              : }
     283              : 
     284              : /// A concurrency token, required to run a job.
     285              : ///
     286              : /// Release the token back to the [Limiter] after the job is complete.
     287              : #[derive(Debug)]
     288              : pub struct Token<'t> {
     289              :     permit: Option<tokio::sync::SemaphorePermit<'t>>,
     290              :     start: Instant,
     291              :     in_flight: Arc<AtomicUsize>,
     292              : }
     293              : 
     294              : /// A snapshot of the state of the [Limiter].
     295              : ///
     296              : /// Not guaranteed to be consistent under high concurrency.
     297              : #[derive(Debug, Clone, Copy)]
     298              : pub struct LimiterState {
     299              :     limit: usize,
     300              :     in_flight: usize,
     301              : }
     302              : 
     303              : /// Whether a job succeeded or failed as a result of congestion/overload.
     304              : ///
     305              : /// Errors not considered to be caused by overload should be ignored.
     306              : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     307              : pub enum Outcome {
     308              :     /// The job succeeded, or failed in a way unrelated to overload.
     309              :     Success,
     310              :     /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
     311              :     /// was observed.
     312              :     Overload,
     313              : }
     314              : 
     315              : impl Outcome {
     316            0 :     fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
     317            0 :         match error {
     318            0 :             reqwest_middleware::Error::Middleware(_) => Outcome::Success,
     319            0 :             reqwest_middleware::Error::Reqwest(e) => {
     320            0 :                 if let Some(status) = e.status() {
     321            0 :                     if status.is_server_error()
     322            0 :                         || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
     323              :                     {
     324            0 :                         Outcome::Overload
     325              :                     } else {
     326            0 :                         Outcome::Success
     327              :                     }
     328              :                 } else {
     329            0 :                     Outcome::Success
     330              :                 }
     331              :             }
     332              :         }
     333            0 :     }
     334            4 :     fn from_reqwest_response(response: &reqwest::Response) -> Self {
     335            4 :         if response.status().is_server_error()
     336            4 :             || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
     337              :         {
     338            0 :             Outcome::Overload
     339              :         } else {
     340            4 :             Outcome::Success
     341              :         }
     342            4 :     }
     343              : }
     344              : 
     345              : impl Limiter {
     346              :     /// Create a limiter with a given limit control algorithm.
     347           16 :     pub fn new(config: RateLimiterConfig) -> Self {
     348           16 :         assert!(config.initial_limit > 0);
     349           16 :         Self {
     350           16 :             limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
     351           16 :             semaphore: Arc::new(Semaphore::new(config.initial_limit)),
     352           16 :             config,
     353           16 :             limits: AtomicUsize::new(config.initial_limit),
     354           16 :             in_flight: Arc::new(AtomicUsize::new(0)),
     355           16 :             #[cfg(test)]
     356           16 :             notifier: None,
     357           16 :         }
     358           16 :     }
     359              :     // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
     360              :     //     assert!(initial_limit > 0);
     361              : 
     362              :     //     Self {
     363              :     //         limit_algo: AsyncMutex::new(limit_algorithm),
     364              :     //         semaphore: Arc::new(Semaphore::new(initial_limit)),
     365              :     //         timeout,
     366              :     //         limits: AtomicUsize::new(initial_limit),
     367              :     //         in_flight: Arc::new(AtomicUsize::new(0)),
     368              :     //         #[cfg(test)]
     369              :     //         notifier: None,
     370              :     //     }
     371              :     // }
     372              : 
     373              :     /// In some cases [Token]s are acquired asynchronously when updating the limit.
     374              :     #[cfg(test)]
     375            2 :     pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
     376            2 :         self.notifier = Some(n);
     377            2 :         self
     378            2 :     }
     379              : 
     380              :     /// Try to immediately acquire a concurrency [Token].
     381              :     ///
     382              :     /// Returns `None` if there are none available.
     383           26 :     pub fn try_acquire(&self) -> Option<Token> {
     384           26 :         let result = if self.config.disable {
     385              :             // If the rate limiter is disabled, we can always acquire a token.
     386            4 :             Some(Token::new(None, self.in_flight.clone()))
     387              :         } else {
     388           22 :             self.semaphore
     389           22 :                 .try_acquire()
     390           22 :                 .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     391           22 :                 .ok()
     392              :         };
     393           26 :         if result.is_some() {
     394           22 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     395           22 :         }
     396           26 :         result
     397           26 :     }
     398              : 
     399              :     /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
     400              :     ///
     401              :     /// Returns `None` if there are none available after `duration`.
     402            8 :     pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
     403            8 :         info!("acquiring token: {:?}", self.semaphore.available_permits());
     404            8 :         let result = if self.config.disable {
     405              :             // If the rate limiter is disabled, we can always acquire a token.
     406            4 :             Some(Token::new(None, self.in_flight.clone()))
     407              :         } else {
     408            6 :             match timeout(duration, self.semaphore.acquire()).await {
     409            4 :                 Ok(maybe_permit) => maybe_permit
     410            4 :                     .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
     411            4 :                     .ok(),
     412            0 :                 Err(_) => None,
     413              :             }
     414              :         };
     415            8 :         if result.is_some() {
     416            8 :             self.in_flight.fetch_add(1, Ordering::AcqRel);
     417            8 :         }
     418            8 :         result
     419            8 :     }
     420              : 
     421              :     /// Return the concurrency [Token], along with the outcome of the job.
     422              :     ///
     423              :     /// The [Outcome] of the job, and the time taken to perform it, may be used
     424              :     /// to update the concurrency limit.
     425              :     ///
     426              :     /// Set the outcome to `None` to ignore the job.
     427           26 :     pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
     428           26 :         tracing::info!("outcome is {:?}", outcome);
     429           26 :         let in_flight = self.in_flight.load(Ordering::Acquire);
     430           26 :         let old_limit = self.limits.load(Ordering::Acquire);
     431           26 :         let available = if self.config.disable {
     432            8 :             0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0.
     433              :         } else {
     434           18 :             self.semaphore.available_permits()
     435              :         };
     436           26 :         let total = in_flight + available;
     437              : 
     438           26 :         let mut algo = self.limit_algo.lock().await;
     439              : 
     440           26 :         let new_limit = if let Some(outcome) = outcome {
     441           20 :             let sample = Sample {
     442           20 :                 latency: token.start.elapsed(),
     443           20 :                 in_flight,
     444           20 :                 outcome,
     445           20 :             };
     446           20 :             algo.update(old_limit, sample).await
     447              :         } else {
     448            6 :             old_limit
     449              :         };
     450           26 :         tracing::info!("new limit is {}", new_limit);
     451           26 :         let actual_limit = if new_limit < total {
     452            4 :             token.forget();
     453            4 :             total.saturating_sub(1)
     454              :         } else {
     455           22 :             if !self.config.disable {
     456           16 :                 self.semaphore.add_permits(new_limit.saturating_sub(total));
     457           16 :             }
     458           22 :             new_limit
     459              :         };
     460           26 :         crate::metrics::RATE_LIMITER_LIMIT
     461           26 :             .with_label_values(&["expected"])
     462           26 :             .set(new_limit as i64);
     463           26 :         crate::metrics::RATE_LIMITER_LIMIT
     464           26 :             .with_label_values(&["actual"])
     465           26 :             .set(actual_limit as i64);
     466           26 :         self.limits.store(new_limit, Ordering::Release);
     467            0 :         #[cfg(test)]
     468           26 :         if let Some(n) = &self.notifier {
     469            2 :             n.notify_one();
     470           24 :         }
     471           26 :     }
     472              : 
     473              :     /// The current state of the limiter.
     474           12 :     pub fn state(&self) -> LimiterState {
     475           12 :         let limit = self.limits.load(Ordering::Relaxed);
     476           12 :         let in_flight = self.in_flight.load(Ordering::Relaxed);
     477           12 :         LimiterState { limit, in_flight }
     478           12 :     }
     479              : }
     480              : 
     481              : impl<'t> Token<'t> {
     482           30 :     fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
     483           30 :         Self {
     484           30 :             permit,
     485           30 :             start: Instant::now(),
     486           30 :             in_flight,
     487           30 :         }
     488           30 :     }
     489              : 
     490            4 :     pub fn forget(&mut self) {
     491            4 :         if let Some(permit) = self.permit.take() {
     492            2 :             permit.forget();
     493            2 :         }
     494            4 :     }
     495              : }
     496              : 
     497              : impl Drop for Token<'_> {
     498           30 :     fn drop(&mut self) {
     499           30 :         self.in_flight.fetch_sub(1, Ordering::AcqRel);
     500           30 :     }
     501              : }
     502              : 
     503              : impl LimiterState {
     504              :     /// The current concurrency limit.
     505           12 :     pub fn limit(&self) -> usize {
     506           12 :         self.limit
     507           12 :     }
     508              :     /// The number of jobs in flight.
     509            2 :     pub fn in_flight(&self) -> usize {
     510            2 :         self.in_flight
     511            2 :     }
     512              : }
     513              : 
     514              : #[async_trait::async_trait]
     515              : impl reqwest_middleware::Middleware for Limiter {
     516            4 :     async fn handle(
     517            4 :         &self,
     518            4 :         req: reqwest::Request,
     519            4 :         extensions: &mut task_local_extensions::Extensions,
     520            4 :         next: reqwest_middleware::Next<'_>,
     521            4 :     ) -> reqwest_middleware::Result<reqwest::Response> {
     522            4 :         let start = Instant::now();
     523            4 :         let token = self
     524            4 :             .acquire_timeout(self.config.timeout)
     525            0 :             .await
     526            4 :             .ok_or_else(|| {
     527            0 :                 reqwest_middleware::Error::Middleware(
     528            0 :                     // TODO: Should we map it into user facing errors?
     529            0 :                     crate::console::errors::ApiError::Console {
     530            0 :                         status: crate::http::StatusCode::TOO_MANY_REQUESTS,
     531            0 :                         text: "Too many requests".into(),
     532            0 :                     }
     533            0 :                     .into(),
     534            0 :                 )
     535            4 :             })?;
     536            4 :         info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
     537            4 :         crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
     538            8 :         match next.run(req, extensions).await {
     539            4 :             Ok(response) => {
     540            4 :                 self.release(token, Some(Outcome::from_reqwest_response(&response)))
     541            0 :                     .await;
     542            4 :                 Ok(response)
     543              :             }
     544            0 :             Err(e) => {
     545            0 :                 self.release(token, Some(Outcome::from_reqwest_error(&e)))
     546            0 :                     .await;
     547            0 :                 Err(e)
     548              :             }
     549              :         }
     550           12 :     }
     551              : }
     552              : 
     553              : #[cfg(test)]
     554              : mod tests {
     555              :     use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};
     556              : 
     557              :     use futures::{task::noop_waker_ref, Future};
     558              :     use rand::SeedableRng;
     559              :     use rustc_hash::FxHasher;
     560              :     use tokio::time;
     561              : 
     562              :     use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome};
     563              :     use crate::{
     564              :         rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
     565              :         EndpointId,
     566              :     };
     567              : 
     568              :     #[tokio::test]
     569            2 :     async fn it_works() {
     570            2 :         let config = super::RateLimiterConfig {
     571            2 :             algorithm: RateLimitAlgorithm::Fixed,
     572            2 :             timeout: Duration::from_secs(1),
     573            2 :             initial_limit: 10,
     574            2 :             disable: false,
     575            2 :             ..Default::default()
     576            2 :         };
     577            2 :         let limiter = Limiter::new(config);
     578            2 : 
     579            2 :         let token = limiter.try_acquire().unwrap();
     580            2 : 
     581            2 :         limiter.release(token, Some(Outcome::Success)).await;
     582            2 : 
     583            2 :         assert_eq!(limiter.state().limit(), 10);
     584            2 :     }
     585              : 
     586              :     #[tokio::test]
     587            2 :     async fn is_fair() {
     588            2 :         let config = super::RateLimiterConfig {
     589            2 :             algorithm: RateLimitAlgorithm::Fixed,
     590            2 :             timeout: Duration::from_secs(1),
     591            2 :             initial_limit: 1,
     592            2 :             disable: false,
     593            2 :             ..Default::default()
     594            2 :         };
     595            2 :         let limiter = Limiter::new(config);
     596            2 : 
     597            2 :         // === TOKEN 1 ===
     598            2 :         let token1 = limiter.try_acquire().unwrap();
     599            2 : 
     600            2 :         let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     601            2 :         assert!(
     602            2 :             token2_fut
     603            2 :                 .as_mut()
     604            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     605            2 :                 .is_pending(),
     606            2 :             "token is acquired by token1"
     607            2 :         );
     608            2 : 
     609            2 :         let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
     610            2 :         assert!(
     611            2 :             token3_fut
     612            2 :                 .as_mut()
     613            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     614            2 :                 .is_pending(),
     615            2 :             "token is acquired by token1"
     616            2 :         );
     617            2 : 
     618            2 :         limiter.release(token1, Some(Outcome::Success)).await;
     619            2 :         // === END TOKEN 1 ===
     620            2 : 
     621            2 :         // === TOKEN 2 ===
     622            2 :         assert!(
     623            2 :             limiter.try_acquire().is_none(),
     624            2 :             "token is acquired by token2"
     625            2 :         );
     626            2 : 
     627            2 :         assert!(
     628            2 :             token3_fut
     629            2 :                 .as_mut()
     630            2 :                 .poll(&mut Context::from_waker(noop_waker_ref()))
     631            2 :                 .is_pending(),
     632            2 :             "token is acquired by token2"
     633            2 :         );
     634            2 : 
     635            2 :         let token2 = token2_fut.await.unwrap();
     636            2 : 
     637            2 :         limiter.release(token2, Some(Outcome::Success)).await;
     638            2 :         // === END TOKEN 2 ===
     639            2 : 
     640            2 :         // === TOKEN 3 ===
     641            2 :         assert!(
     642            2 :             limiter.try_acquire().is_none(),
     643            2 :             "token is acquired by token3"
     644            2 :         );
     645            2 : 
     646            2 :         let token3 = token3_fut.await.unwrap();
     647            2 :         limiter.release(token3, Some(Outcome::Success)).await;
     648            2 :         // === END TOKEN 3 ===
     649            2 : 
     650            2 :         // === TOKEN 4 ===
     651            2 :         let token4 = limiter.try_acquire().unwrap();
     652            2 :         limiter.release(token4, Some(Outcome::Success)).await;
     653            2 :     }
     654              : 
     655              :     #[tokio::test]
     656            2 :     async fn disable() {
     657            2 :         let config = super::RateLimiterConfig {
     658            2 :             algorithm: RateLimitAlgorithm::Fixed,
     659            2 :             timeout: Duration::from_secs(1),
     660            2 :             initial_limit: 1,
     661            2 :             disable: true,
     662            2 :             ..Default::default()
     663            2 :         };
     664            2 :         let limiter = Limiter::new(config);
     665            2 : 
     666            2 :         // === TOKEN 1 ===
     667            2 :         let token1 = limiter.try_acquire().unwrap();
     668            2 :         let token2 = limiter.try_acquire().unwrap();
     669            2 :         let state = limiter.state();
     670            2 :         assert_eq!(state.limit(), 1);
     671            2 :         assert_eq!(state.in_flight(), 2); // For disabled limiter, it's expected.
     672            2 :         limiter.release(token1, None).await;
     673            2 :         limiter.release(token2, None).await;
     674            2 :     }
     675              : 
     676              :     #[test]
     677            2 :     fn rate_bucket_rpi() {
     678            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     679            2 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     680              : 
     681            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     682            2 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     683            2 :     }
     684              : 
     685              :     #[test]
     686            2 :     fn rate_bucket_parse() {
     687            2 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     688            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     689            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     690            2 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     691              : 
     692            2 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     693            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     694            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     695            2 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     696            2 :     }
     697              : 
     698              :     #[test]
     699            2 :     fn default_rate_buckets() {
     700            2 :         let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
     701            2 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     702            2 :     }
     703              : 
     704              :     #[test]
     705              :     #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     706            2 :     fn rate_buckets_validate() {
     707            2 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     708            2 :             .into_iter()
     709            4 :             .map(|s| s.parse().unwrap())
     710            2 :             .collect();
     711            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     712            2 :     }
     713              : 
     714              :     #[tokio::test]
     715            2 :     async fn test_rate_limits() {
     716            2 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     717            2 :             .into_iter()
     718            4 :             .map(|s| s.parse().unwrap())
     719            2 :             .collect();
     720            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     721            2 :         let limiter = EndpointRateLimiter::new(rates);
     722            2 : 
     723            2 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     724            2 : 
     725            2 :         time::pause();
     726            2 : 
     727          202 :         for _ in 0..100 {
     728          200 :             assert!(limiter.check(endpoint.clone(), 1));
     729            2 :         }
     730            2 :         // more connections fail
     731            2 :         assert!(!limiter.check(endpoint.clone(), 1));
     732            2 : 
     733            2 :         // fail even after 500ms as it's in the same bucket
     734            2 :         time::advance(time::Duration::from_millis(500)).await;
     735            2 :         assert!(!limiter.check(endpoint.clone(), 1));
     736            2 : 
     737            2 :         // after a full 1s, 100 requests are allowed again
     738            2 :         time::advance(time::Duration::from_millis(500)).await;
     739           12 :         for _ in 1..6 {
     740          510 :             for _ in 0..50 {
     741          500 :                 assert!(limiter.check(endpoint.clone(), 2));
     742            2 :             }
     743           10 :             time::advance(time::Duration::from_millis(1000)).await;
     744            2 :         }
     745            2 : 
     746            2 :         // more connections after 600 will exceed the 20rps@30s limit
     747            2 :         assert!(!limiter.check(endpoint.clone(), 1));
     748            2 : 
     749            2 :         // will still fail before the 30 second limit
     750            2 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     751            2 :         assert!(!limiter.check(endpoint.clone(), 1));
     752            2 : 
     753            2 :         // after the full 30 seconds, 100 requests are allowed again
     754            2 :         time::advance(time::Duration::from_millis(1)).await;
     755          202 :         for _ in 0..100 {
     756          200 :             assert!(limiter.check(endpoint.clone(), 1));
     757            2 :         }
     758            2 :     }
     759              : 
     760              :     #[tokio::test]
     761            2 :     async fn test_rate_limits_gc() {
     762            2 :         // fixed seeded random/hasher to ensure that the test is not flaky
     763            2 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     764            2 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     765            2 : 
     766            2 :         let limiter = BucketRateLimiter::new_with_rand_and_hasher(
     767            2 :             &RateBucketInfo::DEFAULT_ENDPOINT_SET,
     768            2 :             rand,
     769            2 :             hasher,
     770            2 :         );
     771      2000002 :         for i in 0..1_000_000 {
     772      2000000 :             limiter.check(i, 1);
     773      2000000 :         }
     774            2 :         assert!(limiter.map.len() < 150_000);
     775            2 :     }
     776              : 
     777              :     #[test]
     778            2 :     fn test_default_auth_set() {
     779            2 :         // these values used to exceed u32::MAX
     780            2 :         assert_eq!(
     781            2 :             RateBucketInfo::DEFAULT_AUTH_SET,
     782            2 :             [
     783            2 :                 RateBucketInfo {
     784            2 :                     interval: Duration::from_secs(1),
     785            2 :                     max_rpi: 300 * 4096,
     786            2 :                 },
     787            2 :                 RateBucketInfo {
     788            2 :                     interval: Duration::from_secs(60),
     789            2 :                     max_rpi: 200 * 4096 * 60,
     790            2 :                 },
     791            2 :                 RateBucketInfo {
     792            2 :                     interval: Duration::from_secs(600),
     793            2 :                     max_rpi: 100 * 4096 * 600,
     794            2 :                 }
     795            2 :             ]
     796            2 :         );
     797              : 
     798            8 :         for x in RateBucketInfo::DEFAULT_AUTH_SET {
     799            6 :             let y = x.to_string().parse().unwrap();
     800            6 :             assert_eq!(x, y);
     801              :         }
     802            2 :     }
     803              : }
        

Generated by: LCOV version 2.1-beta