Line data Source code
1 : use std::hash::BuildHasherDefault;
2 : use std::marker::PhantomData;
3 : use std::num::NonZeroUsize;
4 : use std::ops::Index;
5 : use std::sync::OnceLock;
6 :
7 : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
8 : use rustc_hash::FxHasher;
9 :
10 : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
11 :
12 : pub trait InternId: Sized + 'static {
13 : fn get_interner() -> &'static StringInterner<Self>;
14 : }
15 :
16 : pub struct StringInterner<Id> {
17 : inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
18 : _id: PhantomData<Id>,
19 : }
20 :
21 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
22 : pub struct InternedString<Id> {
23 : inner: Spur,
24 : _id: PhantomData<Id>,
25 : }
26 :
27 : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
28 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 0 : self.as_str().fmt(f)
30 0 : }
31 : }
32 :
33 : impl<Id: InternId> InternedString<Id> {
34 33 : pub(crate) fn as_str(&self) -> &'static str {
35 33 : Id::get_interner().inner.resolve(&self.inner)
36 33 : }
37 35 : pub(crate) fn get(s: &str) -> Option<Self> {
38 35 : Id::get_interner().get(s)
39 35 : }
40 : }
41 :
42 : impl<Id: InternId> AsRef<str> for InternedString<Id> {
43 0 : fn as_ref(&self) -> &str {
44 0 : self.as_str()
45 0 : }
46 : }
47 :
48 : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
49 : type Target = str;
50 29 : fn deref(&self) -> &str {
51 29 : self.as_str()
52 29 : }
53 : }
54 :
55 : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
56 27 : fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
57 : struct Visitor<Id>(PhantomData<Id>);
58 : impl<Id: InternId> serde::de::Visitor<'_> for Visitor<Id> {
59 : type Value = InternedString<Id>;
60 :
61 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 0 : formatter.write_str("a string")
63 0 : }
64 :
65 27 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
66 27 : where
67 27 : E: serde::de::Error,
68 27 : {
69 27 : Ok(Id::get_interner().get_or_intern(v))
70 27 : }
71 : }
72 27 : d.deserialize_str(Visitor::<Id>(PhantomData))
73 27 : }
74 : }
75 :
76 : impl<Id: InternId> serde::Serialize for InternedString<Id> {
77 4 : fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
78 4 : self.as_str().serialize(s)
79 4 : }
80 : }
81 :
82 : impl<Id: InternId> StringInterner<Id> {
83 60 : pub(crate) fn new() -> Self {
84 60 : StringInterner {
85 60 : inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
86 60 : Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
87 60 : // unbounded
88 60 : MemoryLimits::for_memory_usage(usize::MAX),
89 60 : BuildHasherDefault::<FxHasher>::default(),
90 60 : ),
91 60 : _id: PhantomData,
92 60 : }
93 60 : }
94 :
95 : #[cfg(test)]
96 1 : fn len(&self) -> usize {
97 1 : self.inner.len()
98 1 : }
99 :
100 : #[cfg(test)]
101 1 : fn current_memory_usage(&self) -> usize {
102 1 : self.inner.current_memory_usage()
103 1 : }
104 :
105 100132 : pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
106 100132 : InternedString {
107 100132 : inner: self.inner.get_or_intern(s),
108 100132 : _id: PhantomData,
109 100132 : }
110 100132 : }
111 :
112 35 : pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
113 35 : Some(InternedString {
114 35 : inner: self.inner.get(s)?,
115 35 : _id: PhantomData,
116 : })
117 35 : }
118 : }
119 :
120 : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
121 : type Output = str;
122 :
123 100000 : fn index(&self, index: InternedString<Id>) -> &Self::Output {
124 100000 : self.inner.resolve(&index.inner)
125 100000 : }
126 : }
127 :
128 : impl<Id: InternId> Default for StringInterner<Id> {
129 60 : fn default() -> Self {
130 60 : Self::new()
131 60 : }
132 : }
133 :
134 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
135 : pub struct RoleNameTag;
136 : impl InternId for RoleNameTag {
137 61 : fn get_interner() -> &'static StringInterner<Self> {
138 : static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
139 61 : ROLE_NAMES.get_or_init(Default::default)
140 61 : }
141 : }
142 : pub type RoleNameInt = InternedString<RoleNameTag>;
143 : impl From<&RoleName> for RoleNameInt {
144 16 : fn from(value: &RoleName) -> Self {
145 16 : RoleNameTag::get_interner().get_or_intern(value)
146 16 : }
147 : }
148 :
149 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
150 : pub struct EndpointIdTag;
151 : impl InternId for EndpointIdTag {
152 72 : fn get_interner() -> &'static StringInterner<Self> {
153 : static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
154 72 : ROLE_NAMES.get_or_init(Default::default)
155 72 : }
156 : }
157 : pub type EndpointIdInt = InternedString<EndpointIdTag>;
158 : impl From<&EndpointId> for EndpointIdInt {
159 30 : fn from(value: &EndpointId) -> Self {
160 30 : EndpointIdTag::get_interner().get_or_intern(value)
161 30 : }
162 : }
163 : impl From<EndpointId> for EndpointIdInt {
164 11 : fn from(value: EndpointId) -> Self {
165 11 : EndpointIdTag::get_interner().get_or_intern(&value)
166 11 : }
167 : }
168 :
169 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
170 : pub struct BranchIdTag;
171 : impl InternId for BranchIdTag {
172 28 : fn get_interner() -> &'static StringInterner<Self> {
173 : static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
174 28 : ROLE_NAMES.get_or_init(Default::default)
175 28 : }
176 : }
177 : pub type BranchIdInt = InternedString<BranchIdTag>;
178 : impl From<&BranchId> for BranchIdInt {
179 18 : fn from(value: &BranchId) -> Self {
180 18 : BranchIdTag::get_interner().get_or_intern(value)
181 18 : }
182 : }
183 : impl From<BranchId> for BranchIdInt {
184 0 : fn from(value: BranchId) -> Self {
185 0 : BranchIdTag::get_interner().get_or_intern(&value)
186 0 : }
187 : }
188 :
189 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
190 : pub struct ProjectIdTag;
191 : impl InternId for ProjectIdTag {
192 39 : fn get_interner() -> &'static StringInterner<Self> {
193 : static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
194 39 : ROLE_NAMES.get_or_init(Default::default)
195 39 : }
196 : }
197 : pub type ProjectIdInt = InternedString<ProjectIdTag>;
198 : impl From<&ProjectId> for ProjectIdInt {
199 30 : fn from(value: &ProjectId) -> Self {
200 30 : ProjectIdTag::get_interner().get_or_intern(value)
201 30 : }
202 : }
203 : impl From<ProjectId> for ProjectIdInt {
204 0 : fn from(value: ProjectId) -> Self {
205 0 : ProjectIdTag::get_interner().get_or_intern(&value)
206 0 : }
207 : }
208 :
209 : #[cfg(test)]
210 : mod tests {
211 : use std::sync::OnceLock;
212 :
213 : use super::InternId;
214 : use crate::intern::StringInterner;
215 :
216 : struct MyId;
217 : impl InternId for MyId {
218 1 : fn get_interner() -> &'static StringInterner<Self> {
219 : pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
220 1 : ROLE_NAMES.get_or_init(Default::default)
221 1 : }
222 : }
223 :
224 : #[test]
225 1 : fn push_many_strings() {
226 : use rand::rngs::StdRng;
227 : use rand::{Rng, SeedableRng};
228 : use rand_distr::Zipf;
229 :
230 1 : let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
231 1 : let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
232 1 :
233 1 : let interner = MyId::get_interner();
234 :
235 : const N: usize = 100_000;
236 1 : let mut verify = Vec::with_capacity(N);
237 100000 : for endpoint in endpoints.take(N) {
238 100000 : let endpoint = format!("ep-string-interning-{endpoint}");
239 100000 : let key = interner.get_or_intern(&endpoint);
240 100000 : verify.push((endpoint, key));
241 100000 : }
242 :
243 100001 : for (s, key) in verify {
244 100000 : assert_eq!(interner[key], s);
245 : }
246 :
247 : // 2031616/59861 = 34 bytes per string
248 1 : assert_eq!(interner.len(), 59_861);
249 : // will have other overhead for the internal hashmaps that are not accounted for.
250 1 : assert_eq!(interner.current_memory_usage(), 2_031_616);
251 1 : }
252 : }
|