Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc, Mutex, MutexGuard,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
9 : /// for the duration of initialization.
10 : ///
11 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
12 : ///
13 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
14 : pub struct OnceCell<T> {
15 : inner: Mutex<Inner<T>>,
16 : initializers: AtomicUsize,
17 : }
18 :
19 : impl<T> Default for OnceCell<T> {
20 : /// Create new uninitialized [`OnceCell`].
21 16 : fn default() -> Self {
22 16 : Self {
23 16 : inner: Default::default(),
24 16 : initializers: AtomicUsize::new(0),
25 16 : }
26 16 : }
27 : }
28 :
29 : /// Semaphore is the current state:
30 : /// - open semaphore means the value is `None`, not yet initialized
31 : /// - closed semaphore means the value has been initialized
32 : #[derive(Debug)]
33 : struct Inner<T> {
34 : init_semaphore: Arc<Semaphore>,
35 : value: Option<T>,
36 : }
37 :
38 : impl<T> Default for Inner<T> {
39 54 : fn default() -> Self {
40 54 : Self {
41 54 : init_semaphore: Arc::new(Semaphore::new(1)),
42 54 : value: None,
43 54 : }
44 54 : }
45 : }
46 :
47 : impl<T> OnceCell<T> {
48 : /// Creates an already initialized `OnceCell` with the given value.
49 970 : pub fn new(value: T) -> Self {
50 970 : let sem = Semaphore::new(1);
51 970 : sem.close();
52 970 : Self {
53 970 : inner: Mutex::new(Inner {
54 970 : init_semaphore: Arc::new(sem),
55 970 : value: Some(value),
56 970 : }),
57 970 : initializers: AtomicUsize::new(0),
58 970 : }
59 970 : }
60 :
61 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
62 : /// returning the guard.
63 : ///
64 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
65 : ///
66 : /// Initialization is panic-safe and cancellation-safe.
67 236 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
68 236 : where
69 236 : F: FnOnce(InitPermit) -> Fut,
70 236 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
71 236 : {
72 : loop {
73 38 : let sem = {
74 238 : let guard = self.inner.lock().unwrap();
75 238 : if guard.value.is_some() {
76 200 : return Ok(Guard(guard));
77 38 : }
78 38 : guard.init_semaphore.clone()
79 : };
80 :
81 : {
82 36 : let permit = {
83 : // increment the count for the duration of queued
84 38 : let _guard = CountWaitingInitializers::start(self);
85 38 : sem.acquire().await
86 : };
87 :
88 36 : let Ok(permit) = permit else {
89 2 : let guard = self.inner.lock().unwrap();
90 2 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
91 : // there was a take_and_deinit in between
92 2 : continue;
93 0 : }
94 0 : assert!(
95 0 : guard.value.is_some(),
96 0 : "semaphore got closed, must be initialized"
97 : );
98 0 : return Ok(Guard(guard));
99 : };
100 :
101 34 : permit.forget();
102 34 : }
103 34 :
104 34 : let permit = InitPermit(sem);
105 34 : let (value, _permit) = factory(permit).await?;
106 :
107 12 : let guard = self.inner.lock().unwrap();
108 12 :
109 12 : return Ok(Self::set0(value, guard));
110 : }
111 232 : }
112 :
113 : /// Returns a guard to an existing initialized value, or returns an unique initialization
114 : /// permit which can be used to initialize this `OnceCell` using `OnceCell::set`.
115 125052 : pub async fn get_or_init_detached(&self) -> Result<Guard<'_, T>, InitPermit> {
116 : // It looks like OnceCell::get_or_init could be implemented using this method instead of
117 : // duplication. However, that makes the future be !Send due to possibly holding on to the
118 : // MutexGuard over an await point.
119 : loop {
120 20 : let sem = {
121 125052 : let guard = self.inner.lock().unwrap();
122 125052 : if guard.value.is_some() {
123 125032 : return Ok(Guard(guard));
124 20 : }
125 20 : guard.init_semaphore.clone()
126 : };
127 :
128 : {
129 20 : let permit = {
130 : // increment the count for the duration of queued
131 20 : let _guard = CountWaitingInitializers::start(self);
132 20 : sem.acquire().await
133 : };
134 :
135 20 : let Ok(permit) = permit else {
136 0 : let guard = self.inner.lock().unwrap();
137 0 : if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
138 : // there was a take_and_deinit in between
139 0 : continue;
140 0 : }
141 0 : assert!(
142 0 : guard.value.is_some(),
143 0 : "semaphore got closed, must be initialized"
144 : );
145 0 : return Ok(Guard(guard));
146 : };
147 :
148 20 : permit.forget();
149 20 : }
150 20 :
151 20 : let permit = InitPermit(sem);
152 20 : return Err(permit);
153 : }
154 125052 : }
155 :
156 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
157 : /// to complete initializing the inner value.
158 : ///
159 : /// # Panics
160 : ///
161 : /// If the inner has already been initialized.
162 28 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
163 28 : let guard = self.inner.lock().unwrap();
164 28 :
165 28 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
166 28 : // give more permits right now.
167 28 : if guard.init_semaphore.try_acquire().is_ok() {
168 0 : drop(guard);
169 0 : panic!("permit is of wrong origin");
170 28 : }
171 28 :
172 28 : Self::set0(value, guard)
173 28 : }
174 :
175 40 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
176 40 : if guard.value.is_some() {
177 0 : drop(guard);
178 0 : unreachable!("we won permit, must not be initialized");
179 40 : }
180 40 : guard.value = Some(value);
181 40 : guard.init_semaphore.close();
182 40 : Guard(guard)
183 40 : }
184 :
185 : /// Returns a guard to an existing initialized value, if any.
186 104 : pub fn get(&self) -> Option<Guard<'_, T>> {
187 104 : let guard = self.inner.lock().unwrap();
188 104 : if guard.value.is_some() {
189 88 : Some(Guard(guard))
190 : } else {
191 16 : None
192 : }
193 104 : }
194 :
195 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
196 30 : pub fn initializer_count(&self) -> usize {
197 30 : self.initializers.load(Ordering::Relaxed)
198 30 : }
199 : }
200 :
201 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
202 : /// initializing task for example at the end of initialization.
203 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
204 :
205 : impl<'a, T> CountWaitingInitializers<'a, T> {
206 58 : fn start(target: &'a OnceCell<T>) -> Self {
207 58 : target.initializers.fetch_add(1, Ordering::Relaxed);
208 58 : CountWaitingInitializers(target)
209 58 : }
210 : }
211 :
212 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
213 58 : fn drop(&mut self) {
214 58 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
215 58 : }
216 : }
217 :
218 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
219 : /// initialized value.
220 : #[derive(Debug)]
221 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
222 :
223 : impl<T> std::ops::Deref for Guard<'_, T> {
224 : type Target = T;
225 :
226 278 : fn deref(&self) -> &Self::Target {
227 278 : self.0
228 278 : .value
229 278 : .as_ref()
230 278 : .expect("guard is not created unless value has been initialized")
231 278 : }
232 : }
233 :
234 : impl<T> std::ops::DerefMut for Guard<'_, T> {
235 125040 : fn deref_mut(&mut self) -> &mut Self::Target {
236 125040 : self.0
237 125040 : .value
238 125040 : .as_mut()
239 125040 : .expect("guard is not created unless value has been initialized")
240 125040 : }
241 : }
242 :
243 : impl<'a, T> Guard<'a, T> {
244 : /// Take the current value, and a new permit for it's deinitialization.
245 : ///
246 : /// The permit will be on a semaphore part of the new internal value, and any following
247 : /// [`OnceCell::get_or_init`] will wait on it to complete.
248 38 : pub fn take_and_deinit(mut self) -> (T, InitPermit) {
249 38 : let mut swapped = Inner::default();
250 38 : let sem = swapped.init_semaphore.clone();
251 38 : // acquire and forget right away, moving the control over to InitPermit
252 38 : sem.try_acquire().expect("we just created this").forget();
253 38 : std::mem::swap(&mut *self.0, &mut swapped);
254 38 : swapped
255 38 : .value
256 38 : .map(|v| (v, InitPermit(sem)))
257 38 : .expect("guard is not created unless value has been initialized")
258 38 : }
259 : }
260 :
261 : /// Type held by OnceCell (de)initializing task.
262 : ///
263 : /// On drop, this type will return the permit.
264 : pub struct InitPermit(Arc<tokio::sync::Semaphore>);
265 :
266 : impl Drop for InitPermit {
267 92 : fn drop(&mut self) {
268 92 : assert_eq!(
269 92 : self.0.available_permits(),
270 : 0,
271 0 : "InitPermit should only exist as the unique permit"
272 : );
273 92 : self.0.add_permits(1);
274 92 : }
275 : }
276 :
#[cfg(test)]
mod tests {
    use futures::Future;

