Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc, Mutex, MutexGuard,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
9 : /// for the duration of initialization.
10 : ///
11 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
12 : ///
13 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
14 : pub struct OnceCell<T> {
15 : inner: Mutex<Inner<T>>,
16 : initializers: AtomicUsize,
17 : }
18 :
19 : impl<T> Default for OnceCell<T> {
20 : /// Create new uninitialized [`OnceCell`].
21 54 : fn default() -> Self {
22 54 : Self {
23 54 : inner: Default::default(),
24 54 : initializers: AtomicUsize::new(0),
25 54 : }
26 54 : }
27 : }
28 :
29 : /// Semaphore is the current state:
30 : /// - open semaphore means the value is `None`, not yet initialized
31 : /// - closed semaphore means the value has been initialized
32 : #[derive(Debug)]
33 : struct Inner<T> {
34 : init_semaphore: Arc<Semaphore>,
35 : value: Option<T>,
36 : }
37 :
38 : impl<T> Default for Inner<T> {
39 1892 : fn default() -> Self {
40 1892 : Self {
41 1892 : init_semaphore: Arc::new(Semaphore::new(1)),
42 1892 : value: None,
43 1892 : }
44 1892 : }
45 : }
46 :
47 : impl<T> OnceCell<T> {
48 : /// Creates an already initialized `OnceCell` with the given value.
49 5184 : pub fn new(value: T) -> Self {
50 5184 : let sem = Semaphore::new(1);
51 5184 : sem.close();
52 5184 : Self {
53 5184 : inner: Mutex::new(Inner {
54 5184 : init_semaphore: Arc::new(sem),
55 5184 : value: Some(value),
56 5184 : }),
57 5184 : initializers: AtomicUsize::new(0),
58 5184 : }
59 5184 : }
60 :
61 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
62 : /// returning the guard.
63 : ///
64 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
65 : ///
66 : /// Initialization is panic-safe and cancellation-safe.
67 714 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
68 714 : where
69 714 : F: FnOnce(InitPermit) -> Fut,
70 714 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
71 714 : {
72 : loop {
73 120 : let sem = {
74 720 : let guard = self.inner.lock().unwrap();
75 720 : if guard.value.is_some() {
76 600 : return Ok(Guard(guard));
77 120 : }
78 120 : guard.init_semaphore.clone()
79 : };
80 :
81 : {
82 114 : let permit = {
83 : // increment the count for the duration of queued
84 120 : let _guard = CountWaitingInitializers::start(self);
85 120 : sem.acquire().await
86 : };
87 :
88 114 : let Ok(permit) = permit else {
89 6 : let guard = self.inner.lock().unwrap();
90 6 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
91 : // there was a take_and_deinit in between
92 6 : continue;
93 0 : }
94 0 : assert!(
95 0 : guard.value.is_some(),
96 0 : "semaphore got closed, must be initialized"
97 : );
98 0 : return Ok(Guard(guard));
99 : };
100 :
101 108 : permit.forget();
102 108 : }
103 108 :
104 108 : let permit = InitPermit(sem);
105 108 : let (value, _permit) = factory(permit).await?;
106 :
107 42 : let guard = self.inner.lock().unwrap();
108 42 :
109 42 : return Ok(Self::set0(value, guard));
110 : }
111 702 : }
112 :
113 : /// Returns a guard to an existing initialized value, or returns an unique initialization
114 : /// permit which can be used to initialize this `OnceCell` using `OnceCell::set`.
115 639891 : pub async fn get_or_init_detached(&self) -> Result<Guard<'_, T>, InitPermit> {
116 : // It looks like OnceCell::get_or_init could be implemented using this method instead of
117 : // duplication. However, that makes the future be !Send due to possibly holding on to the
118 : // MutexGuard over an await point.
119 : loop {
120 60 : let sem = {
121 639891 : let guard = self.inner.lock().unwrap();
122 639891 : if guard.value.is_some() {
123 639831 : return Ok(Guard(guard));
124 60 : }
125 60 : guard.init_semaphore.clone()
126 : };
127 :
128 : {
129 60 : let permit = {
130 : // increment the count for the duration of queued
131 60 : let _guard = CountWaitingInitializers::start(self);
132 60 : sem.acquire().await
133 : };
134 :
135 60 : let Ok(permit) = permit else {
136 0 : let guard = self.inner.lock().unwrap();
137 0 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
138 : // there was a take_and_deinit in between
139 0 : continue;
140 0 : }
141 0 : assert!(
142 0 : guard.value.is_some(),
143 0 : "semaphore got closed, must be initialized"
144 : );
145 0 : return Ok(Guard(guard));
146 : };
147 :
148 60 : permit.forget();
149 60 : }
150 60 :
151 60 : let permit = InitPermit(sem);
152 60 : return Err(permit);
153 : }
154 639891 : }
155 :
156 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
157 : /// to complete initializing the inner value.
158 : ///
159 : /// # Panics
160 : ///
161 : /// If the inner has already been initialized.
162 84 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
163 84 : let guard = self.inner.lock().unwrap();
164 84 :
165 84 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
166 84 : // give more permits right now.
167 84 : if guard.init_semaphore.try_acquire().is_ok() {
168 0 : drop(guard);
169 0 : panic!("permit is of wrong origin");
170 84 : }
171 84 :
172 84 : Self::set0(value, guard)
173 84 : }
174 :
175 126 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
176 126 : if guard.value.is_some() {
177 0 : drop(guard);
178 0 : unreachable!("we won permit, must not be initialized");
179 126 : }
180 126 : guard.value = Some(value);
181 126 : guard.init_semaphore.close();
182 126 : Guard(guard)
183 126 : }
184 :
185 : /// Returns a guard to an existing initialized value, if any.
186 386 : pub fn get(&self) -> Option<Guard<'_, T>> {
187 386 : let guard = self.inner.lock().unwrap();
188 386 : if guard.value.is_some() {
189 338 : Some(Guard(guard))
190 : } else {
191 48 : None
192 : }
193 386 : }
194 :
195 : /// Like [`Guard::take_and_deinit`], but will return `None` if this OnceCell was never
196 : /// initialized.
197 1742 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
198 1742 : let inner = self.inner.get_mut().unwrap();
199 1742 :
200 1742 : inner.take_and_deinit()
201 1742 : }
202 :
203 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
204 90 : pub fn initializer_count(&self) -> usize {
205 90 : self.initializers.load(Ordering::Relaxed)
206 90 : }
207 : }
208 :
209 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
210 : /// initializing task for example at the end of initialization.
211 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
212 :
213 : impl<'a, T> CountWaitingInitializers<'a, T> {
214 180 : fn start(target: &'a OnceCell<T>) -> Self {
215 180 : target.initializers.fetch_add(1, Ordering::Relaxed);
216 180 : CountWaitingInitializers(target)
217 180 : }
218 : }
219 :
220 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
221 180 : fn drop(&mut self) {
222 180 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
223 180 : }
224 : }
225 :
226 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
227 : /// initialized value.
228 : #[derive(Debug)]
229 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
230 :
231 : impl<T> std::ops::Deref for Guard<'_, T> {
232 : type Target = T;
233 :
234 902 : fn deref(&self) -> &Self::Target {
235 902 : self.0
236 902 : .value
237 902 : .as_ref()
238 902 : .expect("guard is not created unless value has been initialized")
239 902 : }
240 : }
241 :
242 : impl<T> std::ops::DerefMut for Guard<'_, T> {
243 639861 : fn deref_mut(&mut self) -> &mut Self::Target {
244 639861 : self.0
245 639861 : .value
246 639861 : .as_mut()
247 639861 : .expect("guard is not created unless value has been initialized")
248 639861 : }
249 : }
250 :
251 : impl<'a, T> Guard<'a, T> {
252 : /// Take the current value, and a new permit for it's deinitialization.
253 : ///
254 : /// The permit will be on a semaphore part of the new internal value, and any following
255 : /// [`OnceCell::get_or_init`] will wait on it to complete.
256 114 : pub fn take_and_deinit(mut self) -> (T, InitPermit) {
257 114 : self.0
258 114 : .take_and_deinit()
259 114 : .expect("guard is not created unless value has been initialized")
260 114 : }
261 : }
262 :
263 : impl<T> Inner<T> {
264 1856 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
265 1856 : let value = self.value.take()?;
266 :
267 1838 : let mut swapped = Inner::default();
268 1838 : let sem = swapped.init_semaphore.clone();
269 1838 : // acquire and forget right away, moving the control over to InitPermit
270 1838 : sem.try_acquire().expect("we just created this").forget();
271 1838 : let permit = InitPermit(sem);
272 1838 : std::mem::swap(self, &mut swapped);
273 1838 : Some((value, permit))
274 1856 : }
275 : }
276 :
277 : /// Type held by OnceCell (de)initializing task.
278 : ///
279 : /// On drop, this type will return the permit.
280 : pub struct InitPermit(Arc<tokio::sync::Semaphore>);
281 :
282 : impl std::fmt::Debug for InitPermit {
283 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 0 : let ptr = Arc::as_ptr(&self.0) as *const ();
285 0 : f.debug_tuple("InitPermit").field(&ptr).finish()
286 0 : }
287 : }
288 :
289 : impl Drop for InitPermit {
290 2006 : fn drop(&mut self) {
291 2006 : assert_eq!(
292 2006 : self.0.available_permits(),
293 : 0,
294 0 : "InitPermit should only exist as the unique permit"
295 : );
296 2006 : self.0.add_permits(1);
297 2006 : }
298 : }
299 :
300 : #[cfg(test)]
301 : mod tests {
302 : use futures::Future;
303 :
304 : use super::*;
305 : use std::{
306 : convert::Infallible,
307 : pin::{pin, Pin},
308 : time::Duration,
309 : };
310 :
311 : #[tokio::test]
312 6 : async fn many_initializers() {
313 6 : #[derive(Default, Debug)]
314 6 : struct Counters {
315 6 : factory_got_to_run: AtomicUsize,
316 6 : future_polled: AtomicUsize,
317 6 : winners: AtomicUsize,
318 6 : }
319 6 :
320 6 : let initializers = 100;
321 6 :
322 6 : let cell = Arc::new(OnceCell::default());
323 6 : let counters = Arc::new(Counters::default());
324 6 : let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
325 6 :
326 6 : let mut js = tokio::task::JoinSet::new();
327 600 : for i in 0..initializers {
328 600 : js.spawn({
329 600 : let cell = cell.clone();
330 600 : let counters = counters.clone();
331 600 : let barrier = barrier.clone();
332 600 :
333 600 : async move {
334 600 : barrier.wait().await;
335 600 : let won = {
336 600 : let g = cell
337 600 : .get_or_init(|permit| {
338 6 : counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
339 6 : async {
340 6 : counters.future_polled.fetch_add(1, Ordering::Relaxed);
341 6 : Ok::<_, Infallible>((i, permit))
342 6 : }
343 600 : })
344 6 : .await
345 600 : .unwrap();
346 600 :
347 600 : *g == i
348 600 : };
349 600 :
350 600 : if won {
351 6 : counters.winners.fetch_add(1, Ordering::Relaxed);
352 594 : }
353 600 : }
354 600 : });
355 600 : }
356 6 :
357 6 : barrier.wait().await;
358 6 :
359 606 : while let Some(next) = js.join_next().await {
360 600 : next.expect("no panics expected");
361 600 : }
362 6 :
363 6 : let mut counters = Arc::try_unwrap(counters).unwrap();
364 6 :
365 6 : assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
366 6 : assert_eq!(*counters.future_polled.get_mut(), 1);
367 6 : assert_eq!(*counters.winners.get_mut(), 1);
368 6 : }
369 :
370 : #[tokio::test(start_paused = true)]
371 6 : async fn reinit_waits_for_deinit() {
372 6 : // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
373 6 : let sleep_for = Duration::from_secs(1);
374 6 : let initial = 42;
375 6 : let reinit = 1;
376 6 : let cell = Arc::new(OnceCell::new(initial));
377 6 :
378 6 : let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
379 6 :
380 6 : let jh = tokio::spawn({
381 6 : let cell = cell.clone();
382 6 : let deinitialization_started = deinitialization_started.clone();
383 6 : async move {
384 6 : let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
385 6 : assert_eq!(answer, initial);
386 6 :
387 6 : deinitialization_started.wait().await;
388 6 : tokio::time::sleep(sleep_for).await;
389 6 : }
390 6 : });
391 6 :
392 6 : deinitialization_started.wait().await;
393 6 :
394 6 : let started_at = tokio::time::Instant::now();
395 6 : cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
396 6 : .await
397 6 : .unwrap();
398 6 :
399 6 : let elapsed = started_at.elapsed();
400 6 : assert!(
401 6 : elapsed >= sleep_for,
402 6 : "initialization should had taken at least the time time slept with permit"
403 6 : );
404 6 :
405 6 : jh.await.unwrap();
406 6 :
407 6 : assert_eq!(*cell.get().unwrap(), reinit);
408 6 : }
409 :
/// The permit returned by `take_and_deinit` can be used with `set` to put a value back,
/// repeatedly.
#[test]
fn reinit_with_deinit_permit() {
    let cell = Arc::new(OnceCell::new(42));

    let (original, permit) = cell.get().unwrap().take_and_deinit();
    cell.set(5, permit);
    assert_eq!(*cell.get().unwrap(), 5);

    let (replacement, permit) = cell.get().unwrap().take_and_deinit();
    assert_eq!(5, replacement);
    cell.set(original, permit);
    assert_eq!(*cell.get().unwrap(), 42);
}
423 :
424 : #[tokio::test]
425 6 : async fn initialization_attemptable_until_ok() {
426 6 : let cell = OnceCell::default();
427 6 :
428 66 : for _ in 0..10 {
429 60 : cell.get_or_init(|_permit| async { Err("whatever error") })
430 6 : .await
431 60 : .unwrap_err();
432 6 : }
433 6 :
434 6 : let g = cell
435 6 : .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
436 6 : .await
437 6 : .unwrap();
438 6 : assert_eq!(*g, "finally success");
439 6 : }
440 :
441 : #[tokio::test]
442 6 : async fn initialization_is_cancellation_safe() {
443 6 : let cell = OnceCell::default();
444 6 :
445 6 : let barrier = tokio::sync::Barrier::new(2);
446 6 :
447 6 : let initializer = cell.get_or_init(|permit| async {
448 6 : barrier.wait().await;
449 6 : futures::future::pending::<()>().await;
450 6 :
451 6 : Ok::<_, Infallible>(("never reached", permit))
452 6 : });
453 6 :
454 6 : tokio::select! {
455 6 : _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
456 6 : _ = barrier.wait() => {}
457 6 : };
458 6 :
459 6 : // now initializer is dropped
460 6 :
461 6 : assert!(cell.get().is_none());
462 6 :
463 6 : let g = cell
464 6 : .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
465 6 : .await
466 6 : .unwrap();
467 6 : assert_eq!(*g, "now initialized");
468 6 : }
469 :
470 : #[tokio::test(start_paused = true)]
471 6 : async fn reproduce_init_take_deinit_race() {
472 12 : init_take_deinit_scenario(|cell, factory| {
473 12 : Box::pin(async {
474 36 : cell.get_or_init(factory).await.unwrap();
475 12 : })
476 12 : })
477 24 : .await;
478 6 : }
479 :
480 : type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
481 : type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
482 :
483 : /// Reproduce an assertion failure.
484 : ///
485 : /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
486 : /// We currently only have one, but the structure is kept.
487 6 : async fn init_take_deinit_scenario<F>(init_way: F)
488 6 : where
489 6 : F: for<'a> Fn(
490 6 : &'a OnceCell<&'static str>,
491 6 : BoxedInitFunction<&'static str, Infallible>,
492 6 : ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
493 6 : {
494 6 : let cell = OnceCell::default();
495 6 :
496 6 : // acquire the init_semaphore only permit to drive initializing tasks in order to waiting
497 6 : // on the same semaphore.
498 6 : let permit = cell
499 6 : .inner
500 6 : .lock()
501 6 : .unwrap()
502 6 : .init_semaphore
503 6 : .clone()
504 6 : .try_acquire_owned()
505 6 : .unwrap();
506 6 :
507 6 : let mut t1 = pin!(init_way(
508 6 : &cell,
509 6 : Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
510 6 : ));
511 6 :
512 6 : let mut t2 = pin!(init_way(
513 6 : &cell,
514 6 : Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
515 6 : ));
516 :
517 : // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
518 : // no longer make progress
519 : tokio::select! {
520 : _ = &mut t2 => unreachable!("it cannot get permit"),
521 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
522 : }
523 :
524 : // followed by t1 in the init_semaphore
525 : tokio::select! {
526 : _ = &mut t1 => unreachable!("it cannot get permit"),
527 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
528 : }
529 :
530 : // now let t2 proceed and initialize
531 6 : drop(permit);
532 6 : t2.await;
533 :
534 6 : let (s, permit) = { cell.get().unwrap().take_and_deinit() };
535 6 : assert_eq!("t2", s);
536 :
537 : // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
538 : // the new one.
539 : tokio::select! {
540 : _ = &mut t1 => unreachable!("it cannot get permit"),
541 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
542 : }
543 :
544 : // only now we get to initialize it
545 6 : drop(permit);
546 6 : t1.await;
547 :
548 6 : assert_eq!("t1", *cell.get().unwrap());
549 6 : }
550 :
551 : #[tokio::test(start_paused = true)]
552 6 : async fn detached_init_smoke() {
553 6 : let target = OnceCell::default();
554 6 :
555 6 : let Err(permit) = target.get_or_init_detached().await else {
556 6 : unreachable!("it is not initialized")
557 6 : };
558 6 :
559 6 : tokio::time::timeout(
560 6 : std::time::Duration::from_secs(3600 * 24 * 7 * 365),
561 6 : target.get_or_init(|permit2| async { Ok::<_, Infallible>((11, permit2)) }),
562 6 : )
563 6 : .await
564 6 : .expect_err("should timeout since we are already holding the permit");
565 6 :
566 6 : target.set(42, permit);
567 6 :
568 6 : let (_answer, permit) = {
569 6 : let guard = target
570 6 : .get_or_init(|permit| async { Ok::<_, Infallible>((11, permit)) })
571 6 : .await
572 6 : .unwrap();
573 6 :
574 6 : assert_eq!(*guard, 42);
575 6 :
576 6 : guard.take_and_deinit()
577 6 : };
578 6 :
579 6 : assert!(target.get().is_none());
580 6 :
581 6 : target.set(11, permit);
582 6 :
583 6 : assert_eq!(*target.get().unwrap(), 11);
584 6 : }
585 :
586 : #[tokio::test]
587 6 : async fn take_and_deinit_on_mut() {
588 6 : use std::convert::Infallible;
589 6 :
590 6 : let mut target = OnceCell::<u32>::default();
591 6 : assert!(target.take_and_deinit().is_none());
592 6 :
593 6 : target
594 6 : .get_or_init(|permit| async move { Ok::<_, Infallible>((42, permit)) })
595 6 : .await
596 6 : .unwrap();
597 6 :
598 6 : let again = target.take_and_deinit();
599 6 : assert!(matches!(again, Some((42, _))), "{again:?}");
600 6 :
601 6 : assert!(target.take_and_deinit().is_none());
602 6 : }
603 : }
|