       1              : //! HyperLogLog is an algorithm for the count-distinct problem,
       2              : //! approximating the number of distinct elements in a multiset.
       3              : //! Calculating the exact cardinality of the distinct elements
       4              : //! of a multiset requires an amount of memory proportional to
       5              : //! the cardinality, which is impractical for very large data sets.
       6              : //! Probabilistic cardinality estimators, such as the HyperLogLog algorithm,
       7              : //! use significantly less memory than this, but can only approximate the cardinality.
       8              : 
       9              : use std::hash::{BuildHasher, BuildHasherDefault, Hash};
      10              : use std::sync::atomic::AtomicU8;
      11              : 
      12              : use measured::LabelGroup;
      13              : use measured::label::{LabelGroupVisitor, LabelName, LabelValue, LabelVisitor};
      14              : use measured::metric::counter::CounterState;
      15              : use measured::metric::name::MetricNameEncoder;
      16              : use measured::metric::{Metric, MetricType, MetricVec};
      17              : use measured::text::TextEncoder;
      18              : use twox_hash::xxh3;
      19              : 
/// Creates a [`HyperLogLogVec`] and registers it with the default registry.
///
/// Two forms are accepted: `(N, opts, label_names)` and `(N, name, help, label_names)`.
/// The second builds the options with `$crate::opts!` and forwards to the first.
/// The value returned by `$crate::register(...)` is `map`ped so that its success value
/// is the vec itself (presumably a `Result` — TODO confirm).
///
/// NOTE(review): this macro looks stale.
/// - `HyperLogLogVec` is declared as `HyperLogLogVec<L, const N: usize>` (two generic
///   parameters), but it is named here as `HyperLogLogVec::<$N>`.
/// - `new` is called with two arguments, whereas this file's tests construct the vec with
///   `new()`.
///
/// Expect any invocation to fail to compile. Confirm against the current `measured`-based
/// API before relying on it.
#[macro_export(local_inner_macros)]
macro_rules! register_hll_vec {
    ($N:literal, $OPTS:expr, $LABELS_NAMES:expr $(,)?) => {{
        // Panics (via `unwrap`) if construction fails.
        let hll_vec = $crate::HyperLogLogVec::<$N>::new($OPTS, $LABELS_NAMES).unwrap();
        $crate::register(Box::new(hll_vec.clone())).map(|_| hll_vec)
    }};

    // Convenience form: build the options from a name and a help string.
    ($N:literal, $NAME:expr, $HELP:expr, $LABELS_NAMES:expr $(,)?) => {{ $crate::register_hll_vec!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES) }};
}
      30              : 
