LCOV - code coverage report
Current view: top level - proxy/src - intern.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 84.5 % 116 98
Test Date: 2024-09-20 13:14:58 Functions: 65.8 % 76 50

            Line data    Source code
       1              : use std::{
       2              :     hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
       3              : };
       4              : 
       5              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       6              : use rustc_hash::FxHasher;
       7              : 
       8              : use crate::{BranchId, EndpointId, ProjectId, RoleName};
       9              : 
      10              : pub trait InternId: Sized + 'static {
      11              :     fn get_interner() -> &'static StringInterner<Self>;
      12              : }
      13              : 
      14              : pub struct StringInterner<Id> {
      15              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      16              :     _id: PhantomData<Id>,
      17              : }
      18              : 
      19              : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      20              : pub struct InternedString<Id> {
      21              :     inner: Spur,
      22              :     _id: PhantomData<Id>,
      23              : }
      24              : 
      25              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      26            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      27            0 :         self.as_str().fmt(f)
      28            0 :     }
      29              : }
      30              : 
      31              : impl<Id: InternId> InternedString<Id> {
      32            4 :     pub(crate) fn as_str(&self) -> &'static str {
      33            4 :         Id::get_interner().inner.resolve(&self.inner)
      34            4 :     }
      35           35 :     pub(crate) fn get(s: &str) -> Option<Self> {
      36           35 :         Id::get_interner().get(s)
      37           35 :     }
      38              : }
      39              : 
      40              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      41            0 :     fn as_ref(&self) -> &str {
      42            0 :         self.as_str()
      43            0 :     }
      44              : }
      45              : 
      46              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      47              :     type Target = str;
      48            0 :     fn deref(&self) -> &str {
      49            0 :         self.as_str()
      50            0 :     }
      51              : }
      52              : 
      53              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      54           23 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      55              :         struct Visitor<Id>(PhantomData<Id>);
      56              :         impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
      57              :             type Value = InternedString<Id>;
      58              : 
      59            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      60            0 :                 formatter.write_str("a string")
      61            0 :             }
      62              : 
      63           23 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      64           23 :             where
      65           23 :                 E: serde::de::Error,
      66           23 :             {
      67           23 :                 Ok(Id::get_interner().get_or_intern(v))
      68           23 :             }
      69              :         }
      70           23 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      71           23 :     }
      72              : }
      73              : 
      74              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      75            4 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      76            4 :         self.as_str().serialize(s)
      77            4 :     }
      78              : }
      79              : 
      80              : impl<Id: InternId> StringInterner<Id> {
      81           53 :     pub(crate) fn new() -> Self {
      82           53 :         StringInterner {
      83           53 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      84           53 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
      85           53 :                 // unbounded
      86           53 :                 MemoryLimits::for_memory_usage(usize::MAX),
      87           53 :                 BuildHasherDefault::<FxHasher>::default(),
      88           53 :             ),
      89           53 :             _id: PhantomData,
      90           53 :         }
      91           53 :     }
      92              : 
      93              :     #[cfg(test)]
      94            1 :     fn len(&self) -> usize {
      95            1 :         self.inner.len()
      96            1 :     }
      97              : 
      98              :     #[cfg(test)]
      99            1 :     fn current_memory_usage(&self) -> usize {
     100            1 :         self.inner.current_memory_usage()
     101            1 :     }
     102              : 
     103       100120 :     pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     104       100120 :         InternedString {
     105       100120 :             inner: self.inner.get_or_intern(s),
     106       100120 :             _id: PhantomData,
     107       100120 :         }
     108       100120 :     }
     109              : 
     110           35 :     pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
     111           35 :         Some(InternedString {
     112           35 :             inner: self.inner.get(s)?,
     113           35 :             _id: PhantomData,
     114              :         })
     115           35 :     }
     116              : }
     117              : 
     118              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     119              :     type Output = str;
     120              : 
     121       100000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     122       100000 :         self.inner.resolve(&index.inner)
     123       100000 :     }
     124              : }
     125              : 
     126              : impl<Id: InternId> Default for StringInterner<Id> {
     127           53 :     fn default() -> Self {
     128           53 :         Self::new()
     129           53 :     }
     130              : }
     131              : 
     132              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     133              : pub(crate) struct RoleNameTag;
     134              : impl InternId for RoleNameTag {
     135           25 :     fn get_interner() -> &'static StringInterner<Self> {
     136              :         static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     137           25 :         ROLE_NAMES.get_or_init(Default::default)
     138           25 :     }
     139              : }
     140              : pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
     141              : impl From<&RoleName> for RoleNameInt {
     142            9 :     fn from(value: &RoleName) -> Self {
     143            9 :         RoleNameTag::get_interner().get_or_intern(value)
     144            9 :     }
     145              : }
     146              : 
     147              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     148              : pub struct EndpointIdTag;
     149              : impl InternId for EndpointIdTag {
     150           69 :     fn get_interner() -> &'static StringInterner<Self> {
     151              :         static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     152           69 :         ROLE_NAMES.get_or_init(Default::default)
     153           69 :     }
     154              : }
     155              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     156              : impl From<&EndpointId> for EndpointIdInt {
     157           30 :     fn from(value: &EndpointId) -> Self {
     158           30 :         EndpointIdTag::get_interner().get_or_intern(value)
     159           30 :     }
     160              : }
     161              : impl From<EndpointId> for EndpointIdInt {
     162           10 :     fn from(value: EndpointId) -> Self {
     163           10 :         EndpointIdTag::get_interner().get_or_intern(&value)
     164           10 :     }
     165              : }
     166              : 
     167              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     168              : pub struct BranchIdTag;
     169              : impl InternId for BranchIdTag {
     170           27 :     fn get_interner() -> &'static StringInterner<Self> {
     171              :         static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     172           27 :         ROLE_NAMES.get_or_init(Default::default)
     173           27 :     }
     174              : }
     175              : pub type BranchIdInt = InternedString<BranchIdTag>;
     176              : impl From<&BranchId> for BranchIdInt {
     177           18 :     fn from(value: &BranchId) -> Self {
     178           18 :         BranchIdTag::get_interner().get_or_intern(value)
     179           18 :     }
     180              : }
     181              : impl From<BranchId> for BranchIdInt {
     182            0 :     fn from(value: BranchId) -> Self {
     183            0 :         BranchIdTag::get_interner().get_or_intern(&value)
     184            0 :     }
     185              : }
     186              : 
     187              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     188              : pub struct ProjectIdTag;
     189              : impl InternId for ProjectIdTag {
     190           38 :     fn get_interner() -> &'static StringInterner<Self> {
     191              :         static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     192           38 :         ROLE_NAMES.get_or_init(Default::default)
     193           38 :     }
     194              : }
     195              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     196              : impl From<&ProjectId> for ProjectIdInt {
     197           30 :     fn from(value: &ProjectId) -> Self {
     198           30 :         ProjectIdTag::get_interner().get_or_intern(value)
     199           30 :     }
     200              : }
     201              : impl From<ProjectId> for ProjectIdInt {
     202            0 :     fn from(value: ProjectId) -> Self {
     203            0 :         ProjectIdTag::get_interner().get_or_intern(&value)
     204            0 :     }
     205              : }
     206              : 
     207              : #[cfg(test)]
     208              : mod tests {
     209              :     use std::sync::OnceLock;
     210              : 
     211              :     use crate::intern::StringInterner;
     212              : 
     213              :     use super::InternId;
     214              : 
     215              :     struct MyId;
     216              :     impl InternId for MyId {
     217            1 :         fn get_interner() -> &'static StringInterner<Self> {
     218              :             pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     219            1 :             ROLE_NAMES.get_or_init(Default::default)
     220            1 :         }
     221              :     }
     222              : 
     223              :     #[test]
     224            1 :     fn push_many_strings() {
     225              :         use rand::{rngs::StdRng, Rng, SeedableRng};
     226              :         use rand_distr::Zipf;
     227              : 
     228            1 :         let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
     229            1 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     230            1 : 
     231            1 :         let interner = MyId::get_interner();
     232              : 
     233              :         const N: usize = 100_000;
     234            1 :         let mut verify = Vec::with_capacity(N);
     235       100000 :         for endpoint in endpoints.take(N) {
     236       100000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     237       100000 :             let key = interner.get_or_intern(&endpoint);
     238       100000 :             verify.push((endpoint, key));
     239       100000 :         }
     240              : 
     241       100001 :         for (s, key) in verify {
     242       100000 :             assert_eq!(interner[key], s);
     243              :         }
     244              : 
     245              :         // 2031616/59861 = 34 bytes per string
     246            1 :         assert_eq!(interner.len(), 59_861);
     247              :         // will have other overhead for the internal hashmaps that are not accounted for.
     248            1 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     249            1 :     }
     250              : }
        

Generated by: LCOV version 2.1-beta