LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions)
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info
Test Date: 2025-07-16 12:29:03

                  Coverage    Total    Hit
Lines:              91.2 %      170    155
Functions:          85.7 %       35     30

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::hash_map::RandomState;
       3              : use std::hash::{BuildHasher, Hash};
       4              : use std::sync::Mutex;
       5              : use std::sync::atomic::{AtomicUsize, Ordering};
       6              : 
       7              : use anyhow::bail;
       8              : use clashmap::ClashMap;
       9              : use itertools::Itertools;
      10              : use rand::rngs::StdRng;
      11              : use rand::{Rng, SeedableRng};
      12              : use tokio::time::{Duration, Instant};
      13              : use tracing::info;
      14              : 
      15              : use super::LeakyBucketConfig;
      16              : use crate::ext::LockExt;
      17              : use crate::intern::EndpointIdInt;
      18              : 
      19              : // Simple per-endpoint rate limiter.
      20              : //
      21              : // Check that the number of connections to the endpoint is below `max_rps` rps.
      22              : // Purposefully ignore the user name and database name: clients can reconnect
      23              : // with different names, so keying on them would still end up sending extra
      24              : // HTTP requests to the control plane.
      25              : pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
      26              : 
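For orientation, a minimal usage sketch of the limiter defined below (not part of the covered file). It mirrors the construction used by the tests at the bottom of this report; the endpoint name is made up, and because `check` is `pub(crate)` the snippet assumes crate-internal code:

    // Sketch only: build the limiter over the default bucket set and consult it
    // once per incoming connection.
    let limiter = WakeComputeRateLimiter::new(&RateBucketInfo::DEFAULT_SET);

    let endpoint = EndpointIdInt::from(EndpointId::from("ep-example-1234"));

    // `n` is the cost of this request; a single connection counts as 1.
    if !limiter.check(endpoint, 1) {
        // Over the per-endpoint budget: reject without contacting the control plane.
    }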
      27              : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
      28              :     map: ClashMap<Key, Vec<RateBucket>, Hasher>,
      29              :     info: Cow<'static, [RateBucketInfo]>,
      30              :     access_count: AtomicUsize,
      31              :     rand: Mutex<Rand>,
      32              : }
      33              : 
      34              : #[derive(Clone, Copy)]
      35              : struct RateBucket {
      36              :     start: Instant,
      37              :     count: u32,
      38              : }
      39              : 
      40              : impl RateBucket {
      41      3000906 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
      42      3000906 :         if now - self.start < info.interval {
      43      3000898 :             self.count + n <= info.max_rpi
      44              :         } else {
      45              :             // bucket expired, reset
      46            8 :             self.count = 0;
      47            8 :             self.start = now;
      48              : 
      49            8 :             true
      50              :         }
      51      3000906 :     }
      52              : 
      53      3000900 :     fn inc(&mut self, n: u32) {
      54      3000900 :         self.count += n;
      55      3000900 :     }
      56              : }
      57              : 
      58              : #[derive(Clone, Copy, PartialEq)]
      59              : pub struct RateBucketInfo {
      60              :     pub(crate) interval: Duration,
      61              :     // requests per interval
      62              :     pub(crate) max_rpi: u32,
      63              : }
      64              : 
      65              : impl std::fmt::Display for RateBucketInfo {
      66           10 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      67           10 :         let rps = self.rps().floor() as u64;
      68           10 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
      69           10 :     }
      70              : }
      71              : 
      72              : impl std::fmt::Debug for RateBucketInfo {
      73            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      74            0 :         write!(f, "{self}")
      75            0 :     }
      76              : }
      77              : 
      78              : impl std::str::FromStr for RateBucketInfo {
      79              :     type Err = anyhow::Error;
      80              : 
      81           11 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      82           11 :         let Some((max_rps, interval)) = s.split_once('@') else {
      83            0 :             bail!("invalid rate info")
      84              :         };
      85           11 :         let max_rps = max_rps.parse()?;
      86           11 :         let interval = humantime::parse_duration(interval)?;
      87           11 :         Ok(Self::new(max_rps, interval))
      88           11 :     }
      89              : }
      90              : 
      91              : impl RateBucketInfo {
      92              :     pub const DEFAULT_SET: [Self; 3] = [
      93              :         Self::new(300, Duration::from_secs(1)),
      94              :         Self::new(200, Duration::from_secs(60)),
      95              :         Self::new(100, Duration::from_secs(600)),
      96              :     ];
      97              : 
      98              :     pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
      99              :         Self::new(500, Duration::from_secs(1)),
     100              :         Self::new(300, Duration::from_secs(60)),
     101              :         Self::new(200, Duration::from_secs(600)),
     102              :     ];
     103              : 
     104           10 :     pub fn rps(&self) -> f64 {
     105           10 :         (self.max_rpi as f64) / self.interval.as_secs_f64()
     106           10 :     }
     107              : 
     108            3 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     109            3 :         info.sort_unstable_by_key(|info| info.interval);
     110            3 :         let invalid = info
     111            3 :             .iter()
     112            3 :             .tuple_windows()
     113            4 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     114            3 :         if let Some((a, b)) = invalid {
     115            1 :             bail!(
     116            1 :                 "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     117              :                 b.max_rpi,
     118              :                 a.max_rpi,
     119              :             );
     120            2 :         }
     121              : 
     122            2 :         Ok(())
     123            3 :     }
     124              : 
     125           15 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     126           15 :         Self {
     127           15 :             interval,
     128           15 :             max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
     129           15 :         }
     130           15 :     }
     131              : 
     132            0 :     pub fn to_leaky_bucket(this: &[Self]) -> Option<LeakyBucketConfig> {
     133              :         // bit of a hack - find the min rps and max rps supported and turn it into
     134              :         // leaky bucket config instead
     135              : 
     136            0 :         let mut iter = this.iter().map(|info| info.rps());
     137            0 :         let first = iter.next()?;
     138              : 
     139            0 :         let (min, max) = (first, first);
     140            0 :         let (min, max) = iter.fold((min, max), |(min, max), rps| {
     141            0 :             (f64::min(min, rps), f64::max(max, rps))
     142            0 :         });
     143              : 
     144            0 :         Some(LeakyBucketConfig { rps: min, max })
     145            0 :     }
     146              : }
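Because `rps()` recovers the `max_rps` each bucket was constructed with, `to_leaky_bucket` (currently uncovered) takes the most restrictive bucket as the sustained rate and the most permissive one as the burst size. A small worked sketch, assuming crate-internal access to the `rps` and `max` fields of `LeakyBucketConfig` as constructed above:

    // DEFAULT_SET is [300@1s, 200@60s, 100@600s], so rps() yields 300, 200 and 100.
    let leaky = RateBucketInfo::to_leaky_bucket(&RateBucketInfo::DEFAULT_SET)
        .expect("DEFAULT_SET is non-empty");
    assert_eq!(leaky.rps, 100.0); // sustained rate: bucket with the largest interval
    assert_eq!(leaky.max, 300.0); // burst size: bucket with the smallest interval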
     147              : 
     148              : impl<K: Hash + Eq> BucketRateLimiter<K> {
     149            1 :     pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
     150            1 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     151            1 :     }
     152              : }
     153              : 
     154              : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
     155            2 :     fn new_with_rand_and_hasher(
     156            2 :         info: impl Into<Cow<'static, [RateBucketInfo]>>,
     157            2 :         rand: R,
     158            2 :         hasher: S,
     159            2 :     ) -> Self {
     160            2 :         let info = info.into();
     161            2 :         info!(buckets = ?info, "endpoint rate limiter");
     162            2 :         Self {
     163            2 :             info,
     164            2 :             map: ClashMap::with_hasher_and_shard_amount(hasher, 64),
     165            2 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     166            2 :             rand: Mutex::new(rand),
     167            2 :         }
     168            2 :     }
     169              : 
     170              :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     171      1000454 :     pub(crate) fn check(&self, key: K, n: u32) -> bool {
     172              :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     173              :         // worst case memory usage is about:
     174              :         //    = 2 * 2048 * 64 * (48B + 72B)
     175              :         //    = 30MB
     176      1000454 :         if self
     177      1000454 :             .access_count
     178      1000454 :             .fetch_add(1, Ordering::AcqRel)
     179      1000454 :             .is_multiple_of(2048)
     180          488 :         {
     181          488 :             self.do_gc();
     182       999966 :         }
     183              : 
     184      1000454 :         let now = Instant::now();
     185      1000454 :         let mut entry = self.map.entry(key).or_insert_with(|| {
     186      1000001 :             vec![
     187      1000001 :                 RateBucket {
     188      1000001 :                     start: now,
     189      1000001 :                     count: 0,
     190      1000001 :                 };
     191      1000001 :                 self.info.len()
     192              :             ]
     193      1000001 :         });
     194              : 
     195      1000454 :         let should_allow_request = entry
     196      1000454 :             .iter_mut()
     197      1000454 :             .zip(&*self.info)
     198      3000906 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
     199              : 
     200      1000454 :         if should_allow_request {
     201              :             // only increment the bucket counts if the request will actually be accepted
     202      3000900 :             entry.iter_mut().for_each(|b| b.inc(n));
     203            4 :         }
     204              : 
     205      1000454 :         should_allow_request
     206      1000454 :     }
     207              : 
     208              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     209              :     /// At worst, we'll double the effective max_rps during the cleanup.
     210              :     /// But that way deletion does not acquire a mutex on each entry access.
     211          488 :     pub(crate) fn do_gc(&self) {
     212          488 :         info!(
     213            0 :             "cleaning up bucket rate limiter, current size = {}",
     214            0 :             self.map.len()
     215              :         );
     216          488 :         let n = self.map.shards().len();
     217              :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
     218              :         // (impossible, in fact, unless we have 2048 threads)
     219          488 :         let shard = self.rand.lock_propagate_poison().gen_range(0..n);
     220          488 :         self.map.shards()[shard].write().clear();
     221          488 :     }
     222              : }
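The comment inside `check()` bounds worst-case memory at 2 * 2048 * 64 * (48B + 72B): GC clears one of the 64 shards every 2048 accesses, and the per-entry byte figures are the comment's own estimates. A quick arithmetic check of that bound:

    // Back-of-the-envelope check of the worst-case size estimate in `check()`;
    // all constants are taken from the comment there.
    let entries: u64 = 2 * 2048 * 64;
    let bytes_per_entry: u64 = 48 + 72;
    assert_eq!(entries * bytes_per_entry / (1024 * 1024), 30); // ~30 MiB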
     223              : 
     224              : #[cfg(test)]
     225              : mod tests {
     226              :     use std::hash::BuildHasherDefault;
     227              :     use std::time::Duration;
     228              : 
     229              :     use rand::SeedableRng;
     230              :     use rustc_hash::FxHasher;
     231              :     use tokio::time;
     232              : 
     233              :     use super::{BucketRateLimiter, WakeComputeRateLimiter};
     234              :     use crate::intern::EndpointIdInt;
     235              :     use crate::rate_limiter::RateBucketInfo;
     236              :     use crate::types::EndpointId;
     237              : 
     238              :     #[test]
     239            1 :     fn rate_bucket_rpi() {
     240            1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     241            1 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     242              : 
     243            1 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     244            1 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     245            1 :     }
     246              : 
     247              :     #[test]
     248            1 :     fn rate_bucket_parse() {
     249            1 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     250            1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     251            1 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     252            1 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     253              : 
     254            1 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     255            1 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     256            1 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     257            1 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     258            1 :     }
     259              : 
     260              :     #[test]
     261            1 :     fn default_rate_buckets() {
     262            1 :         let mut defaults = RateBucketInfo::DEFAULT_SET;
     263            1 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     264            1 :     }
     265              : 
     266              :     #[test]
     267              :     #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     268            1 :     fn rate_buckets_validate() {
     269            1 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     270            1 :             .into_iter()
     271            2 :             .map(|s| s.parse().unwrap())
     272            1 :             .collect();
     273            1 :         RateBucketInfo::validate(&mut rates).unwrap();
     274            1 :     }
     275              : 
     276              :     #[tokio::test]
     277            1 :     async fn test_rate_limits() {
     278            1 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     279            1 :             .into_iter()
     280            2 :             .map(|s| s.parse().unwrap())
     281            1 :             .collect();
     282            1 :         RateBucketInfo::validate(&mut rates).unwrap();
     283            1 :         let limiter = WakeComputeRateLimiter::new(rates);
     284              : 
     285            1 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     286            1 :         let endpoint = EndpointIdInt::from(endpoint);
     287              : 
     288            1 :         time::pause();
     289              : 
     290          101 :         for _ in 0..100 {
     291          100 :             assert!(limiter.check(endpoint, 1));
     292              :         }
     293              :         // more connections fail
     294            1 :         assert!(!limiter.check(endpoint, 1));
     295              : 
     296              :         // fail even after 500ms as it's in the same bucket
     297            1 :         time::advance(time::Duration::from_millis(500)).await;
     298            1 :         assert!(!limiter.check(endpoint, 1));
     299              : 
     300              :         // after a full 1s, 100 requests are allowed again
     301            1 :         time::advance(time::Duration::from_millis(500)).await;
     302            6 :         for _ in 1..6 {
     303          255 :             for _ in 0..50 {
     304          250 :                 assert!(limiter.check(endpoint, 2));
     305              :             }
     306            5 :             time::advance(time::Duration::from_millis(1000)).await;
     307              :         }
     308              : 
     309              :         // more connections after 600 will exceed the 20rps@30s limit
     310            1 :         assert!(!limiter.check(endpoint, 1));
     311              : 
     312              :         // will still fail before the 30 second limit
     313            1 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     314            1 :         assert!(!limiter.check(endpoint, 1));
     315              : 
     316              :         // after the full 30 seconds, 100 requests are allowed again
     317            1 :         time::advance(time::Duration::from_millis(1)).await;
     318          101 :         for _ in 0..100 {
     319          100 :             assert!(limiter.check(endpoint, 1));
     320            1 :         }
     321            1 :     }
     322              : 
     323              :     #[tokio::test]
     324            1 :     async fn test_rate_limits_gc() {
     325              :         // fixed seeded random/hasher to ensure that the test is not flaky
     326            1 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     327            1 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     328              : 
     329            1 :         let limiter =
     330            1 :             BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
     331      1000001 :         for i in 0..1_000_000 {
     332      1000000 :             limiter.check(i, 1);
     333      1000000 :         }
     334            1 :         assert!(limiter.map.len() < 150_000);
     335            1 :     }
     336              : }
        

Generated by: LCOV version 2.1-beta