LCOV - code coverage report
Current view: top level - proxy/src - intern.rs (source / functions) Coverage Total Hit
Test: 5fe7fa8d483b39476409aee736d6d5e32728bfac.info Lines: 80.8 % 125 101
Test Date: 2025-03-12 16:10:49 Functions: 57.1 % 91 52

            Line data    Source code
       1              : use std::hash::BuildHasherDefault;
       2              : use std::marker::PhantomData;
       3              : use std::num::NonZeroUsize;
       4              : use std::ops::Index;
       5              : use std::sync::OnceLock;
       6              : 
       7              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       8              : use rustc_hash::FxHasher;
       9              : 
      10              : use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName};
      11              : 
      12              : pub trait InternId: Sized + 'static {
      13              :     fn get_interner() -> &'static StringInterner<Self>;
      14              : }
      15              : 
      16              : pub struct StringInterner<Id> {
      17              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      18              :     _id: PhantomData<Id>,
      19              : }
      20              : 
      21              : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      22              : pub struct InternedString<Id> {
      23              :     inner: Spur,
      24              :     _id: PhantomData<Id>,
      25              : }
      26              : 
      27              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      28            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      29            0 :         self.as_str().fmt(f)
      30            0 :     }
      31              : }
      32              : 
      33              : impl<Id: InternId> InternedString<Id> {
      34           37 :     pub(crate) fn as_str(&self) -> &'static str {
      35           37 :         Id::get_interner().inner.resolve(&self.inner)
      36           37 :     }
      37           35 :     pub(crate) fn get(s: &str) -> Option<Self> {
      38           35 :         Id::get_interner().get(s)
      39           35 :     }
      40              : }
      41              : 
      42              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      43            0 :     fn as_ref(&self) -> &str {
      44            0 :         self.as_str()
      45            0 :     }
      46              : }
      47              : 
      48              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      49              :     type Target = str;
      50           29 :     fn deref(&self) -> &str {
      51           29 :         self.as_str()
      52           29 :     }
      53              : }
      54              : 
      55              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      56           31 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      57              :         struct Visitor<Id>(PhantomData<Id>);
      58              :         impl<Id: InternId> serde::de::Visitor<'_> for Visitor<Id> {
      59              :             type Value = InternedString<Id>;
      60              : 
      61            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      62            0 :                 formatter.write_str("a string")
      63            0 :             }
      64              : 
      65           31 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      66           31 :             where
      67           31 :                 E: serde::de::Error,
      68           31 :             {
      69           31 :                 Ok(Id::get_interner().get_or_intern(v))
      70           31 :             }
      71              :         }
      72           31 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      73           31 :     }
      74              : }
      75              : 
      76              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      77            8 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      78            8 :         self.as_str().serialize(s)
      79            8 :     }
      80              : }
      81              : 
      82              : impl<Id: InternId> StringInterner<Id> {
      83           60 :     pub(crate) fn new() -> Self {
      84           60 :         StringInterner {
      85           60 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      86           60 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).expect("value is nonzero")),
      87           60 :                 // unbounded
      88           60 :                 MemoryLimits::for_memory_usage(usize::MAX),
      89           60 :                 BuildHasherDefault::<FxHasher>::default(),
      90           60 :             ),
      91           60 :             _id: PhantomData,
      92           60 :         }
      93           60 :     }
      94              : 
      95              :     #[cfg(test)]
      96            1 :     fn len(&self) -> usize {
      97            1 :         self.inner.len()
      98            1 :     }
      99              : 
     100              :     #[cfg(test)]
     101            1 :     fn current_memory_usage(&self) -> usize {
     102            1 :         self.inner.current_memory_usage()
     103            1 :     }
     104              : 
     105       100136 :     pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     106       100136 :         InternedString {
     107       100136 :             inner: self.inner.get_or_intern(s),
     108       100136 :             _id: PhantomData,
     109       100136 :         }
     110       100136 :     }
     111              : 
     112           35 :     pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
     113           35 :         Some(InternedString {
     114           35 :             inner: self.inner.get(s)?,
     115           35 :             _id: PhantomData,
     116              :         })
     117           35 :     }
     118              : }
     119              : 
     120              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     121              :     type Output = str;
     122              : 
     123       100000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     124       100000 :         self.inner.resolve(&index.inner)
     125       100000 :     }
     126              : }
     127              : 
     128              : impl<Id: InternId> Default for StringInterner<Id> {
     129           60 :     fn default() -> Self {
     130           60 :         Self::new()
     131           60 :     }
     132              : }
     133              : 
     134              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     135              : pub struct RoleNameTag;
     136              : impl InternId for RoleNameTag {
     137           61 :     fn get_interner() -> &'static StringInterner<Self> {
     138              :         static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     139           61 :         ROLE_NAMES.get_or_init(Default::default)
     140           61 :     }
     141              : }
     142              : pub type RoleNameInt = InternedString<RoleNameTag>;
     143              : impl From<&RoleName> for RoleNameInt {
     144           16 :     fn from(value: &RoleName) -> Self {
     145           16 :         RoleNameTag::get_interner().get_or_intern(value)
     146           16 :     }
     147              : }
     148              : 
     149              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     150              : pub struct EndpointIdTag;
     151              : impl InternId for EndpointIdTag {
     152           76 :     fn get_interner() -> &'static StringInterner<Self> {
     153              :         static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     154           76 :         ROLE_NAMES.get_or_init(Default::default)
     155           76 :     }
     156              : }
     157              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     158              : impl From<&EndpointId> for EndpointIdInt {
     159           30 :     fn from(value: &EndpointId) -> Self {
     160           30 :         EndpointIdTag::get_interner().get_or_intern(value)
     161           30 :     }
     162              : }
     163              : impl From<EndpointId> for EndpointIdInt {
     164           11 :     fn from(value: EndpointId) -> Self {
     165           11 :         EndpointIdTag::get_interner().get_or_intern(&value)
     166           11 :     }
     167              : }
     168              : 
     169              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     170              : pub struct BranchIdTag;
     171              : impl InternId for BranchIdTag {
     172           32 :     fn get_interner() -> &'static StringInterner<Self> {
     173              :         static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     174           32 :         ROLE_NAMES.get_or_init(Default::default)
     175           32 :     }
     176              : }
     177              : pub type BranchIdInt = InternedString<BranchIdTag>;
     178              : impl From<&BranchId> for BranchIdInt {
     179           18 :     fn from(value: &BranchId) -> Self {
     180           18 :         BranchIdTag::get_interner().get_or_intern(value)
     181           18 :     }
     182              : }
     183              : impl From<BranchId> for BranchIdInt {
     184            0 :     fn from(value: BranchId) -> Self {
     185            0 :         BranchIdTag::get_interner().get_or_intern(&value)
     186            0 :     }
     187              : }
     188              : 
     189              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     190              : pub struct ProjectIdTag;
     191              : impl InternId for ProjectIdTag {
     192           39 :     fn get_interner() -> &'static StringInterner<Self> {
     193              :         static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     194           39 :         ROLE_NAMES.get_or_init(Default::default)
     195           39 :     }
     196              : }
     197              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     198              : impl From<&ProjectId> for ProjectIdInt {
     199           30 :     fn from(value: &ProjectId) -> Self {
     200           30 :         ProjectIdTag::get_interner().get_or_intern(value)
     201           30 :     }
     202              : }
     203              : impl From<ProjectId> for ProjectIdInt {
     204            0 :     fn from(value: ProjectId) -> Self {
     205            0 :         ProjectIdTag::get_interner().get_or_intern(&value)
     206            0 :     }
     207              : }
     208              : 
     209              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     210              : pub struct AccountIdTag;
     211              : impl InternId for AccountIdTag {
     212            0 :     fn get_interner() -> &'static StringInterner<Self> {
     213              :         static ROLE_NAMES: OnceLock<StringInterner<AccountIdTag>> = OnceLock::new();
     214            0 :         ROLE_NAMES.get_or_init(Default::default)
     215            0 :     }
     216              : }
     217              : pub type AccountIdInt = InternedString<AccountIdTag>;
     218              : impl From<&AccountId> for AccountIdInt {
     219            0 :     fn from(value: &AccountId) -> Self {
     220            0 :         AccountIdTag::get_interner().get_or_intern(value)
     221            0 :     }
     222              : }
     223              : impl From<AccountId> for AccountIdInt {
     224            0 :     fn from(value: AccountId) -> Self {
     225            0 :         AccountIdTag::get_interner().get_or_intern(&value)
     226            0 :     }
     227              : }
     228              : 
     229              : #[cfg(test)]
     230              : #[expect(clippy::unwrap_used)]
     231              : mod tests {
     232              :     use std::sync::OnceLock;
     233              : 
     234              :     use super::InternId;
     235              :     use crate::intern::StringInterner;
     236              : 
     237              :     struct MyId;
     238              :     impl InternId for MyId {
     239            1 :         fn get_interner() -> &'static StringInterner<Self> {
     240              :             pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     241            1 :             ROLE_NAMES.get_or_init(Default::default)
     242            1 :         }
     243              :     }
     244              : 
     245              :     #[test]
     246            1 :     fn push_many_strings() {
     247              :         use rand::rngs::StdRng;
     248              :         use rand::{Rng, SeedableRng};
     249              :         use rand_distr::Zipf;
     250              : 
     251            1 :         let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
     252            1 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     253            1 : 
     254            1 :         let interner = MyId::get_interner();
     255              : 
     256              :         const N: usize = 100_000;
     257            1 :         let mut verify = Vec::with_capacity(N);
     258       100000 :         for endpoint in endpoints.take(N) {
     259       100000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     260       100000 :             let key = interner.get_or_intern(&endpoint);
     261       100000 :             verify.push((endpoint, key));
     262       100000 :         }
     263              : 
     264       100001 :         for (s, key) in verify {
     265       100000 :             assert_eq!(interner[key], s);
     266              :         }
     267              : 
     268              :         // 2031616/59861 = 34 bytes per string
     269            1 :         assert_eq!(interner.len(), 59_861);
     270              :         // will have other overhead for the internal hashmaps that are not accounted for.
     271            1 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     272            1 :     }
     273              : }
        

Generated by: LCOV version 2.1-beta