/// Creates a [`HyperLogLog`] and registers it with the default registry.
///
/// Two forms are accepted: `(N, opts)` and `(N, name, help)`. The second builds the options
/// with `$crate::opts!` and forwards to the first. The value returned by
/// `$crate::register(...)` is `map`ped so that its success value is the metric itself
/// (presumably a `Result` — TODO confirm).
///
/// NOTE(review): `HyperLogLog` is an alias for `measured::metric::Metric`. The
/// `with_opts` constructor used here could not be confirmed to exist on that type. Verify
/// that this macro still compiles before using it.
#[macro_export(local_inner_macros)]
macro_rules! register_hll {
    ($N:literal, $OPTS:expr $(,)?) => {{
        // Panics (via `unwrap`) if construction fails.
        let hll = $crate::HyperLogLog::<$N>::with_opts($OPTS).unwrap();
        $crate::register(Box::new(hll.clone())).map(|_| hll)
    }};

    // Convenience form: build the options from a name and a help string.
    ($N:literal, $NAME:expr, $HELP:expr $(,)?) => {{ $crate::register_hll!($N, $crate::opts!($NAME, $HELP)) }};
}
      41              : 
      42              : /// HLL is a probabilistic cardinality measure.
      43              : ///
      44              : /// How to use this time-series for a metric name `my_metrics_total_hll`:
      45              : ///
      46              : /// ```promql
      47              : /// # harmonic mean
      48              : /// 1 / (
      49              : ///     sum (
      50              : ///         2 ^ -(
      51              : ///             # HLL merge operation
      52              : ///             max (my_metrics_total_hll{}) by (hll_shard, other_labels...)
      53              : ///         )
      54              : ///     ) without (hll_shard)
      55              : /// )
      56              : /// * alpha
      57              : /// * shards_count
      58              : /// * shards_count
      59              : /// ```
      60              : ///
      61              : /// If you want an estimate over time, you can use the following query:
      62              : ///
      63              : /// ```promql
      64              : /// # harmonic mean
      65              : /// 1 / (
      66              : ///     sum (
      67              : ///         2 ^ -(
      68              : ///             # HLL merge operation
      69              : ///             max (
      70              : ///                 max_over_time(my_metrics_total_hll{}[$__rate_interval])
      71              : ///             ) by (hll_shard, other_labels...)
      72              : ///         )
      73              : ///     ) without (hll_shard)
      74              : /// )
      75              : /// * alpha
      76              : /// * shards_count
      77              : /// * shards_count
      78              : /// ```
      79              : ///
      80              : /// In the case of low cardinality, you might want to use the linear counting approximation:
      81              : ///
      82              : /// ```promql
      83              : /// # LinearCounting(m, V) = m log (m / V)
      84              : /// shards_count * ln(shards_count /
      85              : ///     # calculate V = how many shards contain a 0
      86              : ///     count(max (proxy_connecting_endpoints{}) by (hll_shard, protocol) == 0) without (hll_shard)
      87              : /// )
      88              : /// ```
      89              : ///
      90              : /// See <> for estimates on alpha
      91              : pub type HyperLogLogVec<L, const N: usize> = MetricVec<HyperLogLogState<N>, L>;
      92              : pub type HyperLogLog<const N: usize> = Metric<HyperLogLogState<N>>;
      93              : 
      94              : pub struct HyperLogLogState<const N: usize> {
      95              :     shards: [AtomicU8; N],
      96              : }
      97              : impl<const N: usize> Default for HyperLogLogState<N> {
      98           76 :     fn default() -> Self {
      99              :         #[allow(clippy::declare_interior_mutable_const)]
     100              :         const ZERO: AtomicU8 = AtomicU8::new(0);
     101           76 :         Self { shards: [ZERO; N] }
     102           76 :     }
     103              : }
     104              : 
     105              : impl<const N: usize> MetricType for HyperLogLogState<N> {
     106              :     type Metadata = ();
     107              : }
     108              : 
     109              : impl<const N: usize> HyperLogLogState<N> {
     110      4040428 :     pub fn measure(&self, item: &impl Hash) {
     111      4040428 :         // changing the hasher will break compatibility with previous measurements.
     112      4040428 :         self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
     113      4040428 :     }
     114              : 
     115      4040428 :     fn record(&self, hash: u64) {
     116      4040428 :         let p = N.ilog2() as u8;
     117      4040428 :         let j = hash & (N as u64 - 1);
     118      4040428 :         let rho = (hash >> p).leading_zeros() as u8 + 1 - p;
     119      4040428 :         self.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed);
     120      4040428 :     }
     121              : 
     122           12 :     fn take_sample(&self) -> [u8; N] {
     123          384 :         self.shards.each_ref().map(|x| {
     124          384 :             // We reset the counter to 0 so we can perform a cardinality measure over any time slice in prometheus.
     125          384 : 
     126          384 :             // This seems like it would be a race condition,
     127          384 :             // but HLL is not impacted by a write in one shard happening in between.
     128          384 :             // This is because in PromQL we will be implementing a harmonic mean of all buckets.
     129          384 :             // we will also merge samples in a time series using `max by (hll_shard)`.
     130          384 : 
     131          384 :             // TODO: maybe we shouldn't reset this on every collect, instead, only after a time window.
     132          384 :             // this would mean that a dev port-forwarding the metrics url won't break the sampling.
     133          384 :             x.swap(0, std::sync::atomic::Ordering::Relaxed)
     134          384 :         })
     135           12 :     }
     136              : }
     137              : 
     138              : impl<W: std::io::Write, const N: usize> measured::metric::MetricEncoding<TextEncoder<W>>
     139              :     for HyperLogLogState<N>
     140              : {
     141            0 :     fn write_type(
     142            0 :         name: impl MetricNameEncoder,
     143            0 :         enc: &mut TextEncoder<W>,
     144            0 :     ) -> Result<(), std::io::Error> {
     145            0 :         enc.write_type(&name, measured::text::MetricType::Gauge)
     146            0 :     }
     147            0 :     fn collect_into(
     148            0 :         &self,
     149            0 :         _: &(),
     150            0 :         labels: impl LabelGroup,
     151            0 :         name: impl MetricNameEncoder,
     152            0 :         enc: &mut TextEncoder<W>,
     153            0 :     ) -> Result<(), std::io::Error> {
     154              :         struct I64(i64);
     155              :         impl LabelValue for I64 {
     156            0 :             fn visit<V: LabelVisitor>(&self, v: V) -> V::Output {
     157            0 :                 v.write_int(self.0)
     158            0 :             }
     159              :         }
     160              : 
     161              :         struct HllShardLabel {
     162              :             hll_shard: i64,
     163              :         }
     164              : 
     165              :         impl LabelGroup for HllShardLabel {
     166            0 :             fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
     167              :                 const LE: &LabelName = LabelName::from_str("hll_shard");
     168            0 :                 v.write_value(LE, &I64(self.hll_shard));
     169            0 :             }
     170              :         }
     171              : 
     172            0 :         self.take_sample()
     173            0 :             .into_iter()
     174            0 :             .enumerate()
     175            0 :             .try_for_each(|(hll_shard, val)| {
     176            0 :                 CounterState::new(val as u64).collect_into(
     177            0 :                     &(),
     178            0 :                     labels.by_ref().compose_with(HllShardLabel {
     179            0 :                         hll_shard: hll_shard as i64,
     180            0 :                     }),
     181            0 :                     name.by_ref(),
     182            0 :                     enc,
     183            0 :                 )
     184            0 :             })
     185            0 :     }
     186              : }
     187              : 
     188              : #[cfg(test)]
     189              : mod tests {
     190              :     use std::collections::HashSet;
     191              : 
     192              :     use measured::FixedCardinalityLabel;
     193              :     use measured::label::StaticLabelSet;
     194              :     use rand::rngs::StdRng;
     195              :     use rand::{Rng, SeedableRng};
     196              :     use rand_distr::{Distribution, Zipf};
     197              : 
     198              :     use crate::HyperLogLogVec;
     199              : 
     200              :     #[derive(FixedCardinalityLabel, Clone, Copy)]
     201              :     #[label(singleton = "x")]
     202              :     enum Label {
     203              :         A,
     204              :         B,
     205              :     }
     206              : 
     207            6 :     fn collect(hll: &HyperLogLogVec<StaticLabelSet<Label>, 32>) -> ([u8; 32], [u8; 32]) {
     208            6 :         // cannot go through the `hll.collect_family_into` interface yet...
     209            6 :         // need to see if I can fix the conflicting impls problem in measured.
     210            6 :         (
     211            6 :             hll.get_metric(hll.with_labels(Label::A)).take_sample(),
     212            6 :             hll.get_metric(hll.with_labels(Label::B)).take_sample(),
     213            6 :         )
     214            6 :     }
     215              : 
     216           18 :     fn get_cardinality(samples: &[[u8; 32]]) -> f64 {
     217           18 :         let mut buckets = [0.0; 32];
     218           42 :         for &sample in samples {
     219          768 :             for (i, m) in sample.into_iter().enumerate() {
     220          768 :                 buckets[i] = f64::max(buckets[i], m as f64);
     221          768 :             }
     222              :         }
     223              : 
     224           18 :         buckets
     225           18 :             .into_iter()
     226          576 :             .map(|f| 2.0f64.powf(-f))
     227           18 :             .sum::<f64>()
     228           18 :             .recip()
     229           18 :             * 0.697
     230           18 :             * 32.0
     231           18 :             * 32.0
     232           18 :     }
     233              : 
     234            6 :     fn test_cardinality(n: usize, dist: impl Distribution<f64>) -> ([usize; 3], [f64; 3]) {
     235            6 :         let hll = HyperLogLogVec::<StaticLabelSet<Label>, 32>::new();
     236            6 : 
     237            6 :         let mut iter = StdRng::seed_from_u64(0x2024_0112).sample_iter(dist);
     238            6 :         let mut set_a = HashSet::new();
     239            6 :         let mut set_b = HashSet::new();
     240              : 
     241      2020200 :         for x in iter.by_ref().take(n) {
     242      2020200 :             set_a.insert(x.to_bits());
     243      2020200 :             hll.get_metric(hll.with_labels(Label::A))
     244      2020200 :                 .measure(&x.to_bits());
     245      2020200 :         }
     246      2020200 :         for x in iter.by_ref().take(n) {
     247      2020200 :             set_b.insert(x.to_bits());
     248      2020200 :             hll.get_metric(hll.with_labels(Label::B))
     249      2020200 :                 .measure(&x.to_bits());
     250      2020200 :         }
     251            6 :         let merge = &set_a | &set_b;
     252            6 : 
     253            6 :         let (a, b) = collect(&hll);
     254            6 :         let len = get_cardinality(&[a, b]);
     255            6 :         let len_a = get_cardinality(&[a]);
     256            6 :         let len_b = get_cardinality(&[b]);
     257            6 : 
     258            6 :         ([merge.len(), set_a.len(), set_b.len()], [len, len_a, len_b])
     259            6 :     }
     260              : 
     261              :     #[test]
     262            1 :     fn test_cardinality_small() {
     263            1 :         let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap());
     264            1 : 
     265            1 :         assert_eq!(actual, [46, 30, 32]);
     266            1 :         assert!(51.3 < estimate[0] && estimate[0] < 51.4);
     267            1 :         assert!(44.0 < estimate[1] && estimate[1] < 44.1);
     268            1 :         assert!(39.0 < estimate[2] && estimate[2] < 39.1);
     269            1 :     }
     270              : 
     271              :     #[test]
     272            1 :     fn test_cardinality_medium() {
     273            1 :         let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap());
     274            1 : 
     275            1 :         assert_eq!(actual, [2529, 1618, 1629]);
     276            1 :         assert!(2309.1 < estimate[0] && estimate[0] < 2309.2);
     277            1 :         assert!(1566.6 < estimate[1] && estimate[1] < 1566.7);
     278            1 :         assert!(1629.5 < estimate[2] && estimate[2] < 1629.6);
     279            1 :     }
     280              : 
     281              :     #[test]
     282            1 :     fn test_cardinality_large() {
     283            1 :         let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap());
     284            1 : 
     285            1 :         assert_eq!(actual, [129077, 79579, 79630]);
     286            1 :         assert!(126067.2 < estimate[0] && estimate[0] < 126067.3);
     287            1 :         assert!(83076.8 < estimate[1] && estimate[1] < 83076.9);
     288            1 :         assert!(64251.2 < estimate[2] && estimate[2] < 64251.3);
     289            1 :     }
     290              : 
     291              :     #[test]
     292            1 :     fn test_cardinality_small2() {
     293            1 :         let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap());
     294            1 : 
     295            1 :         assert_eq!(actual, [92, 58, 60]);
     296            1 :         assert!(116.1 < estimate[0] && estimate[0] < 116.2);
     297            1 :         assert!(81.7 < estimate[1] && estimate[1] < 81.8);
     298            1 :         assert!(69.3 < estimate[2] && estimate[2] < 69.4);
     299            1 :     }
     300              : 
     301              :     #[test]
     302            1 :     fn test_cardinality_medium2() {
     303            1 :         let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap());
     304            1 : 
     305            1 :         assert_eq!(actual, [8201, 5131, 5051]);
     306            1 :         assert!(6846.4 < estimate[0] && estimate[0] < 6846.5);
     307            1 :         assert!(5239.1 < estimate[1] && estimate[1] < 5239.2);
     308            1 :         assert!(4292.8 < estimate[2] && estimate[2] < 4292.9);
     309            1 :     }
     310              : 
     311              :     #[test]
     312            1 :     fn test_cardinality_large2() {
     313            1 :         let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap());
     314            1 : 
     315            1 :         assert_eq!(actual, [777847, 482069, 482246]);
     316            1 :         assert!(699437.4 < estimate[0] && estimate[0] < 699437.5);
     317            1 :         assert!(374948.9 < estimate[1] && estimate[1] < 374949.0);
     318            1 :         assert!(434609.7 < estimate[2] && estimate[2] < 434609.8);
     319            1 :     }
     320              : }

