Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc, Mutex, MutexGuard,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
9 : /// for the duration of initialization.
10 : ///
11 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
12 : ///
13 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
14 : pub struct OnceCell<T> {
15 : inner: Mutex<Inner<T>>,
16 : initializers: AtomicUsize,
17 : }
18 :
19 : impl<T> Default for OnceCell<T> {
20 : /// Create new uninitialized [`OnceCell`].
21 40907 : fn default() -> Self {
22 40907 : Self {
23 40907 : inner: Default::default(),
24 40907 : initializers: AtomicUsize::new(0),
25 40907 : }
26 40907 : }
27 : }
28 :
29 : /// Semaphore is the current state:
30 : /// - open semaphore means the value is `None`, not yet initialized
31 : /// - closed semaphore means the value has been initialized
32 0 : #[derive(Debug)]
33 : struct Inner<T> {
34 : init_semaphore: Arc<Semaphore>,
35 : value: Option<T>,
36 : }
37 :
38 : impl<T> Default for Inner<T> {
39 43430 : fn default() -> Self {
40 43430 : Self {
41 43430 : init_semaphore: Arc::new(Semaphore::new(1)),
42 43430 : value: None,
43 43430 : }
44 43430 : }
45 : }
46 :
47 : impl<T> OnceCell<T> {
48 : /// Creates an already initialized `OnceCell` with the given value.
49 34461 : pub fn new(value: T) -> Self {
50 34461 : let sem = Semaphore::new(1);
51 34461 : sem.close();
52 34461 : Self {
53 34461 : inner: Mutex::new(Inner {
54 34461 : init_semaphore: Arc::new(sem),
55 34461 : value: Some(value),
56 34461 : }),
57 34461 : initializers: AtomicUsize::new(0),
58 34461 : }
59 34461 : }
60 :
61 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
62 : /// returning the guard.
63 : ///
64 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
65 : ///
66 : /// Initialization is panic-safe and cancellation-safe.
67 16853002 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
68 16853002 : where
69 16853002 : F: FnOnce(InitPermit) -> Fut,
70 16853002 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
71 16853002 : {
72 : loop {
73 10619 : let sem = {
74 16853004 : let guard = self.inner.lock().unwrap();
75 16853004 : if guard.value.is_some() {
76 16842385 : return Ok(Guard(guard));
77 10619 : }
78 10619 : guard.init_semaphore.clone()
79 : };
80 :
81 : {
82 10619 : let permit = {
83 : // increment the count for the duration of queued
84 10619 : let _guard = CountWaitingInitializers::start(self);
85 10619 : sem.acquire().await
86 : };
87 :
88 10619 : let Ok(permit) = permit else {
89 499 : let guard = self.inner.lock().unwrap();
90 499 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
91 : // there was a take_and_deinit in between
92 2 : continue;
93 497 : }
94 497 : assert!(
95 497 : guard.value.is_some(),
96 0 : "semaphore got closed, must be initialized"
97 : );
98 497 : return Ok(Guard(guard));
99 : };
100 :
101 10120 : permit.forget();
102 10120 : }
103 10120 :
104 10120 : let permit = InitPermit(sem);
105 29430 : let (value, _permit) = factory(permit).await?;
106 :
107 9468 : let guard = self.inner.lock().unwrap();
108 9468 :
109 9468 : return Ok(Self::set0(value, guard));
110 : }
111 16852992 : }
112 :
113 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
114 : /// to complete initializing the inner value.
115 : ///
116 : /// # Panics
117 : ///
118 : /// If the inner has already been initialized.
119 4 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
120 4 : let guard = self.inner.lock().unwrap();
121 4 :
122 4 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
123 4 : // give more permits right now.
124 4 : if guard.init_semaphore.try_acquire().is_ok() {
125 0 : drop(guard);
126 0 : panic!("permit is of wrong origin");
127 4 : }
128 4 :
129 4 : Self::set0(value, guard)
130 4 : }
131 :
132 9472 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
133 9472 : if guard.value.is_some() {
134 0 : drop(guard);
135 0 : unreachable!("we won permit, must not be initialized");
136 9472 : }
137 9472 : guard.value = Some(value);
138 9472 : guard.init_semaphore.close();
139 9472 : Guard(guard)
140 9472 : }
141 :
142 : /// Returns a guard to an existing initialized value, if any.
143 8100 : pub fn get(&self) -> Option<Guard<'_, T>> {
144 8100 : let guard = self.inner.lock().unwrap();
145 8100 : if guard.value.is_some() {
146 7311 : Some(Guard(guard))
147 : } else {
148 789 : None
149 : }
150 8100 : }
151 :
152 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
153 9456 : pub fn initializer_count(&self) -> usize {
154 9456 : self.initializers.load(Ordering::Relaxed)
155 9456 : }
156 : }
157 :
158 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
159 : /// initializing task for example at the end of initialization.
160 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
161 :
162 : impl<'a, T> CountWaitingInitializers<'a, T> {
163 10619 : fn start(target: &'a OnceCell<T>) -> Self {
164 10619 : target.initializers.fetch_add(1, Ordering::Relaxed);
165 10619 : CountWaitingInitializers(target)
166 10619 : }
167 : }
168 :
169 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
170 10619 : fn drop(&mut self) {
171 10619 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
172 10619 : }
173 : }
174 :
175 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
176 : /// initialized value.
177 0 : #[derive(Debug)]
178 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
179 :
180 : impl<T> std::ops::Deref for Guard<'_, T> {
181 : type Target = T;
182 :
183 2727 : fn deref(&self) -> &Self::Target {
184 2727 : self.0
185 2727 : .value
186 2727 : .as_ref()
187 2727 : .expect("guard is not created unless value has been initialized")
188 2727 : }
189 : }
190 :
191 : impl<T> std::ops::DerefMut for Guard<'_, T> {
192 16854655 : fn deref_mut(&mut self) -> &mut Self::Target {
193 16854655 : self.0
194 16854655 : .value
195 16854655 : .as_mut()
196 16854655 : .expect("guard is not created unless value has been initialized")
197 16854655 : }
198 : }
199 :
200 : impl<'a, T> Guard<'a, T> {
201 : /// Take the current value, and a new permit for it's deinitialization.
202 : ///
203 : /// The permit will be on a semaphore part of the new internal value, and any following
204 : /// [`OnceCell::get_or_init`] will wait on it to complete.
205 2523 : pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
206 2523 : let mut swapped = Inner::default();
207 2523 : let sem = swapped.init_semaphore.clone();
208 2523 : // acquire and forget right away, moving the control over to InitPermit
209 2523 : sem.try_acquire().expect("we just created this").forget();
210 2523 : std::mem::swap(&mut *self.0, &mut swapped);
211 2523 : swapped
212 2523 : .value
213 2523 : .map(|v| (v, InitPermit(sem)))
214 2523 : .expect("guard is not created unless value has been initialized")
215 2523 : }
216 : }
217 :
218 : /// Type held by OnceCell (de)initializing task.
219 : ///
220 : /// On drop, this type will return the permit.
221 : pub struct InitPermit(Arc<tokio::sync::Semaphore>);
222 :
223 : impl Drop for InitPermit {
224 12640 : fn drop(&mut self) {
225 12640 : assert_eq!(
226 12640 : self.0.available_permits(),
227 : 0,
228 0 : "InitPermit should only exist as the unique permit"
229 : );
230 12640 : self.0.add_permits(1);
231 12640 : }
232 : }
233 :
#[cfg(test)]
mod tests {
    use futures::Future;

