Line data Source code
1 : use std::sync::atomic::{AtomicUsize, Ordering};
2 : use std::sync::{Arc, Mutex, MutexGuard};
3 :
4 : use tokio::sync::Semaphore;
5 :
6 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
7 : /// `SemaphorePermit`.
8 : ///
9 : /// Allows use of `take` which does not require holding an outer mutex guard
10 : /// for the duration of initialization.
11 : ///
12 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
13 : ///
14 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
15 : pub struct OnceCell<T> {
16 : inner: Mutex<Inner<T>>,
17 : initializers: AtomicUsize,
18 : }
19 :
20 : impl<T> Default for OnceCell<T> {
21 : /// Create new uninitialized [`OnceCell`].
22 22 : fn default() -> Self {
23 22 : Self {
24 22 : inner: Default::default(),
25 22 : initializers: AtomicUsize::new(0),
26 22 : }
27 22 : }
28 : }
29 :
30 : /// Semaphore is the current state:
31 : /// - open semaphore means the value is `None`, not yet initialized
32 : /// - closed semaphore means the value has been initialized
33 : #[derive(Debug)]
34 : struct Inner<T> {
35 : init_semaphore: Arc<Semaphore>,
36 : value: Option<T>,
37 : }
38 :
39 : impl<T> Default for Inner<T> {
40 1529 : fn default() -> Self {
41 1529 : Self {
42 1529 : init_semaphore: Arc::new(Semaphore::new(1)),
43 1529 : value: None,
44 1529 : }
45 1529 : }
46 : }
47 :
48 : impl<T> OnceCell<T> {
49 : /// Creates an already initialized `OnceCell` with the given value.
50 3878 : pub fn new(value: T) -> Self {
51 3878 : let sem = Semaphore::new(1);
52 3878 : sem.close();
53 3878 : Self {
54 3878 : inner: Mutex::new(Inner {
55 3878 : init_semaphore: Arc::new(sem),
56 3878 : value: Some(value),
57 3878 : }),
58 3878 : initializers: AtomicUsize::new(0),
59 3878 : }
60 3878 : }
61 :
62 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
63 : /// returning the guard.
64 : ///
65 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
66 : ///
67 : /// Initialization is panic-safe and cancellation-safe.
68 119 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
69 119 : where
70 119 : F: FnOnce(InitPermit) -> Fut,
71 119 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
72 119 : {
73 : loop {
74 20 : let sem = {
75 120 : let guard = self.inner.lock().unwrap();
76 120 : if guard.value.is_some() {
77 100 : return Ok(Guard(guard));
78 20 : }
79 20 : guard.init_semaphore.clone()
80 : };
81 :
82 : {
83 19 : let permit = {
84 : // increment the count for the duration of queued
85 20 : let _guard = CountWaitingInitializers::start(self);
86 20 : sem.acquire().await
87 : };
88 :
89 19 : let Ok(permit) = permit else {
90 1 : let guard = self.inner.lock().unwrap();
91 1 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
92 : // there was a take_and_deinit in between
93 1 : continue;
94 0 : }
95 0 : assert!(
96 0 : guard.value.is_some(),
97 0 : "semaphore got closed, must be initialized"
98 : );
99 0 : return Ok(Guard(guard));
100 : };
101 :
102 18 : permit.forget();
103 18 : }
104 18 :
105 18 : let permit = InitPermit(sem);
106 18 : let (value, _permit) = factory(permit).await?;
107 :
108 7 : let guard = self.inner.lock().unwrap();
109 7 :
110 7 : return Ok(Self::set0(value, guard));
111 : }
112 117 : }
113 :
114 : /// Like [`Self::get_or_init_detached_measured`], but without out parameter for time spent waiting.
115 101 : pub async fn get_or_init_detached(&self) -> Result<Guard<'_, T>, InitPermit> {
116 101 : self.get_or_init_detached_measured(None).await
117 1 : }
118 :
119 : /// Returns a guard to an existing initialized value, or returns an unique initialization
120 : /// permit which can be used to initialize this `OnceCell` using `OnceCell::set`.
121 481537 : pub async fn get_or_init_detached_measured(
122 481537 : &self,
123 481537 : mut wait_time: Option<&mut crate::elapsed_accum::ElapsedAccum>,
124 481537 : ) -> Result<Guard<'_, T>, InitPermit> {
125 : // It looks like OnceCell::get_or_init could be implemented using this method instead of
126 : // duplication. However, that makes the future be !Send due to possibly holding on to the
127 : // MutexGuard over an await point.
128 : loop {
129 61 : let sem = {
130 481537 : let guard = self.inner.lock().unwrap();
131 481537 : if guard.value.is_some() {
132 481476 : return Ok(Guard(guard));
133 61 : }
134 61 : guard.init_semaphore.clone()
135 : };
136 : {
137 61 : let permit = {
138 : // increment the count for the duration of queued
139 61 : let _guard = CountWaitingInitializers::start(self);
140 61 : let fut = sem.acquire();
141 61 : if let Some(wait_time) = wait_time.as_mut() {
142 32 : wait_time.measure(fut).await
143 : } else {
144 29 : fut.await
145 : }
146 : };
147 :
148 61 : let Ok(permit) = permit else {
149 0 : let guard = self.inner.lock().unwrap();
150 0 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
151 : // there was a take_and_deinit in between
152 0 : continue;
153 0 : }
154 0 : assert!(
155 0 : guard.value.is_some(),
156 0 : "semaphore got closed, must be initialized"
157 : );
158 0 : return Ok(Guard(guard));
159 : };
160 :
161 61 : permit.forget();
162 61 : }
163 61 :
164 61 : let permit = InitPermit(sem);
165 61 : return Err(permit);
166 : }
167 1 : }
168 :
169 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
170 : /// to complete initializing the inner value.
171 : ///
172 : /// # Panics
173 : ///
174 : /// If the inner has already been initialized.
175 68 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
176 68 : let guard = self.inner.lock().unwrap();
177 68 :
178 68 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
179 68 : // give more permits right now.
180 68 : if guard.init_semaphore.try_acquire().is_ok() {
181 0 : drop(guard);
182 0 : panic!("permit is of wrong origin");
183 68 : }
184 68 :
185 68 : Self::set0(value, guard)
186 68 : }
187 :
188 75 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
189 75 : if guard.value.is_some() {
190 0 : drop(guard);
191 0 : unreachable!("we won permit, must not be initialized");
192 75 : }
193 75 : guard.value = Some(value);
194 75 : guard.init_semaphore.close();
195 75 : Guard(guard)
196 75 : }
197 :
198 : /// Returns a guard to an existing initialized value, if any.
199 512 : pub fn get(&self) -> Option<Guard<'_, T>> {
200 512 : let guard = self.inner.lock().unwrap();
201 512 : if guard.value.is_some() {
202 362 : Some(Guard(guard))
203 : } else {
204 150 : None
205 : }
206 11 : }
207 :
208 : /// Like [`Guard::take_and_deinit`], but will return `None` if this OnceCell was never
209 : /// initialized.
210 1416 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
211 1416 : let inner = self.inner.get_mut().unwrap();
212 1416 :
213 1416 : inner.take_and_deinit()
214 1416 : }
215 :
216 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
217 108 : pub fn initializer_count(&self) -> usize {
218 108 : self.initializers.load(Ordering::Relaxed)
219 108 : }
220 : }
221 :
222 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
223 : /// initializing task for example at the end of initialization.
224 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
225 :
226 : impl<'a, T> CountWaitingInitializers<'a, T> {
227 81 : fn start(target: &'a OnceCell<T>) -> Self {
228 81 : target.initializers.fetch_add(1, Ordering::Relaxed);
229 81 : CountWaitingInitializers(target)
230 81 : }
231 : }
232 :
233 : impl<T> Drop for CountWaitingInitializers<'_, T> {
234 81 : fn drop(&mut self) {
235 81 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
236 81 : }
237 : }
238 :
239 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
240 : /// initialized value.
241 : #[derive(Debug)]
242 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
243 :
244 : impl<T> std::ops::Deref for Guard<'_, T> {
245 : type Target = T;
246 :
247 441 : fn deref(&self) -> &Self::Target {
248 441 : self.0
249 441 : .value
250 441 : .as_ref()
251 441 : .expect("guard is not created unless value has been initialized")
252 441 : }
253 : }
254 :
255 : impl<T> std::ops::DerefMut for Guard<'_, T> {
256 481496 : fn deref_mut(&mut self) -> &mut Self::Target {
257 481496 : self.0
258 481496 : .value
259 481496 : .as_mut()
260 481496 : .expect("guard is not created unless value has been initialized")
261 481496 : }
262 : }
263 :
264 : impl<T> Guard<'_, T> {
265 : /// Take the current value, and a new permit for it's deinitialization.
266 : ///
267 : /// The permit will be on a semaphore part of the new internal value, and any following
268 : /// [`OnceCell::get_or_init`] will wait on it to complete.
269 97 : pub fn take_and_deinit(mut self) -> (T, InitPermit) {
270 97 : self.0
271 97 : .take_and_deinit()
272 97 : .expect("guard is not created unless value has been initialized")
273 97 : }
274 : }
275 :
276 : impl<T> Inner<T> {
277 1513 : pub fn take_and_deinit(&mut self) -> Option<(T, InitPermit)> {
278 1513 : let value = self.value.take()?;
279 :
280 1507 : let mut swapped = Inner::default();
281 1507 : let sem = swapped.init_semaphore.clone();
282 1507 : // acquire and forget right away, moving the control over to InitPermit
283 1507 : sem.try_acquire().expect("we just created this").forget();
284 1507 : let permit = InitPermit(sem);
285 1507 : std::mem::swap(self, &mut swapped);
286 1507 : Some((value, permit))
287 8 : }
288 : }
289 :
290 : /// Type held by OnceCell (de)initializing task.
291 : ///
292 : /// On drop, this type will return the permit.
293 : pub struct InitPermit(Arc<tokio::sync::Semaphore>);
294 :
295 : impl std::fmt::Debug for InitPermit {
296 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 0 : let ptr = Arc::as_ptr(&self.0) as *const ();
298 0 : f.debug_tuple("InitPermit").field(&ptr).finish()
299 0 : }
300 : }
301 :
302 : impl Drop for InitPermit {
303 1586 : fn drop(&mut self) {
304 1586 : assert_eq!(
305 1586 : self.0.available_permits(),
306 : 0,
307 0 : "InitPermit should only exist as the unique permit"
308 : );
309 1586 : self.0.add_permits(1);
310 1586 : }
311 : }
312 :
313 : #[cfg(test)]
314 : mod tests {
315 : use std::convert::Infallible;
316 : use std::pin::{Pin, pin};
317 : use std::time::Duration;
318 :
319 : use futures::Future;
320 :
321 : use super::*;
322 :
323 : #[tokio::test]
324 1 : async fn many_initializers() {
325 1 : #[derive(Default, Debug)]
326 1 : struct Counters {
327 1 : factory_got_to_run: AtomicUsize,
328 1 : future_polled: AtomicUsize,
329 1 : winners: AtomicUsize,
330 1 : }
331 1 :
332 1 : let initializers = 100;
333 1 :
334 1 : let cell = Arc::new(OnceCell::default());
335 1 : let counters = Arc::new(Counters::default());
336 1 : let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
337 1 :
338 1 : let mut js = tokio::task::JoinSet::new();
339 100 : for i in 0..initializers {
340 100 : js.spawn({
341 100 : let cell = cell.clone();
342 100 : let counters = counters.clone();
343 100 : let barrier = barrier.clone();
344 100 :
345 100 : async move {
346 100 : barrier.wait().await;
347 100 : let won = {
348 100 : let g = cell
349 100 : .get_or_init(|permit| {
350 1 : counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
351 1 : async {
352 1 : counters.future_polled.fetch_add(1, Ordering::Relaxed);
353 1 : Ok::<_, Infallible>((i, permit))
354 1 : }
355 100 : })
356 100 : .await
357 100 : .unwrap();
358 100 :
359 100 : *g == i
360 100 : };
361 100 :
362 100 : if won {
363 1 : counters.winners.fetch_add(1, Ordering::Relaxed);
364 99 : }
365 100 : }
366 100 : });
367 100 : }
368 1 :
369 1 : barrier.wait().await;
370 1 :
371 101 : while let Some(next) = js.join_next().await {
372 100 : next.expect("no panics expected");
373 100 : }
374 1 :
375 1 : let mut counters = Arc::try_unwrap(counters).unwrap();
376 1 :
377 1 : assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
378 1 : assert_eq!(*counters.future_polled.get_mut(), 1);
379 1 : assert_eq!(*counters.winners.get_mut(), 1);
380 1 : }
381 :
382 : #[tokio::test(start_paused = true)]
383 1 : async fn reinit_waits_for_deinit() {
384 1 : // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
385 1 : let sleep_for = Duration::from_secs(1);
386 1 : let initial = 42;
387 1 : let reinit = 1;
388 1 : let cell = Arc::new(OnceCell::new(initial));
389 1 :
390 1 : let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
391 1 :
392 1 : let jh = tokio::spawn({
393 1 : let cell = cell.clone();
394 1 : let deinitialization_started = deinitialization_started.clone();
395 1 : async move {
396 1 : let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
397 1 : assert_eq!(answer, initial);
398 1 :
399 1 : deinitialization_started.wait().await;
400 1 : tokio::time::sleep(sleep_for).await;
401 1 : }
402 1 : });
403 1 :
404 1 : deinitialization_started.wait().await;
405 1 :
406 1 : let started_at = tokio::time::Instant::now();
407 1 : cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
408 1 : .await
409 1 : .unwrap();
410 1 :
411 1 : let elapsed = started_at.elapsed();
412 1 : assert!(
413 1 : elapsed >= sleep_for,
414 1 : "initialization should had taken at least the time time slept with permit"
415 1 : );
416 1 :
417 1 : jh.await.unwrap();
418 1 :
419 1 : assert_eq!(*cell.get().unwrap(), reinit);
420 1 : }
421 :
/// The permit returned by `take_and_deinit` can be used to `set` a new value, repeatedly.
#[test]
fn reinit_with_deinit_permit() {
    let cell = Arc::new(OnceCell::new(42));
    // small helpers; each one lets go of its guard before returning
    let take = || cell.get().unwrap().take_and_deinit();
    let current = || *cell.get().unwrap();

    let (forty_two, permit) = take();
    cell.set(5, permit);
    assert_eq!(current(), 5);

    let (five, permit) = take();
    assert_eq!(5, five);
    cell.set(forty_two, permit);
    assert_eq!(current(), 42);
}
435 :
436 : #[tokio::test]
437 1 : async fn initialization_attemptable_until_ok() {
438 1 : let cell = OnceCell::default();
439 1 :
440 11 : for _ in 0..10 {
441 10 : cell.get_or_init(|_permit| async { Err("whatever error") })
442 10 : .await
443 10 : .unwrap_err();
444 1 : }
445 1 :
446 1 : let g = cell
447 1 : .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
448 1 : .await
449 1 : .unwrap();
450 1 : assert_eq!(*g, "finally success");
451 1 : }
452 :
453 : #[tokio::test]
454 1 : async fn initialization_is_cancellation_safe() {
455 1 : let cell = OnceCell::default();
456 1 :
457 1 : let barrier = tokio::sync::Barrier::new(2);
458 1 :
459 1 : let initializer = cell.get_or_init(|permit| async {
460 1 : barrier.wait().await;
461 1 : futures::future::pending::<()>().await;
462 1 :
463 1 : Ok::<_, Infallible>(("never reached", permit))
464 1 : });
465 1 :
466 1 : tokio::select! {
467 1 : _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
468 1 : _ = barrier.wait() => {}
469 1 : };
470 1 :
471 1 : // now initializer is dropped
472 1 :
473 1 : assert!(cell.get().is_none());
474 1 :
475 1 : let g = cell
476 1 : .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
477 1 : .await
478 1 : .unwrap();
479 1 : assert_eq!(*g, "now initialized");
480 1 : }
481 :
482 : #[tokio::test(start_paused = true)]
483 1 : async fn reproduce_init_take_deinit_race() {
484 2 : init_take_deinit_scenario(|cell, factory| {
485 2 : Box::pin(async {
486 2 : cell.get_or_init(factory).await.unwrap();
487 2 : })
488 2 : })
489 1 : .await;
490 1 : }
491 :
492 : type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
493 : type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
494 :
495 : /// Reproduce an assertion failure.
496 : ///
497 : /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
498 : /// We currently only have one, but the structure is kept.
499 1 : async fn init_take_deinit_scenario<F>(init_way: F)
500 1 : where
501 1 : F: for<'a> Fn(
502 1 : &'a OnceCell<&'static str>,
503 1 : BoxedInitFunction<&'static str, Infallible>,
504 1 : ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
505 1 : {
506 1 : let cell = OnceCell::default();
507 1 :
508 1 : // acquire the init_semaphore only permit to drive initializing tasks in order to waiting
509 1 : // on the same semaphore.
510 1 : let permit = cell
511 1 : .inner
512 1 : .lock()
513 1 : .unwrap()
514 1 : .init_semaphore
515 1 : .clone()
516 1 : .try_acquire_owned()
517 1 : .unwrap();
518 1 :
519 1 : let mut t1 = pin!(init_way(
520 1 : &cell,
521 1 : Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
522 1 : ));
523 1 :
524 1 : let mut t2 = pin!(init_way(
525 1 : &cell,
526 1 : Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
527 1 : ));
528 1 :
529 1 : // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
530 1 : // no longer make progress
531 1 : tokio::select! {
532 1 : _ = &mut t2 => unreachable!("it cannot get permit"),
533 1 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
534 1 : }
535 1 :
536 1 : // followed by t1 in the init_semaphore
537 1 : tokio::select! {
538 1 : _ = &mut t1 => unreachable!("it cannot get permit"),
539 1 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
540 1 : }
541 1 :
542 1 : // now let t2 proceed and initialize
543 1 : drop(permit);
544 1 : t2.await;
545 :
546 1 : let (s, permit) = { cell.get().unwrap().take_and_deinit() };
547 1 : assert_eq!("t2", s);
548 :
549 : // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
550 : // the new one.
551 1 : tokio::select! {
552 1 : _ = &mut t1 => unreachable!("it cannot get permit"),
553 1 : _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
554 1 : }
555 1 :
556 1 : // only now we get to initialize it
557 1 : drop(permit);
558 1 : t1.await;
559 :
560 1 : assert_eq!("t1", *cell.get().unwrap());
561 1 : }
562 :
563 : #[tokio::test(start_paused = true)]
564 1 : async fn detached_init_smoke() {
565 1 : let target = OnceCell::default();
566 1 :
567 1 : let Err(permit) = target.get_or_init_detached().await else {
568 1 : unreachable!("it is not initialized")
569 1 : };
570 1 :
571 1 : tokio::time::timeout(
572 1 : std::time::Duration::from_secs(3600 * 24 * 7 * 365),
573 1 : target.get_or_init(|permit2| async { Ok::<_, Infallible>((11, permit2)) }),
574 1 : )
575 1 : .await
576 1 : .expect_err("should timeout since we are already holding the permit");
577 1 :
578 1 : target.set(42, permit);
579 1 :
580 1 : let (_answer, permit) = {
581 1 : let guard = target
582 1 : .get_or_init(|permit| async { Ok::<_, Infallible>((11, permit)) })
583 1 : .await
584 1 : .unwrap();
585 1 :
586 1 : assert_eq!(*guard, 42);
587 1 :
588 1 : guard.take_and_deinit()
589 1 : };
590 1 :
591 1 : assert!(target.get().is_none());
592 1 :
593 1 : target.set(11, permit);
594 1 :
595 1 : assert_eq!(*target.get().unwrap(), 11);
596 1 : }
597 :
598 : #[tokio::test]
599 1 : async fn take_and_deinit_on_mut() {
600 1 : use std::convert::Infallible;
601 1 :
602 1 : let mut target = OnceCell::<u32>::default();
603 1 : assert!(target.take_and_deinit().is_none());
604 1 :
605 1 : target
606 1 : .get_or_init(|permit| async move { Ok::<_, Infallible>((42, permit)) })
607 1 : .await
608 1 : .unwrap();
609 1 :
610 1 : let again = target.take_and_deinit();
611 1 : assert!(matches!(again, Some((42, _))), "{again:?}");
612 1 :
613 1 : assert!(target.take_and_deinit().is_none());
614 1 : }
615 : }
|