LCOV - differential code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 95.5 % 157 150 7 150
Current Date: 2023-10-19 02:04:12 Functions: 64.4 % 45 29 16 29
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : #![warn(missing_docs)]
       2                 : 
       3                 : use std::cmp::{Eq, Ordering, PartialOrd};
       4                 : use std::collections::BinaryHeap;
       5                 : use std::fmt::Debug;
       6                 : use std::mem;
       7                 : use std::sync::Mutex;
       8                 : use std::time::Duration;
       9                 : use tokio::sync::watch::{channel, Receiver, Sender};
      10                 : use tokio::time::timeout;
      11                 : 
      12                 : /// An error happened while waiting for a number
      13 CBC           4 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      14                 : pub enum SeqWaitError {
      15                 :     /// The wait timeout was reached
      16                 :     #[error("seqwait timeout was reached")]
      17                 :     Timeout,
      18                 : 
      19                 :     /// [`SeqWait::shutdown`] was called
      20                 :     #[error("SeqWait::shutdown was called")]
      21                 :     Shutdown,
      22                 : }
      23                 : 
      24                 : /// Monotonically increasing value
      25                 : ///
      26                 : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      27                 : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      28                 : /// any type that can expose counter. `V` is the type of exposed counter.
      29                 : pub trait MonotonicCounter<V> {
      30                 :     /// Bump counter value and check that it goes forward
      31                 :     /// N.B.: new_val is an actual new value, not a difference.
      32                 :     fn cnt_advance(&mut self, new_val: V);
      33                 : 
      34                 :     /// Get counter value
      35                 :     fn cnt_value(&self) -> V;
      36                 : }
      37                 : 
      38                 : /// Internal components of a `SeqWait`
      39                 : struct SeqWaitInt<S, V>
      40                 : where
      41                 :     S: MonotonicCounter<V>,
      42                 :     V: Ord,
      43                 : {
      44                 :     waiters: BinaryHeap<Waiter<V>>,
      45                 :     current: S,
      46                 :     shutdown: bool,
      47                 : }
      48                 : 
      49                 : struct Waiter<T>
      50                 : where
      51                 :     T: Ord,
      52                 : {
      53                 :     wake_num: T,              // wake me when this number arrives ...
      54                 :     wake_channel: Sender<()>, // ... by sending a message to this channel
      55                 : }
      56                 : 
      57                 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
      58                 : // to get that.
      59                 : impl<T: Ord> PartialOrd for Waiter<T> {
      60           90426 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
      61           90426 :         Some(self.cmp(other))
      62           90426 :     }
      63                 : }
      64                 : 
      65                 : impl<T: Ord> Ord for Waiter<T> {
      66           90426 :     fn cmp(&self, other: &Self) -> Ordering {
      67           90426 :         other.wake_num.cmp(&self.wake_num)
      68           90426 :     }
      69                 : }
      70                 : 
      71                 : impl<T: Ord> PartialEq for Waiter<T> {
      72 UBC           0 :     fn eq(&self, other: &Self) -> bool {
      73               0 :         other.wake_num == self.wake_num
      74               0 :     }
      75                 : }
      76                 : 
      77                 : impl<T: Ord> Eq for Waiter<T> {}
      78                 : 
      79                 : /// A tool for waiting on a sequence number
      80                 : ///
      81                 : /// This provides a way to wait the arrival of a number.
      82                 : /// As soon as the number arrives by another caller calling
      83                 : /// [`advance`], then the waiter will be woken up.
      84                 : ///
      85                 : /// This implementation takes a blocking Mutex on both [`wait_for`]
      86                 : /// and [`advance`], meaning there may be unexpected executor blocking
      87                 : /// due to thread scheduling unfairness. There are probably better
      88                 : /// implementations, but we can probably live with this for now.
      89                 : ///
      90                 : /// [`wait_for`]: SeqWait::wait_for
      91                 : /// [`advance`]: SeqWait::advance
      92                 : ///
      93                 : /// `S` means Storage, `V` is type of counter that this storage exposes.
      94                 : ///
      95                 : pub struct SeqWait<S, V>
      96                 : where
      97                 :     S: MonotonicCounter<V>,
      98                 :     V: Ord,
      99                 : {
     100                 :     internal: Mutex<SeqWaitInt<S, V>>,
     101                 : }
     102                 : 
     103                 : impl<S, V> SeqWait<S, V>
     104                 : where
     105                 :     S: MonotonicCounter<V> + Copy,
     106                 :     V: Ord + Copy,
     107                 : {
     108                 :     /// Create a new `SeqWait`, initialized to a particular number
     109 CBC        1304 :     pub fn new(starting_num: S) -> Self {
     110            1304 :         let internal = SeqWaitInt {
     111            1304 :             waiters: BinaryHeap::new(),
     112            1304 :             current: starting_num,
     113            1304 :             shutdown: false,
     114            1304 :         };
     115            1304 :         SeqWait {
     116            1304 :             internal: Mutex::new(internal),
     117            1304 :         }
     118            1304 :     }
     119                 : 
     120                 :     /// Shut down a `SeqWait`, causing all waiters (present and
     121                 :     /// future) to return an error.
     122               2 :     pub fn shutdown(&self) {
     123               2 :         let waiters = {
     124               2 :             // Prevent new waiters; wake all those that exist.
     125               2 :             // Wake everyone with an error.
     126               2 :             let mut internal = self.internal.lock().unwrap();
     127               2 : 
     128               2 :             // This will steal the entire waiters map.
     129               2 :             // When we drop it all waiters will be woken.
     130               2 :             mem::take(&mut internal.waiters)
     131               2 : 
     132               2 :             // Drop the lock as we exit this scope.
     133               2 :         };
     134               2 : 
     135               2 :         // When we drop the waiters list, each Receiver will
     136               2 :         // be woken with an error.
     137               2 :         // This drop doesn't need to be explicit; it's done
     138               2 :         // here to make it easier to read the code and understand
     139               2 :         // the order of events.
     140               2 :         drop(waiters);
     141               2 :     }
     142                 : 
     143                 :     /// Wait for a number to arrive
     144                 :     ///
     145                 :     /// This call won't complete until someone has called `advance`
     146                 :     /// with a number greater than or equal to the one we're waiting for.
     147                 :     ///
     148                 :     /// This function is async cancellation-safe.
     149               4 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     150               4 :         match self.queue_for_wait(num) {
     151               1 :             Ok(None) => Ok(()),
     152               3 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     153 UBC           0 :             Err(e) => Err(e),
     154                 :         }
     155 CBC           4 :     }
     156                 : 
     157                 :     /// Wait for a number to arrive
     158                 :     ///
     159                 :     /// This call won't complete until someone has called `advance`
     160                 :     /// with a number greater than or equal to the one we're waiting for.
     161                 :     ///
     162                 :     /// If that hasn't happened after the specified timeout duration,
     163                 :     /// [`SeqWaitError::Timeout`] will be returned.
     164                 :     ///
     165                 :     /// This function is async cancellation-safe.
     166         1278622 :     pub async fn wait_for_timeout(
     167         1278622 :         &self,
     168         1278622 :         num: V,
     169         1278622 :         timeout_duration: Duration,
     170         1278622 :     ) -> Result<(), SeqWaitError> {
     171         1278622 :         match self.queue_for_wait(num) {
     172         1176301 :             Ok(None) => Ok(()),
     173          145153 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     174          102314 :                 Ok(Ok(())) => Ok(()),
     175 UBC           0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     176 CBC           6 :                 Err(_) => Err(SeqWaitError::Timeout),
     177                 :             },
     178 UBC           0 :             Err(e) => Err(e),
     179                 :         }
     180 CBC     1278621 :     }
     181                 : 
     182                 :     /// Register and return a channel that will be notified when a number arrives,
     183                 :     /// or None, if it has already arrived.
     184         1278626 :     fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
     185         1278626 :         let mut internal = self.internal.lock().unwrap();
     186         1278626 :         if internal.current.cnt_value() >= num {
     187         1176302 :             return Ok(None);
     188          102324 :         }
     189          102324 :         if internal.shutdown {
     190 UBC           0 :             return Err(SeqWaitError::Shutdown);
     191 CBC      102324 :         }
     192          102324 : 
     193          102324 :         // Create a new channel.
     194          102324 :         let (tx, rx) = channel(());
     195          102324 :         internal.waiters.push(Waiter {
     196          102324 :             wake_num: num,
     197          102324 :             wake_channel: tx,
     198          102324 :         });
     199          102324 :         // Drop the lock as we exit this scope.
     200          102324 :         Ok(Some(rx))
     201         1278626 :     }
     202                 : 
     203                 :     /// Announce a new number has arrived
     204                 :     ///
     205                 :     /// All waiters at this value or below will be woken.
     206                 :     ///
     207                 :     /// Returns the old number.
     208        69330607 :     pub fn advance(&self, num: V) -> V {
     209                 :         let old_value;
     210        69330606 :         let wake_these = {
     211        69330607 :             let mut internal = self.internal.lock().unwrap();
     212        69330607 : 
     213        69330607 :             old_value = internal.current.cnt_value();
     214        69330607 :             if old_value >= num {
     215               1 :                 return old_value;
     216        69330606 :             }
     217        69330606 :             internal.current.cnt_advance(num);
     218        69330606 : 
     219        69330606 :             // Pop all waiters <= num from the heap. Collect them in a vector, and
     220        69330606 :             // wake them up after releasing the lock.
     221        69330606 :             let mut wake_these = Vec::new();
     222        69432924 :             while let Some(n) = internal.waiters.peek() {
     223        12212997 :                 if n.wake_num > num {
     224        12110679 :                     break;
     225          102318 :                 }
     226          102318 :                 wake_these.push(internal.waiters.pop().unwrap().wake_channel);
     227                 :             }
     228        69330606 :             wake_these
     229                 :         };
     230                 : 
     231        69432924 :         for tx in wake_these {
     232          102318 :             // This can fail if there are no receivers.
     233          102318 :             // We don't care; discard the error.
     234          102318 :             let _ = tx.send(());
     235          102318 :         }
     236        69330606 :         old_value
     237        69330607 :     }
     238                 : 
     239                 :     /// Read the current value, without waiting.
     240        82268228 :     pub fn load(&self) -> S {
     241        82268228 :         self.internal.lock().unwrap().current
     242        82268228 :     }
     243                 : }
     244                 : 
     245                 : #[cfg(test)]
     246                 : mod tests {
     247                 :     use super::*;
     248                 :     use std::sync::Arc;
     249                 :     use std::time::Duration;
     250                 : 
     251                 :     impl MonotonicCounter<i32> for i32 {
     252               3 :         fn cnt_advance(&mut self, val: i32) {
     253               3 :             assert!(*self <= val);
     254               3 :             *self = val;
     255               3 :         }
     256              10 :         fn cnt_value(&self) -> i32 {
     257              10 :             *self
     258              10 :         }
     259                 :     }
     260                 : 
     261               1 :     #[tokio::test]
     262               1 :     async fn seqwait() {
     263               1 :         let seq = Arc::new(SeqWait::new(0));
     264               1 :         let seq2 = Arc::clone(&seq);
     265               1 :         let seq3 = Arc::clone(&seq);
     266               1 :         let jh1 = tokio::task::spawn(async move {
     267               1 :             seq2.wait_for(42).await.expect("wait_for 42");
     268               1 :             let old = seq2.advance(100);
     269               1 :             assert_eq!(old, 99);
     270               1 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     271               1 :                 .await
     272               1 :                 .expect_err("no 999");
     273               1 :         });
     274               1 :         let jh2 = tokio::task::spawn(async move {
     275               1 :             seq3.wait_for(42).await.expect("wait_for 42");
     276               1 :             seq3.wait_for(0).await.expect("wait_for 0");
     277               1 :         });
     278               1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     279               1 :         let old = seq.advance(99);
     280               1 :         assert_eq!(old, 0);
     281               1 :         seq.wait_for(100).await.expect("wait_for 100");
     282               1 : 
     283               1 :         // Calling advance with a smaller value is a no-op
     284               1 :         assert_eq!(seq.advance(98), 100);
     285               1 :         assert_eq!(seq.load(), 100);
     286                 : 
     287               1 :         jh1.await.unwrap();
     288               1 :         jh2.await.unwrap();
     289               1 : 
     290               1 :         seq.shutdown();
     291                 :     }
     292                 : 
     293               1 :     #[tokio::test]
     294               1 :     async fn seqwait_timeout() {
     295               1 :         let seq = Arc::new(SeqWait::new(0));
     296               1 :         let seq2 = Arc::clone(&seq);
     297               1 :         let jh = tokio::task::spawn(async move {
     298               1 :             let timeout = Duration::from_millis(1);
     299               1 :             let res = seq2.wait_for_timeout(42, timeout).await;
     300               1 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     301               1 :         });
     302               1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     303                 :         // This will attempt to wake, but nothing will happen
     304                 :         // because the waiter already dropped its Receiver.
     305               1 :         let old = seq.advance(99);
     306               1 :         assert_eq!(old, 0);
     307               1 :         jh.await.unwrap();
     308               1 : 
     309               1 :         seq.shutdown();
     310                 :     }
     311                 : }
        

Generated by: LCOV version 2.1-beta