|             Line data    Source code 
       1              : //! Try RCU extension lifted from <https://github.com/vorner/arc-swap/issues/94#issuecomment-1987154023>
       2              : 
       3              : pub trait ArcSwapExt<T> {
       4              :     /// [`ArcSwap::rcu`](arc_swap::ArcSwap::rcu), but with Result that short-circuits on error.
       5              :     fn try_rcu<R, F, E>(&self, f: F) -> Result<T, E>
       6              :     where
       7              :         F: FnMut(&T) -> Result<R, E>,
       8              :         R: Into<T>;
       9              : }
      10              : 
      11              : impl<T, S> ArcSwapExt<T> for arc_swap::ArcSwapAny<T, S>
      12              : where
      13              :     T: arc_swap::RefCnt,
      14              :     S: arc_swap::strategy::CaS<T>,
      15              : {
      16            3 :     fn try_rcu<R, F, E>(&self, mut f: F) -> Result<T, E>
      17            3 :     where
      18            3 :         F: FnMut(&T) -> Result<R, E>,
      19            3 :         R: Into<T>,
      20              :     {
      21            2 :         fn ptr_eq<Base, A, B>(a: A, b: B) -> bool
      22            2 :         where
      23            2 :             A: arc_swap::AsRaw<Base>,
      24            2 :             B: arc_swap::AsRaw<Base>,
      25              :         {
      26            2 :             let a = a.as_raw();
      27            2 :             let b = b.as_raw();
      28            2 :             std::ptr::eq(a, b)
      29            1 :         }
      30              : 
      31            3 :         let mut cur = self.load();
      32              :         loop {
      33            3 :             let new = f(&cur)?.into();
      34            2 :             let prev = self.compare_and_swap(&*cur, new);
      35            2 :             let swapped = ptr_eq(&*cur, &*prev);
      36            2 :             if swapped {
      37            2 :                 return Ok(arc_swap::Guard::into_inner(prev));
      38            0 :             } else {
      39            0 :                 cur = prev;
      40            0 :             }
      41              :         }
      42            2 :     }
      43              : }
      44              : 
      45              : #[cfg(test)]
      46              : mod tests {
      47              :     use std::sync::Arc;
      48              : 
      49              :     use arc_swap::ArcSwap;
      50              : 
      51              :     use super::*;
      52              : 
      53              :     #[test]
      54            1 :     fn test_try_rcu_success() {
      55            1 :         let swap = ArcSwap::from(Arc::new(42));
      56              : 
      57            1 :         let result = swap.try_rcu(|value| -> Result<_, String> { Ok(**value + 1) });
      58              : 
      59            1 :         assert!(result.is_ok());
      60            1 :         assert_eq!(**swap.load(), 43);
      61            1 :     }
      62              : 
      63              :     #[test]
      64            1 :     fn test_try_rcu_error() {
      65            1 :         let swap = ArcSwap::from(Arc::new(42));
      66              : 
      67            1 :         let result = swap.try_rcu(|value| -> Result<i32, _> {
      68            1 :             if **value == 42 {
      69            1 :                 Err("err")
      70              :             } else {
      71            0 :                 Ok(**value + 1)
      72              :             }
      73            1 :         });
      74              : 
      75            1 :         assert!(result.is_err());
      76            1 :         assert_eq!(result.unwrap_err(), "err");
      77            1 :         assert_eq!(**swap.load(), 42);
      78            1 :     }
      79              : }
         |