LCOV - code coverage report
Current view: top level - proxy/src/scram - countmin.rs (source / functions) Coverage Total Hit
Test: 02e8c57acd6e2b986849f552ca30280d54699b79.info Lines: 95.8 % 120 115
Test Date: 2024-06-26 17:13:54 Functions: 100.0 % 10 10

            Line data    Source code
       1              : use std::hash::Hash;
       2              : 
       3              : /// estimator of hash jobs per second.
       4              : /// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
       5              : pub struct CountMinSketch {
       6              :     // one for each depth
       7              :     hashers: Vec<ahash::RandomState>,
       8              :     width: usize,
       9              :     depth: usize,
      10              :     // buckets, width*depth
      11              :     buckets: Vec<u32>,
      12              : }
      13              : 
      14              : impl CountMinSketch {
      15              :     /// Given parameters (ε, δ),
      16              :     ///   set width = ceil(e/ε)
      17              :     ///   set depth = ceil(ln(1/δ))
      18              :     ///
      19              :     /// guarantees:
      20              :     /// actual <= estimate
      21              :     /// estimate <= actual + ε * N with probability 1 - δ
      22              :     /// where N is the cardinality of the stream
      23           36 :     pub fn with_params(epsilon: f64, delta: f64) -> Self {
      24           36 :         CountMinSketch::new(
      25           36 :             (std::f64::consts::E / epsilon).ceil() as usize,
      26           36 :             (1.0_f64 / delta).ln().ceil() as usize,
      27           36 :         )
      28           36 :     }
      29              : 
      30           36 :     fn new(width: usize, depth: usize) -> Self {
      31           36 :         Self {
      32           36 :             #[cfg(test)]
      33           36 :             hashers: (0..depth)
      34          144 :                 .map(|i| {
      35          144 :                     // digits of pi for good randomness
      36          144 :                     ahash::RandomState::with_seeds(
      37          144 :                         314159265358979323,
      38          144 :                         84626433832795028,
      39          144 :                         84197169399375105,
      40          144 :                         82097494459230781 + i as u64,
      41          144 :                     )
      42          144 :                 })
      43           36 :                 .collect(),
      44           36 :             #[cfg(not(test))]
      45           36 :             hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
      46            0 :             width,
      47            0 :             depth,
      48            0 :             buckets: vec![0; width * depth],
      49            0 :         }
      50            0 :     }
      51              : 
      52     17904018 :     pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
      53     17904018 :         let mut min = u32::MAX;
      54     62664078 :         for row in 0..self.depth {
      55     62664078 :             let col = (self.hashers[row].hash_one(t) as usize) % self.width;
      56     62664078 : 
      57     62664078 :             let row = &mut self.buckets[row * self.width..][..self.width];
      58     62664078 :             row[col] = row[col].saturating_add(x);
      59     62664078 :             min = std::cmp::min(min, row[col]);
      60     62664078 :         }
      61     17904018 :         min
      62     17904018 :     }
      63              : 
      64           10 :     pub fn reset(&mut self) {
      65           10 :         self.buckets.clear();
      66           10 :         self.buckets.resize(self.width * self.depth, 0);
      67           10 :     }
      68              : }
      69              : 
      70              : #[cfg(test)]
      71              : mod tests {
      72              :     use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
      73              : 
      74              :     use super::CountMinSketch;
      75              : 
      76           16 :     fn eval_precision(n: usize, p: f64, q: f64) -> usize {
      77           16 :         // fixed value of phi for consistent test
      78           16 :         let mut rng = StdRng::seed_from_u64(16180339887498948482);
      79           16 : 
      80           16 :         #[allow(non_snake_case)]
      81           16 :         let mut N = 0;
      82           16 : 
      83           16 :         let mut ids = vec![];
      84           16 : 
      85         8800 :         for _ in 0..n {
      86         8800 :             // number of insert operations
      87         8800 :             let n = rng.gen_range(1..100);
      88         8800 :             // number to insert at once
      89         8800 :             let m = rng.gen_range(1..4096);
      90         8800 : 
      91         8800 :             let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
      92         8800 :             ids.push((id, n, m));
      93         8800 : 
      94         8800 :             // N = sum(actual)
      95         8800 :             N += n * m;
      96         8800 :         }
      97              : 
      98              :         // q% of counts will be within p of the actual value
      99           16 :         let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     100           16 : 
     101           16 :         dbg!(sketch.buckets.len());
     102           16 : 
     103           16 :         // insert a bunch of entries in a random order
     104           16 :         let mut ids2 = ids.clone();
     105        65496 :         while !ids2.is_empty() {
     106        65480 :             ids2.shuffle(&mut rng);
     107        65480 : 
     108        65480 :             let mut i = 0;
     109     17960688 :             while i < ids2.len() {
     110     17895208 :                 sketch.inc_and_return(&ids2[i].0, ids2[i].1);
     111     17895208 :                 ids2[i].2 -= 1;
     112     17895208 :                 if ids2[i].2 == 0 {
     113         8800 :                     ids2.remove(i);
     114     17886408 :                 } else {
     115     17886408 :                     i += 1;
     116     17886408 :                 }
     117              :             }
     118              :         }
     119              : 
     120           16 :         let mut within_p = 0;
     121         8816 :         for (id, n, m) in ids {
     122         8800 :             let actual = n * m;
     123         8800 :             let estimate = sketch.inc_and_return(&id, 0);
     124         8800 : 
     125         8800 :             // This estimate has the guarantee that actual <= estimate
     126         8800 :             assert!(actual <= estimate);
     127              : 
     128              :             // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
     129              :             // ε = p / N, δ = 1 - q;
     130              :             // therefore, estimate <= actual + p with probability q.
     131         8800 :             if estimate as f64 <= actual as f64 + p {
     132         8778 :                 within_p += 1;
     133         8778 :             }
     134              :         }
     135           16 :         within_p
     136           16 :     }
     137              : 
     138              :     #[test]
     139            2 :     fn precision() {
     140            2 :         assert_eq!(eval_precision(100, 100.0, 0.99), 100);
     141            2 :         assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
     142            2 :         assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
     143            2 :         assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
     144              : 
     145              :         // seems to be more precise than the literature indicates?
     146              :         // probably numbers are too small to truly represent the probabilities.
     147            2 :         assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
     148            2 :         assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
     149            2 :         assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
     150            2 :         assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
     151            2 :     }
     152              : 
     153              :     // returns memory usage in bytes, and the time complexity per insert.
     154            8 :     fn eval_cost(p: f64, q: f64) -> (usize, usize) {
     155            8 :         #[allow(non_snake_case)]
     156            8 :         // N = sum(actual)
     157            8 :         // Let's assume 1021 samples, all of 4096
     158            8 :         let N = 1021 * 4096;
     159            8 :         let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     160            8 : 
     161            8 :         let memory = std::mem::size_of::<u32>() * sketch.buckets.len();
     162            8 :         let time = sketch.depth;
     163            8 :         (memory, time)
     164            8 :     }
     165              : 
     166              :     #[test]
     167            2 :     fn memory_usage() {
     168            2 :         assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
     169            2 :         assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
     170            2 :         assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
     171            2 :         assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
     172            2 :     }
     173              : }
        

Generated by: LCOV version 2.1-beta