Line data Source code
1 : use std::{
2 : hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
3 : };
4 :
5 : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
6 : use rustc_hash::FxHasher;
7 :
8 : use crate::{BranchId, EndpointId, ProjectId, RoleName};
9 :
10 : pub trait InternId: Sized + 'static {
11 : fn get_interner() -> &'static StringInterner<Self>;
12 : }
13 :
14 : pub struct StringInterner<Id> {
15 : inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
16 : _id: PhantomData<Id>,
17 : }
18 :
19 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
20 : pub struct InternedString<Id> {
21 : inner: Spur,
22 : _id: PhantomData<Id>,
23 : }
24 :
25 : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
26 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 0 : self.as_str().fmt(f)
28 0 : }
29 : }
30 :
31 : impl<Id: InternId> InternedString<Id> {
32 24 : pub(crate) fn as_str(&self) -> &'static str {
33 24 : Id::get_interner().inner.resolve(&self.inner)
34 24 : }
35 210 : pub(crate) fn get(s: &str) -> Option<Self> {
36 210 : Id::get_interner().get(s)
37 210 : }
38 : }
39 :
40 : impl<Id: InternId> AsRef<str> for InternedString<Id> {
41 0 : fn as_ref(&self) -> &str {
42 0 : self.as_str()
43 0 : }
44 : }
45 :
46 : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
47 : type Target = str;
48 0 : fn deref(&self) -> &str {
49 0 : self.as_str()
50 0 : }
51 : }
52 :
53 : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
54 138 : fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
55 138 : struct Visitor<Id>(PhantomData<Id>);
56 138 : impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
57 138 : type Value = InternedString<Id>;
58 138 :
59 138 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 0 : formatter.write_str("a string")
61 0 : }
62 138 :
63 138 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
64 138 : where
65 138 : E: serde::de::Error,
66 138 : {
67 138 : Ok(Id::get_interner().get_or_intern(v))
68 138 : }
69 138 : }
70 138 : d.deserialize_str(Visitor::<Id>(PhantomData))
71 138 : }
72 : }
73 :
74 : impl<Id: InternId> serde::Serialize for InternedString<Id> {
75 24 : fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
76 24 : self.as_str().serialize(s)
77 24 : }
78 : }
79 :
80 : impl<Id: InternId> StringInterner<Id> {
81 318 : pub(crate) fn new() -> Self {
82 318 : StringInterner {
83 318 : inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
84 318 : Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
85 318 : // unbounded
86 318 : MemoryLimits::for_memory_usage(usize::MAX),
87 318 : BuildHasherDefault::<FxHasher>::default(),
88 318 : ),
89 318 : _id: PhantomData,
90 318 : }
91 318 : }
92 :
93 : #[cfg(test)]
94 6 : fn len(&self) -> usize {
95 6 : self.inner.len()
96 6 : }
97 :
98 : #[cfg(test)]
99 6 : fn current_memory_usage(&self) -> usize {
100 6 : self.inner.current_memory_usage()
101 6 : }
102 :
103 600720 : pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
104 600720 : InternedString {
105 600720 : inner: self.inner.get_or_intern(s),
106 600720 : _id: PhantomData,
107 600720 : }
108 600720 : }
109 :
110 210 : pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
111 210 : Some(InternedString {
112 210 : inner: self.inner.get(s)?,
113 210 : _id: PhantomData,
114 : })
115 210 : }
116 : }
117 :
118 : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
119 : type Output = str;
120 :
121 600000 : fn index(&self, index: InternedString<Id>) -> &Self::Output {
122 600000 : self.inner.resolve(&index.inner)
123 600000 : }
124 : }
125 :
126 : impl<Id: InternId> Default for StringInterner<Id> {
127 318 : fn default() -> Self {
128 318 : Self::new()
129 318 : }
130 : }
131 :
132 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
133 : pub(crate) struct RoleNameTag;
134 : impl InternId for RoleNameTag {
135 150 : fn get_interner() -> &'static StringInterner<Self> {
136 150 : static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
137 150 : ROLE_NAMES.get_or_init(Default::default)
138 150 : }
139 : }
140 : pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
141 : impl From<&RoleName> for RoleNameInt {
142 54 : fn from(value: &RoleName) -> Self {
143 54 : RoleNameTag::get_interner().get_or_intern(value)
144 54 : }
145 : }
146 :
147 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
148 : pub struct EndpointIdTag;
149 : impl InternId for EndpointIdTag {
150 414 : fn get_interner() -> &'static StringInterner<Self> {
151 414 : static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
152 414 : ROLE_NAMES.get_or_init(Default::default)
153 414 : }
154 : }
155 : pub type EndpointIdInt = InternedString<EndpointIdTag>;
156 : impl From<&EndpointId> for EndpointIdInt {
157 180 : fn from(value: &EndpointId) -> Self {
158 180 : EndpointIdTag::get_interner().get_or_intern(value)
159 180 : }
160 : }
161 : impl From<EndpointId> for EndpointIdInt {
162 60 : fn from(value: EndpointId) -> Self {
163 60 : EndpointIdTag::get_interner().get_or_intern(&value)
164 60 : }
165 : }
166 :
167 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
168 : pub struct BranchIdTag;
169 : impl InternId for BranchIdTag {
170 162 : fn get_interner() -> &'static StringInterner<Self> {
171 162 : static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
172 162 : ROLE_NAMES.get_or_init(Default::default)
173 162 : }
174 : }
175 : pub type BranchIdInt = InternedString<BranchIdTag>;
176 : impl From<&BranchId> for BranchIdInt {
177 108 : fn from(value: &BranchId) -> Self {
178 108 : BranchIdTag::get_interner().get_or_intern(value)
179 108 : }
180 : }
181 : impl From<BranchId> for BranchIdInt {
182 0 : fn from(value: BranchId) -> Self {
183 0 : BranchIdTag::get_interner().get_or_intern(&value)
184 0 : }
185 : }
186 :
187 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
188 : pub struct ProjectIdTag;
189 : impl InternId for ProjectIdTag {
190 228 : fn get_interner() -> &'static StringInterner<Self> {
191 228 : static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
192 228 : ROLE_NAMES.get_or_init(Default::default)
193 228 : }
194 : }
195 : pub type ProjectIdInt = InternedString<ProjectIdTag>;
196 : impl From<&ProjectId> for ProjectIdInt {
197 180 : fn from(value: &ProjectId) -> Self {
198 180 : ProjectIdTag::get_interner().get_or_intern(value)
199 180 : }
200 : }
201 : impl From<ProjectId> for ProjectIdInt {
202 0 : fn from(value: ProjectId) -> Self {
203 0 : ProjectIdTag::get_interner().get_or_intern(&value)
204 0 : }
205 : }
206 :
207 : #[cfg(test)]
208 : mod tests {
209 : use std::sync::OnceLock;
210 :
211 : use crate::intern::StringInterner;
212 :
213 : use super::InternId;
214 :
215 : struct MyId;
216 : impl InternId for MyId {
217 6 : fn get_interner() -> &'static StringInterner<Self> {
218 6 : pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
219 6 : ROLE_NAMES.get_or_init(Default::default)
220 6 : }
221 : }
222 :
223 : #[test]
224 6 : fn push_many_strings() {
225 6 : use rand::{rngs::StdRng, Rng, SeedableRng};
226 6 : use rand_distr::Zipf;
227 6 :
228 6 : let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
229 6 : let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
230 6 :
231 6 : let interner = MyId::get_interner();
232 6 :
233 6 : const N: usize = 100_000;
234 6 : let mut verify = Vec::with_capacity(N);
235 600000 : for endpoint in endpoints.take(N) {
236 600000 : let endpoint = format!("ep-string-interning-{endpoint}");
237 600000 : let key = interner.get_or_intern(&endpoint);
238 600000 : verify.push((endpoint, key));
239 600000 : }
240 :
241 600006 : for (s, key) in verify {
242 600000 : assert_eq!(interner[key], s);
243 : }
244 :
245 : // 2031616/59861 = 34 bytes per string
246 6 : assert_eq!(interner.len(), 59_861);
247 : // will have other overhead for the internal hashmaps that are not accounted for.
248 6 : assert_eq!(interner.current_memory_usage(), 2_031_616);
249 6 : }
250 : }
|