|             Line data    Source code 
       1              : //!
//! RCU stands for Read-Copy-Update. It's a synchronization mechanism somewhat
//! similar to a lock, but it allows readers to "hold on" to an old value of RCU
//! without blocking writers, and allows writing a new value without blocking
//! readers. When you update the value, the new value is immediately visible
//! to new readers. The update returns a wait list that the writer can use to
//! wait until all existing readers of the old value have finished; once that
//! wait completes, no one sees the old value anymore.
       8              : //!
       9              : //! This implementation isn't wait-free; it uses an RwLock that is held for a
      10              : //! short duration when the value is read or updated.
      11              : //!
      12              : //! # Examples
      13              : //!
      14              : //! Read a value and do things with it while holding the guard:
      15              : //!
      16              : //! ```
      17              : //! # let rcu = utils::simple_rcu::Rcu::new(1);
      18              : //! {
      19              : //!     let read = rcu.read();
      20              : //!     println!("the current value is {}", *read);
      21              : //!     // exiting the scope drops the read-guard, and allows concurrent writers
      22              : //!     // to finish.
      23              : //! }
      24              : //! ```
      25              : //!
      26              : //! Increment the value by one, and wait for old readers to finish:
      27              : //!
      28              : //! ```
      29              : //! # async fn dox() {
      30              : //! # let rcu = utils::simple_rcu::Rcu::new(1);
      31              : //! let write_guard = rcu.lock_for_write();
      32              : //!
      33              : //! // NB: holding `write_guard` blocks new readers and writers. Keep this section short!
      34              : //! let new_value = *write_guard + 1;
      35              : //!
      36              : //! let waitlist = write_guard.store_and_unlock(new_value); // consumes `write_guard`
      37              : //!
      38              : //! // Concurrent reads and writes are now possible again. Wait for all the readers
      39              : //! // that still observe the old value to finish.
      40              : //! waitlist.wait().await;
      41              : //! # }
      42              : //! ```
      43              : //!
      44              : #![warn(missing_docs)]
      45              : 
      46              : use std::ops::Deref;
      47              : use std::sync::{Arc, RwLock, RwLockWriteGuard, Weak};
      48              : 
      49              : use tokio::sync::watch;
      50              : 
      51              : /// Rcu allows multiple readers to read and hold onto a value without blocking
      52              : /// (for very long).
      53              : ///
      54              : /// Storing to the Rcu updates the value, making new readers immediately see
      55              : /// the new value, but it also waits for all current readers to finish.
      56              : pub struct Rcu<V> {
      57              :     inner: RwLock<RcuInner<V>>,
      58              : }
      59              : 
      60              : struct RcuInner<V> {
      61              :     current_cell: Arc<RcuCell<V>>,
      62              :     old_cells: Vec<Weak<RcuCell<V>>>,
      63              : }
      64              : 
      65              : ///
      66              : /// RcuCell holds one value. It can be the latest one, or an old one.
      67              : ///
      68              : struct RcuCell<V> {
      69              :     value: V,
      70              : 
      71              :     /// A dummy channel. We never send anything to this channel. The point is
      72              :     /// that when the RcuCell is dropped, any subscribed Receivers will be notified
      73              :     /// that the channel is closed. Updaters can use this to wait out until the
      74              :     /// RcuCell has been dropped, i.e. until the old value is no longer in use.
      75              :     ///
      76              :     /// We never send anything to this, we just need to hold onto it so that the
      77              :     /// Receivers will be notified when it's dropped.
      78              :     watch: watch::Sender<()>,
      79              : }
      80              : 
      81              : impl<V> RcuCell<V> {
      82          971 :     fn new(value: V) -> Self {
      83          971 :         let (watch_sender, _) = watch::channel(());
      84          971 :         RcuCell {
      85          971 :             value,
      86          971 :             watch: watch_sender,
      87          971 :         }
      88          971 :     }
      89              : }
      90              : 
      91              : impl<V> Rcu<V> {
      92              :     /// Create a new `Rcu`, initialized to `starting_val`
      93          905 :     pub fn new(starting_val: V) -> Self {
      94          905 :         let inner = RcuInner {
      95          905 :             current_cell: Arc::new(RcuCell::new(starting_val)),
      96          905 :             old_cells: Vec::new(),
      97          905 :         };
      98          905 :         Self {
      99          905 :             inner: RwLock::new(inner),
     100          905 :         }
     101          905 :     }
     102              : 
     103              :     ///
     104              :     /// Read current value. Any store() calls will block until the returned
     105              :     /// guard object is dropped.
     106              :     ///
     107      1708194 :     pub fn read(&self) -> RcuReadGuard<V> {
     108      1708194 :         let current_cell = Arc::clone(&self.inner.read().unwrap().current_cell);
     109      1708194 :         RcuReadGuard { cell: current_cell }
     110      1708194 :     }
     111              : 
     112              :     ///
     113              :     /// Lock the current value for updating. Returns a guard object that can be
     114              :     /// used to read the current value, and to store a new value.
     115              :     ///
     116              :     /// Note: holding the write-guard blocks concurrent readers, so you should
     117              :     /// finish the update and drop the guard quickly! Multiple writers can be
     118              :     /// waiting on the RcuWriteGuard::store step at the same time, however.
     119              :     ///
     120           66 :     pub fn lock_for_write(&self) -> RcuWriteGuard<'_, V> {
     121           66 :         let inner = self.inner.write().unwrap();
     122           66 :         RcuWriteGuard { inner }
     123           66 :     }
     124              : }
     125              : 
     126              : ///
     127              : /// Read guard returned by `read`
     128              : ///
     129              : pub struct RcuReadGuard<V> {
     130              :     cell: Arc<RcuCell<V>>,
     131              : }
     132              : 
     133              : impl<V> Deref for RcuReadGuard<V> {
     134              :     type Target = V;
     135              : 
     136         3997 :     fn deref(&self) -> &V {
     137         3997 :         &self.cell.value
     138         3997 :     }
     139              : }
     140              : 
     141              : ///
     142              : /// Write guard returned by `write`
     143              : ///
     144              : /// NB: Holding this guard blocks all concurrent `read` and `write` calls, so it should only be
     145              : /// held for a short duration!
     146              : ///
     147              : /// Calling [`Self::store_and_unlock`] consumes the guard, making new reads and new writes possible
     148              : /// again.
     149              : ///
     150              : pub struct RcuWriteGuard<'a, V> {
     151              :     inner: RwLockWriteGuard<'a, RcuInner<V>>,
     152              : }
     153              : 
     154              : impl<V> Deref for RcuWriteGuard<'_, V> {
     155              :     type Target = V;
     156              : 
     157           10 :     fn deref(&self) -> &V {
     158           10 :         &self.inner.current_cell.value
     159           10 :     }
     160              : }
     161              : 
     162              : impl<V> RcuWriteGuard<'_, V> {
     163              :     ///
     164              :     /// Store a new value. The new value will be written to the Rcu immediately,
     165              :     /// and will be immediately seen by any `read` calls that start afterwards.
     166              :     ///
     167              :     /// Returns a list of readers that can see old values. You can call `wait()`
     168              :     /// on it to wait for them to finish.
     169              :     ///
     170           66 :     pub fn store_and_unlock(mut self, new_val: V) -> RcuWaitList {
     171           66 :         let new_cell = Arc::new(RcuCell::new(new_val));
     172           66 : 
     173           66 :         let mut watches = Vec::new();
     174           66 :         {
     175           66 :             let old = std::mem::replace(&mut self.inner.current_cell, new_cell);
     176           66 :             self.inner.old_cells.push(Arc::downgrade(&old));
     177           66 : 
     178           66 :             // cleanup old cells that no longer have any readers, and collect
     179           66 :             // the watches for any that do.
     180           66 :             self.inner.old_cells.retain(|weak| {
     181           83 :                 if let Some(cell) = weak.upgrade() {
     182           67 :                     watches.push(cell.watch.subscribe());
     183           67 :                     true
     184              :                 } else {
     185           16 :                     false
     186              :                 }
     187           66 :             });
     188           66 :         }
     189           66 :         RcuWaitList(watches)
     190           66 :     }
     191              : }
     192              : 
     193              : ///
     194              : /// List of readers who can still see old values.
     195              : ///
     196              : pub struct RcuWaitList(Vec<watch::Receiver<()>>);
     197              : 
     198              : impl RcuWaitList {
     199              :     ///
     200              :     /// Wait for old readers to finish.
     201              :     ///
     202           66 :     pub async fn wait(mut self) {
     203              :         // after all the old_cells are no longer in use, we're done
     204           67 :         for w in self.0.iter_mut() {
     205              :             // This will block until the Receiver is closed. That happens when
     206              :             // the RcuCell is dropped.
     207              :             #[allow(clippy::single_match)]
     208           67 :             match w.changed().await {
     209            0 :                 Ok(_) => panic!("changed() unexpectedly succeeded on dummy channel"),
     210           67 :                 Err(_) => {
     211           67 :                     // closed, which means that the cell has been dropped, and
     212           67 :                     // its value is no longer in use
     213           67 :                 }
     214              :             }
     215              :         }
     216            2 :     }
     217              : }
     218              : 
     219              : #[cfg(test)]
     220              : mod tests {
     221              :     use std::sync::Mutex;
     222              :     use std::time::Duration;
     223              : 
     224              :     use super::*;
     225              : 
     226              :     #[tokio::test]
     227            1 :     async fn two_writers() {
     228            1 :         let rcu = Rcu::new(1);
     229            1 : 
     230            1 :         let read1 = rcu.read();
     231            1 :         assert_eq!(*read1, 1);
     232            1 : 
     233            1 :         let write2 = rcu.lock_for_write();
     234            1 :         assert_eq!(*write2, 1);
     235            1 :         let wait2 = write2.store_and_unlock(2);
     236            1 : 
     237            1 :         let read2 = rcu.read();
     238            1 :         assert_eq!(*read2, 2);
     239            1 : 
     240            1 :         let write3 = rcu.lock_for_write();
     241            1 :         assert_eq!(*write3, 2);
     242            1 :         let wait3 = write3.store_and_unlock(3);
     243            1 : 
     244            1 :         // new reader can see the new value, and old readers continue to see the old values.
     245            1 :         let read3 = rcu.read();
     246            1 :         assert_eq!(*read3, 3);
     247            1 :         assert_eq!(*read2, 2);
     248            1 :         assert_eq!(*read1, 1);
     249            1 : 
     250            1 :         let log = Arc::new(Mutex::new(Vec::new()));
     251            1 :         // Wait for the old readers to finish in separate tasks.
     252            1 :         let log_clone = Arc::clone(&log);
     253            1 :         let task2 = tokio::spawn(async move {
     254            1 :             wait2.wait().await;
     255            1 :             log_clone.lock().unwrap().push("wait2 done");
     256            1 :         });
     257            1 :         let log_clone = Arc::clone(&log);
     258            1 :         let task3 = tokio::spawn(async move {
     259            1 :             wait3.wait().await;
     260            1 :             log_clone.lock().unwrap().push("wait3 done");
     261            1 :         });
     262            1 : 
     263            1 :         // without this sleep the test can pass on accident if the writer is slow
     264            1 :         tokio::time::sleep(Duration::from_millis(100)).await;
     265            1 : 
     266            1 :         // Release first reader. This allows first write to finish, but calling
     267            1 :         // wait() on the 'task3' would still block.
     268            1 :         log.lock().unwrap().push("dropping read1");
     269            1 :         drop(read1);
     270            1 :         task2.await.unwrap();
     271            1 : 
     272            1 :         assert!(!task3.is_finished());
     273            1 : 
     274            1 :         tokio::time::sleep(Duration::from_millis(100)).await;
     275            1 : 
     276            1 :         // Release second reader, and finish second writer.
     277            1 :         log.lock().unwrap().push("dropping read2");
     278            1 :         drop(read2);
     279            1 :         task3.await.unwrap();
     280            1 : 
     281            1 :         assert_eq!(
     282            1 :             log.lock().unwrap().as_slice(),
     283            1 :             &[
     284            1 :                 "dropping read1",
     285            1 :                 "wait2 done",
     286            1 :                 "dropping read2",
     287            1 :                 "wait3 done"
     288            1 :             ]
     289            1 :         );
     290            1 :     }
     291              : }
         |