LCOV - code coverage report
Current view: top level - libs/utils/src/sync - heavier_once_cell.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 80.1 % 287 230
Test Date: 2024-02-07 07:37:29 Functions: 70.5 % 105 74

            Line data    Source code
       1              : use std::sync::{
       2              :     atomic::{AtomicUsize, Ordering},
       3              :     Arc,
       4              : };
       5              : use tokio::sync::Semaphore;
       6              : 
       7              : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
       8              : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
       9              : /// for the duration of initialization.
      10              : ///
      11              : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
      12              : ///
      13              : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
      14              : pub struct OnceCell<T> {
      15              :     inner: tokio::sync::RwLock<Inner<T>>,
      16              :     initializers: AtomicUsize,
      17              : }
      18              : 
      19              : impl<T> Default for OnceCell<T> {
      20              :     /// Create new uninitialized [`OnceCell`].
      21        44123 :     fn default() -> Self {
      22        44123 :         Self {
      23        44123 :             inner: Default::default(),
      24        44123 :             initializers: AtomicUsize::new(0),
      25        44123 :         }
      26        44123 :     }
      27              : }
      28              : 
      29              : /// Semaphore is the current state:
      30              : /// - open semaphore means the value is `None`, not yet initialized
      31              : /// - closed semaphore means the value has been initialized
      32            0 : #[derive(Debug)]
      33              : struct Inner<T> {
      34              :     init_semaphore: Arc<Semaphore>,
      35              :     value: Option<T>,
      36              : }
      37              : 
      38              : impl<T> Default for Inner<T> {
      39        46690 :     fn default() -> Self {
      40        46690 :         Self {
      41        46690 :             init_semaphore: Arc::new(Semaphore::new(1)),
      42        46690 :             value: None,
      43        46690 :         }
      44        46690 :     }
      45              : }
      46              : 
      47              : impl<T> OnceCell<T> {
      48              :     /// Creates an already initialized `OnceCell` with the given value.
      49        34667 :     pub fn new(value: T) -> Self {
      50        34667 :         let sem = Semaphore::new(1);
      51        34667 :         sem.close();
      52        34667 :         Self {
      53        34667 :             inner: tokio::sync::RwLock::new(Inner {
      54        34667 :                 init_semaphore: Arc::new(sem),
      55        34667 :                 value: Some(value),
      56        34667 :             }),
      57        34667 :             initializers: AtomicUsize::new(0),
      58        34667 :         }
      59        34667 :     }
      60              : 
      61              :     /// Returns a guard to an existing initialized value, or uniquely initializes the value before
      62              :     /// returning the guard.
      63              :     ///
      64              :     /// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
      65              :     ///
      66              :     /// Initialization is panic-safe and cancellation-safe.
      67     23943801 :     pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
      68     23943801 :     where
      69     23943801 :         F: FnOnce(InitPermit) -> Fut,
      70     23943801 :         Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
      71     23943801 :     {
      72        10926 :         let sem = {
      73     23943801 :             let guard = self.inner.write().await;
      74     23943801 :             if guard.value.is_some() {
      75     23932875 :                 return Ok(GuardMut(guard));
      76        10926 :             }
      77        10926 :             guard.init_semaphore.clone()
      78              :         };
      79              : 
      80        10926 :         let permit = {
      81              :             // increment the count for the duration of queued
      82        10926 :             let _guard = CountWaitingInitializers::start(self);
      83        10926 :             sem.acquire_owned().await
      84              :         };
      85              : 
      86        10926 :         match permit {
      87        10402 :             Ok(permit) => {
      88        10402 :                 let permit = InitPermit(permit);
      89        30618 :                 let (value, _permit) = factory(permit).await?;
      90              : 
      91         9776 :                 let guard = self.inner.write().await;
      92              : 
      93         9776 :                 Ok(Self::set0(value, guard))
      94              :             }
      95          524 :             Err(_closed) => {
      96          524 :                 let guard = self.inner.write().await;
      97              :                 assert!(
      98          524 :                     guard.value.is_some(),
      99            0 :                     "semaphore got closed, must be initialized"
     100              :                 );
     101          524 :                 return Ok(GuardMut(guard));
     102              :             }
     103              :         }
     104     23943790 :     }
     105              : 
     106              :     /// Returns a guard to an existing initialized value, or uniquely initializes the value before
     107              :     /// returning the guard.
     108              :     ///
     109              :     /// Initialization is panic-safe and cancellation-safe.
     110            0 :     pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
     111            0 :     where
     112            0 :         F: FnOnce(InitPermit) -> Fut,
     113            0 :         Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
     114            0 :     {
     115            0 :         let sem = {
     116            0 :             let guard = self.inner.read().await;
     117            0 :             if guard.value.is_some() {
     118            0 :                 return Ok(GuardRef(guard));
     119            0 :             }
     120            0 :             guard.init_semaphore.clone()
     121              :         };
     122              : 
     123            0 :         let permit = {
     124              :             // increment the count for the duration of queued
     125            0 :             let _guard = CountWaitingInitializers::start(self);
     126            0 :             sem.acquire_owned().await
     127              :         };
     128              : 
     129            0 :         match permit {
     130            0 :             Ok(permit) => {
     131            0 :                 let permit = InitPermit(permit);
     132            0 :                 let (value, _permit) = factory(permit).await?;
     133              : 
     134            0 :                 let guard = self.inner.write().await;
     135              : 
     136            0 :                 Ok(Self::set0(value, guard).downgrade())
     137              :             }
     138            0 :             Err(_closed) => {
     139            0 :                 let guard = self.inner.read().await;
     140              :                 assert!(
     141            0 :                     guard.value.is_some(),
     142            0 :                     "semaphore got closed, must be initialized"
     143              :                 );
     144            0 :                 return Ok(GuardRef(guard));
     145              :             }
     146              :         }
     147            0 :     }
     148              : 
     149              :     /// Assuming a permit is held after previous call to [`GuardMut::take_and_deinit`], it can be used
     150              :     /// to complete initializing the inner value.
     151              :     ///
     152              :     /// # Panics
     153              :     ///
     154              :     /// If the inner has already been initialized.
     155            4 :     pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
     156            4 :         let guard = self.inner.write().await;
     157              : 
     158              :         // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
     159              :         // give more permits right now.
     160            4 :         if guard.init_semaphore.try_acquire().is_ok() {
     161            0 :             drop(guard);
     162            0 :             panic!("permit is of wrong origin");
     163            4 :         }
     164            4 : 
     165            4 :         Self::set0(value, guard)
     166            4 :     }
     167              : 
     168         9780 :     fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
     169         9780 :         if guard.value.is_some() {
     170            0 :             drop(guard);
     171            0 :             unreachable!("we won permit, must not be initialized");
     172         9780 :         }
     173         9780 :         guard.value = Some(value);
     174         9780 :         guard.init_semaphore.close();
     175         9780 :         GuardMut(guard)
     176         9780 :     }
     177              : 
     178              :     /// Returns a guard to an existing initialized value, if any.
     179         8152 :     pub async fn get_mut(&self) -> Option<GuardMut<'_, T>> {
     180         8152 :         let guard = self.inner.write().await;
     181         8152 :         if guard.value.is_some() {
     182         7398 :             Some(GuardMut(guard))
     183              :         } else {
     184          754 :             None
     185              :         }
     186         8152 :     }
     187              : 
     188              :     /// Returns a guard to an existing initialized value, if any.
     189            0 :     pub async fn get(&self) -> Option<GuardRef<'_, T>> {
     190            0 :         let guard = self.inner.read().await;
     191            0 :         if guard.value.is_some() {
     192            0 :             Some(GuardRef(guard))
     193              :         } else {
     194            0 :             None
     195              :         }
     196            0 :     }
     197              : 
     198              :     /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
     199         9768 :     pub fn initializer_count(&self) -> usize {
     200         9768 :         self.initializers.load(Ordering::Relaxed)
     201         9768 :     }
     202              : }
     203              : 
     204              : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
     205              : /// initializing task for example at the end of initialization.
     206              : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
     207              : 
     208              : impl<'a, T> CountWaitingInitializers<'a, T> {
     209        10926 :     fn start(target: &'a OnceCell<T>) -> Self {
     210        10926 :         target.initializers.fetch_add(1, Ordering::Relaxed);
     211        10926 :         CountWaitingInitializers(target)
     212        10926 :     }
     213              : }
     214              : 
     215              : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
     216        10926 :     fn drop(&mut self) {
     217        10926 :         self.0.initializers.fetch_sub(1, Ordering::Relaxed);
     218        10926 :     }
     219              : }
     220              : 
     221              : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
     222              : /// initialized value.
     223            0 : #[derive(Debug)]
     224              : pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner<T>>);
     225              : 
     226              : impl<T> std::ops::Deref for GuardMut<'_, T> {
     227              :     type Target = T;
     228              : 
     229         2771 :     fn deref(&self) -> &Self::Target {
     230         2771 :         self.0
     231         2771 :             .value
     232         2771 :             .as_ref()
     233         2771 :             .expect("guard is not created unless value has been initialized")
     234         2771 :     }
     235              : }
     236              : 
     237              : impl<T> std::ops::DerefMut for GuardMut<'_, T> {
     238     23945530 :     fn deref_mut(&mut self) -> &mut Self::Target {
     239     23945530 :         self.0
     240     23945530 :             .value
     241     23945530 :             .as_mut()
     242     23945530 :             .expect("guard is not created unless value has been initialized")
     243     23945530 :     }
     244              : }
     245              : 
     246              : impl<'a, T> GuardMut<'a, T> {
     247              :     /// Take the current value, and a new permit for it's deinitialization.
     248              :     ///
     249              :     /// The permit will be on a semaphore part of the new internal value, and any following
     250              :     /// [`OnceCell::get_or_init`] will wait on it to complete.
     251         2567 :     pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
     252         2567 :         let mut swapped = Inner::default();
     253         2567 :         let permit = swapped
     254         2567 :             .init_semaphore
     255         2567 :             .clone()
     256         2567 :             .try_acquire_owned()
     257         2567 :             .expect("we just created this");
     258         2567 :         std::mem::swap(&mut *self.0, &mut swapped);
     259         2567 :         swapped
     260         2567 :             .value
     261         2567 :             .map(|v| (v, InitPermit(permit)))
     262         2567 :             .expect("guard is not created unless value has been initialized")
     263         2567 :     }
     264              : 
     265            0 :     pub fn downgrade(self) -> GuardRef<'a, T> {
     266            0 :         GuardRef(self.0.downgrade())
     267            0 :     }
     268              : }
     269              : 
     270            0 : #[derive(Debug)]
     271              : pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner<T>>);
     272              : 
     273              : impl<T> std::ops::Deref for GuardRef<'_, T> {
     274              :     type Target = T;
     275              : 
     276            0 :     fn deref(&self) -> &Self::Target {
     277            0 :         self.0
     278            0 :             .value
     279            0 :             .as_ref()
     280            0 :             .expect("guard is not created unless value has been initialized")
     281            0 :     }
     282              : }
     283              : 
     284              : /// Type held by OnceCell (de)initializing task.
     285              : pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
     286              : 
     287              : #[cfg(test)]
     288              : mod tests {
     289              :     use super::*;
     290              :     use std::{
     291              :         convert::Infallible,
     292              :         sync::atomic::{AtomicUsize, Ordering},
     293              :         time::Duration,
     294              :     };
     295              : 
     296            2 :     #[tokio::test]
     297            2 :     async fn many_initializers() {
     298            2 :         #[derive(Default, Debug)]
     299            2 :         struct Counters {
     300            2 :             factory_got_to_run: AtomicUsize,
     301            2 :             future_polled: AtomicUsize,
     302            2 :             winners: AtomicUsize,
     303            2 :         }
     304            2 : 
     305            2 :         let initializers = 100;
     306            2 : 
     307            2 :         let cell = Arc::new(OnceCell::default());
     308            2 :         let counters = Arc::new(Counters::default());
     309            2 :         let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
     310            2 : 
     311            2 :         let mut js = tokio::task::JoinSet::new();
     312          200 :         for i in 0..initializers {
     313          200 :             js.spawn({
     314          200 :                 let cell = cell.clone();
     315          200 :                 let counters = counters.clone();
     316          200 :                 let barrier = barrier.clone();
     317          200 : 
     318          200 :                 async move {
     319          200 :                     barrier.wait().await;
     320          200 :                     let won = {
     321          200 :                         let g = cell
     322          200 :                             .get_mut_or_init(|permit| {
     323            2 :                                 counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
     324            2 :                                 async {
     325            2 :                                     counters.future_polled.fetch_add(1, Ordering::Relaxed);
     326            2 :                                     Ok::<_, Infallible>((i, permit))
     327            2 :                                 }
     328          200 :                             })
     329            0 :                             .await
     330          200 :                             .unwrap();
     331          200 : 
     332          200 :                         *g == i
     333          200 :                     };
     334          200 : 
     335          200 :                     if won {
     336            2 :                         counters.winners.fetch_add(1, Ordering::Relaxed);
     337          198 :                     }
     338          200 :                 }
     339          200 :             });
     340          200 :         }
     341              : 
     342            2 :         barrier.wait().await;
     343              : 
     344          202 :         while let Some(next) = js.join_next().await {
     345          200 :             next.expect("no panics expected");
     346          200 :         }
     347              : 
     348            2 :         let mut counters = Arc::try_unwrap(counters).unwrap();
     349            2 : 
     350            2 :         assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
     351            2 :         assert_eq!(*counters.future_polled.get_mut(), 1);
     352            2 :         assert_eq!(*counters.winners.get_mut(), 1);
     353              :     }
     354              : 
     355            2 :     #[tokio::test(start_paused = true)]
     356            2 :     async fn reinit_waits_for_deinit() {
     357            2 :         // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
     358            2 :         let sleep_for = Duration::from_secs(1);
     359            2 :         let initial = 42;
     360            2 :         let reinit = 1;
     361            2 :         let cell = Arc::new(OnceCell::new(initial));
     362            2 : 
     363            2 :         let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
     364            2 : 
     365            2 :         let jh = tokio::spawn({
     366            2 :             let cell = cell.clone();
     367            2 :             let deinitialization_started = deinitialization_started.clone();
     368            2 :             async move {
     369            2 :                 let (answer, _permit) = cell
     370            2 :                     .get_mut()
     371            0 :                     .await
     372            2 :                     .expect("initialized to value")
     373            2 :                     .take_and_deinit();
     374            2 :                 assert_eq!(answer, initial);
     375              : 
     376            2 :                 deinitialization_started.wait().await;
     377            2 :                 tokio::time::sleep(sleep_for).await;
     378            2 :             }
     379            2 :         });
     380            2 : 
     381            2 :         deinitialization_started.wait().await;
     382              : 
     383            2 :         let started_at = tokio::time::Instant::now();
     384            2 :         cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
     385            2 :             .await
     386            2 :             .unwrap();
     387            2 : 
     388            2 :         let elapsed = started_at.elapsed();
     389            2 :         assert!(
     390            2 :             elapsed >= sleep_for,
     391            0 :             "initialization should had taken at least the time time slept with permit"
     392              :         );
     393              : 
     394            2 :         jh.await.unwrap();
     395              : 
     396            2 :         assert_eq!(*cell.get_mut().await.unwrap(), reinit);
     397              :     }
     398              : 
     399            2 :     #[tokio::test]
     400            2 :     async fn reinit_with_deinit_permit() {
     401            2 :         let cell = Arc::new(OnceCell::new(42));
     402              : 
     403            2 :         let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
     404            2 :         cell.set(5, permit).await;
     405            2 :         assert_eq!(*cell.get_mut().await.unwrap(), 5);
     406              : 
     407            2 :         let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
     408            2 :         assert_eq!(5, five);
     409            2 :         cell.set(mol, permit).await;
     410            2 :         assert_eq!(*cell.get_mut().await.unwrap(), 42);
     411              :     }
     412              : 
     413            2 :     #[tokio::test]
     414            2 :     async fn initialization_attemptable_until_ok() {
     415            2 :         let cell = OnceCell::default();
     416              : 
     417           22 :         for _ in 0..10 {
     418           20 :             cell.get_mut_or_init(|_permit| async { Err("whatever error") })
     419            0 :                 .await
     420           20 :                 .unwrap_err();
     421              :         }
     422              : 
     423            2 :         let g = cell
     424            2 :             .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
     425            0 :             .await
     426            2 :             .unwrap();
     427            2 :         assert_eq!(*g, "finally success");
     428              :     }
     429              : 
     430            2 :     #[tokio::test]
     431            2 :     async fn initialization_is_cancellation_safe() {
     432            2 :         let cell = OnceCell::default();
     433            2 : 
     434            2 :         let barrier = tokio::sync::Barrier::new(2);
     435            2 : 
     436            2 :         let initializer = cell.get_mut_or_init(|permit| async {
     437            2 :             barrier.wait().await;
     438            0 :             futures::future::pending::<()>().await;
     439              : 
     440            0 :             Ok::<_, Infallible>(("never reached", permit))
     441            2 :         });
     442            2 : 
     443            2 :         tokio::select! {
     444            2 :             _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
     445            2 :             _ = barrier.wait() => {}
     446            2 :         };
     447              : 
     448              :         // now initializer is dropped
     449              : 
     450            2 :         assert!(cell.get_mut().await.is_none());
     451              : 
     452            2 :         let g = cell
     453            2 :             .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
     454            0 :             .await
     455            2 :             .unwrap();
     456            2 :         assert_eq!(*g, "now initialized");
     457              :     }
     458              : }
        

Generated by: LCOV version 2.1-beta