LCOV - code coverage report
Current view: top level - proxy/src - intern.rs (source / functions) Coverage Total Hit
Test: 792183ae0ef4f1f8b22e9ac7e8748740ab73f873.info Lines: 85.2 % 135 115
Test Date: 2024-06-26 01:04:33 Functions: 63.3 % 79 50

            Line data    Source code
       1              : use std::{
       2              :     hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
       3              : };
       4              : 
       5              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       6              : use rustc_hash::FxHasher;
       7              : 
       8              : use crate::{BranchId, EndpointId, ProjectId, RoleName};
       9              : 
      10              : pub trait InternId: Sized + 'static {
      11              :     fn get_interner() -> &'static StringInterner<Self>;
      12              : }
      13              : 
      14              : pub struct StringInterner<Id> {
      15              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      16              :     _id: PhantomData<Id>,
      17              : }
      18              : 
      19              : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      20              : pub struct InternedString<Id> {
      21              :     inner: Spur,
      22              :     _id: PhantomData<Id>,
      23              : }
      24              : 
      25              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      26            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      27            0 :         self.as_str().fmt(f)
      28            0 :     }
      29              : }
      30              : 
      31              : impl<Id: InternId> InternedString<Id> {
      32            8 :     pub fn as_str(&self) -> &'static str {
      33            8 :         Id::get_interner().inner.resolve(&self.inner)
      34            8 :     }
      35           70 :     pub fn get(s: &str) -> Option<Self> {
      36           70 :         Id::get_interner().get(s)
      37           70 :     }
      38              : }
      39              : 
      40              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      41            0 :     fn as_ref(&self) -> &str {
      42            0 :         self.as_str()
      43            0 :     }
      44              : }
      45              : 
      46              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      47              :     type Target = str;
      48            0 :     fn deref(&self) -> &str {
      49            0 :         self.as_str()
      50            0 :     }
      51              : }
      52              : 
      53              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      54           46 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      55           46 :         struct Visitor<Id>(PhantomData<Id>);
      56           46 :         impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
      57           46 :             type Value = InternedString<Id>;
      58           46 : 
      59           46 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
      60            0 :                 formatter.write_str("a string")
      61            0 :             }
      62           46 : 
      63           46 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      64           46 :             where
      65           46 :                 E: serde::de::Error,
      66           46 :             {
      67           46 :                 Ok(Id::get_interner().get_or_intern(v))
      68           46 :             }
      69           46 :         }
      70           46 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      71           46 :     }
      72              : }
      73              : 
      74              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      75            8 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      76            8 :         self.as_str().serialize(s)
      77            8 :     }
      78              : }
      79              : 
      80              : impl<Id: InternId> StringInterner<Id> {
      81          106 :     pub fn new() -> Self {
      82          106 :         StringInterner {
      83          106 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      84          106 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
      85          106 :                 // unbounded
      86          106 :                 MemoryLimits::for_memory_usage(usize::MAX),
      87          106 :                 BuildHasherDefault::<FxHasher>::default(),
      88          106 :             ),
      89          106 :             _id: PhantomData,
      90          106 :         }
      91          106 :     }
      92              : 
      93            0 :     pub fn is_empty(&self) -> bool {
      94            0 :         self.inner.is_empty()
      95            0 :     }
      96              : 
      97            2 :     pub fn len(&self) -> usize {
      98            2 :         self.inner.len()
      99            2 :     }
     100              : 
     101            2 :     pub fn current_memory_usage(&self) -> usize {
     102            2 :         self.inner.current_memory_usage()
     103            2 :     }
     104              : 
     105       200240 :     pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     106       200240 :         InternedString {
     107       200240 :             inner: self.inner.get_or_intern(s),
     108       200240 :             _id: PhantomData,
     109       200240 :         }
     110       200240 :     }
     111              : 
     112           70 :     pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
     113           70 :         Some(InternedString {
     114           70 :             inner: self.inner.get(s)?,
     115           70 :             _id: PhantomData,
     116              :         })
     117           70 :     }
     118              : }
     119              : 
     120              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     121              :     type Output = str;
     122              : 
     123       200000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     124       200000 :         self.inner.resolve(&index.inner)
     125       200000 :     }
     126              : }
     127              : 
     128              : impl<Id: InternId> Default for StringInterner<Id> {
     129          106 :     fn default() -> Self {
     130          106 :         Self::new()
     131          106 :     }
     132              : }
     133              : 
     134              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     135              : pub struct RoleNameTag;
     136              : impl InternId for RoleNameTag {
     137           50 :     fn get_interner() -> &'static StringInterner<Self> {
     138           50 :         pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     139           50 :         ROLE_NAMES.get_or_init(Default::default)
     140           50 :     }
     141              : }
     142              : pub type RoleNameInt = InternedString<RoleNameTag>;
     143              : impl From<&RoleName> for RoleNameInt {
     144           18 :     fn from(value: &RoleName) -> Self {
     145           18 :         RoleNameTag::get_interner().get_or_intern(value)
     146           18 :     }
     147              : }
     148              : 
     149              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     150              : pub struct EndpointIdTag;
     151              : impl InternId for EndpointIdTag {
     152          138 :     fn get_interner() -> &'static StringInterner<Self> {
     153          138 :         pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     154          138 :         ROLE_NAMES.get_or_init(Default::default)
     155          138 :     }
     156              : }
     157              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     158              : impl From<&EndpointId> for EndpointIdInt {
     159           60 :     fn from(value: &EndpointId) -> Self {
     160           60 :         EndpointIdTag::get_interner().get_or_intern(value)
     161           60 :     }
     162              : }
     163              : impl From<EndpointId> for EndpointIdInt {
     164           20 :     fn from(value: EndpointId) -> Self {
     165           20 :         EndpointIdTag::get_interner().get_or_intern(&value)
     166           20 :     }
     167              : }
     168              : 
     169              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     170              : pub struct BranchIdTag;
     171              : impl InternId for BranchIdTag {
     172           54 :     fn get_interner() -> &'static StringInterner<Self> {
     173           54 :         pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     174           54 :         ROLE_NAMES.get_or_init(Default::default)
     175           54 :     }
     176              : }
     177              : pub type BranchIdInt = InternedString<BranchIdTag>;
     178              : impl From<&BranchId> for BranchIdInt {
     179           36 :     fn from(value: &BranchId) -> Self {
     180           36 :         BranchIdTag::get_interner().get_or_intern(value)
     181           36 :     }
     182              : }
     183              : impl From<BranchId> for BranchIdInt {
     184            0 :     fn from(value: BranchId) -> Self {
     185            0 :         BranchIdTag::get_interner().get_or_intern(&value)
     186            0 :     }
     187              : }
     188              : 
     189              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     190              : pub struct ProjectIdTag;
     191              : impl InternId for ProjectIdTag {
     192           76 :     fn get_interner() -> &'static StringInterner<Self> {
     193           76 :         pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     194           76 :         ROLE_NAMES.get_or_init(Default::default)
     195           76 :     }
     196              : }
     197              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     198              : impl From<&ProjectId> for ProjectIdInt {
     199           60 :     fn from(value: &ProjectId) -> Self {
     200           60 :         ProjectIdTag::get_interner().get_or_intern(value)
     201           60 :     }
     202              : }
     203              : impl From<ProjectId> for ProjectIdInt {
     204            0 :     fn from(value: ProjectId) -> Self {
     205            0 :         ProjectIdTag::get_interner().get_or_intern(&value)
     206            0 :     }
     207              : }
     208              : 
     209              : #[cfg(test)]
     210              : mod tests {
     211              :     use std::sync::OnceLock;
     212              : 
     213              :     use crate::intern::StringInterner;
     214              : 
     215              :     use super::InternId;
     216              : 
     217              :     struct MyId;
     218              :     impl InternId for MyId {
     219            2 :         fn get_interner() -> &'static StringInterner<Self> {
     220            2 :             pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     221            2 :             ROLE_NAMES.get_or_init(Default::default)
     222            2 :         }
     223              :     }
     224              : 
     225              :     #[test]
     226            2 :     fn push_many_strings() {
     227            2 :         use rand::{rngs::StdRng, Rng, SeedableRng};
     228            2 :         use rand_distr::Zipf;
     229            2 : 
     230            2 :         let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
     231            2 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     232            2 : 
     233            2 :         let interner = MyId::get_interner();
     234            2 : 
     235            2 :         const N: usize = 100_000;
     236            2 :         let mut verify = Vec::with_capacity(N);
     237       200000 :         for endpoint in endpoints.take(N) {
     238       200000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     239       200000 :             let key = interner.get_or_intern(&endpoint);
     240       200000 :             verify.push((endpoint, key));
     241       200000 :         }
     242              : 
     243       200002 :         for (s, key) in verify {
     244       200000 :             assert_eq!(interner[key], s);
     245              :         }
     246              : 
     247              :         // 2031616/59861 = 34 bytes per string
     248            2 :         assert_eq!(interner.len(), 59_861);
     249              :         // will have other overhead for the internal hashmaps that are not accounted for.
     250            2 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     251            2 :     }
     252              : }
        

Generated by: LCOV version 2.1-beta