|             Line data    Source code 
       1              : use std::hash::BuildHasherDefault;
       2              : use std::marker::PhantomData;
       3              : use std::num::NonZeroUsize;
       4              : use std::ops::Index;
       5              : use std::sync::OnceLock;
       6              : 
       7              : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
       8              : use rustc_hash::FxHasher;
       9              : 
      10              : use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName};
      11              : 
      12              : pub trait InternId: Sized + 'static {
      13              :     fn get_interner() -> &'static StringInterner<Self>;
      14              : }
      15              : 
      16              : pub struct StringInterner<Id> {
      17              :     inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
      18              :     _id: PhantomData<Id>,
      19              : }
      20              : 
      21              : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
      22              : pub struct InternedString<Id> {
      23              :     inner: Spur,
      24              :     _id: PhantomData<Id>,
      25              : }
      26              : 
      27              : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
      28            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      29            0 :         self.as_str().fmt(f)
      30            0 :     }
      31              : }
      32              : 
      33              : impl<Id: InternId> InternedString<Id> {
      34           45 :     pub(crate) fn as_str(&self) -> &'static str {
      35           45 :         Id::get_interner().inner.resolve(&self.inner)
      36           45 :     }
      37           20 :     pub(crate) fn get(s: &str) -> Option<Self> {
      38           20 :         Id::get_interner().get(s)
      39           20 :     }
      40              : }
      41              : 
      42              : impl<Id: InternId> AsRef<str> for InternedString<Id> {
      43            0 :     fn as_ref(&self) -> &str {
      44            0 :         self.as_str()
      45            0 :     }
      46              : }
      47              : 
      48              : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
      49              :     type Target = str;
      50           29 :     fn deref(&self) -> &str {
      51           29 :         self.as_str()
      52           29 :     }
      53              : }
      54              : 
      55              : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
      56           40 :     fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
      57              :         struct Visitor<Id>(PhantomData<Id>);
      58              :         impl<Id: InternId> serde::de::Visitor<'_> for Visitor<Id> {
      59              :             type Value = InternedString<Id>;
      60              : 
      61            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      62            0 :                 formatter.write_str("a string")
      63            0 :             }
      64              : 
      65           40 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
      66           40 :             where
      67           40 :                 E: serde::de::Error,
      68              :             {
      69           40 :                 Ok(Id::get_interner().get_or_intern(v))
      70           40 :             }
      71              :         }
      72           40 :         d.deserialize_str(Visitor::<Id>(PhantomData))
      73           40 :     }
      74              : }
      75              : 
      76              : impl<Id: InternId> serde::Serialize for InternedString<Id> {
      77           16 :     fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
      78           16 :         self.as_str().serialize(s)
      79           16 :     }
      80              : }
      81              : 
      82              : impl<Id: InternId> StringInterner<Id> {
      83           81 :     pub(crate) fn new() -> Self {
      84           81 :         StringInterner {
      85           81 :             inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
      86           81 :                 Capacity::new(2500, NonZeroUsize::new(1 << 16).expect("value is nonzero")),
      87           81 :                 // unbounded
      88           81 :                 MemoryLimits::for_memory_usage(usize::MAX),
      89           81 :                 BuildHasherDefault::<FxHasher>::default(),
      90           81 :             ),
      91           81 :             _id: PhantomData,
      92           81 :         }
      93           81 :     }
      94              : 
      95              :     #[cfg(test)]
      96            1 :     fn len(&self) -> usize {
      97            1 :         self.inner.len()
      98            1 :     }
      99              : 
     100              :     #[cfg(test)]
     101            1 :     fn current_memory_usage(&self) -> usize {
     102            1 :         self.inner.current_memory_usage()
     103            1 :     }
     104              : 
     105       100166 :     pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
     106       100166 :         InternedString {
     107       100166 :             inner: self.inner.get_or_intern(s),
     108       100166 :             _id: PhantomData,
     109       100166 :         }
     110       100166 :     }
     111              : 
     112           20 :     pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
     113              :         Some(InternedString {
     114           20 :             inner: self.inner.get(s)?,
     115           19 :             _id: PhantomData,
     116              :         })
     117           20 :     }
     118              : }
     119              : 
     120              : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
     121              :     type Output = str;
     122              : 
     123       100000 :     fn index(&self, index: InternedString<Id>) -> &Self::Output {
     124       100000 :         self.inner.resolve(&index.inner)
     125       100000 :     }
     126              : }
     127              : 
     128              : impl<Id: InternId> Default for StringInterner<Id> {
     129           81 :     fn default() -> Self {
     130           81 :         Self::new()
     131           81 :     }
     132              : }
     133              : 
     134              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     135              : pub struct RoleNameTag;
     136              : impl InternId for RoleNameTag {
     137           61 :     fn get_interner() -> &'static StringInterner<Self> {
     138              :         static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
     139           61 :         ROLE_NAMES.get_or_init(Default::default)
     140           61 :     }
     141              : }
     142              : pub type RoleNameInt = InternedString<RoleNameTag>;
     143              : impl From<&RoleName> for RoleNameInt {
     144           23 :     fn from(value: &RoleName) -> Self {
     145           23 :         RoleNameTag::get_interner().get_or_intern(value)
     146           23 :     }
     147              : }
     148              : 
     149              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     150              : pub struct EndpointIdTag;
     151              : impl InternId for EndpointIdTag {
     152           80 :     fn get_interner() -> &'static StringInterner<Self> {
     153              :         static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
     154           80 :         ROLE_NAMES.get_or_init(Default::default)
     155           80 :     }
     156              : }
     157              : pub type EndpointIdInt = InternedString<EndpointIdTag>;
     158              : impl From<&EndpointId> for EndpointIdInt {
     159           38 :     fn from(value: &EndpointId) -> Self {
     160           38 :         EndpointIdTag::get_interner().get_or_intern(value)
     161           38 :     }
     162              : }
     163              : impl From<EndpointId> for EndpointIdInt {
     164            8 :     fn from(value: EndpointId) -> Self {
     165            8 :         EndpointIdTag::get_interner().get_or_intern(&value)
     166            8 :     }
     167              : }
     168              : 
     169              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     170              : pub struct BranchIdTag;
     171              : impl InternId for BranchIdTag {
     172           48 :     fn get_interner() -> &'static StringInterner<Self> {
     173              :         static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
     174           48 :         ROLE_NAMES.get_or_init(Default::default)
     175           48 :     }
     176              : }
     177              : pub type BranchIdInt = InternedString<BranchIdTag>;
     178              : impl From<&BranchId> for BranchIdInt {
     179           26 :     fn from(value: &BranchId) -> Self {
     180           26 :         BranchIdTag::get_interner().get_or_intern(value)
     181           26 :     }
     182              : }
     183              : impl From<BranchId> for BranchIdInt {
     184            0 :     fn from(value: BranchId) -> Self {
     185            0 :         BranchIdTag::get_interner().get_or_intern(&value)
     186            0 :     }
     187              : }
     188              : 
     189              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     190              : pub struct ProjectIdTag;
     191              : impl InternId for ProjectIdTag {
     192           42 :     fn get_interner() -> &'static StringInterner<Self> {
     193              :         static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
     194           42 :         ROLE_NAMES.get_or_init(Default::default)
     195           42 :     }
     196              : }
     197              : pub type ProjectIdInt = InternedString<ProjectIdTag>;
     198              : impl From<&ProjectId> for ProjectIdInt {
     199           31 :     fn from(value: &ProjectId) -> Self {
     200           31 :         ProjectIdTag::get_interner().get_or_intern(value)
     201           31 :     }
     202              : }
     203              : impl From<ProjectId> for ProjectIdInt {
     204            0 :     fn from(value: ProjectId) -> Self {
     205            0 :         ProjectIdTag::get_interner().get_or_intern(&value)
     206            0 :     }
     207              : }
     208              : 
     209              : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     210              : pub struct AccountIdTag;
     211              : impl InternId for AccountIdTag {
     212            0 :     fn get_interner() -> &'static StringInterner<Self> {
     213              :         static ROLE_NAMES: OnceLock<StringInterner<AccountIdTag>> = OnceLock::new();
     214            0 :         ROLE_NAMES.get_or_init(Default::default)
     215            0 :     }
     216              : }
     217              : pub type AccountIdInt = InternedString<AccountIdTag>;
     218              : impl From<&AccountId> for AccountIdInt {
     219            0 :     fn from(value: &AccountId) -> Self {
     220            0 :         AccountIdTag::get_interner().get_or_intern(value)
     221            0 :     }
     222              : }
     223              : impl From<AccountId> for AccountIdInt {
     224            0 :     fn from(value: AccountId) -> Self {
     225            0 :         AccountIdTag::get_interner().get_or_intern(&value)
     226            0 :     }
     227              : }
     228              : 
     229              : #[cfg(test)]
     230              : mod tests {
     231              :     use std::sync::OnceLock;
     232              : 
     233              :     use super::InternId;
     234              :     use crate::intern::StringInterner;
     235              : 
     236              :     struct MyId;
     237              :     impl InternId for MyId {
     238            1 :         fn get_interner() -> &'static StringInterner<Self> {
     239              :             pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
     240            1 :             ROLE_NAMES.get_or_init(Default::default)
     241            1 :         }
     242              :     }
     243              : 
     244              :     #[test]
     245            1 :     fn push_many_strings() {
     246              :         use rand::rngs::StdRng;
     247              :         use rand::{Rng, SeedableRng};
     248              :         use rand_distr::Zipf;
     249              : 
     250            1 :         let endpoint_dist = Zipf::new(500000.0, 0.8).unwrap();
     251            1 :         let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
     252              : 
     253            1 :         let interner = MyId::get_interner();
     254              : 
     255              :         const N: usize = 100_000;
     256            1 :         let mut verify = Vec::with_capacity(N);
     257       100000 :         for endpoint in endpoints.take(N) {
     258       100000 :             let endpoint = format!("ep-string-interning-{endpoint}");
     259       100000 :             let key = interner.get_or_intern(&endpoint);
     260       100000 :             verify.push((endpoint, key));
     261       100000 :         }
     262              : 
     263       100001 :         for (s, key) in verify {
     264       100000 :             assert_eq!(interner[key], s);
     265              :         }
     266              : 
     267              :         // 2031616/59861 = 34 bytes per string
     268            1 :         assert_eq!(interner.len(), 59_861);
     269              :         // will have other overhead for the internal hashmaps that are not accounted for.
     270            1 :         assert_eq!(interner.current_memory_usage(), 2_031_616);
     271            1 :     }
     272              : }
         |