Line data Source code
1 : //! Try RCU extension lifted from <https://github.com/vorner/arc-swap/issues/94#issuecomment-1987154023>
2 :
3 : pub trait ArcSwapExt<T> {
4 : /// [`ArcSwap::rcu`](arc_swap::ArcSwap::rcu), but with Result that short-circuits on error.
5 : fn try_rcu<R, F, E>(&self, f: F) -> Result<T, E>
6 : where
7 : F: FnMut(&T) -> Result<R, E>,
8 : R: Into<T>;
9 : }
10 :
11 : impl<T, S> ArcSwapExt<T> for arc_swap::ArcSwapAny<T, S>
12 : where
13 : T: arc_swap::RefCnt,
14 : S: arc_swap::strategy::CaS<T>,
15 : {
16 2 : fn try_rcu<R, F, E>(&self, mut f: F) -> Result<T, E>
17 2 : where
18 2 : F: FnMut(&T) -> Result<R, E>,
19 2 : R: Into<T>,
20 2 : {
21 1 : fn ptr_eq<Base, A, B>(a: A, b: B) -> bool
22 1 : where
23 1 : A: arc_swap::AsRaw<Base>,
24 1 : B: arc_swap::AsRaw<Base>,
25 1 : {
26 1 : let a = a.as_raw();
27 1 : let b = b.as_raw();
28 1 : std::ptr::eq(a, b)
29 1 : }
30 :
31 2 : let mut cur = self.load();
32 : loop {
33 2 : let new = f(&cur)?.into();
34 1 : let prev = self.compare_and_swap(&*cur, new);
35 1 : let swapped = ptr_eq(&*cur, &*prev);
36 1 : if swapped {
37 1 : return Ok(arc_swap::Guard::into_inner(prev));
38 0 : } else {
39 0 : cur = prev;
40 0 : }
41 : }
42 2 : }
43 : }
44 :
45 : #[cfg(test)]
46 : mod tests {
47 : use super::*;
48 : use arc_swap::ArcSwap;
49 : use std::sync::Arc;
50 :
51 : #[test]
52 1 : fn test_try_rcu_success() {
53 1 : let swap = ArcSwap::from(Arc::new(42));
54 1 :
55 1 : let result = swap.try_rcu(|value| -> Result<_, String> { Ok(**value + 1) });
56 1 :
57 1 : assert!(result.is_ok());
58 1 : assert_eq!(**swap.load(), 43);
59 1 : }
60 :
61 : #[test]
62 1 : fn test_try_rcu_error() {
63 1 : let swap = ArcSwap::from(Arc::new(42));
64 1 :
65 1 : let result = swap.try_rcu(|value| -> Result<i32, _> {
66 1 : if **value == 42 {
67 1 : Err("err")
68 : } else {
69 0 : Ok(**value + 1)
70 : }
71 1 : });
72 1 :
73 1 : assert!(result.is_err());
74 1 : assert_eq!(result.unwrap_err(), "err");
75 1 : assert_eq!(**swap.load(), 42);
76 1 : }
77 : }
|