LCOV - code coverage report
Current view: top level - proxy/src - intern.rs (source / functions) Coverage Total Hit
Test: 249f165943bd2c492f96a3f7d250276e4addca1a.info Lines: 87.1 % 116 101
Test Date: 2024-11-20 18:39:52 Functions: 63.4 % 82 52

            Line data    Source code
       1              : use std::hash::BuildHasherDefault;
       2              : use std::marker::PhantomData;
       3              : use std::num::NonZeroUsize;
       4              : use std::ops::Index;
       5              : use std::sync::OnceLock;
       6              : 
       7              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       8              : use rustc_hash::FxHasher;
       9              : 
      10              : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
      11              : 
      12              : pub trait InternId: Sized + 'static {
      13              :     fn get_interner() -> &'static StringInterner<Self>;
      14              : }
      15              : 
      16              : pub struct StringInterner<Id> {
      17              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      18              :     _id: PhantomData<Id>,
      19              : }
      20              : 
      21              : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      22              : pub struct InternedString<Id> {
      23              :     inner: Spur,
      24              :     _id: PhantomData<Id>,
      25              : }
      26              : 
      27              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      28            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      29            0 :         self.as_str().fmt(f)
      30            0 :     }
      31              : }
      32              : 
      33              : impl<Id: InternId> InternedString<Id> {
      34           33 :     pub(crate) fn as_str(&self) -> &'static str {
      35           33 :         Id::get_interner().inner.resolve(&self.inner)
      36           33 :     }
      37           35 :     pub(crate) fn get(s: &str) -> Option<Self> {
      38           35 :         Id::get_interner().get(s)
      39           35 :     }
      40              : }
      41              : 
      42              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      43            0 :     fn as_ref(&self) -> &str {
      44            0 :         self.as_str()
      45            0 :     }
      46              : }
      47              : 
      48              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      49              :     type Target = str;
      50           29 :     fn deref(&self) -> &str {
      51           29 :         self.as_str()
      52           29 :     }
      53              : }
      54              : 
      55              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      56           27 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      57              :         struct Visitor<Id>(PhantomData<Id>);
      58              :         impl<Id: InternId> serde::de::Visitor<'_> for Visitor<Id> {
      59              :             type Value = InternedString<Id>;
      60              : 
      61            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      62            0 :                 formatter.write_str("a string")
      63            0 :             }
      64              : 
      65           27 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      66           27 :             where
      67           27 :                 E: serde::de::Error,
      68           27 :             {
      69           27 :                 Ok(Id::get_interner().get_or_intern(v))
      70           27 :             }
      71              :         }
      72           27 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      73           27 :     }
      74              : }
      75              : 
      76              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      77            4 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      78            4 :         self.as_str().serialize(s)
      79            4 :     }
      80              : }
      81              : 
      82              : impl<Id: InternId> StringInterner<Id> {
      83           60 :     pub(crate) fn new() -> Self {
      84           60 :         StringInterner {
      85           60 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      86           60 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
      87           60 :                 // unbounded
      88           60 :                 MemoryLimits::for_memory_usage(usize::MAX),
      89           60 :                 BuildHasherDefault::<FxHasher>::default(),
      90           60 :             ),
      91           60 :             _id: PhantomData,
      92           60 :         }
      93           60 :     }
      94              : 
      95              :     #[cfg(test)]
      96            1 :     fn len(&self) -> usize {
      97            1 :         self.inner.len()
      98            1 :     }
      99              : 
     100              :     #[cfg(test)]
     101            1 :     fn current_memory_usage(&self) -> usize {
     102            1 :         self.inner.current_memory_usage()
     103            1 :     }
     104              : 
     105       100132 :     pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     106       100132 :         InternedString {
     107       100132 :             inner: self.inner.get_or_intern(s),
     108       100132 :             _id: PhantomData,
     109       100132 :         }
     110       100132 :     }
     111              : 
     112           35 :     pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
     113           35 :         Some(InternedString {
     114           35 :             inner: self.inner.get(s)?,
     115           35 :             _id: PhantomData,
     116              :         })
     117           35 :     }
     118              : }
     119              : 
     120              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     121              :     type Output = str;
     122              : 
     123       100000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     124       100000 :         self.inner.resolve(&index.inner)
     125       100000 :     }
     126              : }
     127              : 
     128              : impl<Id: InternId> Default for StringInterner<Id> {
     129           60 :     fn default() -> Self {
     130           60 :         Self::new()
     131           60 :     }
     132              : }
     133              : 
     134              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     135              : pub struct RoleNameTag;
     136              : impl InternId for RoleNameTag {
     137           61 :     fn get_interner() -> &'static StringInterner<Self> {
     138              :         static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     139           61 :         ROLE_NAMES.get_or_init(Default::default)
     140           61 :     }
     141              : }
     142              : pub type RoleNameInt = InternedString<RoleNameTag>;
     143              : impl From<&RoleName> for RoleNameInt {
     144           16 :     fn from(value: &RoleName) -> Self {
     145           16 :         RoleNameTag::get_interner().get_or_intern(value)
     146           16 :     }
     147              : }
     148              : 
     149              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     150              : pub struct EndpointIdTag;
     151              : impl InternId for EndpointIdTag {
     152           72 :     fn get_interner() -> &'static StringInterner<Self> {
     153              :         static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     154           72 :         ROLE_NAMES.get_or_init(Default::default)
     155           72 :     }
     156              : }
     157              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     158              : impl From<&EndpointId> for EndpointIdInt {
     159           30 :     fn from(value: &EndpointId) -> Self {
     160           30 :         EndpointIdTag::get_interner().get_or_intern(value)
     161           30 :     }
     162              : }
     163              : impl From<EndpointId> for EndpointIdInt {
     164           11 :     fn from(value: EndpointId) -> Self {
     165           11 :         EndpointIdTag::get_interner().get_or_intern(&value)
     166           11 :     }
     167              : }
     168              : 
     169              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     170              : pub struct BranchIdTag;
     171              : impl InternId for BranchIdTag {
     172           28 :     fn get_interner() -> &'static StringInterner<Self> {
     173              :         static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     174           28 :         ROLE_NAMES.get_or_init(Default::default)
     175           28 :     }
     176              : }
     177              : pub type BranchIdInt = InternedString<BranchIdTag>;
     178              : impl From<&BranchId> for BranchIdInt {
     179           18 :     fn from(value: &BranchId) -> Self {
     180           18 :         BranchIdTag::get_interner().get_or_intern(value)
     181           18 :     }
     182              : }
     183              : impl From<BranchId> for BranchIdInt {
     184            0 :     fn from(value: BranchId) -> Self {
     185            0 :         BranchIdTag::get_interner().get_or_intern(&value)
     186            0 :     }
     187              : }
     188              : 
     189              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     190              : pub struct ProjectIdTag;
     191              : impl InternId for ProjectIdTag {
     192           39 :     fn get_interner() -> &'static StringInterner<Self> {
     193              :         static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     194           39 :         ROLE_NAMES.get_or_init(Default::default)
     195           39 :     }
     196              : }
     197              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     198              : impl From<&ProjectId> for ProjectIdInt {
     199           30 :     fn from(value: &ProjectId) -> Self {
     200           30 :         ProjectIdTag::get_interner().get_or_intern(value)
     201           30 :     }
     202              : }
     203              : impl From<ProjectId> for ProjectIdInt {
     204            0 :     fn from(value: ProjectId) -> Self {
     205            0 :         ProjectIdTag::get_interner().get_or_intern(&value)
     206            0 :     }
     207              : }
     208              : 
     209              : #[cfg(test)]
     210              : mod tests {
     211              :     use std::sync::OnceLock;
     212              : 
     213              :     use super::InternId;
     214              :     use crate::intern::StringInterner;
     215              : 
     216              :     struct MyId;
     217              :     impl InternId for MyId {
     218            1 :         fn get_interner() -> &'static StringInterner<Self> {
     219              :             pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     220            1 :             ROLE_NAMES.get_or_init(Default::default)
     221            1 :         }
     222              :     }
     223              : 
     224              :     #[test]
     225            1 :     fn push_many_strings() {
     226              :         use rand::rngs::StdRng;
     227              :         use rand::{Rng, SeedableRng};
     228              :         use rand_distr::Zipf;
     229              : 
     230            1 :         let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
     231            1 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     232            1 : 
     233            1 :         let interner = MyId::get_interner();
     234              : 
     235              :         const N: usize = 100_000;
     236            1 :         let mut verify = Vec::with_capacity(N);
     237       100000 :         for endpoint in endpoints.take(N) {
     238       100000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     239       100000 :             let key = interner.get_or_intern(&endpoint);
     240       100000 :             verify.push((endpoint, key));
     241       100000 :         }
     242              : 
     243       100001 :         for (s, key) in verify {
     244       100000 :             assert_eq!(interner[key], s);
     245              :         }
     246              : 
     247              :         // 2031616/59861 = 34 bytes per string
     248            1 :         assert_eq!(interner.len(), 59_861);
     249              :         // will have other overhead for the internal hashmaps that are not accounted for.
     250            1 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     251            1 :     }
     252              : }
        

Generated by: LCOV version 2.1-beta