LCOV - code coverage report
Current view: top level - proxy/src/scram - countmin.rs (source / functions) Coverage Total Hit
Test: 249f165943bd2c492f96a3f7d250276e4addca1a.info Lines: 92.0 % 113 104
Test Date: 2024-11-20 18:39:52 Functions: 90.9 % 11 10

            Line data    Source code
       1              : use std::hash::Hash;
       2              : 
       3              : /// estimator of hash jobs per second.
       4              : /// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
       5              : pub(crate) struct CountMinSketch {
       6              :     // one for each depth
       7              :     hashers: Vec<ahash::RandomState>,
       8              :     width: usize,
       9              :     depth: usize,
      10              :     // buckets, width*depth
      11              :     buckets: Vec<u32>,
      12              : }
      13              : 
      14              : impl CountMinSketch {
      15              :     /// Given parameters (ε, δ),
      16              :     ///   set width = ceil(e/ε)
      17              :     ///   set depth = ceil(ln(1/δ))
      18              :     ///
      19              :     /// guarantees:
      20              :     /// actual <= estimate
      21              :     /// estimate <= actual + ε * N with probability 1 - δ
      22              :     /// where N is the cardinality of the stream
      23           18 :     pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self {
      24           18 :         CountMinSketch::new(
      25           18 :             (std::f64::consts::E / epsilon).ceil() as usize,
      26           18 :             (1.0_f64 / delta).ln().ceil() as usize,
      27           18 :         )
      28           18 :     }
      29              : 
      30           18 :     fn new(width: usize, depth: usize) -> Self {
      31           18 :         Self {
      32           18 :             #[cfg(test)]
      33           18 :             hashers: (0..depth)
      34           72 :                 .map(|i| {
      35           72 :                     // digits of pi for good randomness
      36           72 :                     ahash::RandomState::with_seeds(
      37           72 :                         314159265358979323,
      38           72 :                         84626433832795028,
      39           72 :                         84197169399375105,
      40           72 :                         82097494459230781 + i as u64,
      41           72 :                     )
      42           72 :                 })
      43           18 :                 .collect(),
      44           18 :             #[cfg(not(test))]
      45           18 :             hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
      46            0 :             width,
      47            0 :             depth,
      48            0 :             buckets: vec![0; width * depth],
      49            0 :         }
      50            0 :     }
      51              : 
      52       227174 :     pub(crate) fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
      53       227174 :         let mut min = u32::MAX;
      54       795118 :         for row in 0..self.depth {
      55       795118 :             let col = (self.hashers[row].hash_one(t) as usize) % self.width;
      56       795118 : 
      57       795118 :             let row = &mut self.buckets[row * self.width..][..self.width];
      58       795118 :             row[col] = row[col].saturating_add(x);
      59       795118 :             min = std::cmp::min(min, row[col]);
      60       795118 :         }
      61       227174 :         min
      62       227174 :     }
      63              : 
      64            0 :     pub(crate) fn reset(&mut self) {
      65            0 :         self.buckets.clear();
      66            0 :         self.buckets.resize(self.width * self.depth, 0);
      67            0 :     }
      68              : }
      69              : 
      70              : #[cfg(test)]
      71              : mod tests {
      72              :     use rand::rngs::StdRng;
      73              :     use rand::seq::SliceRandom;
      74              :     use rand::{Rng, SeedableRng};
      75              : 
      76              :     use super::CountMinSketch;
      77              : 
      78            8 :     fn eval_precision(n: usize, p: f64, q: f64) -> usize {
      79            8 :         // fixed value of phi for consistent test
      80            8 :         let mut rng = StdRng::seed_from_u64(16180339887498948482);
      81            8 : 
      82            8 :         #[allow(non_snake_case)]
      83            8 :         let mut N = 0;
      84            8 : 
      85            8 :         let mut ids = vec![];
      86            8 : 
      87         4400 :         for _ in 0..n {
      88         4400 :             // number to insert at once
      89         4400 :             let n = rng.gen_range(1..4096);
      90         4400 :             // number of insert operations
      91         4400 :             let m = rng.gen_range(1..100);
      92         4400 : 
      93         4400 :             let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
      94         4400 :             ids.push((id, n, m));
      95         4400 : 
      96         4400 :             // N = sum(actual)
      97         4400 :             N += n * m;
      98         4400 :         }
      99              : 
     100              :         // q% of counts will be within p of the actual value
     101            8 :         let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     102            8 : 
     103            8 :         // insert a bunch of entries in a random order
     104            8 :         let mut ids2 = ids.clone();
     105          800 :         while !ids2.is_empty() {
     106          792 :             ids2.shuffle(&mut rng);
     107       222768 :             ids2.retain_mut(|id| {
     108       222768 :                 sketch.inc_and_return(&id.0, id.1);
     109       222768 :                 id.2 -= 1;
     110       222768 :                 id.2 > 0
     111       222768 :             });
     112              :         }
     113              : 
     114            8 :         let mut within_p = 0;
     115         4408 :         for (id, n, m) in ids {
     116         4400 :             let actual = n * m;
     117         4400 :             let estimate = sketch.inc_and_return(&id, 0);
     118         4400 : 
     119         4400 :             // This estimate has the guarantee that actual <= estimate
     120         4400 :             assert!(actual <= estimate);
     121              : 
     122              :             // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
     123              :             // ε = p / N, δ = 1 - q;
     124              :             // therefore, estimate <= actual + p with probability q.
     125         4400 :             if estimate as f64 <= actual as f64 + p {
     126         4384 :                 within_p += 1;
     127         4384 :             }
     128              :         }
     129            8 :         within_p
     130            8 :     }
     131              : 
     132              :     #[test]
     133            1 :     fn precision() {
     134            1 :         assert_eq!(eval_precision(100, 100.0, 0.99), 100);
     135            1 :         assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
     136            1 :         assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
     137            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
     138              : 
     139              :         // seems to be more precise than the literature indicates?
     140              :         // probably numbers are too small to truly represent the probabilities.
     141            1 :         assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
     142            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
     143            1 :         assert_eq!(eval_precision(100, 4096.0, 0.1), 96);
     144            1 :         assert_eq!(eval_precision(1000, 4096.0, 0.1), 988);
     145            1 :     }
     146              : 
     147              :     // returns memory usage in bytes, and the time complexity per insert.
     148            4 :     fn eval_cost(p: f64, q: f64) -> (usize, usize) {
     149            4 :         #[allow(non_snake_case)]
     150            4 :         // N = sum(actual)
     151            4 :         // Let's assume 1021 samples, all of 4096
     152            4 :         let N = 1021 * 4096;
     153            4 :         let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
     154            4 : 
     155            4 :         let memory = size_of::<u32>() * sketch.buckets.len();
     156            4 :         let time = sketch.depth;
     157            4 :         (memory, time)
     158            4 :     }
     159              : 
     160              :     #[test]
     161            1 :     fn memory_usage() {
     162            1 :         assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
     163            1 :         assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
     164            1 :         assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
     165            1 :         assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
     166            1 :     }
     167              : }
        

Generated by: LCOV version 2.1-beta