LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions)
Test: 07bee600374ccd486c69370d0972d9035964fe68.info
Test Date: 2025-02-20 13:11:02
Coverage:   Lines: 84.9 % (185 of 218 hit)   Functions: 80.4 % (37 of 46 hit)

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::hash_map::RandomState;
       3              : use std::hash::{BuildHasher, Hash};
       4              : use std::sync::atomic::{AtomicUsize, Ordering};
       5              : use std::sync::Mutex;
       6              : 
       7              : use anyhow::bail;
       8              : use clashmap::ClashMap;
       9              : use itertools::Itertools;
      10              : use rand::rngs::StdRng;
      11              : use rand::{Rng, SeedableRng};
      12              : use tokio::time::{Duration, Instant};
      13              : use tracing::info;
      14              : 
      15              : use crate::ext::LockExt;
      16              : use crate::intern::EndpointIdInt;
      17              : 
      18              : pub struct GlobalRateLimiter {
      19              :     data: Vec<RateBucket>,
      20              :     info: Vec<RateBucketInfo>,
      21              : }
      22              : 
      23              : impl GlobalRateLimiter {
      24            0 :     pub fn new(info: Vec<RateBucketInfo>) -> Self {
      25            0 :         Self {
      26            0 :             data: vec![
      27            0 :                 RateBucket {
      28            0 :                     start: Instant::now(),
      29            0 :                     count: 0,
      30            0 :                 };
      31            0 :                 info.len()
      32            0 :             ],
      33            0 :             info,
      34            0 :         }
      35            0 :     }
      36              : 
       37              :     /// Check that the number of connections is below `max_rps` rps.
      38            0 :     pub fn check(&mut self) -> bool {
      39            0 :         let now = Instant::now();
      40            0 : 
      41            0 :         let should_allow_request = self
      42            0 :             .data
      43            0 :             .iter_mut()
      44            0 :             .zip(&self.info)
      45            0 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
      46            0 : 
      47            0 :         if should_allow_request {
      48            0 :             // only increment the bucket counts if the request will actually be accepted
      49            0 :             self.data.iter_mut().for_each(|b| b.inc(1));
      50            0 :         }
      51              : 
      52            0 :         should_allow_request
      53            0 :     }
      54              : }
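A minimal usage sketch of the global limiter (module paths and the helper functions below are assumptions, not part of this file):

    use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};

    // Hypothetical construction: 100k rps short-term, 50k rps sustained over 10s.
    fn cancellation_limiter() -> GlobalRateLimiter {
        GlobalRateLimiter::new(RateBucketInfo::DEFAULT_REDIS_SET.to_vec())
    }

    // `check` increments the buckets only when every one of them has room.
    fn try_admit(limiter: &mut GlobalRateLimiter) -> bool {
        limiter.check()
    }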
      55              : 
       56              : // Simple per-endpoint rate limiter.
       57              : //
       58              : // Checks that the number of connections to the endpoint is below `max_rps` rps.
       59              : // User name and database name are purposefully ignored: clients can reconnect
       60              : // with different names, yet each connection still costs us HTTP requests to
       61              : // the control plane, so the limit is keyed on the endpoint alone.
      62              : pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
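Within the crate, the per-endpoint limiter is used in the same way; a short sketch (the wrapper functions and bucket choice are illustrative, and `check` is crate-private):

    use crate::intern::EndpointIdInt;
    use crate::rate_limiter::{RateBucketInfo, WakeComputeRateLimiter};

    fn wake_limiter() -> WakeComputeRateLimiter {
        WakeComputeRateLimiter::new(RateBucketInfo::DEFAULT_ENDPOINT_SET.to_vec())
    }

    // Keyed on the endpoint only; one wake-compute attempt consumes one token.
    fn may_wake(limiter: &WakeComputeRateLimiter, endpoint: EndpointIdInt) -> bool {
        limiter.check(endpoint, 1)
    }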
      63              : 
      64              : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
      65              :     map: ClashMap<Key, Vec<RateBucket>, Hasher>,
      66              :     info: Cow<'static, [RateBucketInfo]>,
      67              :     access_count: AtomicUsize,
      68              :     rand: Mutex<Rand>,
      69              : }
      70              : 
      71              : #[derive(Clone, Copy)]
      72              : struct RateBucket {
      73              :     start: Instant,
      74              :     count: u32,
      75              : }
      76              : 
      77              : impl RateBucket {
      78      3000915 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
      79      3000915 :         if now - self.start < info.interval {
      80      3000907 :             self.count + n <= info.max_rpi
      81              :         } else {
      82              :             // bucket expired, reset
      83            8 :             self.count = 0;
      84            8 :             self.start = now;
      85            8 : 
      86            8 :             true
      87              :         }
      88      3000915 :     }
      89              : 
      90      3000909 :     fn inc(&mut self, n: u32) {
      91      3000909 :         self.count += n;
      92      3000909 :     }
      93              : }
      94              : 
      95              : #[derive(Clone, Copy, PartialEq)]
      96              : pub struct RateBucketInfo {
      97              :     pub(crate) interval: Duration,
      98              :     // requests per interval
      99              :     pub(crate) max_rpi: u32,
     100              : }
     101              : 
     102              : impl std::fmt::Display for RateBucketInfo {
     103           18 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     104           18 :         let rps = self.rps().floor() as u64;
     105           18 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
     106           18 :     }
     107              : }
     108              : 
     109              : impl std::fmt::Debug for RateBucketInfo {
     110            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     111            0 :         write!(f, "{self}")
     112            0 :     }
     113              : }
     114              : 
     115              : impl std::str::FromStr for RateBucketInfo {
     116              :     type Err = anyhow::Error;
     117              : 
     118           19 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     119           19 :         let Some((max_rps, interval)) = s.split_once('@') else {
     120            0 :             bail!("invalid rate info")
     121              :         };
     122           19 :         let max_rps = max_rps.parse()?;
     123           19 :         let interval = humantime::parse_duration(interval)?;
     124           19 :         Ok(Self::new(max_rps, interval))
     125           19 :     }
     126              : }
     127              : 
     128              : impl RateBucketInfo {
     129              :     pub const DEFAULT_SET: [Self; 3] = [
     130              :         Self::new(300, Duration::from_secs(1)),
     131              :         Self::new(200, Duration::from_secs(60)),
     132              :         Self::new(100, Duration::from_secs(600)),
     133              :     ];
     134              : 
     135              :     pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
     136              :         Self::new(500, Duration::from_secs(1)),
     137              :         Self::new(300, Duration::from_secs(60)),
     138              :         Self::new(200, Duration::from_secs(600)),
     139              :     ];
     140              : 
      141              :     // Every session carries a cancel key, so this limit is essentially a global proxy limit.
     142              :     pub const DEFAULT_REDIS_SET: [Self; 2] = [
     143              :         Self::new(100_000, Duration::from_secs(1)),
     144              :         Self::new(50_000, Duration::from_secs(10)),
     145              :     ];
     146              : 
     147              :     /// All of these are per endpoint-maskedip pair.
     148              :     /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
     149              :     ///
     150              :     /// First bucket: 1000mcpus total per endpoint-ip pair
      151              :     /// * 4096000 requests per second with 1 hash round.
     152              :     /// * 1000 requests per second with 4096 hash rounds.
     153              :     /// * 6.8 requests per second with 600000 hash rounds.
     154              :     pub const DEFAULT_AUTH_SET: [Self; 3] = [
     155              :         Self::new(1000 * 4096, Duration::from_secs(1)),
     156              :         Self::new(600 * 4096, Duration::from_secs(60)),
     157              :         Self::new(300 * 4096, Duration::from_secs(600)),
     158              :     ];
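A back-of-envelope check of the first bucket above (sketch only; the helper is hypothetical and assumes the caller charges `n = hash_rounds` per authentication attempt, which is not visible in this file):

    // Effective requests per second admitted by the 1-second bucket,
    // given the pbkdf2 round count of each request.
    fn effective_auth_rps(hash_rounds: u32) -> f64 {
        let units_per_second = (1000u64 * 4096) as f64; // ~1000 mcpus of pbkdf2 per second
        units_per_second / f64::from(hash_rounds)
    }
    // effective_auth_rps(1)       == 4_096_000.0
    // effective_auth_rps(4096)    == 1_000.0
    // effective_auth_rps(600_000) ~= 6.83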
     159              : 
     160           18 :     pub fn rps(&self) -> f64 {
     161           18 :         (self.max_rpi as f64) / self.interval.as_secs_f64()
     162           18 :     }
     163              : 
     164            3 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     165            8 :         info.sort_unstable_by_key(|info| info.interval);
     166            3 :         let invalid = info
     167            3 :             .iter()
     168            3 :             .tuple_windows()
     169            4 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     170            3 :         if let Some((a, b)) = invalid {
     171            1 :             bail!(
     172            1 :                 "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     173            1 :                 b.max_rpi,
     174            1 :                 a.max_rpi,
     175            1 :             );
     176            2 :         }
     177            2 : 
     178            2 :         Ok(())
     179            3 :     }
     180              : 
     181           23 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     182           23 :         Self {
     183           23 :             interval,
     184           23 :             max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
     185           23 :         }
     186           23 :     }
     187              : }
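Putting the parsing, construction, and validation above together, a typical configuration flow might look like this (sketch; `parse_buckets` is a hypothetical helper):

    use crate::rate_limiter::RateBucketInfo;

    fn parse_buckets(specs: &[&str]) -> anyhow::Result<Vec<RateBucketInfo>> {
        let mut buckets = specs
            .iter()
            .map(|s| s.parse::<RateBucketInfo>())
            .collect::<anyhow::Result<Vec<_>>>()?;
        // Sorts by interval and rejects sets where a longer bucket allows fewer
        // requests per interval than a shorter one, e.g. ["300@1s", "10@10s"]
        // fails because 10@10s permits 100 rpi while 300@1s permits 300 rpi.
        RateBucketInfo::validate(&mut buckets)?;
        Ok(buckets)
    }

    // parse_buckets(&["100@1s", "20@30s"]) is accepted: 100 rpi <= 600 rpi.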
     188              : 
     189              : impl<K: Hash + Eq> BucketRateLimiter<K> {
     190            4 :     pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
     191            4 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     192            4 :     }
     193              : }
     194              : 
     195              : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
     196            5 :     fn new_with_rand_and_hasher(
     197            5 :         info: impl Into<Cow<'static, [RateBucketInfo]>>,
     198            5 :         rand: R,
     199            5 :         hasher: S,
     200            5 :     ) -> Self {
     201            5 :         let info = info.into();
     202            5 :         info!(buckets = ?info, "endpoint rate limiter");
     203            5 :         Self {
     204            5 :             info,
     205            5 :             map: ClashMap::with_hasher_and_shard_amount(hasher, 64),
     206            5 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     207            5 :             rand: Mutex::new(rand),
     208            5 :         }
     209            5 :     }
     210              : 
      211              :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     212      1000457 :     pub(crate) fn check(&self, key: K, n: u32) -> bool {
     213      1000457 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     214      1000457 :         // worst case memory usage is about:
     215      1000457 :         //    = 2 * 2048 * 64 * (48B + 72B)
     216      1000457 :         //    = 30MB
     217      1000457 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     218          488 :             self.do_gc();
     219       999969 :         }
     220              : 
     221      1000457 :         let now = Instant::now();
     222      1000457 :         let mut entry = self.map.entry(key).or_insert_with(|| {
     223      1000004 :             vec![
     224      1000004 :                 RateBucket {
     225      1000004 :                     start: now,
     226      1000004 :                     count: 0,
     227      1000004 :                 };
     228      1000004 :                 self.info.len()
     229      1000004 :             ]
     230      1000457 :         });
     231      1000457 : 
     232      1000457 :         let should_allow_request = entry
     233      1000457 :             .iter_mut()
     234      1000457 :             .zip(&*self.info)
     235      3000915 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
     236      1000457 : 
     237      1000457 :         if should_allow_request {
     238      1000453 :             // only increment the bucket counts if the request will actually be accepted
     239      3000909 :             entry.iter_mut().for_each(|b| b.inc(n));
     240      1000453 :         }
     241              : 
     242      1000457 :         should_allow_request
     243      1000457 :     }
     244              : 
     245              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     246              :     /// At worst, we'll double the effective max_rps during the cleanup.
      247              :     /// But that way deletion does not acquire a mutex on each entry access.
     248          488 :     pub(crate) fn do_gc(&self) {
     249          488 :         info!(
     250            0 :             "cleaning up bucket rate limiter, current size = {}",
     251            0 :             self.map.len()
     252              :         );
     253          488 :         let n = self.map.shards().len();
     254          488 :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
      255          488 :         // (impossible, in fact, unless we have 2048 threads)
     256          488 :         let shard = self.rand.lock_propagate_poison().gen_range(0..n);
     257          488 :         self.map.shards()[shard].write().clear();
     258          488 :     }
     259              : }
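For reference, the worst-case memory figure quoted in `check`'s GC comment works out as follows (the 48 B and 72 B per-entry sizes are taken from that comment, not measured here):

    // One of 64 shards is cleared every 2048 calls, so roughly 2 * 2048 * 64
    // entries can be live at once, at about 48 B of key plus 72 B of buckets each.
    const WORST_CASE_BYTES: usize = 2 * 2048 * 64 * (48 + 72);
    // = 31_457_280 bytes, i.e. 30 MiB.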
     260              : 
     261              : #[cfg(test)]
     262              : #[expect(clippy::unwrap_used)]
     263              : mod tests {
     264              :     use std::hash::BuildHasherDefault;
     265              :     use std::time::Duration;
     266              : 
     267              :     use rand::SeedableRng;
     268              :     use rustc_hash::FxHasher;
     269              :     use tokio::time;
     270              : 
     271              :     use super::{BucketRateLimiter, WakeComputeRateLimiter};
     272              :     use crate::intern::EndpointIdInt;
     273              :     use crate::rate_limiter::RateBucketInfo;
     274              :     use crate::types::EndpointId;
     275              : 
     276              :     #[test]
     277            1 :     fn rate_bucket_rpi() {
     278            1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     279            1 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     280              : 
     281            1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     282            1 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     283            1 :     }
     284              : 
     285              :     #[test]
     286            1 :     fn rate_bucket_parse() {
     287            1 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     288            1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     289            1 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     290            1 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     291              : 
     292            1 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     293            1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     294            1 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     295            1 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     296            1 :     }
     297              : 
     298              :     #[test]
     299            1 :     fn default_rate_buckets() {
     300            1 :         let mut defaults = RateBucketInfo::DEFAULT_SET;
     301            1 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     302            1 :     }
     303              : 
     304              :     #[test]
     305              :     #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     306            1 :     fn rate_buckets_validate() {
     307            1 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     308            1 :             .into_iter()
     309            2 :             .map(|s| s.parse().unwrap())
     310            1 :             .collect();
     311            1 :         RateBucketInfo::validate(&mut rates).unwrap();
     312            1 :     }
     313              : 
     314              :     #[tokio::test]
     315            1 :     async fn test_rate_limits() {
     316            1 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     317            1 :             .into_iter()
     318            2 :             .map(|s| s.parse().unwrap())
     319            1 :             .collect();
     320            1 :         RateBucketInfo::validate(&mut rates).unwrap();
     321            1 :         let limiter = WakeComputeRateLimiter::new(rates);
     322            1 : 
     323            1 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     324            1 :         let endpoint = EndpointIdInt::from(endpoint);
     325            1 : 
     326            1 :         time::pause();
     327            1 : 
     328          101 :         for _ in 0..100 {
     329          100 :             assert!(limiter.check(endpoint, 1));
     330            1 :         }
     331            1 :         // more connections fail
     332            1 :         assert!(!limiter.check(endpoint, 1));
     333            1 : 
     334            1 :         // fail even after 500ms as it's in the same bucket
     335            1 :         time::advance(time::Duration::from_millis(500)).await;
     336            1 :         assert!(!limiter.check(endpoint, 1));
     337            1 : 
     338            1 :         // after a full 1s, 100 requests are allowed again
     339            1 :         time::advance(time::Duration::from_millis(500)).await;
     340            6 :         for _ in 1..6 {
     341          255 :             for _ in 0..50 {
     342          250 :                 assert!(limiter.check(endpoint, 2));
     343            1 :             }
     344            5 :             time::advance(time::Duration::from_millis(1000)).await;
     345            1 :         }
     346            1 : 
      347            1 :         // after those 600 requests, more connections exceed the 20rps@30s limit (600 per 30s)
     348            1 :         assert!(!limiter.check(endpoint, 1));
     349            1 : 
     350            1 :         // will still fail before the 30 second limit
     351            1 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     352            1 :         assert!(!limiter.check(endpoint, 1));
     353            1 : 
     354            1 :         // after the full 30 seconds, 100 requests are allowed again
     355            1 :         time::advance(time::Duration::from_millis(1)).await;
     356          101 :         for _ in 0..100 {
     357          100 :             assert!(limiter.check(endpoint, 1));
     358            1 :         }
     359            1 :     }
     360              : 
     361              :     #[tokio::test]
     362            1 :     async fn test_rate_limits_gc() {
     363            1 :         // fixed seeded random/hasher to ensure that the test is not flaky
     364            1 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     365            1 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     366            1 : 
     367            1 :         let limiter =
     368            1 :             BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
     369      1000001 :         for i in 0..1_000_000 {
     370      1000000 :             limiter.check(i, 1);
     371      1000000 :         }
     372            1 :         assert!(limiter.map.len() < 150_000);
     373            1 :     }
     374              : }
        

Generated by: LCOV version 2.1-beta