LCOV - code coverage report
Current view: top level - proxy/src/scram - countmin.rs (source / functions) Coverage Total Hit
Test: a43a77853355b937a79c57b07a8f05607cf29e6c.info Lines: 92.0 % 113 104
Test Date: 2024-09-19 12:04:32 Functions: 90.9 % 11 10

            Line data    Source code
       1              : use std::hash::Hash;
       2              : 
       3              : /// Count-min sketch, used as an estimator of hash jobs per second.
       4              : /// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
       5              : pub(crate) struct CountMinSketch {
       6              :     // one hasher per row, i.e. `depth` of them
       7              :     hashers: Vec<ahash::RandomState>,
       8              :     width: usize, // number of counters in each row
       9              :     depth: usize, // number of rows
      10              :     // flat, row-major counter table of width*depth entries (row r starts at r*width)
      11              :     buckets: Vec<u32>,
      12              : }
      13              : 
      14              : impl CountMinSketch {
      15              :     /// Builds a sketch sized from the error parameters (ε, δ):
      16              :     ///   width = ceil(e/ε)
      17              :     ///   depth = ceil(ln(1/δ))
      18              :     ///
      19              :     /// Guarantees for every key:
      20              :     ///   actual <= estimate
      21              :     ///   estimate <= actual + ε * N with probability 1 - δ
      22              :     /// where N is the sum of all increments inserted into the sketch (not the number of distinct keys).
      23           18 :     pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self {
      24           18 :         CountMinSketch::new(
      25           18 :             (std::f64::consts::E / epsilon).ceil() as usize,
      26           18 :             (1.0_f64 / delta).ln().ceil() as usize, // NOTE(review): inputs are not validated; delta >= 1.0 gives depth 0
      27           18 :         )
      28           18 :     }
      29              :     /// Allocates a zeroed `width * depth` counter table and one hasher per row (fixed seeds under `cfg(test)`, random otherwise).
      30           18 :     fn new(width: usize, depth: usize) -> Self {
      31           18 :         Self {
      32           18 :             #[cfg(test)]
      33           18 :             hashers: (0..depth)
      34           72 :                 .map(|i| {
      35           72 :                     // fixed seeds (digits of pi) keep tests deterministic; adding the row index makes each row's hasher distinct
      36           72 :                     ahash::RandomState::with_seeds(
      37           72 :                         314159265358979323,
      38           72 :                         84626433832795028,
      39           72 :                         84197169399375105,
      40           72 :                         82097494459230781 + i as u64,
      41           72 :                     )
      42           72 :                 })
      43           18 :                 .collect(),
      44           18 :             #[cfg(not(test))]
      45           18 :             hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
      46            0 :             width,
      47            0 :             depth,
      48            0 :             buckets: vec![0; width * depth],
      49            0 :         }
      50            0 :     }
      51              :     /// Adds `x` (saturating) to `t`'s counter in every row and returns the smallest of those counters, which is the estimate for `t`; `x = 0` only queries.
      52       227176 :     pub(crate) fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
      53       227176 :         let mut min = u32::MAX;
      54       795128 :         for row in 0..self.depth {
      55       795128 :             let col = (self.hashers[row].hash_one(t) as usize) % self.width;
      56       795128 :             // borrow just this row's `width` counters out of the flat table
      57       795128 :             let row = &mut self.buckets[row * self.width..][..self.width];
      58       795128 :             row[col] = row[col].saturating_add(x);
      59       795128 :             min = std::cmp::min(min, row[col]);
      60       795128 :         }
      61       227176 :         min
      62       227176 :     }
      63              :     /// Sets every counter back to zero; the dimensions and hashers are left unchanged.
      64            0 :     pub(crate) fn reset(&mut self) {
      65            0 :         self.buckets.clear();
      66            0 :         self.buckets.resize(self.width * self.depth, 0);
      67            0 :     }
      68              : }
      69              : 
      70              : #[cfg(test)]
      71              : mod tests {
      72              :     use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
      73              : 
      74              :     use super::CountMinSketch;
      75              :     /// Inserts `n` random ids into a sketch built for (p, q) and returns how many estimates end up within `p` of the actual count.
      76            8 :     fn eval_precision(n: usize, p: f64, q: f64) -> usize {
      77            8 :         // fixed seed (digits of phi) so every run sees the same data
      78            8 :         let mut rng = StdRng::seed_from_u64(16180339887498948482);
      79            8 : 
      80            8 :         #[allow(non_snake_case)]
      81            8 :         let mut N = 0;
      82            8 : 
      83            8 :         let mut ids = vec![];
      84            8 : 
      85         4400 :         for _ in 0..n {
      86         4400 :             // amount added by each insert operation (shadows the parameter `n`)
      87         4400 :             let n = rng.gen_range(1..4096);
      88         4400 :             // number of insert operations for this id
      89         4400 :             let m = rng.gen_range(1..100);
      90         4400 : 
      91         4400 :             let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
      92         4400 :             ids.push((id, n, m));
      93         4400 : 
      94         4400 :             // N = sum(actual) over all ids
      95         4400 :             N += n * m;
      96         4400 :         }
      97              : 
      98              :         // with probability q (a fraction, not a percentage) an estimate is within p of the actual value
      99            8 :         let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     100            8 : 
     101            8 :         // interleave the inserts in random order; an id is dropped once its m inserts are done
     102            8 :         let mut ids2 = ids.clone();
     103          800 :         while !ids2.is_empty() {
     104          792 :             ids2.shuffle(&mut rng);
     105       222768 :             ids2.retain_mut(|id| {
     106       222768 :                 sketch.inc_and_return(&id.0, id.1);
     107       222768 :                 id.2 -= 1;
     108       222768 :                 id.2 > 0
     109       222768 :             });
     110              :         }
     111              : 
     112            8 :         let mut within_p = 0;
     113         4408 :         for (id, n, m) in ids {
     114         4400 :             let actual = n * m;
     115         4400 :             let estimate = sketch.inc_and_return(&id, 0);
     116         4400 : 
     117         4400 :             // The estimate is guaranteed never to undercount: actual <= estimate
     118         4400 :             assert!(actual <= estimate);
     119              : 
     120              :             // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
     121              :             // ε = p / N, δ = 1 - q;
     122              :             // therefore, estimate <= actual + p with probability q.
     123         4400 :             if estimate as f64 <= actual as f64 + p {
     124         4384 :                 within_p += 1;
     125         4384 :             }
     126              :         }
     127            8 :         within_p
     128            8 :     }
     129              : 
     130              :     #[test]
     131            1 :     fn precision() {
     132            1 :         assert_eq!(eval_precision(100, 100.0, 0.99), 100);
     133            1 :         assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
     134            1 :         assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
     135            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
     136              : 
     137              :         // The results are more precise than the theoretical bound requires;
     138              :         // the inputs are probably too small for the probabilities to show up.
     139            1 :         assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
     140            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
     141            1 :         assert_eq!(eval_precision(100, 4096.0, 0.1), 96);
     142            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.1), 988);
     143            1 :     }
     144              : 
     145              :     // Returns the counter table's memory usage in bytes and the per-insert cost (the number of rows, one hash each).
     146            4 :     fn eval_cost(p: f64, q: f64) -> (usize, usize) {
     147            4 :         #[allow(non_snake_case)]
     148            4 :         // N = sum(actual)
     149            4 :         // Assume 1021 samples, each with an actual count of 4096
     150            4 :         let N = 1021 * 4096;
     151            4 :         let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     152            4 : 
     153            4 :         let memory = size_of::<u32>() * sketch.buckets.len();
     154            4 :         let time = sketch.depth;
     155            4 :         (memory, time)
     156            4 :     }
     157              : 
     158              :     #[test]
     159            1 :     fn memory_usage() {
     160            1 :         assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
     161            1 :         assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
     162            1 :         assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
     163            1 :         assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
     164            1 :     }
     165              : }
        

Generated by: LCOV version 2.1-beta