Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc, Mutex, MutexGuard,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`.
9 : ///
10 : /// Allows use of `take` which does not require holding an outer mutex guard
11 : /// for the duration of initialization.
12 : ///
13 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
14 : ///
15 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
16 : pub struct OnceCell<T> {
17 : inner: Mutex<Inner<T>>,
18 : initializers: AtomicUsize,
19 : }
20 :
21 : impl<T> Default for OnceCell<T> {
22 : /// Create new uninitialized [`OnceCell`].
23 14 : fn default() -> Self {
24 14 : Self {
25 14 : inner: Default::default(),
26 14 : initializers: AtomicUsize::new(0),
27 14 : }
28 14 : }
29 : }
30 :
/// Semaphore is the current state:
/// - open semaphore means the value is `None`, not yet initialized
/// - closed semaphore means the value has been initialized
///
/// While open, the semaphore holds a single permit; whoever owns it (as an [`InitPermit`]) is the
/// only party allowed to initialize `value`. A closed semaphore is never reopened: deinitializing
/// ([`Inner::take_and_deinit`]) replaces the whole `Inner`, semaphore included.
#[derive(Debug)]
struct Inner<T> {
    /// Gatekeeper for initializing `value`; closed once `value` has been set.
    init_semaphore: Arc<Semaphore>,
    /// The initialized value, or `None` while uninitialized.
    value: Option<T>,
}
39 :
40 : impl<T> Default for Inner<T> {
41 621 : fn default() -> Self {
42 621 : Self {
43 621 : init_semaphore: Arc::new(Semaphore::new(1)),
44 621 : value: None,
45 621 : }
46 621 : }
47 : }
48 :
49 : impl<T> OnceCell<T> {
50 : /// Creates an already initialized `OnceCell` with the given value.
51 1756 : pub fn new(value: T) -> Self {
52 1756 : let sem = Semaphore::new(1);
53 1756 : sem.close();
54 1756 : Self {
55 1756 : inner: Mutex::new(Inner {
56 1756 : init_semaphore: Arc::new(sem),
57 1756 : value: Some(value),
58 1756 : }),
59 1756 : initializers: AtomicUsize::new(0),
60 1756 : }
61 1756 : }
62 :
63 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
64 : /// returning the guard.
65 : ///
66 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
67 : ///
68 : /// Initialization is panic-safe and cancellation-safe.
69 119 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
70 119 : where
71 119 : F: FnOnce(InitPermit) -> Fut,
72 119 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
73 119 : {
74 : loop {
75 20 : let sem = {
76 120 : let guard = self.inner.lock().unwrap();
77 120 : if guard.value.is_some() {
78 100 : return Ok(Guard(guard));
79 20 : }
80 20 : guard.init_semaphore.clone()
81 : };
82 :
83 : {
84 19 : let permit = {
85 : // increment the count for the duration of queued
86 20 : let _guard = CountWaitingInitializers::start(self);
87 20 : sem.acquire().await
88 : };
89 :
90 19 : let Ok(permit) = permit else {
91 1 : let guard = self.inner.lock().unwrap();
92 1 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
93 : // there was a take_and_deinit in between
94 1 : continue;
95 0 : }
96 0 : assert!(
97 0 : guard.value.is_some(),
98 0 : "semaphore got closed, must be initialized"
99 : );
100 0 : return Ok(Guard(guard));
101 : };
102 :
103 18 : permit.forget();
104 18 : }
105 18 :
106 18 : let permit = InitPermit(sem);
107 18 : let (value, _permit) = factory(permit).await?;
108 :
109 7 : let guard = self.inner.lock().unwrap();
110 7 :
111 7 : return Ok(Self::set0(value, guard));
112 : }
113 117 : }
114 :
115 : /// Returns a guard to an existing initialized value, or returns an unique initialization
116 : /// permit which can be used to initialize this `OnceCell` using `OnceCell::set`.
117 240361 : pub async fn get_or_init_detached(&self) -> Result<Guard<'_, T>, InitPermit> {
118 : // It looks like OnceCell::get_or_init could be implemented using this method instead of
119 : // duplication. However, that makes the future be !Send due to possibly holding on to the
120 : // MutexGuard over an await point.
121 : loop {
122 21 : let sem = {
123 240361 : let guard = self.inner.lock().unwrap();
124 240361 : if guard.value.is_some() {
125 240340 : return Ok(Guard(guard));
126 21 : }
127 21 : guard.init_semaphore.clone()
128 : };
129 :
130 : {
131 21 : let permit = {
132 : // increment the count for the duration of queued
133 21 : let _guard = CountWaitingInitializers::start(self);
134 21 : sem.acquire().await
135 : };
136 :
137 21 : let Ok(permit) = permit else {
138 0 : let guard = self.inner.lock().unwrap();
139 0 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
140 : // there was a take_and_deinit in between
141 0 : continue;
142 0 : }
143 0 : assert!(
144 0 : guard.value.is_some(),
145 0 : "semaphore got closed, must be initialized"
146 : );
147 0 : return Ok(Guard(guard));
148 : };
149 :
150 21 : permit.forget();
151 21 : }
152 21 :
153 21 : let permit = InitPermit(sem);
154 21 : return Err(permit);
155 : }
156 240361 : }
157 :
158 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
159 : /// to complete initializing the inner value.
160 : ///
161 : /// # Panics
162 : ///
163 : /// If the inner has already been initialized.
164 26 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
165 26 : let guard = self.inner.lock().unwrap();
166 26 :
167 26 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
168 26 : // give more permits right now.
169 26 : if guard.init_semaphore.try_acquire().is_ok() {
170 0 : drop(guard);
171 0 : panic!("permit is of wrong origin");
172 26 : }
173 26 :
174 26 : Self::set0(value, guard)
175 26 : }
176 :
177 33 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
178 33 : if guard.value.is_some() {
179 0 : drop(guard);
180 0 : unreachable!("we won permit, must not be initialized");
181 33 : }
182 33 : guard.value = Some(value);
183 33 : guard.init_semaphore.close();
184 33 : Guard(guard)
185 33 : }
186 :
187 : /// Returns a guard to an existing initialized value, if any.
188 118 : pub fn get(&self) -> Option<Guard<'_, T>> {
189 118 : let guard = self.inner.lock().unwrap();
190 118 : if guard.value.is_some() {
191 104 : Some(Guard(guard))
192 : } else {
193 14 : None
194 : }
195 118 : }
196 :
197 : /// Like [`Guard::take_and_deinit`], but will return `None` if this OnceCell was never
198 : /// initialized.
199 578 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
200 578 : let inner = self.inner.get_mut().unwrap();
201 578 :
202 578 : inner.take_and_deinit()
203 578 : }
204 :
205 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
206 30 : pub fn initializer_count(&self) -> usize {
207 30 : self.initializers.load(Ordering::Relaxed)
208 30 : }
209 : }
210 :
/// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
/// initializing task for example at the end of initialization.
///
/// [`CountWaitingInitializers::start`] increments the cell's `initializers` counter, and
/// dropping the guard decrements it again. That keeps the count correct even when the waiting
/// future is cancelled. The current count is read via [`OnceCell::initializer_count`].
struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
214 :
215 : impl<'a, T> CountWaitingInitializers<'a, T> {
216 41 : fn start(target: &'a OnceCell<T>) -> Self {
217 41 : target.initializers.fetch_add(1, Ordering::Relaxed);
218 41 : CountWaitingInitializers(target)
219 41 : }
220 : }
221 :
222 : impl<T> Drop for CountWaitingInitializers<'_, T> {
223 41 : fn drop(&mut self) {
224 41 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
225 41 : }
226 : }
227 :
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
/// initialized value.
///
/// Wraps the [`MutexGuard`] of the cell's inner mutex, so the mutex stays locked for as long as
/// the guard lives. Other accessors of the same `OnceCell` block on that lock meanwhile. A guard
/// is only constructed while the value is `Some`, which the `Deref`/`DerefMut` impls rely on.
#[derive(Debug)]
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
232 :
233 : impl<T> std::ops::Deref for Guard<'_, T> {
234 : type Target = T;
235 :
236 193 : fn deref(&self) -> &Self::Target {
237 193 : self.0
238 193 : .value
239 193 : .as_ref()
240 193 : .expect("guard is not created unless value has been initialized")
241 193 : }
242 : }
243 :
244 : impl<T> std::ops::DerefMut for Guard<'_, T> {
245 240350 : fn deref_mut(&mut self) -> &mut Self::Target {
246 240350 : self.0
247 240350 : .value
248 240350 : .as_mut()
249 240350 : .expect("guard is not created unless value has been initialized")
250 240350 : }
251 : }
252 :
253 : impl<T> Guard<'_, T> {
254 : /// Take the current value, and a new permit for it's deinitialization.
255 : ///
256 : /// The permit will be on a semaphore part of the new internal value, and any following
257 : /// [`OnceCell::get_or_init`] will wait on it to complete.
258 33 : pub fn take_and_deinit(mut self) -> (T, InitPermit) {
259 33 : self.0
260 33 : .take_and_deinit()
261 33 : .expect("guard is not created unless value has been initialized")
262 33 : }
263 : }
264 :
265 : impl<T> Inner<T> {
266 611 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
267 611 : let value = self.value.take()?;
268 :
269 607 : let mut swapped = Inner::default();
270 607 : let sem = swapped.init_semaphore.clone();
271 607 : // acquire and forget right away, moving the control over to InitPermit
272 607 : sem.try_acquire().expect("we just created this").forget();
273 607 : let permit = InitPermit(sem);
274 607 : std::mem::swap(self, &mut swapped);
275 607 : Some((value, permit))
276 611 : }
277 : }
278 :
/// Type held by OnceCell (de)initializing task.
///
/// On drop, this type will return the permit.
///
/// Holds the init semaphore whose single permit was acquired and then forgotten when this value
/// was created. `Drop` adds that permit back, so the next queued task may proceed unless a
/// completed initialization has closed the semaphore in the meantime.
pub struct InitPermit(Arc<tokio::sync::Semaphore>);
283 :
284 : impl std::fmt::Debug for InitPermit {
285 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 0 : let ptr = Arc::as_ptr(&self.0) as *const ();
287 0 : f.debug_tuple("InitPermit").field(&ptr).finish()
288 0 : }
289 : }
290 :
291 : impl Drop for InitPermit {
292 646 : fn drop(&mut self) {
293 646 : assert_eq!(
294 646 : self.0.available_permits(),
295 : 0,
296 0 : "InitPermit should only exist as the unique permit"
297 : );
298 646 : self.0.add_permits(1);
299 646 : }
300 : }
301 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        convert::Infallible,
        future::Future,
        pin::{pin, Pin},
        time::Duration,
    };

    #[tokio::test]
    async fn many_initializers() {
        #[derive(Default, Debug)]
        struct Counters {
            factory_got_to_run: AtomicUsize,
            future_polled: AtomicUsize,
            winners: AtomicUsize,
        }

        let initializers = 100;

        let cell = Arc::new(OnceCell::default());
        let counters = Arc::new(Counters::default());
        let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));

        let mut js = tokio::task::JoinSet::new();
        for i in 0..initializers {
            js.spawn({
                let cell = cell.clone();
                let counters = counters.clone();
                let barrier = barrier.clone();

                async move {
                    barrier.wait().await;
                    let won = {
                        let g = cell
                            .get_or_init(|permit| {
                                counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                                async {
                                    counters.future_polled.fetch_add(1, Ordering::Relaxed);
                                    Ok::<_, Infallible>((i, permit))
                                }
                            })
                            .await
                            .unwrap();

                        *g == i
                    };

                    if won {
                        counters.winners.fetch_add(1, Ordering::Relaxed);
                    }
                }
            });
        }

        barrier.wait().await;

        while let Some(next) = js.join_next().await {
            next.expect("no panics expected");
        }

        let mut counters = Arc::try_unwrap(counters).unwrap();

        // exactly one of the racing tasks got to run its factory and won
        assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
        assert_eq!(*counters.future_polled.get_mut(), 1);
        assert_eq!(*counters.winners.get_mut(), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn reinit_waits_for_deinit() {
        // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
        let sleep_for = Duration::from_secs(1);
        let initial = 42;
        let reinit = 1;
        let cell = Arc::new(OnceCell::new(initial));

        let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));

        let jh = tokio::spawn({
            let cell = cell.clone();
            let deinitialization_started = deinitialization_started.clone();
            async move {
                let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
                assert_eq!(answer, initial);

                deinitialization_started.wait().await;
                tokio::time::sleep(sleep_for).await;
            }
        });

        deinitialization_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

        let elapsed = started_at.elapsed();
        assert!(
            elapsed >= sleep_for,
            "initialization should have taken at least the time slept with permit"
        );

        jh.await.unwrap();

        assert_eq!(*cell.get().unwrap(), reinit);
    }

    #[test]
    fn reinit_with_deinit_permit() {
        let cell = Arc::new(OnceCell::new(42));

        let (mol, permit) = cell.get().unwrap().take_and_deinit();
        cell.set(5, permit);
        assert_eq!(*cell.get().unwrap(), 5);

        let (five, permit) = cell.get().unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(mol, permit);
        assert_eq!(*cell.get().unwrap(), 42);
    }

    #[tokio::test]
    async fn initialization_attemptable_until_ok() {
        let cell = OnceCell::default();

        for _ in 0..10 {
            cell.get_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "finally success");
    }

    #[tokio::test]
    async fn initialization_is_cancellation_safe() {
        let cell = OnceCell::default();

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_or_init(|permit| async {
            barrier.wait().await;
            futures::future::pending::<()>().await;

            Ok::<_, Infallible>(("never reached", permit))
        });

        tokio::select! {
            _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
            _ = barrier.wait() => {}
        };

        // now initializer is dropped

        assert!(cell.get().is_none());

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "now initialized");
    }

    #[tokio::test(start_paused = true)]
    async fn reproduce_init_take_deinit_race() {
        init_take_deinit_scenario(|cell, factory| {
            Box::pin(async {
                cell.get_or_init(factory).await.unwrap();
            })
        })
        .await;
    }

    type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
    type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;

    /// Reproduce an assertion failure: a task queued on a semaphore that was closed by
    /// initialization, and then replaced by `take_and_deinit`, must retry on the new semaphore
    /// instead of asserting that the value is initialized.
    ///
    /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
    /// We currently only have one, but the structure is kept.
    async fn init_take_deinit_scenario<F>(init_way: F)
    where
        F: for<'a> Fn(
            &'a OnceCell<&'static str>,
            BoxedInitFunction<&'static str, Infallible>,
        ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
    {
        let cell = OnceCell::default();

        // acquire the only permit of init_semaphore, so that both initializing tasks can be
        // driven, in a controlled order, to wait on the same semaphore.
        let permit = cell
            .inner
            .lock()
            .unwrap()
            .init_semaphore
            .clone()
            .try_acquire_owned()
            .unwrap();

        let mut t1 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
        ));

        let mut t2 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
        ));

        // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
        // no longer make progress
        tokio::select! {
            _ = &mut t2 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // followed by t1 in the init_semaphore
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // now let t2 proceed and initialize
        drop(permit);
        t2.await;

        let (s, permit) = { cell.get().unwrap().take_and_deinit() };
        assert_eq!("t2", s);

        // t1 now sees its (old) semaphore as closed, but it cannot yet get a permit from the new
        // one, because we are holding it.
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // only now we get to initialize it
        drop(permit);
        t1.await;

        assert_eq!("t1", *cell.get().unwrap());
    }

    #[tokio::test(start_paused = true)]
    async fn detached_init_smoke() {
        let target = OnceCell::default();

        let Err(permit) = target.get_or_init_detached().await else {
            unreachable!("it is not initialized")
        };

        tokio::time::timeout(
            std::time::Duration::from_secs(3600 * 24 * 7 * 365),
            target.get_or_init(|permit2| async { Ok::<_, Infallible>((11, permit2)) }),
        )
        .await
        .expect_err("should timeout since we are already holding the permit");

        target.set(42, permit);

        let (_answer, permit) = {
            let guard = target
                .get_or_init(|permit| async { Ok::<_, Infallible>((11, permit)) })
                .await
                .unwrap();

            assert_eq!(*guard, 42);

            guard.take_and_deinit()
        };

        assert!(target.get().is_none());

        target.set(11, permit);

        assert_eq!(*target.get().unwrap(), 11);
    }

    #[tokio::test]
    async fn take_and_deinit_on_mut() {
        let mut target = OnceCell::<u32>::default();
        assert!(target.take_and_deinit().is_none());

        target
            .get_or_init(|permit| async move { Ok::<_, Infallible>((42, permit)) })
            .await
            .unwrap();

        let again = target.take_and_deinit();
        assert!(matches!(again, Some((42, _))), "{again:?}");

        assert!(target.take_and_deinit().is_none());
    }
}
|