LCOV - code coverage report
Current view: top level - libs/utils/src/sync - heavier_once_cell.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 97.6 % 337 329
Test Date: 2024-02-14 18:05:35 Functions: 81.6 % 114 93

            Line data    Source code
       1              : use std::sync::{
       2              :     atomic::{AtomicUsize, Ordering},
       3              :     Arc, Mutex, MutexGuard,
       4              : };
       5              : use tokio::sync::Semaphore;
       6              : 
       7              : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
       8              : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
       9              : /// for the duration of initialization.
      10              : ///
      11              : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
      12              : ///
      13              : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
      14              : pub struct OnceCell<T> {
      15              :     inner: Mutex<Inner<T>>,
      16              :     initializers: AtomicUsize,
      17              : }
      18              : 
      19              : impl<T> Default for OnceCell<T> {
      20              :     /// Create new uninitialized [`OnceCell`].
      21        40907 :     fn default() -> Self {
      22        40907 :         Self {
      23        40907 :             inner: Default::default(),
      24        40907 :             initializers: AtomicUsize::new(0),
      25        40907 :         }
      26        40907 :     }
      27              : }
      28              : 
      29              : /// Semaphore is the current state:
      30              : /// - open semaphore means the value is `None`, not yet initialized
      31              : /// - closed semaphore means the value has been initialized
      32            0 : #[derive(Debug)]
      33              : struct Inner<T> {
      34              :     init_semaphore: Arc<Semaphore>,
      35              :     value: Option<T>,
      36              : }
      37              : 
      38              : impl<T> Default for Inner<T> {
      39        43430 :     fn default() -> Self {
      40        43430 :         Self {
      41        43430 :             init_semaphore: Arc::new(Semaphore::new(1)),
      42        43430 :             value: None,
      43        43430 :         }
      44        43430 :     }
      45              : }
      46              : 
      47              : impl<T> OnceCell<T> {
      48              :     /// Creates an already initialized `OnceCell` with the given value.
      49        34461 :     pub fn new(value: T) -> Self {
      50        34461 :         let sem = Semaphore::new(1);
      51        34461 :         sem.close();
      52        34461 :         Self {
      53        34461 :             inner: Mutex::new(Inner {
      54        34461 :                 init_semaphore: Arc::new(sem),
      55        34461 :                 value: Some(value),
      56        34461 :             }),
      57        34461 :             initializers: AtomicUsize::new(0),
      58        34461 :         }
      59        34461 :     }
      60              : 
      61              :     /// Returns a guard to an existing initialized value, or uniquely initializes the value before
      62              :     /// returning the guard.
      63              :     ///
      64              :     /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
      65              :     ///
      66              :     /// Initialization is panic-safe and cancellation-safe.
      67     16853002 :     pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
      68     16853002 :     where
      69     16853002 :         F: FnOnce(InitPermit) -> Fut,
      70     16853002 :         Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
      71     16853002 :     {
      72              :         loop {
      73        10619 :             let sem = {
      74     16853004 :                 let guard = self.inner.lock().unwrap();
      75     16853004 :                 if guard.value.is_some() {
      76     16842385 :                     return Ok(Guard(guard));
      77        10619 :                 }
      78        10619 :                 guard.init_semaphore.clone()
      79              :             };
      80              : 
      81              :             {
      82        10619 :                 let permit = {
      83              :                     // increment the count for the duration of queued
      84        10619 :                     let _guard = CountWaitingInitializers::start(self);
      85        10619 :                     sem.acquire().await
      86              :                 };
      87              : 
      88        10619 :                 let Ok(permit) = permit else {
      89          499 :                     let guard = self.inner.lock().unwrap();
      90          499 :                     if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
      91              :                         // there was a take_and_deinit in between
      92            2 :                         continue;
      93          497 :                     }
      94          497 :                     assert!(
      95          497 :                         guard.value.is_some(),
      96            0 :                         "semaphore got closed, must be initialized"
      97              :                     );
      98          497 :                     return Ok(Guard(guard));
      99              :                 };
     100              : 
     101        10120 :                 permit.forget();
     102        10120 :             }
     103        10120 : 
     104        10120 :             let permit = InitPermit(sem);
     105        29430 :             let (value, _permit) = factory(permit).await?;
     106              : 
     107         9468 :             let guard = self.inner.lock().unwrap();
     108         9468 : 
     109         9468 :             return Ok(Self::set0(value, guard));
     110              :         }
     111     16852992 :     }
     112              : 
     113              :     /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
     114              :     /// to complete initializing the inner value.
     115              :     ///
     116              :     /// # Panics
     117              :     ///
     118              :     /// If the inner has already been initialized.
     119            4 :     pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
     120            4 :         let guard = self.inner.lock().unwrap();
     121            4 : 
     122            4 :         // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
     123            4 :         // give more permits right now.
     124            4 :         if guard.init_semaphore.try_acquire().is_ok() {
     125            0 :             drop(guard);
     126            0 :             panic!("permit is of wrong origin");
     127            4 :         }
     128            4 : 
     129            4 :         Self::set0(value, guard)
     130            4 :     }
     131              : 
     132         9472 :     fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
     133         9472 :         if guard.value.is_some() {
     134            0 :             drop(guard);
     135            0 :             unreachable!("we won permit, must not be initialized");
     136         9472 :         }
     137         9472 :         guard.value = Some(value);
     138         9472 :         guard.init_semaphore.close();
     139         9472 :         Guard(guard)
     140         9472 :     }
     141              : 
     142              :     /// Returns a guard to an existing initialized value, if any.
     143         8100 :     pub fn get(&self) -> Option<Guard<'_, T>> {
     144         8100 :         let guard = self.inner.lock().unwrap();
     145         8100 :         if guard.value.is_some() {
     146         7311 :             Some(Guard(guard))
     147              :         } else {
     148          789 :             None
     149              :         }
     150         8100 :     }
     151              : 
     152              :     /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
     153         9456 :     pub fn initializer_count(&self) -> usize {
     154         9456 :         self.initializers.load(Ordering::Relaxed)
     155         9456 :     }
     156              : }
     157              : 
     158              : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
     159              : /// initializing task for example at the end of initialization.
     160              : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
     161              : 
     162              : impl<'a, T> CountWaitingInitializers<'a, T> {
     163        10619 :     fn start(target: &'a OnceCell<T>) -> Self {
     164        10619 :         target.initializers.fetch_add(1, Ordering::Relaxed);
     165        10619 :         CountWaitingInitializers(target)
     166        10619 :     }
     167              : }
     168              : 
     169              : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
     170        10619 :     fn drop(&mut self) {
     171        10619 :         self.0.initializers.fetch_sub(1, Ordering::Relaxed);
     172        10619 :     }
     173              : }
     174              : 
     175              : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
     176              : /// initialized value.
     177            0 : #[derive(Debug)]
     178              : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
     179              : 
     180              : impl<T> std::ops::Deref for Guard<'_, T> {
     181              :     type Target = T;
     182              : 
     183         2727 :     fn deref(&self) -> &Self::Target {
     184         2727 :         self.0
     185         2727 :             .value
     186         2727 :             .as_ref()
     187         2727 :             .expect("guard is not created unless value has been initialized")
     188         2727 :     }
     189              : }
     190              : 
     191              : impl<T> std::ops::DerefMut for Guard<'_, T> {
     192     16854655 :     fn deref_mut(&mut self) -> &mut Self::Target {
     193     16854655 :         self.0
     194     16854655 :             .value
     195     16854655 :             .as_mut()
     196     16854655 :             .expect("guard is not created unless value has been initialized")
     197     16854655 :     }
     198              : }
     199              : 
     200              : impl<'a, T> Guard<'a, T> {
     201              :     /// Take the current value, and a new permit for it's deinitialization.
     202              :     ///
     203              :     /// The permit will be on a semaphore part of the new internal value, and any following
     204              :     /// [`OnceCell::get_or_init`] will wait on it to complete.
     205         2523 :     pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
     206         2523 :         let mut swapped = Inner::default();
     207         2523 :         let sem = swapped.init_semaphore.clone();
     208         2523 :         // acquire and forget right away, moving the control over to InitPermit
     209         2523 :         sem.try_acquire().expect("we just created this").forget();
     210         2523 :         std::mem::swap(&mut *self.0, &mut swapped);
     211         2523 :         swapped
     212         2523 :             .value
     213         2523 :             .map(|v| (v, InitPermit(sem)))
     214         2523 :             .expect("guard is not created unless value has been initialized")
     215         2523 :     }
     216              : }
     217              : 
     218              : /// Type held by OnceCell (de)initializing task.
     219              : ///
     220              : /// On drop, this type will return the permit.
     221              : pub struct InitPermit(Arc<tokio::sync::Semaphore>);
     222              : 
     223              : impl Drop for InitPermit {
     224        12640 :     fn drop(&mut self) {
     225        12640 :         assert_eq!(
     226        12640 :             self.0.available_permits(),
     227              :             0,
     228            0 :             "InitPermit should only exist as the unique permit"
     229              :         );
     230        12640 :         self.0.add_permits(1);
     231        12640 :     }
     232              : }
     233              : 
     234              : #[cfg(test)]
     235              : mod tests {
     236              :     use futures::Future;
     237              : 
     238              :     use super::*;
     239              :     use std::{
     240              :         convert::Infallible,
     241              :         pin::{pin, Pin},
     242              :         sync::atomic::{AtomicUsize, Ordering},
     243              :         time::Duration,
     244              :     };
     245              : 
     246            2 :     #[tokio::test]
     247            2 :     async fn many_initializers() {
     248            2 :         #[derive(Default, Debug)]
     249            2 :         struct Counters {
     250            2 :             factory_got_to_run: AtomicUsize,
     251            2 :             future_polled: AtomicUsize,
     252            2 :             winners: AtomicUsize,
     253            2 :         }
     254            2 : 
     255            2 :         let initializers = 100;
     256            2 : 
     257            2 :         let cell = Arc::new(OnceCell::default());
     258            2 :         let counters = Arc::new(Counters::default());
     259            2 :         let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
     260            2 : 
     261            2 :         let mut js = tokio::task::JoinSet::new();
     262          200 :         for i in 0..initializers {
     263          200 :             js.spawn({
     264          200 :                 let cell = cell.clone();
     265          200 :                 let counters = counters.clone();
     266          200 :                 let barrier = barrier.clone();
     267          200 : 
     268          200 :                 async move {
     269          200 :                     barrier.wait().await;
     270          200 :                     let won = {
     271          200 :                         let g = cell
     272          200 :                             .get_or_init(|permit| {
     273            2 :                                 counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
     274            2 :                                 async {
     275            2 :                                     counters.future_polled.fetch_add(1, Ordering::Relaxed);
     276            2 :                                     Ok::<_, Infallible>((i, permit))
     277            2 :                                 }
     278          200 :                             })
     279            2 :                             .await
     280          200 :                             .unwrap();
     281          200 : 
     282          200 :                         *g == i
     283          200 :                     };
     284          200 : 
     285          200 :                     if won {
     286            2 :                         counters.winners.fetch_add(1, Ordering::Relaxed);
     287          198 :                     }
     288          200 :                 }
     289          200 :             });
     290          200 :         }
     291            2 : 
     292            2 :         barrier.wait().await;
     293            2 : 
     294          202 :         while let Some(next) = js.join_next().await {
     295          200 :             next.expect("no panics expected");
     296          200 :         }
     297            2 : 
     298            2 :         let mut counters = Arc::try_unwrap(counters).unwrap();
     299            2 : 
     300            2 :         assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
     301            2 :         assert_eq!(*counters.future_polled.get_mut(), 1);
     302            2 :         assert_eq!(*counters.winners.get_mut(), 1);
     303            2 :     }
     304              : 
     305            2 :     #[tokio::test(start_paused = true)]
     306            2 :     async fn reinit_waits_for_deinit() {
     307            2 :         // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
     308            2 :         let sleep_for = Duration::from_secs(1);
     309            2 :         let initial = 42;
     310            2 :         let reinit = 1;
     311            2 :         let cell = Arc::new(OnceCell::new(initial));
     312            2 : 
     313            2 :         let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
     314            2 : 
     315            2 :         let jh = tokio::spawn({
     316            2 :             let cell = cell.clone();
     317            2 :             let deinitialization_started = deinitialization_started.clone();
     318            2 :             async move {
     319            2 :                 let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
     320            2 :                 assert_eq!(answer, initial);
     321            2 : 
     322            2 :                 deinitialization_started.wait().await;
     323            2 :                 tokio::time::sleep(sleep_for).await;
     324            2 :             }
     325            2 :         });
     326            2 : 
     327            2 :         deinitialization_started.wait().await;
     328            2 : 
     329            2 :         let started_at = tokio::time::Instant::now();
     330            2 :         cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
     331            2 :             .await
     332            2 :             .unwrap();
     333            2 : 
     334            2 :         let elapsed = started_at.elapsed();
     335            2 :         assert!(
     336            2 :             elapsed >= sleep_for,
     337            2 :             "initialization should had taken at least the time time slept with permit"
     338            2 :         );
     339            2 : 
     340            2 :         jh.await.unwrap();
     341            2 : 
     342            2 :         assert_eq!(*cell.get().unwrap(), reinit);
     343            2 :     }
     344              : 
     345            2 :     #[test]
     346            2 :     fn reinit_with_deinit_permit() {
     347            2 :         let cell = Arc::new(OnceCell::new(42));
     348            2 : 
     349            2 :         let (mol, permit) = cell.get().unwrap().take_and_deinit();
     350            2 :         cell.set(5, permit);
     351            2 :         assert_eq!(*cell.get().unwrap(), 5);
     352              : 
     353            2 :         let (five, permit) = cell.get().unwrap().take_and_deinit();
     354            2 :         assert_eq!(5, five);
     355            2 :         cell.set(mol, permit);
     356            2 :         assert_eq!(*cell.get().unwrap(), 42);
     357            2 :     }
     358              : 
     359            2 :     #[tokio::test]
     360            2 :     async fn initialization_attemptable_until_ok() {
     361            2 :         let cell = OnceCell::default();
     362            2 : 
     363           22 :         for _ in 0..10 {
     364           20 :             cell.get_or_init(|_permit| async { Err("whatever error") })
     365            2 :                 .await
     366           20 :                 .unwrap_err();
     367            2 :         }
     368            2 : 
     369            2 :         let g = cell
     370            2 :             .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
     371            2 :             .await
     372            2 :             .unwrap();
     373            2 :         assert_eq!(*g, "finally success");
     374            2 :     }
     375              : 
     376            2 :     #[tokio::test]
     377            2 :     async fn initialization_is_cancellation_safe() {
     378            2 :         let cell = OnceCell::default();
     379            2 : 
     380            2 :         let barrier = tokio::sync::Barrier::new(2);
     381            2 : 
     382            2 :         let initializer = cell.get_or_init(|permit| async {
     383            2 :             barrier.wait().await;
     384            2 :             futures::future::pending::<()>().await;
     385            2 : 
     386            2 :             Ok::<_, Infallible>(("never reached", permit))
     387            2 :         });
     388            2 : 
     389            3 :         tokio::select! {
     390            3 :             _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
     391            3 :             _ = barrier.wait() => {}
     392            3 :         };
     393            2 : 
     394            2 :         // now initializer is dropped
     395            2 : 
     396            2 :         assert!(cell.get().is_none());
     397            2 : 
     398            2 :         let g = cell
     399            2 :             .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
     400            2 :             .await
     401            2 :             .unwrap();
     402            2 :         assert_eq!(*g, "now initialized");
     403            2 :     }
     404              : 
     405            2 :     #[tokio::test(start_paused = true)]
     406            2 :     async fn reproduce_init_take_deinit_race() {
     407            4 :         init_take_deinit_scenario(|cell, factory| {
     408            4 :             Box::pin(async {
     409           10 :                 cell.get_or_init(factory).await.unwrap();
     410            4 :             })
     411            4 :         })
     412            8 :         .await;
     413            2 :     }
     414              : 
     415              :     type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
     416              :     type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
     417              : 
     418              :     /// Reproduce an assertion failure.
     419              :     ///
     420              :     /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
     421              :     /// We currently only have one, but the structure is kept.
     422            2 :     async fn init_take_deinit_scenario<F>(init_way: F)
     423            2 :     where
     424            2 :         F: for<'a> Fn(
     425            2 :             &'a OnceCell<&'static str>,
     426            2 :             BoxedInitFunction<&'static str, Infallible>,
     427            2 :         ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
     428            2 :     {
     429            2 :         let cell = OnceCell::default();
     430            2 : 
     431            2 :         // acquire the init_semaphore only permit to drive initializing tasks in order to waiting
     432            2 :         // on the same semaphore.
     433            2 :         let permit = cell
     434            2 :             .inner
     435            2 :             .lock()
     436            2 :             .unwrap()
     437            2 :             .init_semaphore
     438            2 :             .clone()
     439            2 :             .try_acquire_owned()
     440            2 :             .unwrap();
     441            2 : 
     442            2 :         let mut t1 = pin!(init_way(
     443            2 :             &cell,
     444            2 :             Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
     445            2 :         ));
     446            2 : 
     447            2 :         let mut t2 = pin!(init_way(
     448            2 :             &cell,
     449            2 :             Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
     450            2 :         ));
     451            2 : 
     452            2 :         // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
     453            2 :         // no longer make progress
     454            4 :         tokio::select! {
     455            4 :             _ = &mut t2 => unreachable!("it cannot get permit"),
     456            4 :             _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
     457            4 :         }
     458              : 
     459              :         // followed by t1 in the init_semaphore
     460            2 :         tokio::select! {
     461            4 :             _ = &mut t1 => unreachable!("it cannot get permit"),
     462            4 :             _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
     463            4 :         }
     464              : 
     465              :         // now let t2 proceed and initialize
     466            2 :         drop(permit);
     467            2 :         t2.await;
     468              : 
     469            2 :         let (s, permit) = { cell.get().unwrap().take_and_deinit() };
     470            2 :         assert_eq!("t2", s);
     471              : 
     472              :         // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
     473              :         // the new one.
     474            2 :         tokio::select! {
     475            6 :             _ = &mut t1 => unreachable!("it cannot get permit"),
     476            6 :             _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
     477            6 :         }
     478              : 
     479              :         // only now we get to initialize it
     480            2 :         drop(permit);
     481            2 :         t1.await;
     482              : 
     483            2 :         assert_eq!("t1", *cell.get().unwrap());
     484            2 :     }
     485              : }
        

Generated by: LCOV version 2.1-beta