LCOV - code coverage report
Current view: top level - proxy/src/scram - countmin.rs (source / functions) Coverage Total Hit
Test: 2aa98e37cd3250b9a68c97ef6050b16fe702ab33.info Lines: 95.8 % 118 113
Test Date: 2024-08-29 11:33:10 Functions: 100.0 % 10 10

            Line data    Source code
       1              : use std::hash::Hash;
       2              : 
       3              : /// estimator of hash jobs per second.
       4              : /// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
       5              : pub(crate) struct CountMinSketch {
       6              :     // one for each depth
       7              :     hashers: Vec<ahash::RandomState>,
       8              :     width: usize,
       9              :     depth: usize,
      10              :     // buckets, width*depth
      11              :     buckets: Vec<u32>,
      12              : }
      13              : 
      14              : impl CountMinSketch {
      15              :     /// Given parameters (ε, δ),
      16              :     ///   set width = ceil(e/ε)
      17              :     ///   set depth = ceil(ln(1/δ))
      18              :     ///
      19              :     /// guarantees:
      20              :     /// actual <= estimate
      21              :     /// estimate <= actual + ε * N with probability 1 - δ
      22              :     /// where N is the cardinality of the stream
      23          108 :     pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self {
      24          108 :         CountMinSketch::new(
      25          108 :             (std::f64::consts::E / epsilon).ceil() as usize,
      26          108 :             (1.0_f64 / delta).ln().ceil() as usize,
      27          108 :         )
      28          108 :     }
      29              : 
      30          108 :     fn new(width: usize, depth: usize) -> Self {
      31          108 :         Self {
      32          108 :             #[cfg(test)]
      33          108 :             hashers: (0..depth)
      34          432 :                 .map(|i| {
      35          432 :                     // digits of pi for good randomness
      36          432 :                     ahash::RandomState::with_seeds(
      37          432 :                         314159265358979323,
      38          432 :                         84626433832795028,
      39          432 :                         84197169399375105,
      40          432 :                         82097494459230781 + i as u64,
      41          432 :                     )
      42          432 :                 })
      43          108 :                 .collect(),
      44          108 :             #[cfg(not(test))]
      45          108 :             hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
      46            0 :             width,
      47            0 :             depth,
      48            0 :             buckets: vec![0; width * depth],
      49            0 :         }
      50            0 :     }
      51              : 
      52     53712059 :     pub(crate) fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
      53     53712059 :         let mut min = u32::MAX;
      54    187992259 :         for row in 0..self.depth {
      55    187992259 :             let col = (self.hashers[row].hash_one(t) as usize) % self.width;
      56    187992259 : 
      57    187992259 :             let row = &mut self.buckets[row * self.width..][..self.width];
      58    187992259 :             row[col] = row[col].saturating_add(x);
      59    187992259 :             min = std::cmp::min(min, row[col]);
      60    187992259 :         }
      61     53712059 :         min
      62     53712059 :     }
      63              : 
      64           29 :     pub(crate) fn reset(&mut self) {
      65           29 :         self.buckets.clear();
      66           29 :         self.buckets.resize(self.width * self.depth, 0);
      67           29 :     }
      68              : }
      69              : 
      70              : #[cfg(test)]
      71              : mod tests {
      72              :     use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
      73              : 
      74              :     use super::CountMinSketch;
      75              : 
      76           48 :     fn eval_precision(n: usize, p: f64, q: f64) -> usize {
      77           48 :         // fixed value of phi for consistent test
      78           48 :         let mut rng = StdRng::seed_from_u64(16180339887498948482);
      79           48 : 
      80           48 :         #[allow(non_snake_case)]
      81           48 :         let mut N = 0;
      82           48 : 
      83           48 :         let mut ids = vec![];
      84           48 : 
      85        26400 :         for _ in 0..n {
      86        26400 :             // number of insert operations
      87        26400 :             let n = rng.gen_range(1..100);
      88        26400 :             // number to insert at once
      89        26400 :             let m = rng.gen_range(1..4096);
      90        26400 : 
      91        26400 :             let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
      92        26400 :             ids.push((id, n, m));
      93        26400 : 
      94        26400 :             // N = sum(actual)
      95        26400 :             N += n * m;
      96        26400 :         }
      97              : 
      98              :         // q% of counts will be within p of the actual value
      99           48 :         let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     100           48 : 
     101           48 :         // insert a bunch of entries in a random order
     102           48 :         let mut ids2 = ids.clone();
     103       196488 :         while !ids2.is_empty() {
     104       196440 :             ids2.shuffle(&mut rng);
     105       196440 : 
     106       196440 :             let mut i = 0;
     107     53882064 :             while i < ids2.len() {
     108     53685624 :                 sketch.inc_and_return(&ids2[i].0, ids2[i].1);
     109     53685624 :                 ids2[i].2 -= 1;
     110     53685624 :                 if ids2[i].2 == 0 {
     111        26400 :                     ids2.remove(i);
     112     53659224 :                 } else {
     113     53659224 :                     i += 1;
     114     53659224 :                 }
     115              :             }
     116              :         }
     117              : 
     118           48 :         let mut within_p = 0;
     119        26448 :         for (id, n, m) in ids {
     120        26400 :             let actual = n * m;
     121        26400 :             let estimate = sketch.inc_and_return(&id, 0);
     122        26400 : 
     123        26400 :             // This estimate has the guarantee that actual <= estimate
     124        26400 :             assert!(actual <= estimate);
     125              : 
     126              :             // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
     127              :             // ε = p / N, δ = 1 - q;
     128              :             // therefore, estimate <= actual + p with probability q.
     129        26400 :             if estimate as f64 <= actual as f64 + p {
     130        26334 :                 within_p += 1;
     131        26334 :             }
     132              :         }
     133           48 :         within_p
     134           48 :     }
     135              : 
     136              :     #[test]
     137            6 :     fn precision() {
     138            6 :         assert_eq!(eval_precision(100, 100.0, 0.99), 100);
     139            6 :         assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
     140            6 :         assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
     141            6 :         assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
     142              : 
     143              :         // seems to be more precise than the literature indicates?
     144              :         // probably numbers are too small to truly represent the probabilities.
     145            6 :         assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
     146            6 :         assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
     147            6 :         assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
     148            6 :         assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
     149            6 :     }
     150              : 
     151              :     // returns memory usage in bytes, and the time complexity per insert.
     152           24 :     fn eval_cost(p: f64, q: f64) -> (usize, usize) {
     153           24 :         #[allow(non_snake_case)]
     154           24 :         // N = sum(actual)
     155           24 :         // Let's assume 1021 samples, all of 4096
     156           24 :         let N = 1021 * 4096;
     157           24 :         let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     158           24 : 
     159           24 :         let memory = size_of::<u32>() * sketch.buckets.len();
     160           24 :         let time = sketch.depth;
     161           24 :         (memory, time)
     162           24 :     }
     163              : 
     164              :     #[test]
     165            6 :     fn memory_usage() {
     166            6 :         assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
     167            6 :         assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
     168            6 :         assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
     169            6 :         assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
     170            6 :     }
     171              : }
        

Generated by: LCOV version 2.1-beta