LCOV - differential code coverage report
Current view: top level - libs/utils/src/sync - heavier_once_cell.rs (source / functions) Coverage Total Hit UBC GBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 94.8 % 249 236 13 1 235
Current Date: 2024-01-09 02:06:09 Functions: 76.9 % 91 70 21 70
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use std::sync::{
       2                 :     atomic::{AtomicUsize, Ordering},
       3                 :     Arc, Mutex, MutexGuard,
       4                 : };
       5                 : use tokio::sync::Semaphore;
       6                 : 
       7                 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
       8                 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
       9                 : /// for the duration of initialization.
      10                 : ///
      11                 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
      12                 : ///
      13                 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
      14                 : pub struct OnceCell<T> {
      15                 :     inner: Mutex<Inner<T>>,
      16                 :     initializers: AtomicUsize,
      17                 : }
      18                 : 
      19                 : impl<T> Default for OnceCell<T> {
      20                 :     /// Create new uninitialized [`OnceCell`].
      21 CBC       43723 :     fn default() -> Self {
      22           43723 :         Self {
      23           43723 :             inner: Default::default(),
      24           43723 :             initializers: AtomicUsize::new(0),
      25           43723 :         }
      26           43723 :     }
      27                 : }
      28                 : 
      29                 : /// Semaphore is the current state:
      30                 : /// - open semaphore means the value is `None`, not yet initialized
      31                 : /// - closed semaphore means the value has been initialized
      32 UBC           0 : #[derive(Debug)]
      33                 : struct Inner<T> {
      34                 :     init_semaphore: Arc<Semaphore>,
      35                 :     value: Option<T>,
      36                 : }
      37                 : 
      38                 : impl<T> Default for Inner<T> {
      39 CBC       46264 :     fn default() -> Self {
      40           46264 :         Self {
      41           46264 :             init_semaphore: Arc::new(Semaphore::new(1)),
      42           46264 :             value: None,
      43           46264 :         }
      44           46264 :     }
      45                 : }
      46                 : 
      47                 : impl<T> OnceCell<T> {
      48                 :     /// Creates an already initialized `OnceCell` with the given value.
      49           34535 :     pub fn new(value: T) -> Self {
      50           34535 :         let sem = Semaphore::new(1);
      51           34535 :         sem.close();
      52           34535 :         Self {
      53           34535 :             inner: Mutex::new(Inner {
      54           34535 :                 init_semaphore: Arc::new(sem),
      55           34535 :                 value: Some(value),
      56           34535 :             }),
      57           34535 :             initializers: AtomicUsize::new(0),
      58           34535 :         }
      59           34535 :     }
      60                 : 
      61                 :     /// Returns a guard to an existing initialized value, or uniquely initializes the value before
      62                 :     /// returning the guard.
      63                 :     ///
      64                 :     /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
      65                 :     ///
      66                 :     /// Initialization is panic-safe and cancellation-safe.
      67        15463516 :     pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
      68        15463516 :     where
      69        15463516 :         F: FnOnce(InitPermit) -> Fut,
      70        15463516 :         Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
      71        15463516 :     {
      72           10552 :         let sem = {
      73        15463516 :             let guard = self.inner.lock().unwrap();
      74        15463516 :             if guard.value.is_some() {
      75        15452964 :                 return Ok(Guard(guard));
      76           10552 :             }
      77           10552 :             guard.init_semaphore.clone()
      78                 :         };
      79                 : 
      80           10552 :         let permit = {
      81                 :             // increment the count for the duration of queued
      82           10552 :             let _guard = CountWaitingInitializers::start(self);
      83           10552 :             sem.acquire_owned().await
      84                 :         };
      85                 : 
      86           10552 :         match permit {
      87           10050 :             Ok(permit) => {
      88           10050 :                 let permit = InitPermit(permit);
      89           29253 :                 let (value, _permit) = factory(permit).await?;
      90                 : 
      91            9400 :                 let guard = self.inner.lock().unwrap();
      92            9400 : 
      93            9400 :                 Ok(Self::set0(value, guard))
      94                 :             }
      95             502 :             Err(_closed) => {
      96             502 :                 let guard = self.inner.lock().unwrap();
      97             502 :                 assert!(
      98             502 :                     guard.value.is_some(),
      99 UBC           0 :                     "semaphore got closed, must be initialized"
     100                 :                 );
     101 CBC         502 :                 return Ok(Guard(guard));
     102                 :             }
     103                 :         }
     104        15463506 :     }
     105                 : 
     106                 :     /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
     107                 :     /// to complete initializing the inner value.
     108                 :     ///
     109                 :     /// # Panics
     110                 :     ///
     111                 :     /// If the inner has already been initialized.
     112               2 :     pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
     113               2 :         let guard = self.inner.lock().unwrap();
     114               2 : 
     115               2 :         // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
     116               2 :         // give more permits right now.
     117               2 :         if guard.init_semaphore.try_acquire().is_ok() {
     118 UBC           0 :             drop(guard);
     119               0 :             panic!("permit is of wrong origin");
     120 CBC           2 :         }
     121               2 : 
     122               2 :         Self::set0(value, guard)
     123               2 :     }
     124                 : 
     125            9402 :     fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
     126            9402 :         if guard.value.is_some() {
     127 UBC           0 :             drop(guard);
     128               0 :             unreachable!("we won permit, must not be initialized");
     129 CBC        9402 :         }
     130            9402 :         guard.value = Some(value);
     131            9402 :         guard.init_semaphore.close();
     132            9402 :         Guard(guard)
     133            9402 :     }
     134                 : 
     135                 :     /// Returns a guard to an existing initialized value, if any.
     136            8058 :     pub fn get(&self) -> Option<Guard<'_, T>> {
     137            8058 :         let guard = self.inner.lock().unwrap();
     138            8058 :         if guard.value.is_some() {
     139            7250 :             Some(Guard(guard))
     140                 :         } else {
     141             808 :             None
     142                 :         }
     143            8058 :     }
     144                 : 
     145                 :     /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
     146            9396 :     pub fn initializer_count(&self) -> usize {
     147            9396 :         self.initializers.load(Ordering::Relaxed)
     148            9396 :     }
     149                 : }
     150                 : 
     151                 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
     152                 : /// initializing task for example at the end of initialization.
     153                 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
     154                 : 
     155                 : impl<'a, T> CountWaitingInitializers<'a, T> {
     156           10552 :     fn start(target: &'a OnceCell<T>) -> Self {
     157           10552 :         target.initializers.fetch_add(1, Ordering::Relaxed);
     158           10552 :         CountWaitingInitializers(target)
     159           10552 :     }
     160                 : }
     161                 : 
     162                 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
     163           10552 :     fn drop(&mut self) {
     164           10552 :         self.0.initializers.fetch_sub(1, Ordering::Relaxed);
     165           10552 :     }
     166                 : }
     167                 : 
     168                 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
     169                 : /// initialized value.
     170 UBC           0 : #[derive(Debug)]
     171                 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
     172                 : 
     173                 : impl<T> std::ops::Deref for Guard<'_, T> {
     174                 :     type Target = T;
     175                 : 
     176 CBC        2643 :     fn deref(&self) -> &Self::Target {
     177            2643 :         self.0
     178            2643 :             .value
     179            2643 :             .as_ref()
     180            2643 :             .expect("guard is not created unless value has been initialized")
     181            2643 :     }
     182                 : }
     183                 : 
     184                 : impl<T> std::ops::DerefMut for Guard<'_, T> {
     185        15465301 :     fn deref_mut(&mut self) -> &mut Self::Target {
     186        15465301 :         self.0
     187        15465301 :             .value
     188        15465301 :             .as_mut()
     189        15465301 :             .expect("guard is not created unless value has been initialized")
     190        15465301 :     }
     191                 : }
     192                 : 
     193                 : impl<'a, T> Guard<'a, T> {
     194                 :     /// Take the current value, and a new permit for it's deinitialization.
     195                 :     ///
     196                 :     /// The permit will be on a semaphore part of the new internal value, and any following
     197                 :     /// [`OnceCell::get_or_init`] will wait on it to complete.
     198            2541 :     pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
     199            2541 :         let mut swapped = Inner::default();
     200            2541 :         let permit = swapped
     201            2541 :             .init_semaphore
     202            2541 :             .clone()
     203            2541 :             .try_acquire_owned()
     204            2541 :             .expect("we just created this");
     205            2541 :         std::mem::swap(&mut *self.0, &mut swapped);
     206            2541 :         swapped
     207            2541 :             .value
     208            2541 :             .map(|v| (v, InitPermit(permit)))
     209            2541 :             .expect("guard is not created unless value has been initialized")
     210            2541 :     }
     211                 : }
     212                 : 
     213                 : /// Type held by OnceCell (de)initializing task.
     214                 : pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
     215                 : 
     216                 : #[cfg(test)]
     217                 : mod tests {
     218                 :     use super::*;
     219                 :     use std::{
     220                 :         convert::Infallible,
     221                 :         sync::atomic::{AtomicUsize, Ordering},
     222                 :         time::Duration,
     223                 :     };
     224                 : 
     225               1 :     #[tokio::test]
     226               1 :     async fn many_initializers() {
     227               1 :         #[derive(Default, Debug)]
     228               1 :         struct Counters {
     229               1 :             factory_got_to_run: AtomicUsize,
     230               1 :             future_polled: AtomicUsize,
     231               1 :             winners: AtomicUsize,
     232               1 :         }
     233               1 : 
     234               1 :         let initializers = 100;
     235               1 : 
     236               1 :         let cell = Arc::new(OnceCell::default());
     237               1 :         let counters = Arc::new(Counters::default());
     238               1 :         let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
     239               1 : 
     240               1 :         let mut js = tokio::task::JoinSet::new();
     241             100 :         for i in 0..initializers {
     242             100 :             js.spawn({
     243             100 :                 let cell = cell.clone();
     244             100 :                 let counters = counters.clone();
     245             100 :                 let barrier = barrier.clone();
     246             100 : 
     247             100 :                 async move {
     248             100 :                     barrier.wait().await;
     249             100 :                     let won = {
     250             100 :                         let g = cell
     251             100 :                             .get_or_init(|permit| {
     252               1 :                                 counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
     253               1 :                                 async {
     254               1 :                                     counters.future_polled.fetch_add(1, Ordering::Relaxed);
     255               1 :                                     Ok::<_, Infallible>((i, permit))
     256               1 :                                 }
     257             100 :                             })
     258 UBC           0 :                             .await
     259 CBC         100 :                             .unwrap();
     260             100 : 
     261             100 :                         *g == i
     262             100 :                     };
     263             100 : 
     264             100 :                     if won {
     265               1 :                         counters.winners.fetch_add(1, Ordering::Relaxed);
     266              99 :                     }
     267             100 :                 }
     268             100 :             });
     269             100 :         }
     270                 : 
     271               1 :         barrier.wait().await;
     272                 : 
     273             101 :         while let Some(next) = js.join_next().await {
     274             100 :             next.expect("no panics expected");
     275             100 :         }
     276                 : 
     277               1 :         let mut counters = Arc::try_unwrap(counters).unwrap();
     278               1 : 
     279               1 :         assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
     280               1 :         assert_eq!(*counters.future_polled.get_mut(), 1);
     281               1 :         assert_eq!(*counters.winners.get_mut(), 1);
     282                 :     }
     283                 : 
     284               1 :     #[tokio::test(start_paused = true)]
     285               1 :     async fn reinit_waits_for_deinit() {
     286               1 :         // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
     287               1 :         let sleep_for = Duration::from_secs(1);
     288               1 :         let initial = 42;
     289               1 :         let reinit = 1;
     290               1 :         let cell = Arc::new(OnceCell::new(initial));
     291               1 : 
     292               1 :         let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
     293               1 : 
     294               1 :         let jh = tokio::spawn({
     295               1 :             let cell = cell.clone();
     296               1 :             let deinitialization_started = deinitialization_started.clone();
     297               1 :             async move {
     298               1 :                 let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
     299               1 :                 assert_eq!(answer, initial);
     300                 : 
     301               1 :                 deinitialization_started.wait().await;
     302               1 :                 tokio::time::sleep(sleep_for).await;
     303               1 :             }
     304               1 :         });
     305               1 : 
     306               1 :         deinitialization_started.wait().await;
     307                 : 
     308               1 :         let started_at = tokio::time::Instant::now();
     309               1 :         cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
     310               1 :             .await
     311               1 :             .unwrap();
     312               1 : 
     313               1 :         let elapsed = started_at.elapsed();
     314               1 :         assert!(
     315               1 :             elapsed >= sleep_for,
     316 UBC           0 :             "initialization should had taken at least the time time slept with permit"
     317                 :         );
     318                 : 
     319 CBC           1 :         jh.await.unwrap();
     320               1 : 
     321               1 :         assert_eq!(*cell.get().unwrap(), reinit);
     322                 :     }
     323                 : 
     324               1 :     #[test]
     325               1 :     fn reinit_with_deinit_permit() {
     326               1 :         let cell = Arc::new(OnceCell::new(42));
     327               1 : 
     328               1 :         let (mol, permit) = cell.get().unwrap().take_and_deinit();
     329               1 :         cell.set(5, permit);
     330               1 :         assert_eq!(*cell.get().unwrap(), 5);
     331                 : 
     332               1 :         let (five, permit) = cell.get().unwrap().take_and_deinit();
     333               1 :         assert_eq!(5, five);
     334               1 :         cell.set(mol, permit);
     335               1 :         assert_eq!(*cell.get().unwrap(), 42);
     336               1 :     }
     337                 : 
     338               1 :     #[tokio::test]
     339               1 :     async fn initialization_attemptable_until_ok() {
     340               1 :         let cell = OnceCell::default();
     341                 : 
     342              11 :         for _ in 0..10 {
     343              10 :             cell.get_or_init(|_permit| async { Err("whatever error") })
     344 UBC           0 :                 .await
     345 CBC          10 :                 .unwrap_err();
     346                 :         }
     347                 : 
     348               1 :         let g = cell
     349               1 :             .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
     350 UBC           0 :             .await
     351 CBC           1 :             .unwrap();
     352               1 :         assert_eq!(*g, "finally success");
     353                 :     }
     354                 : 
     355               1 :     #[tokio::test]
     356               1 :     async fn initialization_is_cancellation_safe() {
     357               1 :         let cell = OnceCell::default();
     358               1 : 
     359               1 :         let barrier = tokio::sync::Barrier::new(2);
     360               1 : 
     361               1 :         let initializer = cell.get_or_init(|permit| async {
     362               1 :             barrier.wait().await;
     363 GBC           1 :             futures::future::pending::<()>().await;
     364                 : 
     365 UBC           0 :             Ok::<_, Infallible>(("never reached", permit))
     366 CBC           1 :         });
     367               1 : 
     368               2 :         tokio::select! {
     369               2 :             _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
     370               2 :             _ = barrier.wait() => {}
     371               2 :         };
     372                 : 
     373                 :         // now initializer is dropped
     374                 : 
     375               1 :         assert!(cell.get().is_none());
     376                 : 
     377               1 :         let g = cell
     378               1 :             .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
     379 UBC           0 :             .await
     380 CBC           1 :             .unwrap();
     381               1 :         assert_eq!(*g, "now initialized");
     382                 :     }
     383                 : }
        

Generated by: LCOV version 2.1-beta