Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
9 : /// for the duration of initialization.
10 : ///
11 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
12 : ///
13 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
14 : pub struct OnceCell<T> {
15 : inner: tokio::sync::RwLock<Inner<T>>,
16 : initializers: AtomicUsize,
17 : }
18 :
19 : impl<T> Default for OnceCell<T> {
20 : /// Create new uninitialized [`OnceCell`].
21 44123 : fn default() -> Self {
22 44123 : Self {
23 44123 : inner: Default::default(),
24 44123 : initializers: AtomicUsize::new(0),
25 44123 : }
26 44123 : }
27 : }
28 :
29 : /// Semaphore is the current state:
30 : /// - open semaphore means the value is `None`, not yet initialized
31 : /// - closed semaphore means the value has been initialized
32 0 : #[derive(Debug)]
33 : struct Inner<T> {
34 : init_semaphore: Arc<Semaphore>,
35 : value: Option<T>,
36 : }
37 :
38 : impl<T> Default for Inner<T> {
39 46690 : fn default() -> Self {
40 46690 : Self {
41 46690 : init_semaphore: Arc::new(Semaphore::new(1)),
42 46690 : value: None,
43 46690 : }
44 46690 : }
45 : }
46 :
47 : impl<T> OnceCell<T> {
48 : /// Creates an already initialized `OnceCell` with the given value.
49 34667 : pub fn new(value: T) -> Self {
50 34667 : let sem = Semaphore::new(1);
51 34667 : sem.close();
52 34667 : Self {
53 34667 : inner: tokio::sync::RwLock::new(Inner {
54 34667 : init_semaphore: Arc::new(sem),
55 34667 : value: Some(value),
56 34667 : }),
57 34667 : initializers: AtomicUsize::new(0),
58 34667 : }
59 34667 : }
60 :
61 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
62 : /// returning the guard.
63 : ///
64 : /// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
65 : ///
66 : /// Initialization is panic-safe and cancellation-safe.
67 23943801 : pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
68 23943801 : where
69 23943801 : F: FnOnce(InitPermit) -> Fut,
70 23943801 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
71 23943801 : {
72 10926 : let sem = {
73 23943801 : let guard = self.inner.write().await;
74 23943801 : if guard.value.is_some() {
75 23932875 : return Ok(GuardMut(guard));
76 10926 : }
77 10926 : guard.init_semaphore.clone()
78 : };
79 :
80 10926 : let permit = {
81 : // increment the count for the duration of queued
82 10926 : let _guard = CountWaitingInitializers::start(self);
83 10926 : sem.acquire_owned().await
84 : };
85 :
86 10926 : match permit {
87 10402 : Ok(permit) => {
88 10402 : let permit = InitPermit(permit);
89 30618 : let (value, _permit) = factory(permit).await?;
90 :
91 9776 : let guard = self.inner.write().await;
92 :
93 9776 : Ok(Self::set0(value, guard))
94 : }
95 524 : Err(_closed) => {
96 524 : let guard = self.inner.write().await;
97 : assert!(
98 524 : guard.value.is_some(),
99 0 : "semaphore got closed, must be initialized"
100 : );
101 524 : return Ok(GuardMut(guard));
102 : }
103 : }
104 23943790 : }
105 :
106 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
107 : /// returning the guard.
108 : ///
109 : /// Initialization is panic-safe and cancellation-safe.
110 0 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
111 0 : where
112 0 : F: FnOnce(InitPermit) -> Fut,
113 0 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
114 0 : {
115 0 : let sem = {
116 0 : let guard = self.inner.read().await;
117 0 : if guard.value.is_some() {
118 0 : return Ok(GuardRef(guard));
119 0 : }
120 0 : guard.init_semaphore.clone()
121 : };
122 :
123 0 : let permit = {
124 : // increment the count for the duration of queued
125 0 : let _guard = CountWaitingInitializers::start(self);
126 0 : sem.acquire_owned().await
127 : };
128 :
129 0 : match permit {
130 0 : Ok(permit) => {
131 0 : let permit = InitPermit(permit);
132 0 : let (value, _permit) = factory(permit).await?;
133 :
134 0 : let guard = self.inner.write().await;
135 :
136 0 : Ok(Self::set0(value, guard).downgrade())
137 : }
138 0 : Err(_closed) => {
139 0 : let guard = self.inner.read().await;
140 : assert!(
141 0 : guard.value.is_some(),
142 0 : "semaphore got closed, must be initialized"
143 : );
144 0 : return Ok(GuardRef(guard));
145 : }
146 : }
147 0 : }
148 :
149 : /// Assuming a permit is held after previous call to [`GuardMut::take_and_deinit`], it can be used
150 : /// to complete initializing the inner value.
151 : ///
152 : /// # Panics
153 : ///
154 : /// If the inner has already been initialized.
155 4 : pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
156 4 : let guard = self.inner.write().await;
157 :
158 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
159 : // give more permits right now.
160 4 : if guard.init_semaphore.try_acquire().is_ok() {
161 0 : drop(guard);
162 0 : panic!("permit is of wrong origin");
163 4 : }
164 4 :
165 4 : Self::set0(value, guard)
166 4 : }
167 :
168 9780 : fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
169 9780 : if guard.value.is_some() {
170 0 : drop(guard);
171 0 : unreachable!("we won permit, must not be initialized");
172 9780 : }
173 9780 : guard.value = Some(value);
174 9780 : guard.init_semaphore.close();
175 9780 : GuardMut(guard)
176 9780 : }
177 :
178 : /// Returns a guard to an existing initialized value, if any.
179 8152 : pub async fn get_mut(&self) -> Option<GuardMut<'_, T>> {
180 8152 : let guard = self.inner.write().await;
181 8152 : if guard.value.is_some() {
182 7398 : Some(GuardMut(guard))
183 : } else {
184 754 : None
185 : }
186 8152 : }
187 :
188 : /// Returns a guard to an existing initialized value, if any.
189 0 : pub async fn get(&self) -> Option<GuardRef<'_, T>> {
190 0 : let guard = self.inner.read().await;
191 0 : if guard.value.is_some() {
192 0 : Some(GuardRef(guard))
193 : } else {
194 0 : None
195 : }
196 0 : }
197 :
198 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
199 9768 : pub fn initializer_count(&self) -> usize {
200 9768 : self.initializers.load(Ordering::Relaxed)
201 9768 : }
202 : }
203 :
204 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
205 : /// initializing task for example at the end of initialization.
206 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
207 :
208 : impl<'a, T> CountWaitingInitializers<'a, T> {
209 10926 : fn start(target: &'a OnceCell<T>) -> Self {
210 10926 : target.initializers.fetch_add(1, Ordering::Relaxed);
211 10926 : CountWaitingInitializers(target)
212 10926 : }
213 : }
214 :
215 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
216 10926 : fn drop(&mut self) {
217 10926 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
218 10926 : }
219 : }
220 :
221 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
222 : /// initialized value.
223 0 : #[derive(Debug)]
224 : pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner<T>>);
225 :
226 : impl<T> std::ops::Deref for GuardMut<'_, T> {
227 : type Target = T;
228 :
229 2771 : fn deref(&self) -> &Self::Target {
230 2771 : self.0
231 2771 : .value
232 2771 : .as_ref()
233 2771 : .expect("guard is not created unless value has been initialized")
234 2771 : }
235 : }
236 :
237 : impl<T> std::ops::DerefMut for GuardMut<'_, T> {
238 23945530 : fn deref_mut(&mut self) -> &mut Self::Target {
239 23945530 : self.0
240 23945530 : .value
241 23945530 : .as_mut()
242 23945530 : .expect("guard is not created unless value has been initialized")
243 23945530 : }
244 : }
245 :
246 : impl<'a, T> GuardMut<'a, T> {
247 : /// Take the current value, and a new permit for it's deinitialization.
248 : ///
249 : /// The permit will be on a semaphore part of the new internal value, and any following
250 : /// [`OnceCell::get_or_init`] will wait on it to complete.
251 2567 : pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
252 2567 : let mut swapped = Inner::default();
253 2567 : let permit = swapped
254 2567 : .init_semaphore
255 2567 : .clone()
256 2567 : .try_acquire_owned()
257 2567 : .expect("we just created this");
258 2567 : std::mem::swap(&mut *self.0, &mut swapped);
259 2567 : swapped
260 2567 : .value
261 2567 : .map(|v| (v, InitPermit(permit)))
262 2567 : .expect("guard is not created unless value has been initialized")
263 2567 : }
264 :
265 0 : pub fn downgrade(self) -> GuardRef<'a, T> {
266 0 : GuardRef(self.0.downgrade())
267 0 : }
268 : }
269 :
270 0 : #[derive(Debug)]
271 : pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner<T>>);
272 :
273 : impl<T> std::ops::Deref for GuardRef<'_, T> {
274 : type Target = T;
275 :
276 0 : fn deref(&self) -> &Self::Target {
277 0 : self.0
278 0 : .value
279 0 : .as_ref()
280 0 : .expect("guard is not created unless value has been initialized")
281 0 : }
282 : }
283 :
284 : /// Type held by OnceCell (de)initializing task.
285 : pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
286 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    /// Many tasks race to initialize the same cell; exactly one factory may run and win.
    #[tokio::test]
    async fn many_initializers() {
        #[derive(Default, Debug)]
        struct Counters {
            // times a factory closure was called
            factory_got_to_run: AtomicUsize,
            // times a factory's future was polled
            future_polled: AtomicUsize,
            // tasks that observed their own value in the cell
            winners: AtomicUsize,
        }

        let initializers = 100;

        let cell = Arc::new(OnceCell::default());
        let counters = Arc::new(Counters::default());
        // +1 for this test task, so all spawned tasks are released at once
        let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));

        let mut js = tokio::task::JoinSet::new();
        for i in 0..initializers {
            js.spawn({
                let cell = cell.clone();
                let counters = counters.clone();
                let barrier = barrier.clone();

                async move {
                    barrier.wait().await;
                    let won = {
                        let g = cell
                            .get_mut_or_init(|permit| {
                                counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                                async {
                                    counters.future_polled.fetch_add(1, Ordering::Relaxed);
                                    Ok::<_, Infallible>((i, permit))
                                }
                            })
                            .await
                            .unwrap();

                        *g == i
                    };

                    if won {
                        counters.winners.fetch_add(1, Ordering::Relaxed);
                    }
                }
            });
        }

        barrier.wait().await;

        while let Some(next) = js.join_next().await {
            next.expect("no panics expected");
        }

        let mut counters = Arc::try_unwrap(counters).unwrap();

        assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
        assert_eq!(*counters.future_polled.get_mut(), 1);
        assert_eq!(*counters.winners.get_mut(), 1);
    }

    /// Re-initialization must wait until the task that deinitialized drops its permit.
    #[tokio::test(start_paused = true)]
    async fn reinit_waits_for_deinit() {
        // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
        let sleep_for = Duration::from_secs(1);
        let initial = 42;
        let reinit = 1;
        let cell = Arc::new(OnceCell::new(initial));

        let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));

        let jh = tokio::spawn({
            let cell = cell.clone();
            let deinitialization_started = deinitialization_started.clone();
            async move {
                // `_permit` stays alive until the end of this task, i.e. through the sleep below
                let (answer, _permit) = cell
                    .get_mut()
                    .await
                    .expect("initialized to value")
                    .take_and_deinit();
                assert_eq!(answer, initial);

                deinitialization_started.wait().await;
                tokio::time::sleep(sleep_for).await;
            }
        });

        deinitialization_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

        let elapsed = started_at.elapsed();
        assert!(
            elapsed >= sleep_for,
            "initialization should had taken at least the time time slept with permit"
        );

        jh.await.unwrap();

        assert_eq!(*cell.get_mut().await.unwrap(), reinit);
    }

    /// The permit returned by `take_and_deinit` can be used with `set` to store a new value.
    #[tokio::test]
    async fn reinit_with_deinit_permit() {
        let cell = Arc::new(OnceCell::new(42));

        let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
        cell.set(5, permit).await;
        assert_eq!(*cell.get_mut().await.unwrap(), 5);

        let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(mol, permit).await;
        assert_eq!(*cell.get_mut().await.unwrap(), 42);
    }

    /// A failing factory leaves the cell uninitialized, so later attempts can still succeed.
    #[tokio::test]
    async fn initialization_attemptable_until_ok() {
        let cell = OnceCell::default();

        for _ in 0..10 {
            cell.get_mut_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let g = cell
            .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "finally success");
    }

    /// Dropping an in-progress initialization releases the permit and leaves the cell empty.
    #[tokio::test]
    async fn initialization_is_cancellation_safe() {
        let cell = OnceCell::default();

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_mut_or_init(|permit| async {
            barrier.wait().await;
            futures::future::pending::<()>().await;

            Ok::<_, Infallible>(("never reached", permit))
        });

        // the second branch completes once the factory has started, dropping `initializer`
        tokio::select! {
            _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
            _ = barrier.wait() => {}
        };

        // now initializer is dropped

        assert!(cell.get_mut().await.is_none());

        let g = cell
            .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "now initialized");
    }
}
|