TLA Line data Source code
1 : use std::sync::{
2 : atomic::{AtomicUsize, Ordering},
3 : Arc, Mutex, MutexGuard,
4 : };
5 : use tokio::sync::Semaphore;
6 :
7 : /// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
8 : /// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
9 : /// for the duration of initialization.
10 : ///
11 : /// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
12 : ///
13 : /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
14 : pub struct OnceCell<T> {
15 : inner: Mutex<Inner<T>>,
16 : initializers: AtomicUsize,
17 : }
18 :
19 : impl<T> Default for OnceCell<T> {
20 : /// Create new uninitialized [`OnceCell`].
21 CBC 43723 : fn default() -> Self {
22 43723 : Self {
23 43723 : inner: Default::default(),
24 43723 : initializers: AtomicUsize::new(0),
25 43723 : }
26 43723 : }
27 : }
28 :
29 : /// Semaphore is the current state:
30 : /// - open semaphore means the value is `None`, not yet initialized
31 : /// - closed semaphore means the value has been initialized
32 UBC 0 : #[derive(Debug)]
33 : struct Inner<T> {
34 : init_semaphore: Arc<Semaphore>,
35 : value: Option<T>,
36 : }
37 :
38 : impl<T> Default for Inner<T> {
39 CBC 46264 : fn default() -> Self {
40 46264 : Self {
41 46264 : init_semaphore: Arc::new(Semaphore::new(1)),
42 46264 : value: None,
43 46264 : }
44 46264 : }
45 : }
46 :
47 : impl<T> OnceCell<T> {
48 : /// Creates an already initialized `OnceCell` with the given value.
49 34535 : pub fn new(value: T) -> Self {
50 34535 : let sem = Semaphore::new(1);
51 34535 : sem.close();
52 34535 : Self {
53 34535 : inner: Mutex::new(Inner {
54 34535 : init_semaphore: Arc::new(sem),
55 34535 : value: Some(value),
56 34535 : }),
57 34535 : initializers: AtomicUsize::new(0),
58 34535 : }
59 34535 : }
60 :
61 : /// Returns a guard to an existing initialized value, or uniquely initializes the value before
62 : /// returning the guard.
63 : ///
64 : /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
65 : ///
66 : /// Initialization is panic-safe and cancellation-safe.
67 15463516 : pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
68 15463516 : where
69 15463516 : F: FnOnce(InitPermit) -> Fut,
70 15463516 : Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
71 15463516 : {
72 10552 : let sem = {
73 15463516 : let guard = self.inner.lock().unwrap();
74 15463516 : if guard.value.is_some() {
75 15452964 : return Ok(Guard(guard));
76 10552 : }
77 10552 : guard.init_semaphore.clone()
78 : };
79 :
80 10552 : let permit = {
81 : // increment the count for the duration of queued
82 10552 : let _guard = CountWaitingInitializers::start(self);
83 10552 : sem.acquire_owned().await
84 : };
85 :
86 10552 : match permit {
87 10050 : Ok(permit) => {
88 10050 : let permit = InitPermit(permit);
89 29253 : let (value, _permit) = factory(permit).await?;
90 :
91 9400 : let guard = self.inner.lock().unwrap();
92 9400 :
93 9400 : Ok(Self::set0(value, guard))
94 : }
95 502 : Err(_closed) => {
96 502 : let guard = self.inner.lock().unwrap();
97 502 : assert!(
98 502 : guard.value.is_some(),
99 UBC 0 : "semaphore got closed, must be initialized"
100 : );
101 CBC 502 : return Ok(Guard(guard));
102 : }
103 : }
104 15463506 : }
105 :
106 : /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
107 : /// to complete initializing the inner value.
108 : ///
109 : /// # Panics
110 : ///
111 : /// If the inner has already been initialized.
112 2 : pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
113 2 : let guard = self.inner.lock().unwrap();
114 2 :
115 2 : // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
116 2 : // give more permits right now.
117 2 : if guard.init_semaphore.try_acquire().is_ok() {
118 UBC 0 : drop(guard);
119 0 : panic!("permit is of wrong origin");
120 CBC 2 : }
121 2 :
122 2 : Self::set0(value, guard)
123 2 : }
124 :
125 9402 : fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
126 9402 : if guard.value.is_some() {
127 UBC 0 : drop(guard);
128 0 : unreachable!("we won permit, must not be initialized");
129 CBC 9402 : }
130 9402 : guard.value = Some(value);
131 9402 : guard.init_semaphore.close();
132 9402 : Guard(guard)
133 9402 : }
134 :
135 : /// Returns a guard to an existing initialized value, if any.
136 8058 : pub fn get(&self) -> Option<Guard<'_, T>> {
137 8058 : let guard = self.inner.lock().unwrap();
138 8058 : if guard.value.is_some() {
139 7250 : Some(Guard(guard))
140 : } else {
141 808 : None
142 : }
143 8058 : }
144 :
145 : /// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
146 9396 : pub fn initializer_count(&self) -> usize {
147 9396 : self.initializers.load(Ordering::Relaxed)
148 9396 : }
149 : }
150 :
151 : /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
152 : /// initializing task for example at the end of initialization.
153 : struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
154 :
155 : impl<'a, T> CountWaitingInitializers<'a, T> {
156 10552 : fn start(target: &'a OnceCell<T>) -> Self {
157 10552 : target.initializers.fetch_add(1, Ordering::Relaxed);
158 10552 : CountWaitingInitializers(target)
159 10552 : }
160 : }
161 :
162 : impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
163 10552 : fn drop(&mut self) {
164 10552 : self.0.initializers.fetch_sub(1, Ordering::Relaxed);
165 10552 : }
166 : }
167 :
168 : /// Uninteresting guard object to allow short-lived access to inspect or clone the held,
169 : /// initialized value.
170 UBC 0 : #[derive(Debug)]
171 : pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
172 :
173 : impl<T> std::ops::Deref for Guard<'_, T> {
174 : type Target = T;
175 :
176 CBC 2643 : fn deref(&self) -> &Self::Target {
177 2643 : self.0
178 2643 : .value
179 2643 : .as_ref()
180 2643 : .expect("guard is not created unless value has been initialized")
181 2643 : }
182 : }
183 :
184 : impl<T> std::ops::DerefMut for Guard<'_, T> {
185 15465301 : fn deref_mut(&mut self) -> &mut Self::Target {
186 15465301 : self.0
187 15465301 : .value
188 15465301 : .as_mut()
189 15465301 : .expect("guard is not created unless value has been initialized")
190 15465301 : }
191 : }
192 :
193 : impl<'a, T> Guard<'a, T> {
194 : /// Take the current value, and a new permit for it's deinitialization.
195 : ///
196 : /// The permit will be on a semaphore part of the new internal value, and any following
197 : /// [`OnceCell::get_or_init`] will wait on it to complete.
198 2541 : pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
199 2541 : let mut swapped = Inner::default();
200 2541 : let permit = swapped
201 2541 : .init_semaphore
202 2541 : .clone()
203 2541 : .try_acquire_owned()
204 2541 : .expect("we just created this");
205 2541 : std::mem::swap(&mut *self.0, &mut swapped);
206 2541 : swapped
207 2541 : .value
208 2541 : .map(|v| (v, InitPermit(permit)))
209 2541 : .expect("guard is not created unless value has been initialized")
210 2541 : }
211 : }
212 :
213 : /// Type held by OnceCell (de)initializing task.
214 : pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
215 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    // A hundred tasks race to initialize an empty cell: only one factory may run, and only the
    // task whose factory ran may find its own id stored in the cell.
    #[tokio::test]
    async fn many_initializers() {
        #[derive(Default, Debug)]
        struct Counters {
            factory_got_to_run: AtomicUsize,
            future_polled: AtomicUsize,
            winners: AtomicUsize,
        }

        let task_count = 100;

        let cell = Arc::new(OnceCell::default());
        let counters = Arc::new(Counters::default());
        // released only once every spawned task and this test body are waiting on it
        let start_line = Arc::new(tokio::sync::Barrier::new(task_count + 1));

        let mut tasks = tokio::task::JoinSet::new();
        for id in 0..task_count {
            let cell = Arc::clone(&cell);
            let counters = Arc::clone(&counters);
            let start_line = Arc::clone(&start_line);

            tasks.spawn(async move {
                start_line.wait().await;

                let guard = cell
                    .get_or_init(|permit| {
                        counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                        async {
                            counters.future_polled.fetch_add(1, Ordering::Relaxed);
                            Ok::<_, Infallible>((id, permit))
                        }
                    })
                    .await
                    .unwrap();
                let won = *guard == id;
                // release the cell's lock before touching the counters
                drop(guard);

                if won {
                    counters.winners.fetch_add(1, Ordering::Relaxed);
                }
            });
        }

        start_line.wait().await;

        while let Some(joined) = tasks.join_next().await {
            joined.expect("no panics expected");
        }

        let Counters {
            factory_got_to_run,
            future_polled,
            winners,
        } = Arc::try_unwrap(counters).unwrap();

        assert_eq!(factory_got_to_run.into_inner(), 1);
        assert_eq!(future_polled.into_inner(), 1);
        assert_eq!(winners.into_inner(), 1);
    }

    // A reinitialization must wait until the task holding the deinitialization permit lets go.
    #[tokio::test(start_paused = true)]
    async fn reinit_waits_for_deinit() {
        // tokio time is paused, so holding the permit for this long costs no wall-clock time
        let hold_for = Duration::from_secs(1);
        let initial = 42;
        let reinit = 1;
        let cell = Arc::new(OnceCell::new(initial));

        let deinit_started = Arc::new(tokio::sync::Barrier::new(2));

        let deinit_task = {
            let cell = Arc::clone(&cell);
            let deinit_started = Arc::clone(&deinit_started);
            tokio::spawn(async move {
                let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
                assert_eq!(answer, initial);

                deinit_started.wait().await;
                tokio::time::sleep(hold_for).await;
            })
        };

        deinit_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

        assert!(
            started_at.elapsed() >= hold_for,
            "initialization should had taken at least the time time slept with permit"
        );

        deinit_task.await.unwrap();

        assert_eq!(*cell.get().unwrap(), reinit);
    }

    // The permit returned by `take_and_deinit` can be handed straight to `set`.
    #[test]
    fn reinit_with_deinit_permit() {
        let cell = OnceCell::new(42);

        let (original, permit) = cell.get().unwrap().take_and_deinit();
        cell.set(5, permit);
        assert_eq!(*cell.get().unwrap(), 5);

        let (five, permit) = cell.get().unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(original, permit);
        assert_eq!(*cell.get().unwrap(), 42);
    }

    // Failed factories leave the cell empty, so later calls can still initialize it.
    #[tokio::test]
    async fn initialization_attemptable_until_ok() {
        let cell = OnceCell::default();

        for _attempt in 0..10 {
            let _err = cell
                .get_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let guard = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*guard, "finally success");
    }

    // Dropping an in-progress initialization releases the permit for the next caller.
    #[tokio::test]
    async fn initialization_is_cancellation_safe() {
        let cell = OnceCell::default();

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_or_init(|permit| async {
            barrier.wait().await;
            std::future::pending::<()>().await;

            Ok::<_, Infallible>(("never reached", permit))
        });

        tokio::select! {
            _ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
            _ = barrier.wait() => {}
        };

        // the losing `initializer` future has been dropped by now

        assert!(cell.get().is_none());

        let guard = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*guard, "now initialized");
    }
}
|