Line data Source code
1 : use std::{
2 : hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
3 : };
4 :
5 : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
6 : use rustc_hash::FxHasher;
7 :
8 : use crate::{BranchId, EndpointId, ProjectId, RoleName};
9 :
10 : pub trait InternId: Sized + 'static {
11 : fn get_interner() -> &'static StringInterner<Self>;
12 : }
13 :
14 : pub struct StringInterner<Id> {
15 : inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
16 : _id: PhantomData<Id>,
17 : }
18 :
19 244 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
20 : pub struct InternedString<Id> {
21 : inner: Spur,
22 : _id: PhantomData<Id>,
23 : }
24 :
25 : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
26 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 0 : self.as_str().fmt(f)
28 0 : }
29 : }
30 :
31 : impl<Id: InternId> InternedString<Id> {
32 0 : pub fn as_str(&self) -> &'static str {
33 0 : Id::get_interner().inner.resolve(&self.inner)
34 0 : }
35 74 : pub fn get(s: &str) -> Option<Self> {
36 74 : Id::get_interner().get(s)
37 74 : }
38 : }
39 :
40 : impl<Id: InternId> AsRef<str> for InternedString<Id> {
41 0 : fn as_ref(&self) -> &str {
42 0 : self.as_str()
43 0 : }
44 : }
45 :
46 : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
47 : type Target = str;
48 0 : fn deref(&self) -> &str {
49 0 : self.as_str()
50 0 : }
51 : }
52 :
53 : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
54 6 : fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
55 6 : struct Visitor<Id>(PhantomData<Id>);
56 6 : impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
57 6 : type Value = InternedString<Id>;
58 6 :
59 6 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
60 0 : formatter.write_str("a string")
61 0 : }
62 6 :
63 6 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
64 6 : where
65 6 : E: serde::de::Error,
66 6 : {
67 6 : Ok(Id::get_interner().get_or_intern(v))
68 6 : }
69 6 : }
70 6 : d.deserialize_str(Visitor::<Id>(PhantomData))
71 6 : }
72 : }
73 :
74 : impl<Id: InternId> serde::Serialize for InternedString<Id> {
75 0 : fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
76 0 : self.as_str().serialize(s)
77 0 : }
78 : }
79 :
80 : impl<Id: InternId> StringInterner<Id> {
81 27 : pub fn new() -> Self {
82 27 : StringInterner {
83 27 : inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
84 27 : Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
85 27 : // unbounded
86 27 : MemoryLimits::for_memory_usage(usize::MAX),
87 27 : BuildHasherDefault::<FxHasher>::default(),
88 27 : ),
89 27 : _id: PhantomData,
90 27 : }
91 27 : }
92 :
93 0 : pub fn is_empty(&self) -> bool {
94 0 : self.inner.is_empty()
95 0 : }
96 :
97 2 : pub fn len(&self) -> usize {
98 2 : self.inner.len()
99 2 : }
100 :
101 2 : pub fn current_memory_usage(&self) -> usize {
102 2 : self.inner.current_memory_usage()
103 2 : }
104 :
105 200070 : pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
106 200070 : InternedString {
107 200070 : inner: self.inner.get_or_intern(s),
108 200070 : _id: PhantomData,
109 200070 : }
110 200070 : }
111 :
112 74 : pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
113 74 : Some(InternedString {
114 74 : inner: self.inner.get(s)?,
115 70 : _id: PhantomData,
116 : })
117 74 : }
118 : }
119 :
120 : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
121 : type Output = str;
122 :
123 200000 : fn index(&self, index: InternedString<Id>) -> &Self::Output {
124 200000 : self.inner.resolve(&index.inner)
125 200000 : }
126 : }
127 :
128 : impl<Id: InternId> Default for StringInterner<Id> {
129 27 : fn default() -> Self {
130 27 : Self::new()
131 27 : }
132 : }
133 :
134 0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
135 : pub struct RoleNameTag;
136 : impl InternId for RoleNameTag {
137 50 : fn get_interner() -> &'static StringInterner<Self> {
138 50 : pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
139 50 : ROLE_NAMES.get_or_init(Default::default)
140 50 : }
141 : }
142 : pub type RoleNameInt = InternedString<RoleNameTag>;
143 : impl From<&RoleName> for RoleNameInt {
144 18 : fn from(value: &RoleName) -> Self {
145 18 : RoleNameTag::get_interner().get_or_intern(value)
146 18 : }
147 : }
148 :
149 0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
150 : pub struct EndpointIdTag;
151 : impl InternId for EndpointIdTag {
152 64 : fn get_interner() -> &'static StringInterner<Self> {
153 64 : pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
154 64 : ROLE_NAMES.get_or_init(Default::default)
155 64 : }
156 : }
157 : pub type EndpointIdInt = InternedString<EndpointIdTag>;
158 : impl From<&EndpointId> for EndpointIdInt {
159 20 : fn from(value: &EndpointId) -> Self {
160 20 : EndpointIdTag::get_interner().get_or_intern(value)
161 20 : }
162 : }
163 :
164 0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
165 : pub struct BranchIdTag;
166 : impl InternId for BranchIdTag {
167 0 : fn get_interner() -> &'static StringInterner<Self> {
168 0 : pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
169 0 : ROLE_NAMES.get_or_init(Default::default)
170 0 : }
171 : }
172 : pub type BranchIdInt = InternedString<BranchIdTag>;
173 : impl From<&BranchId> for BranchIdInt {
174 0 : fn from(value: &BranchId) -> Self {
175 0 : BranchIdTag::get_interner().get_or_intern(value)
176 0 : }
177 : }
178 :
179 0 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
180 : pub struct ProjectIdTag;
181 : impl InternId for ProjectIdTag {
182 30 : fn get_interner() -> &'static StringInterner<Self> {
183 30 : pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
184 30 : ROLE_NAMES.get_or_init(Default::default)
185 30 : }
186 : }
187 : pub type ProjectIdInt = InternedString<ProjectIdTag>;
188 : impl From<&ProjectId> for ProjectIdInt {
189 26 : fn from(value: &ProjectId) -> Self {
190 26 : ProjectIdTag::get_interner().get_or_intern(value)
191 26 : }
192 : }
193 :
194 : #[cfg(test)]
195 : mod tests {
196 : use std::sync::OnceLock;
197 :
198 : use crate::intern::StringInterner;
199 :
200 : use super::InternId;
201 :
202 : struct MyId;
203 : impl InternId for MyId {
204 2 : fn get_interner() -> &'static StringInterner<Self> {
205 2 : pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
206 2 : ROLE_NAMES.get_or_init(Default::default)
207 2 : }
208 : }
209 :
210 2 : #[test]
211 2 : fn push_many_strings() {
212 2 : use rand::{rngs::StdRng, Rng, SeedableRng};
213 2 : use rand_distr::Zipf;
214 2 :
215 2 : let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
216 2 : let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
217 2 :
218 2 : let interner = MyId::get_interner();
219 2 :
220 2 : const N: usize = 100_000;
221 2 : let mut verify = Vec::with_capacity(N);
222 200000 : for endpoint in endpoints.take(N) {
223 200000 : let endpoint = format!("ep-string-interning-{endpoint}");
224 200000 : let key = interner.get_or_intern(&endpoint);
225 200000 : verify.push((endpoint, key));
226 200000 : }
227 :
228 200002 : for (s, key) in verify {
229 200000 : assert_eq!(interner[key], s);
230 : }
231 :
232 : // 2031616/59861 = 34 bytes per string
233 2 : assert_eq!(interner.len(), 59_861);
234 : // will have other overhead for the internal hashmaps that are not accounted for.
235 2 : assert_eq!(interner.current_memory_usage(), 2_031_616);
236 2 : }
237 : }
|