Line data Source code
1 : use std::{
2 : hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
3 : };
4 :
5 : use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
6 : use rustc_hash::FxHasher;
7 :
8 : use crate::{BranchId, EndpointId, ProjectId, RoleName};
9 :
10 : pub trait InternId: Sized + 'static {
11 : fn get_interner() -> &'static StringInterner<Self>;
12 : }
13 :
14 : pub struct StringInterner<Id> {
15 : inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
16 : _id: PhantomData<Id>,
17 : }
18 :
19 : #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
20 : pub struct InternedString<Id> {
21 : inner: Spur,
22 : _id: PhantomData<Id>,
23 : }
24 :
25 : impl<Id: InternId> std::fmt::Display for InternedString<Id> {
26 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 0 : self.as_str().fmt(f)
28 0 : }
29 : }
30 :
31 : impl<Id: InternId> InternedString<Id> {
32 4 : pub(crate) fn as_str(&self) -> &'static str {
33 4 : Id::get_interner().inner.resolve(&self.inner)
34 4 : }
35 35 : pub(crate) fn get(s: &str) -> Option<Self> {
36 35 : Id::get_interner().get(s)
37 35 : }
38 : }
39 :
40 : impl<Id: InternId> AsRef<str> for InternedString<Id> {
41 0 : fn as_ref(&self) -> &str {
42 0 : self.as_str()
43 0 : }
44 : }
45 :
46 : impl<Id: InternId> std::ops::Deref for InternedString<Id> {
47 : type Target = str;
48 0 : fn deref(&self) -> &str {
49 0 : self.as_str()
50 0 : }
51 : }
52 :
53 : impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
54 26 : fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
55 : struct Visitor<Id>(PhantomData<Id>);
56 : impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
57 : type Value = InternedString<Id>;
58 :
59 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 0 : formatter.write_str("a string")
61 0 : }
62 :
63 26 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
64 26 : where
65 26 : E: serde::de::Error,
66 26 : {
67 26 : Ok(Id::get_interner().get_or_intern(v))
68 26 : }
69 : }
70 26 : d.deserialize_str(Visitor::<Id>(PhantomData))
71 26 : }
72 : }
73 :
74 : impl<Id: InternId> serde::Serialize for InternedString<Id> {
75 4 : fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
76 4 : self.as_str().serialize(s)
77 4 : }
78 : }
79 :
80 : impl<Id: InternId> StringInterner<Id> {
81 53 : pub(crate) fn new() -> Self {
82 53 : StringInterner {
83 53 : inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
84 53 : Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
85 53 : // unbounded
86 53 : MemoryLimits::for_memory_usage(usize::MAX),
87 53 : BuildHasherDefault::<FxHasher>::default(),
88 53 : ),
89 53 : _id: PhantomData,
90 53 : }
91 53 : }
92 :
93 : #[cfg(test)]
94 1 : fn len(&self) -> usize {
95 1 : self.inner.len()
96 1 : }
97 :
98 : #[cfg(test)]
99 1 : fn current_memory_usage(&self) -> usize {
100 1 : self.inner.current_memory_usage()
101 1 : }
102 :
103 100123 : pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
104 100123 : InternedString {
105 100123 : inner: self.inner.get_or_intern(s),
106 100123 : _id: PhantomData,
107 100123 : }
108 100123 : }
109 :
110 35 : pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
111 35 : Some(InternedString {
112 35 : inner: self.inner.get(s)?,
113 35 : _id: PhantomData,
114 : })
115 35 : }
116 : }
117 :
118 : impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
119 : type Output = str;
120 :
121 100000 : fn index(&self, index: InternedString<Id>) -> &Self::Output {
122 100000 : self.inner.resolve(&index.inner)
123 100000 : }
124 : }
125 :
126 : impl<Id: InternId> Default for StringInterner<Id> {
127 53 : fn default() -> Self {
128 53 : Self::new()
129 53 : }
130 : }
131 :
132 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
133 : pub(crate) struct RoleNameTag;
134 : impl InternId for RoleNameTag {
135 25 : fn get_interner() -> &'static StringInterner<Self> {
136 : static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
137 25 : ROLE_NAMES.get_or_init(Default::default)
138 25 : }
139 : }
140 : pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
141 : impl From<&RoleName> for RoleNameInt {
142 9 : fn from(value: &RoleName) -> Self {
143 9 : RoleNameTag::get_interner().get_or_intern(value)
144 9 : }
145 : }
146 :
147 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
148 : pub struct EndpointIdTag;
149 : impl InternId for EndpointIdTag {
150 70 : fn get_interner() -> &'static StringInterner<Self> {
151 : static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
152 70 : ROLE_NAMES.get_or_init(Default::default)
153 70 : }
154 : }
155 : pub type EndpointIdInt = InternedString<EndpointIdTag>;
156 : impl From<&EndpointId> for EndpointIdInt {
157 30 : fn from(value: &EndpointId) -> Self {
158 30 : EndpointIdTag::get_interner().get_or_intern(value)
159 30 : }
160 : }
161 : impl From<EndpointId> for EndpointIdInt {
162 10 : fn from(value: EndpointId) -> Self {
163 10 : EndpointIdTag::get_interner().get_or_intern(&value)
164 10 : }
165 : }
166 :
167 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
168 : pub struct BranchIdTag;
169 : impl InternId for BranchIdTag {
170 28 : fn get_interner() -> &'static StringInterner<Self> {
171 : static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
172 28 : ROLE_NAMES.get_or_init(Default::default)
173 28 : }
174 : }
175 : pub type BranchIdInt = InternedString<BranchIdTag>;
176 : impl From<&BranchId> for BranchIdInt {
177 18 : fn from(value: &BranchId) -> Self {
178 18 : BranchIdTag::get_interner().get_or_intern(value)
179 18 : }
180 : }
181 : impl From<BranchId> for BranchIdInt {
182 0 : fn from(value: BranchId) -> Self {
183 0 : BranchIdTag::get_interner().get_or_intern(&value)
184 0 : }
185 : }
186 :
187 : #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
188 : pub struct ProjectIdTag;
189 : impl InternId for ProjectIdTag {
190 39 : fn get_interner() -> &'static StringInterner<Self> {
191 : static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
192 39 : ROLE_NAMES.get_or_init(Default::default)
193 39 : }
194 : }
195 : pub type ProjectIdInt = InternedString<ProjectIdTag>;
196 : impl From<&ProjectId> for ProjectIdInt {
197 30 : fn from(value: &ProjectId) -> Self {
198 30 : ProjectIdTag::get_interner().get_or_intern(value)
199 30 : }
200 : }
201 : impl From<ProjectId> for ProjectIdInt {
202 0 : fn from(value: ProjectId) -> Self {
203 0 : ProjectIdTag::get_interner().get_or_intern(&value)
204 0 : }
205 : }
206 :
207 : #[cfg(test)]
208 : mod tests {
209 : use std::sync::OnceLock;
210 :
211 : use crate::intern::StringInterner;
212 :
213 : use super::InternId;
214 :
215 : struct MyId;
216 : impl InternId for MyId {
217 1 : fn get_interner() -> &'static StringInterner<Self> {
218 : pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
219 1 : ROLE_NAMES.get_or_init(Default::default)
220 1 : }
221 : }
222 :
223 : #[test]
224 1 : fn push_many_strings() {
225 : use rand::{rngs::StdRng, Rng, SeedableRng};
226 : use rand_distr::Zipf;
227 :
228 1 : let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
229 1 : let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
230 1 :
231 1 : let interner = MyId::get_interner();
232 :
233 : const N: usize = 100_000;
234 1 : let mut verify = Vec::with_capacity(N);
235 100000 : for endpoint in endpoints.take(N) {
236 100000 : let endpoint = format!("ep-string-interning-{endpoint}");
237 100000 : let key = interner.get_or_intern(&endpoint);
238 100000 : verify.push((endpoint, key));
239 100000 : }
240 :
241 100001 : for (s, key) in verify {
242 100000 : assert_eq!(interner[key], s);
243 : }
244 :
245 : // 2031616/59861 = 34 bytes per string
246 1 : assert_eq!(interner.len(), 59_861);
247 : // will have other overhead for the internal hashmaps that are not accounted for.
248 1 : assert_eq!(interner.current_memory_usage(), 2_031_616);
249 1 : }
250 : }
|