    use super::*;
    use std::{
        convert::Infallible,
        pin::{pin, Pin},
        time::Duration,
    };

    /// Timeout long enough to never elapse before the future it races against could make
    /// progress; with paused tokio time it only elapses once nothing else can run.
    const PRACTICALLY_FOREVER: Duration = Duration::from_secs(3600 * 24 * 7 * 365);

    /// Many tasks race to initialize; exactly one factory runs and everyone sees its value.
    #[tokio::test]
    async fn many_initializers() {
        #[derive(Default, Debug)]
        struct Counters {
            factory_got_to_run: AtomicUsize,
            future_polled: AtomicUsize,
            winners: AtomicUsize,
        }

        let initializers = 100;

        let cell = Arc::new(OnceCell::default());
        let counters = Arc::new(Counters::default());
        // +1 for this test task, which releases all of the initializers at once
        let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));

        let mut js = tokio::task::JoinSet::new();
        for i in 0..initializers {
            js.spawn({
                let cell = cell.clone();
                let counters = counters.clone();
                let barrier = barrier.clone();

                async move {
                    barrier.wait().await;
                    let won = {
                        let g = cell
                            .get_or_init(|permit| {
                                counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                                async {
                                    counters.future_polled.fetch_add(1, Ordering::Relaxed);
                                    Ok::<_, Infallible>((i, permit))
                                }
                            })
                            .await
                            .unwrap();

                        *g == i
                    };

                    if won {
                        counters.winners.fetch_add(1, Ordering::Relaxed);
                    }
                }
            });
        }

        barrier.wait().await;

        while let Some(next) = js.join_next().await {
            next.expect("no panics expected");
        }

        let mut counters = Arc::try_unwrap(counters).unwrap();

        assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
        assert_eq!(*counters.future_polled.get_mut(), 1);
        assert_eq!(*counters.winners.get_mut(), 1);
    }

    /// Reinitialization has to wait for as long as the deinitializing task holds its permit.
    #[tokio::test(start_paused = true)]
    async fn reinit_waits_for_deinit() {
        // with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
        let sleep_for = Duration::from_secs(1);
        let initial = 42;
        let reinit = 1;
        let cell = Arc::new(OnceCell::new(initial));

        let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));

        let jh = tokio::spawn({
            let cell = cell.clone();
            let deinitialization_started = deinitialization_started.clone();
            async move {
                let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
                assert_eq!(answer, initial);

                deinitialization_started.wait().await;
                tokio::time::sleep(sleep_for).await;
            }
        });

        deinitialization_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

        let elapsed = started_at.elapsed();
        assert!(
            elapsed >= sleep_for,
            "initialization should have taken at least the time slept with permit"
        );

        jh.await.unwrap();

        assert_eq!(*cell.get().unwrap(), reinit);
    }

    /// The permit from `take_and_deinit` can be used to `set` a new value, repeatedly.
    #[test]
    fn reinit_with_deinit_permit() {
        let cell = Arc::new(OnceCell::new(42));

        let (mol, permit) = cell.get().unwrap().take_and_deinit();
        cell.set(5, permit);
        assert_eq!(*cell.get().unwrap(), 5);

        let (five, permit) = cell.get().unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(mol, permit);
        assert_eq!(*cell.get().unwrap(), 42);
    }

    /// A permit issued by another cell must not be accepted by `set`.
    #[test]
    #[should_panic(expected = "permit is of wrong origin")]
    fn set_with_foreign_permit_panics() {
        let donor = OnceCell::new(1);
        let (_, foreign_permit) = donor.get().unwrap().take_and_deinit();

        // uninitialized, and nobody is holding its initialization permit
        let target = OnceCell::default();
        target.set(2, foreign_permit);
    }

    /// Failed initializations leave the cell uninitialized and retryable.
    #[tokio::test]
    async fn initialization_attemptable_until_ok() {
        let cell = OnceCell::default();

        for _ in 0..10 {
            cell.get_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "finally success");
    }

    /// Dropping an in-progress initialization releases the permit for the next attempt.
    #[tokio::test]
    async fn initialization_is_cancellation_safe() {
        let cell = OnceCell::default();

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_or_init(|permit| async {
            barrier.wait().await;
            futures::future::pending::<()>().await;

            Ok::<_, Infallible>(("never reached", permit))
        });

        tokio::select! {
            _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
            _ = barrier.wait() => {}
        };

        // now initializer is dropped

        assert!(cell.get().is_none());

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "now initialized");
    }

    #[tokio::test(start_paused = true)]
    async fn reproduce_init_take_deinit_race() {
        init_take_deinit_scenario(|cell, factory| {
            Box::pin(async {
                cell.get_or_init(factory).await.unwrap();
            })
        })
        .await;
    }

    type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
    type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;

    /// Reproduce an assertion failure: an initializer queued on a semaphore which gets closed by
    /// a competing initialization, followed by a `take_and_deinit` swapping in a new semaphore.
    ///
    /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
    /// We currently only have one, but the structure is kept.
    async fn init_take_deinit_scenario<F>(init_way: F)
    where
        F: for<'a> Fn(
            &'a OnceCell<&'static str>,
            BoxedInitFunction<&'static str, Infallible>,
        ) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
    {
        let cell = OnceCell::default();

        // acquire the only permit of the init_semaphore, in order to drive the initializing
        // tasks into waiting on the same semaphore.
        let permit = cell
            .inner
            .lock()
            .unwrap()
            .init_semaphore
            .clone()
            .try_acquire_owned()
            .unwrap();

        let mut t1 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
        ));

        let mut t2 = pin!(init_way(
            &cell,
            Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
        ));

        // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
        // no longer make progress
        tokio::select! {
            _ = &mut t2 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(PRACTICALLY_FOREVER) => {}
        }

        // followed by t1 in the init_semaphore
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(PRACTICALLY_FOREVER) => {}
        }

        // now let t2 proceed and initialize
        drop(permit);
        t2.await;

        let (s, permit) = { cell.get().unwrap().take_and_deinit() };
        assert_eq!("t2", s);

        // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
        // the new one.
        tokio::select! {
            _ = &mut t1 => unreachable!("it cannot get permit"),
            _ = tokio::time::sleep(PRACTICALLY_FOREVER) => {}
        }

        // only now we get to initialize it
        drop(permit);
        t1.await;

        assert_eq!("t1", *cell.get().unwrap());
    }

    /// Walk through detached initialization: permit, `set`, `take_and_deinit`, `set` again.
    #[tokio::test(start_paused = true)]
    async fn detached_init_smoke() {
        let target = OnceCell::default();

        let Err(permit) = target.get_or_init_detached().await else {
            unreachable!("it is not initialized")
        };

        tokio::time::timeout(
            PRACTICALLY_FOREVER,
            target.get_or_init(|permit2| async { Ok::<_, Infallible>((11, permit2)) }),
        )
        .await
        .expect_err("should timeout since we are already holding the permit");

        target.set(42, permit);

        let (_answer, permit) = {
            // already initialized by `set`, so this factory never gets to run
            let guard = target
                .get_or_init(|permit| async { Ok::<_, Infallible>((11, permit)) })
                .await
                .unwrap();

            assert_eq!(*guard, 42);

            guard.take_and_deinit()
        };

        assert!(target.get().is_none());

        target.set(11, permit);

        assert_eq!(*target.get().unwrap(), 11);
    }
}
|