Line data Source code
1 : use std::hash::Hash;
2 :
3 : /// estimator of hash jobs per second.
4 : /// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
5 : pub struct CountMinSketch {
6 : // one for each depth
7 : hashers: Vec<ahash::RandomState>,
8 : width: usize,
9 : depth: usize,
10 : // buckets, width*depth
11 : buckets: Vec<u32>,
12 : }
13 :
14 : impl CountMinSketch {
15 : /// Given parameters (ε, δ),
16 : /// set width = ceil(e/ε)
17 : /// set depth = ceil(ln(1/δ))
18 : ///
19 : /// guarantees:
20 : /// actual <= estimate
21 : /// estimate <= actual + ε * N with probability 1 - δ
22 : /// where N is the cardinality of the stream
23 36 : pub fn with_params(epsilon: f64, delta: f64) -> Self {
24 36 : CountMinSketch::new(
25 36 : (std::f64::consts::E / epsilon).ceil() as usize,
26 36 : (1.0_f64 / delta).ln().ceil() as usize,
27 36 : )
28 36 : }
29 :
30 36 : fn new(width: usize, depth: usize) -> Self {
31 36 : Self {
32 36 : #[cfg(test)]
33 36 : hashers: (0..depth)
34 144 : .map(|i| {
35 144 : // digits of pi for good randomness
36 144 : ahash::RandomState::with_seeds(
37 144 : 314159265358979323,
38 144 : 84626433832795028,
39 144 : 84197169399375105,
40 144 : 82097494459230781 + i as u64,
41 144 : )
42 144 : })
43 36 : .collect(),
44 36 : #[cfg(not(test))]
45 36 : hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
46 0 : width,
47 0 : depth,
48 0 : buckets: vec![0; width * depth],
49 0 : }
50 0 : }
51 :
52 17904020 : pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
53 17904020 : let mut min = u32::MAX;
54 62664088 : for row in 0..self.depth {
55 62664088 : let col = (self.hashers[row].hash_one(t) as usize) % self.width;
56 62664088 :
57 62664088 : let row = &mut self.buckets[row * self.width..][..self.width];
58 62664088 : row[col] = row[col].saturating_add(x);
59 62664088 : min = std::cmp::min(min, row[col]);
60 62664088 : }
61 17904020 : min
62 17904020 : }
63 :
64 10 : pub fn reset(&mut self) {
65 10 : self.buckets.clear();
66 10 : self.buckets.resize(self.width * self.depth, 0);
67 10 : }
68 : }
69 :
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};

    use super::CountMinSketch;

    /// Inserts `n` random ids into a sketch sized for (ε = p / N, δ = 1 - q) and
    /// returns how many of them end up with an estimate within `p` of the
    /// actual count.
    fn eval_precision(n: usize, p: f64, q: f64) -> usize {
        // fixed value of phi for consistent test
        let mut rng = StdRng::seed_from_u64(16180339887498948482);

        // N = sum(actual); named after the symbol used in the literature.
        #[allow(non_snake_case)]
        let mut N = 0;

        let mut ids = vec![];

        for _ in 0..n {
            // amount added by each insert operation
            let amount = rng.gen_range(1..100);
            // number of insert operations performed for this id
            let ops = rng.gen_range(1..4096);

            let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
            ids.push((id, amount, ops));

            N += amount * ops;
        }

        // q% of counts will be within p of the actual value
        let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);

        dbg!(sketch.buckets.len());

        // insert a bunch of entries in a random order: each pass performs one
        // insert per remaining id and drops the ids that have no operations left
        let mut pending = ids.clone();
        while !pending.is_empty() {
            pending.shuffle(&mut rng);

            pending.retain_mut(|(id, amount, ops_left)| {
                sketch.inc_and_return(&*id, *amount);
                *ops_left -= 1;
                *ops_left != 0
            });
        }

        let mut within_p = 0;
        for (id, amount, ops) in ids {
            let actual = amount * ops;
            // adding 0 reads the estimate without modifying the sketch
            let estimate = sketch.inc_and_return(&id, 0);

            // This estimate has the guarantee that actual <= estimate
            assert!(actual <= estimate);

            // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
            // ε = p / N, δ = 1 - q;
            // therefore, estimate <= actual + p with probability q.
            if estimate as f64 <= actual as f64 + p {
                within_p += 1;
            }
        }
        within_p
    }

    #[test]
    fn precision() {
        assert_eq!(eval_precision(100, 100.0, 0.99), 100);
        assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
        assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
        assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);

        // seems to be more precise than the literature indicates?
        // probably numbers are too small to truly represent the probabilities.
        assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
        assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
        assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
        assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
    }

    // returns memory usage in bytes, and the time complexity per insert.
    fn eval_cost(p: f64, q: f64) -> (usize, usize) {
        // N = sum(actual)
        // Let's assume 1021 samples, all of 4096
        #[allow(non_snake_case)]
        let N = 1021 * 4096;
        let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);

        let memory = std::mem::size_of::<u32>() * sketch.buckets.len();
        // one hash and one counter update per row
        let time = sketch.depth;
        (memory, time)
    }

    #[test]
    fn memory_usage() {
        assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
        assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
        assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
        assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
    }
}
|