LCOV - code coverage report
Current view: top level - proxy/src/rate_limiter - limiter.rs (source / functions)
Test:      02e8c57acd6e2b986849f552ca30280d54699b79.info
Test Date: 2024-06-26 17:13:54

              Coverage   Total   Hit
Lines:          84.7 %     215   182
Functions:      77.6 %      49    38

            Line data    Source code
       1              : use std::{
       2              :     borrow::Cow,
       3              :     collections::hash_map::RandomState,
       4              :     hash::{BuildHasher, Hash},
       5              :     sync::{
       6              :         atomic::{AtomicUsize, Ordering},
       7              :         Mutex,
       8              :     },
       9              : };
      10              : 
      11              : use anyhow::bail;
      12              : use dashmap::DashMap;
      13              : use itertools::Itertools;
      14              : use rand::{rngs::StdRng, Rng, SeedableRng};
      15              : use tokio::time::{Duration, Instant};
      16              : use tracing::info;
      17              : 
      18              : use crate::intern::EndpointIdInt;
      19              : 
      20              : pub struct GlobalRateLimiter {
      21              :     data: Vec<RateBucket>,
      22              :     info: Vec<RateBucketInfo>,
      23              : }
      24              : 
      25              : impl GlobalRateLimiter {
      26            0 :     pub fn new(info: Vec<RateBucketInfo>) -> Self {
      27            0 :         Self {
      28            0 :             data: vec![
      29            0 :                 RateBucket {
      30            0 :                     start: Instant::now(),
      31            0 :                     count: 0,
      32            0 :                 };
      33            0 :                 info.len()
      34            0 :             ],
      35            0 :             info,
      36            0 :         }
      37            0 :     }
      38              : 
       39              :     /// Check that the number of connections is below `max_rps` rps.
      40            0 :     pub fn check(&mut self) -> bool {
      41            0 :         let now = Instant::now();
      42            0 : 
      43            0 :         let should_allow_request = self
      44            0 :             .data
      45            0 :             .iter_mut()
      46            0 :             .zip(&self.info)
      47            0 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
      48            0 : 
      49            0 :         if should_allow_request {
      50            0 :             // only increment the bucket counts if the request will actually be accepted
      51            0 :             self.data.iter_mut().for_each(|b| b.inc(1));
      52            0 :         }
      53              : 
      54            0 :         should_allow_request
      55            0 :     }
      56              : }
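
A minimal usage sketch of the limiter above, assuming GlobalRateLimiter and
RateBucketInfo are reachable from the caller's crate as defined in this file;
the accept loop itself is a placeholder. check() walks every bucket first and
only increments the counters when all of them have headroom, so a rejected
request never consumes budget.

    fn serve(mut limiter: GlobalRateLimiter) {
        loop {
            if limiter.check() {
                // every bucket had headroom and was incremented: admit the request
            } else {
                // at least one interval is saturated: reject, no counter was touched
            }
        }
    }

    // Construction with the default bucket set defined further down:
    // let limiter = GlobalRateLimiter::new(RateBucketInfo::DEFAULT_SET.to_vec());
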
      57              : 
       58              : // Simple per-endpoint rate limiter.
       59              : //
       60              : // Checks that the number of connections to the endpoint is below `max_rps` rps.
       61              : // User name and database name are purposefully ignored: clients can reconnect
       62              : // with different names, so limiting on them would still let extra HTTP requests
       63              : // reach the control plane.
      64              : pub type EndpointRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
      65              : 
      66              : pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
      67              :     map: DashMap<Key, Vec<RateBucket>, Hasher>,
      68              :     info: Cow<'static, [RateBucketInfo]>,
      69              :     access_count: AtomicUsize,
      70              :     rand: Mutex<Rand>,
      71              : }
      72              : 
      73              : #[derive(Clone, Copy)]
      74              : struct RateBucket {
      75              :     start: Instant,
      76              :     count: u32,
      77              : }
      78              : 
      79              : impl RateBucket {
      80      6001848 :     fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
      81      6001848 :         if now - self.start < info.interval {
      82      6001832 :             self.count + n <= info.max_rpi
      83              :         } else {
      84              :             // bucket expired, reset
      85           16 :             self.count = 0;
      86           16 :             self.start = now;
      87           16 : 
      88           16 :             true
      89              :         }
      90      6001848 :     }
      91              : 
      92      6001836 :     fn inc(&mut self, n: u32) {
      93      6001836 :         self.count += n;
      94      6001836 :     }
      95              : }
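
RateBucket is a plain fixed-window counter: requests are counted against
max_rpi until interval elapses, then the window resets. The check and the
increment are deliberately split so that the limiters above and below can
verify every bucket before incrementing any of them. A self-contained sketch
of the same idea folded into a single method, for illustration only (not this
crate's API):

    use std::time::{Duration, Instant};

    struct Window {
        start: Instant,
        count: u32,
    }

    impl Window {
        /// Allow up to `max` events per `interval`, resetting once the window
        /// has fully elapsed (the same fixed-window scheme as RateBucket).
        fn allow(&mut self, max: u32, interval: Duration, n: u32) -> bool {
            let now = Instant::now();
            if now.duration_since(self.start) >= interval {
                self.start = now;
                self.count = 0;
            }
            if self.count + n <= max {
                self.count += n;
                true
            } else {
                false
            }
        }
    }
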
      96              : 
      97              : #[derive(Clone, Copy, PartialEq)]
      98              : pub struct RateBucketInfo {
      99              :     pub interval: Duration,
     100              :     // requests per interval
     101              :     pub max_rpi: u32,
     102              : }
     103              : 
     104              : impl std::fmt::Display for RateBucketInfo {
     105           38 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     106           38 :         let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
     107           38 :         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
     108           38 :     }
     109              : }
     110              : 
     111              : impl std::fmt::Debug for RateBucketInfo {
     112            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     113            0 :         write!(f, "{self}")
     114            0 :     }
     115              : }
     116              : 
     117              : impl std::str::FromStr for RateBucketInfo {
     118              :     type Err = anyhow::Error;
     119              : 
     120           64 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     121           64 :         let Some((max_rps, interval)) = s.split_once('@') else {
     122            0 :             bail!("invalid rate info")
     123              :         };
     124           64 :         let max_rps = max_rps.parse()?;
     125           64 :         let interval = humantime::parse_duration(interval)?;
     126           64 :         Ok(Self::new(max_rps, interval))
     127           64 :     }
     128              : }
     129              : 
     130              : impl RateBucketInfo {
     131              :     pub const DEFAULT_SET: [Self; 3] = [
     132              :         Self::new(300, Duration::from_secs(1)),
     133              :         Self::new(200, Duration::from_secs(60)),
     134              :         Self::new(100, Duration::from_secs(600)),
     135              :     ];
     136              : 
     137              :     pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
     138              :         Self::new(500, Duration::from_secs(1)),
     139              :         Self::new(300, Duration::from_secs(60)),
     140              :         Self::new(200, Duration::from_secs(600)),
     141              :     ];
     142              : 
     143            6 :     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
     144           16 :         info.sort_unstable_by_key(|info| info.interval);
     145            6 :         let invalid = info
     146            6 :             .iter()
     147            6 :             .tuple_windows()
     148            8 :             .find(|(a, b)| a.max_rpi > b.max_rpi);
     149            6 :         if let Some((a, b)) = invalid {
     150            2 :             bail!(
     151            2 :                 "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
     152            2 :                 b.max_rpi,
     153            2 :                 a.max_rpi,
     154            2 :             );
     155            4 :         }
     156            4 : 
     157            4 :         Ok(())
     158            6 :     }
     159              : 
     160           72 :     pub const fn new(max_rps: u32, interval: Duration) -> Self {
     161           72 :         Self {
     162           72 :             interval,
     163           72 :             max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
     164           72 :         }
     165           72 :     }
     166              : }
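
RateBucketInfo stores a per-bucket budget (max_rpi) derived from the
configured rate: max_rpi = max_rps * interval_ms / 1000. validate() sorts the
buckets by interval and rejects a set in which a longer interval admits fewer
requests per bucket than a shorter one, since the shorter bucket's limit could
then never be reached. A small worked example, mirroring the behavior
exercised by the tests at the bottom of this file:

    fn bucket_set_example() {
        let fast: RateBucketInfo = "300@1s".parse().unwrap();  // max_rpi = 300 * 1_000 / 1_000 = 300
        let slow: RateBucketInfo = "10@10s".parse().unwrap();  // max_rpi = 10 * 10_000 / 1_000 = 100
        assert_eq!((fast.max_rpi, slow.max_rpi), (300, 100));

        // After sorting by interval, 10@10s allows fewer requests per bucket
        // than 300@1s, so this combination is rejected.
        let mut set = [fast, slow];
        assert!(RateBucketInfo::validate(&mut set).is_err());
    }
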
     167              : 
     168              : impl<K: Hash + Eq> BucketRateLimiter<K> {
     169           14 :     pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
     170           14 :         Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
     171           14 :     }
     172              : }
     173              : 
     174              : impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
     175           16 :     fn new_with_rand_and_hasher(
     176           16 :         info: impl Into<Cow<'static, [RateBucketInfo]>>,
     177           16 :         rand: R,
     178           16 :         hasher: S,
     179           16 :     ) -> Self {
     180           16 :         let info = info.into();
     181           16 :         info!(buckets = ?info, "endpoint rate limiter");
     182           16 :         Self {
     183           16 :             info,
     184           16 :             map: DashMap::with_hasher_and_shard_amount(hasher, 64),
     185           16 :             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
     186           16 :             rand: Mutex::new(rand),
     187           16 :         }
     188           16 :     }
     189              : 
      190              :     /// Check that the number of connections to the endpoint is below `max_rps` rps.
     191      2000920 :     pub fn check(&self, key: K, n: u32) -> bool {
     192      2000920 :         // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
     193      2000920 :         // worst case memory usage is about:
     194      2000920 :         //    = 2 * 2048 * 64 * (48B + 72B)
     195      2000920 :         //    = 30MB
     196      2000920 :         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
     197          976 :             self.do_gc();
     198      1999944 :         }
     199              : 
     200      2000920 :         let now = Instant::now();
     201      2000920 :         let mut entry = self.map.entry(key).or_insert_with(|| {
     202      2000014 :             vec![
     203      2000014 :                 RateBucket {
     204      2000014 :                     start: now,
     205      2000014 :                     count: 0,
     206      2000014 :                 };
     207      2000014 :                 self.info.len()
     208      2000014 :             ]
     209      2000920 :         });
     210      2000920 : 
     211      2000920 :         let should_allow_request = entry
     212      2000920 :             .iter_mut()
     213      2000920 :             .zip(&*self.info)
     214      6001848 :             .all(|(bucket, info)| bucket.should_allow_request(info, now, n));
     215      2000920 : 
     216      2000920 :         if should_allow_request {
     217      2000912 :             // only increment the bucket counts if the request will actually be accepted
     218      6001836 :             entry.iter_mut().for_each(|b| b.inc(n));
     219      2000912 :         }
     220              : 
     221      2000920 :         should_allow_request
     222      2000920 :     }
     223              : 
     224              :     /// Clean the map. Simple strategy: remove all entries in a random shard.
     225              :     /// At worst, we'll double the effective max_rps during the cleanup.
      226              :     /// But that way deletion does not acquire a mutex on each entry access.
     227          976 :     pub fn do_gc(&self) {
     228          976 :         info!(
     229            0 :             "cleaning up bucket rate limiter, current size = {}",
     230            0 :             self.map.len()
     231              :         );
     232          976 :         let n = self.map.shards().len();
     233          976 :         // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
      234          976 :         // (impossible, in fact, unless we have 2048 threads)
     235          976 :         let shard = self.rand.lock().unwrap().gen_range(0..n);
     236          976 :         self.map.shards()[shard].write().clear();
     237          976 :     }
     238              : }
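
Putting it together: BucketRateLimiter keeps an independent set of buckets per
key in a sharded DashMap and, roughly every 2048th check across all keys,
clears one random shard instead of tracking per-entry expiry (the comment's
worst-case bound works out to 2 * 2048 * 64 * 120 B, about 30 MB). A usage
sketch in the spirit of the tests below; admit_connection is a hypothetical
caller, not part of this module:

    fn build_limiter() -> EndpointRateLimiter {
        EndpointRateLimiter::new(RateBucketInfo::DEFAULT_ENDPOINT_SET.to_vec())
    }

    fn admit_connection(limiter: &EndpointRateLimiter, endpoint: EndpointIdInt) -> bool {
        // charges one request against every bucket configured for this endpoint;
        // returns false (and charges nothing) if any bucket is already full
        limiter.check(endpoint, 1)
    }
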
     239              : 
     240              : #[cfg(test)]
     241              : mod tests {
     242              :     use std::{hash::BuildHasherDefault, time::Duration};
     243              : 
     244              :     use rand::SeedableRng;
     245              :     use rustc_hash::FxHasher;
     246              :     use tokio::time;
     247              : 
     248              :     use super::{BucketRateLimiter, EndpointRateLimiter};
     249              :     use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
     250              : 
     251              :     #[test]
     252            2 :     fn rate_bucket_rpi() {
     253            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
     254            2 :         assert_eq!(rate_bucket.max_rpi, 50 * 5);
     255              : 
     256            2 :         let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
     257            2 :         assert_eq!(rate_bucket.max_rpi, 50 / 2);
     258            2 :     }
     259              : 
     260              :     #[test]
     261            2 :     fn rate_bucket_parse() {
     262            2 :         let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
     263            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(10));
     264            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 10);
     265            2 :         assert_eq!(rate_bucket.to_string(), "100@10s");
     266              : 
     267            2 :         let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
     268            2 :         assert_eq!(rate_bucket.interval, Duration::from_secs(60));
     269            2 :         assert_eq!(rate_bucket.max_rpi, 100 * 60);
     270            2 :         assert_eq!(rate_bucket.to_string(), "100@1m");
     271            2 :     }
     272              : 
     273              :     #[test]
     274            2 :     fn default_rate_buckets() {
     275            2 :         let mut defaults = RateBucketInfo::DEFAULT_SET;
     276            2 :         RateBucketInfo::validate(&mut defaults[..]).unwrap();
     277            2 :     }
     278              : 
     279              :     #[test]
     280              :     #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
     281            2 :     fn rate_buckets_validate() {
     282            2 :         let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
     283            2 :             .into_iter()
     284            4 :             .map(|s| s.parse().unwrap())
     285            2 :             .collect();
     286            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     287            2 :     }
     288              : 
     289              :     #[tokio::test]
     290            2 :     async fn test_rate_limits() {
     291            2 :         let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
     292            2 :             .into_iter()
     293            4 :             .map(|s| s.parse().unwrap())
     294            2 :             .collect();
     295            2 :         RateBucketInfo::validate(&mut rates).unwrap();
     296            2 :         let limiter = EndpointRateLimiter::new(rates);
     297            2 : 
     298            2 :         let endpoint = EndpointId::from("ep-my-endpoint-1234");
     299            2 :         let endpoint = EndpointIdInt::from(endpoint);
     300            2 : 
     301            2 :         time::pause();
     302            2 : 
     303          202 :         for _ in 0..100 {
     304          200 :             assert!(limiter.check(endpoint, 1));
     305            2 :         }
     306            2 :         // more connections fail
     307            2 :         assert!(!limiter.check(endpoint, 1));
     308            2 : 
     309            2 :         // fail even after 500ms as it's in the same bucket
     310            2 :         time::advance(time::Duration::from_millis(500)).await;
     311            2 :         assert!(!limiter.check(endpoint, 1));
     312            2 : 
     313            2 :         // after a full 1s, 100 requests are allowed again
     314            2 :         time::advance(time::Duration::from_millis(500)).await;
     315           12 :         for _ in 1..6 {
     316          510 :             for _ in 0..50 {
     317          500 :                 assert!(limiter.check(endpoint, 2));
     318            2 :             }
     319           10 :             time::advance(time::Duration::from_millis(1000)).await;
     320            2 :         }
     321            2 : 
     322            2 :         // more connections after 600 will exceed the 20rps@30s limit
     323            2 :         assert!(!limiter.check(endpoint, 1));
     324            2 : 
     325            2 :         // will still fail before the 30 second limit
     326            2 :         time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
     327            2 :         assert!(!limiter.check(endpoint, 1));
     328            2 : 
     329            2 :         // after the full 30 seconds, 100 requests are allowed again
     330            2 :         time::advance(time::Duration::from_millis(1)).await;
     331          202 :         for _ in 0..100 {
     332          200 :             assert!(limiter.check(endpoint, 1));
     333            2 :         }
     334            2 :     }
     335              : 
     336              :     #[tokio::test]
     337            2 :     async fn test_rate_limits_gc() {
     338            2 :         // fixed seeded random/hasher to ensure that the test is not flaky
     339            2 :         let rand = rand::rngs::StdRng::from_seed([1; 32]);
     340            2 :         let hasher = BuildHasherDefault::<FxHasher>::default();
     341            2 : 
     342            2 :         let limiter =
     343            2 :             BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
     344      2000002 :         for i in 0..1_000_000 {
     345      2000000 :             limiter.check(i, 1);
     346      2000000 :         }
     347            2 :         assert!(limiter.map.len() < 150_000);
     348            2 :     }
     349              : }
        

Generated by: LCOV version 2.1-beta