Line data Source code
1 : use std::{
2 : hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
3 : };
4 :
5 : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
6 : use rustc_hash::FxHasher;
7 :
8 : use crate::{BranchId, EndpointId, ProjectId, RoleName};
9 :
10 : pub trait InternId: Sized + 'static {
11 : fn get_interner() -> &'static StringInterner<Self>;
12 : }
13 :
14 : pub struct StringInterner<Id> {
15 : inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
16 : _id: PhantomData<Id>,
17 : }
18 :
19 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
20 : pub struct InternedString<Id> {
21 : inner: Spur,
22 : _id: PhantomData<Id>,
23 : }
24 :
25 : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
26 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 0 : self.as_str().fmt(f)
28 0 : }
29 : }
30 :
31 : impl<Id: InternId> InternedString<Id> {
32 8 : pub fn as_str(&self) -> &'static str {
33 8 : Id::get_interner().inner.resolve(&self.inner)
34 8 : }
35 70 : pub fn get(s: &str) -> Option<Self> {
36 70 : Id::get_interner().get(s)
37 70 : }
38 : }
39 :
40 : impl<Id: InternId> AsRef<str> for InternedString<Id> {
41 0 : fn as_ref(&self) -> &str {
42 0 : self.as_str()
43 0 : }
44 : }
45 :
46 : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
47 : type Target = str;
48 0 : fn deref(&self) -> &str {
49 0 : self.as_str()
50 0 : }
51 : }
52 :
53 : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
54 46 : fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
55 46 : struct Visitor<Id>(PhantomData<Id>);
56 46 : impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
57 46 : type Value = InternedString<Id>;
58 46 :
59 46 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
60 0 : formatter.write_str("a string")
61 0 : }
62 46 :
63 46 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
64 46 : where
65 46 : E: serde::de::Error,
66 46 : {
67 46 : Ok(Id::get_interner().get_or_intern(v))
68 46 : }
69 46 : }
70 46 : d.deserialize_str(Visitor::<Id>(PhantomData))
71 46 : }
72 : }
73 :
74 : impl<Id: InternId> serde::Serialize for InternedString<Id> {
75 8 : fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
76 8 : self.as_str().serialize(s)
77 8 : }
78 : }
79 :
80 : impl<Id: InternId> StringInterner<Id> {
81 100 : pub fn new() -> Self {
82 100 : StringInterner {
83 100 : inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
84 100 : Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
85 100 : // unbounded
86 100 : MemoryLimits::for_memory_usage(usize::MAX),
87 100 : BuildHasherDefault::<FxHasher>::default(),
88 100 : ),
89 100 : _id: PhantomData,
90 100 : }
91 100 : }
92 :
93 0 : pub fn is_empty(&self) -> bool {
94 0 : self.inner.is_empty()
95 0 : }
96 :
97 2 : pub fn len(&self) -> usize {
98 2 : self.inner.len()
99 2 : }
100 :
101 2 : pub fn current_memory_usage(&self) -> usize {
102 2 : self.inner.current_memory_usage()
103 2 : }
104 :
105 200230 : pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
106 200230 : InternedString {
107 200230 : inner: self.inner.get_or_intern(s),
108 200230 : _id: PhantomData,
109 200230 : }
110 200230 : }
111 :
112 70 : pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
113 70 : Some(InternedString {
114 70 : inner: self.inner.get(s)?,
115 70 : _id: PhantomData,
116 : })
117 70 : }
118 : }
119 :
120 : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
121 : type Output = str;
122 :
123 200000 : fn index(&self, index: InternedString<Id>) -> &Self::Output {
124 200000 : self.inner.resolve(&index.inner)
125 200000 : }
126 : }
127 :
128 : impl<Id: InternId> Default for StringInterner<Id> {
129 100 : fn default() -> Self {
130 100 : Self::new()
131 100 : }
132 : }
133 :
134 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
135 : pub struct RoleNameTag;
136 : impl InternId for RoleNameTag {
137 50 : fn get_interner() -> &'static StringInterner<Self> {
138 50 : pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
139 50 : ROLE_NAMES.get_or_init(Default::default)
140 50 : }
141 : }
142 : pub type RoleNameInt = InternedString<RoleNameTag>;
143 : impl From<&RoleName> for RoleNameInt {
144 18 : fn from(value: &RoleName) -> Self {
145 18 : RoleNameTag::get_interner().get_or_intern(value)
146 18 : }
147 : }
148 :
149 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
150 : pub struct EndpointIdTag;
151 : impl InternId for EndpointIdTag {
152 128 : fn get_interner() -> &'static StringInterner<Self> {
153 128 : pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
154 128 : ROLE_NAMES.get_or_init(Default::default)
155 128 : }
156 : }
157 : pub type EndpointIdInt = InternedString<EndpointIdTag>;
158 : impl From<&EndpointId> for EndpointIdInt {
159 56 : fn from(value: &EndpointId) -> Self {
160 56 : EndpointIdTag::get_interner().get_or_intern(value)
161 56 : }
162 : }
163 : impl From<EndpointId> for EndpointIdInt {
164 14 : fn from(value: EndpointId) -> Self {
165 14 : EndpointIdTag::get_interner().get_or_intern(&value)
166 14 : }
167 : }
168 :
169 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
170 : pub struct BranchIdTag;
171 : impl InternId for BranchIdTag {
172 54 : fn get_interner() -> &'static StringInterner<Self> {
173 54 : pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
174 54 : ROLE_NAMES.get_or_init(Default::default)
175 54 : }
176 : }
177 : pub type BranchIdInt = InternedString<BranchIdTag>;
178 : impl From<&BranchId> for BranchIdInt {
179 36 : fn from(value: &BranchId) -> Self {
180 36 : BranchIdTag::get_interner().get_or_intern(value)
181 36 : }
182 : }
183 : impl From<BranchId> for BranchIdInt {
184 0 : fn from(value: BranchId) -> Self {
185 0 : BranchIdTag::get_interner().get_or_intern(&value)
186 0 : }
187 : }
188 :
189 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
190 : pub struct ProjectIdTag;
191 : impl InternId for ProjectIdTag {
192 76 : fn get_interner() -> &'static StringInterner<Self> {
193 76 : pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
194 76 : ROLE_NAMES.get_or_init(Default::default)
195 76 : }
196 : }
197 : pub type ProjectIdInt = InternedString<ProjectIdTag>;
198 : impl From<&ProjectId> for ProjectIdInt {
199 60 : fn from(value: &ProjectId) -> Self {
200 60 : ProjectIdTag::get_interner().get_or_intern(value)
201 60 : }
202 : }
203 : impl From<ProjectId> for ProjectIdInt {
204 0 : fn from(value: ProjectId) -> Self {
205 0 : ProjectIdTag::get_interner().get_or_intern(&value)
206 0 : }
207 : }
208 :
209 : #[cfg(test)]
210 : mod tests {
211 : use std::sync::OnceLock;
212 :
213 : use crate::intern::StringInterner;
214 :
215 : use super::InternId;
216 :
217 : struct MyId;
218 : impl InternId for MyId {
219 2 : fn get_interner() -> &'static StringInterner<Self> {
220 2 : pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
221 2 : ROLE_NAMES.get_or_init(Default::default)
222 2 : }
223 : }
224 :
225 : #[test]
226 2 : fn push_many_strings() {
227 2 : use rand::{rngs::StdRng, Rng, SeedableRng};
228 2 : use rand_distr::Zipf;
229 2 :
230 2 : let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
231 2 : let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
232 2 :
233 2 : let interner = MyId::get_interner();
234 2 :
235 2 : const N: usize = 100_000;
236 2 : let mut verify = Vec::with_capacity(N);
237 200000 : for endpoint in endpoints.take(N) {
238 200000 : let endpoint = format!("ep-string-interning-{endpoint}");
239 200000 : let key = interner.get_or_intern(&endpoint);
240 200000 : verify.push((endpoint, key));
241 200000 : }
242 :
243 200002 : for (s, key) in verify {
244 200000 : assert_eq!(interner[key], s);
245 : }
246 :
247 : // 2031616/59861 = 34 bytes per string
248 2 : assert_eq!(interner.len(), 59_861);
249 : // will have other overhead for the internal hashmaps that are not accounted for.
250 2 : assert_eq!(interner.current_memory_usage(), 2_031_616);
251 2 : }
252 : }
|