LCOV - code coverage report
Current view: top level - proxy/src - intern.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 76.5 % 132 101
Test Date: 2024-02-12 20:26:03 Functions: 46.9 % 81 38

            Line data    Source code
       1              : use std::{
       2              :     hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
       3              : };
       4              : 
       5              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       6              : use rustc_hash::FxHasher;
       7              : 
       8              : use crate::{BranchId, EndpointId, ProjectId, RoleName};
       9              : 
      10              : pub trait InternId: Sized + 'static {
      11              :     fn get_interner() -> &'static StringInterner<Self>;
      12              : }
      13              : 
      14              : pub struct StringInterner<Id> {
      15              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      16              :     _id: PhantomData<Id>,
      17              : }
      18              : 
      19          244 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      20              : pub struct InternedString<Id> {
      21              :     inner: Spur,
      22              :     _id: PhantomData<Id>,
      23              : }
      24              : 
      25              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      26            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      27            0 :         self.as_str().fmt(f)
      28            0 :     }
      29              : }
      30              : 
      31              : impl<Id: InternId> InternedString<Id> {
      32            0 :     pub fn as_str(&self) -> &'static str {
      33            0 :         Id::get_interner().inner.resolve(&self.inner)
      34            0 :     }
      35           74 :     pub fn get(s: &str) -> Option<Self> {
      36           74 :         Id::get_interner().get(s)
      37           74 :     }
      38              : }
      39              : 
      40              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      41            0 :     fn as_ref(&self) -> &str {
      42            0 :         self.as_str()
      43            0 :     }
      44              : }
      45              : 
      46              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      47              :     type Target = str;
      48            0 :     fn deref(&self) -> &str {
      49            0 :         self.as_str()
      50            0 :     }
      51              : }
      52              : 
      53              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      54            6 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      55            6 :         struct Visitor<Id>(PhantomData<Id>);
      56            6 :         impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
      57            6 :             type Value = InternedString<Id>;
      58            6 : 
      59            6 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
      60            0 :                 formatter.write_str("a string")
      61            0 :             }
      62            6 : 
      63            6 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      64            6 :             where
      65            6 :                 E: serde::de::Error,
      66            6 :             {
      67            6 :                 Ok(Id::get_interner().get_or_intern(v))
      68            6 :             }
      69            6 :         }
      70            6 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      71            6 :     }
      72              : }
      73              : 
      74              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      75            0 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      76            0 :         self.as_str().serialize(s)
      77            0 :     }
      78              : }
      79              : 
      80              : impl<Id: InternId> StringInterner<Id> {
      81           27 :     pub fn new() -> Self {
      82           27 :         StringInterner {
      83           27 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      84           27 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
      85           27 :                 // unbounded
      86           27 :                 MemoryLimits::for_memory_usage(usize::MAX),
      87           27 :                 BuildHasherDefault::<FxHasher>::default(),
      88           27 :             ),
      89           27 :             _id: PhantomData,
      90           27 :         }
      91           27 :     }
      92              : 
      93            0 :     pub fn is_empty(&self) -> bool {
      94            0 :         self.inner.is_empty()
      95            0 :     }
      96              : 
      97            2 :     pub fn len(&self) -> usize {
      98            2 :         self.inner.len()
      99            2 :     }
     100              : 
     101            2 :     pub fn current_memory_usage(&self) -> usize {
     102            2 :         self.inner.current_memory_usage()
     103            2 :     }
     104              : 
     105       200070 :     pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     106       200070 :         InternedString {
     107       200070 :             inner: self.inner.get_or_intern(s),
     108       200070 :             _id: PhantomData,
     109       200070 :         }
     110       200070 :     }
     111              : 
     112           74 :     pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
     113           74 :         Some(InternedString {
     114           74 :             inner: self.inner.get(s)?,
     115           70 :             _id: PhantomData,
     116              :         })
     117           74 :     }
     118              : }
     119              : 
     120              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     121              :     type Output = str;
     122              : 
     123       200000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     124       200000 :         self.inner.resolve(&index.inner)
     125       200000 :     }
     126              : }
     127              : 
     128              : impl<Id: InternId> Default for StringInterner<Id> {
     129           27 :     fn default() -> Self {
     130           27 :         Self::new()
     131           27 :     }
     132              : }
     133              : 
     134            0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     135              : pub struct RoleNameTag;
     136              : impl InternId for RoleNameTag {
     137           50 :     fn get_interner() -> &'static StringInterner<Self> {
     138           50 :         pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     139           50 :         ROLE_NAMES.get_or_init(Default::default)
     140           50 :     }
     141              : }
     142              : pub type RoleNameInt = InternedString<RoleNameTag>;
     143              : impl From<&RoleName> for RoleNameInt {
     144           18 :     fn from(value: &RoleName) -> Self {
     145           18 :         RoleNameTag::get_interner().get_or_intern(value)
     146           18 :     }
     147              : }
     148              : 
     149            0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     150              : pub struct EndpointIdTag;
     151              : impl InternId for EndpointIdTag {
     152           64 :     fn get_interner() -> &'static StringInterner<Self> {
     153           64 :         pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     154           64 :         ROLE_NAMES.get_or_init(Default::default)
     155           64 :     }
     156              : }
     157              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     158              : impl From<&EndpointId> for EndpointIdInt {
     159           20 :     fn from(value: &EndpointId) -> Self {
     160           20 :         EndpointIdTag::get_interner().get_or_intern(value)
     161           20 :     }
     162              : }
     163              : 
     164            0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     165              : pub struct BranchIdTag;
     166              : impl InternId for BranchIdTag {
     167            0 :     fn get_interner() -> &'static StringInterner<Self> {
     168            0 :         pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     169            0 :         ROLE_NAMES.get_or_init(Default::default)
     170            0 :     }
     171              : }
     172              : pub type BranchIdInt = InternedString<BranchIdTag>;
     173              : impl From<&BranchId> for BranchIdInt {
     174            0 :     fn from(value: &BranchId) -> Self {
     175            0 :         BranchIdTag::get_interner().get_or_intern(value)
     176            0 :     }
     177              : }
     178              : 
     179            0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     180              : pub struct ProjectIdTag;
     181              : impl InternId for ProjectIdTag {
     182           30 :     fn get_interner() -> &'static StringInterner<Self> {
     183           30 :         pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     184           30 :         ROLE_NAMES.get_or_init(Default::default)
     185           30 :     }
     186              : }
     187              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     188              : impl From<&ProjectId> for ProjectIdInt {
     189           26 :     fn from(value: &ProjectId) -> Self {
     190           26 :         ProjectIdTag::get_interner().get_or_intern(value)
     191           26 :     }
     192              : }
     193              : 
     194              : #[cfg(test)]
     195              : mod tests {
     196              :     use std::sync::OnceLock;
     197              : 
     198              :     use crate::intern::StringInterner;
     199              : 
     200              :     use super::InternId;
     201              : 
     202              :     struct MyId;
     203              :     impl InternId for MyId {
     204            2 :         fn get_interner() -> &'static StringInterner<Self> {
     205            2 :             pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     206            2 :             ROLE_NAMES.get_or_init(Default::default)
     207            2 :         }
     208              :     }
     209              : 
     210            2 :     #[test]
     211            2 :     fn push_many_strings() {
     212            2 :         use rand::{rngs::StdRng, Rng, SeedableRng};
     213            2 :         use rand_distr::Zipf;
     214            2 : 
     215            2 :         let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
     216            2 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     217            2 : 
     218            2 :         let interner = MyId::get_interner();
     219            2 : 
     220            2 :         const N: usize = 100_000;
     221            2 :         let mut verify = Vec::with_capacity(N);
     222       200000 :         for endpoint in endpoints.take(N) {
     223       200000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     224       200000 :             let key = interner.get_or_intern(&endpoint);
     225       200000 :             verify.push((endpoint, key));
     226       200000 :         }
     227              : 
     228       200002 :         for (s, key) in verify {
     229       200000 :             assert_eq!(interner[key], s);
     230              :         }
     231              : 
     232              :         // 2031616/59861 = 34 bytes per string
     233            2 :         assert_eq!(interner.len(), 59_861);
     234              :         // will have other overhead for the internal hashmaps that are not accounted for.
     235            2 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     236            2 :     }
     237              : }
        

Generated by: LCOV version 2.1-beta