    use super::*;
    use std::{
        convert::Infallible,
        pin::{pin, Pin},
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    /// Many tasks race on `get_or_init` of one cell: the factory (and its future) must run
    /// exactly once, and exactly one task must observe its own value as the result.
    #[tokio::test]
    async fn many_initializers() {
        #[derive(Default, Debug)]
        struct Counters {
            factory_got_to_run: AtomicUsize,
            future_polled: AtomicUsize,
            winners: AtomicUsize,
        }

        let initializers = 100;

        let cell = Arc::new(OnceCell::default());
        let counters = Arc::new(Counters::default());
        // +1 so the test body itself releases all initializers at once
        let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));

        let mut js = tokio::task::JoinSet::new();
        for i in 0..initializers {
            js.spawn({
                let cell = cell.clone();
                let counters = counters.clone();
                let barrier = barrier.clone();

                async move {
                    barrier.wait().await;
                    let won = {
                        let g = cell
                            .get_or_init(|permit| {
                                counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                                async {
                                    counters.future_polled.fetch_add(1, Ordering::Relaxed);
                                    Ok::<_, Infallible>((i, permit))
                                }
                            })
                            .await
                            .unwrap();

                        // only the task whose factory ran sees its own index
                        *g == i
                    };

                    if won {
                        counters.winners.fetch_add(1, Ordering::Relaxed);
                    }
                }
            });
        }

        barrier.wait().await;

        while let Some(next) = js.join_next().await {
            next.expect("no panics expected");
        }

        let mut counters = Arc::try_unwrap(counters).unwrap();

        assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
        assert_eq!(*counters.future_polled.get_mut(), 1);
        assert_eq!(*counters.winners.get_mut(), 1);
    }

    /// While a task holds the permit returned by `take_and_deinit`, `get_or_init` must wait
    /// until that permit is dropped.
    #[tokio::test(start_paused = true)]
    async fn reinit_waits_for_deinit() {
        // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
        let sleep_for = Duration::from_secs(1);
        let initial = 42;
        let reinit = 1;
        let cell = Arc::new(OnceCell::new(initial));

        let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));

        let jh = tokio::spawn({
            let cell = cell.clone();
            let deinitialization_started = deinitialization_started.clone();
            async move {
                // `_permit` stays alive until the end of this block, i.e. through the sleep
                let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
                assert_eq!(answer, initial);

                deinitialization_started.wait().await;
                tokio::time::sleep(sleep_for).await;
            }
        });

        deinitialization_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

        let elapsed = started_at.elapsed();
        assert!(
            elapsed >= sleep_for,
            "initialization should had taken at least the time time slept with permit"
        );

        jh.await.unwrap();

        assert_eq!(*cell.get().unwrap(), reinit);
    }

    /// The permit from `take_and_deinit` can be passed to `set` to synchronously reinitialize.
    #[test]
    fn reinit_with_deinit_permit() {
        let cell = Arc::new(OnceCell::new(42));

        let (mol, permit) = cell.get().unwrap().take_and_deinit();
        cell.set(5, permit);
        assert_eq!(*cell.get().unwrap(), 5);

        let (five, permit) = cell.get().unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(mol, permit);
        assert_eq!(*cell.get().unwrap(), 42);
    }

    /// A factory returning `Err` leaves the cell uninitialized, so later attempts can succeed.
    #[tokio::test]
    async fn initialization_attemptable_until_ok() {
        let cell = OnceCell::default();

        for _ in 0..10 {
            cell.get_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "finally success");
    }

    /// Dropping an in-progress `get_or_init` future (while its factory holds the permit) must
    /// leave the cell uninitialized and allow a later initialization.
    #[tokio::test]
    async fn initialization_is_cancellation_safe() {
        let cell = OnceCell::default();

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_or_init(|permit| async {
            barrier.wait().await;
            futures::future::pending::<()>().await;

            Ok::<_, Infallible>(("never reached", permit))
        });

        // the second branch completes once the factory has reached the barrier, i.e. while it
        // holds the permit
        tokio::select! {
            _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
            _ = barrier.wait() => {}
        };

        // now initializer is dropped

        assert!(cell.get().is_none());

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "now initialized");
    }

    /// Drives the `init_take_deinit_scenario` through `get_or_init`.
    #[tokio::test(start_paused = true)]
    async fn reproduce_init_take_deinit_race() {
        init_take_deinit_scenario(|cell, factory| {
            Box::pin(async {
                cell.get_or_init(factory).await.unwrap();
            })
        })
        .await;
    }

    type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
    type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;

    /// Reproduce an assertion failure.
    ///
    /// Two initializers t1 and t2 queue on the same semaphore. t2 initializes (closing that
    /// semaphore) and the value is then taken with `take_and_deinit`, which installs a new
    /// semaphore. t1 then sees its old semaphore closed although the cell is uninitialized.
    /// `get_or_init` must notice the swap and retry on the new semaphore. Presumably the
    /// assertion that used to fail is "semaphore got closed, must be initialized" in
    /// `get_or_init` — confirm against history.
    ///
    /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
    /// We currently only have one, but the structure is kept.
    async fn init_take_deinit_scenario<F>(init_way: F)
    where
        F: for<'a> Fn(
            &'a OnceCell<&'static str>,
            BoxedInitFunction<&'static str, Infallible>,
        ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
    {
        let cell = OnceCell::default();

        // acquire the only permit of init_semaphore ourselves, so that both initializing tasks
        // queue up on that same semaphore in an order we control.
        let permit = cell
            .inner
            .lock()
            .unwrap()
            .init_semaphore
            .clone()
            .try_acquire_owned()
            .unwrap();

        let mut t1 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
        ));

        let mut t2 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
        ));

        // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
        // no longer make progress
        tokio::select! {
            _ = &mut t2 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // followed by t1 in the init_semaphore
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // now let t2 proceed and initialize
        drop(permit);
        t2.await;

        let (s, permit) = { cell.get().unwrap().take_and_deinit() };
        assert_eq!("t2", s);

        // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
        // the new one.
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
        }

        // only now we get to initialize it
        drop(permit);
        t1.await;

        assert_eq!("t1", *cell.get().unwrap());
    }
}
|