LCOV - differential code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 95.6 % 160 153 7 153
Current Date: 2024-01-09 02:06:09 Functions: 65.2 % 46 30 16 30
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : #![warn(missing_docs)]
       2                 : 
       3                 : use std::cmp::{Eq, Ordering, PartialOrd};
       4                 : use std::collections::BinaryHeap;
       5                 : use std::fmt::Debug;
       6                 : use std::mem;
       7                 : use std::sync::Mutex;
       8                 : use std::time::Duration;
       9                 : use tokio::sync::watch::{channel, Receiver, Sender};
      10                 : use tokio::time::timeout;
      11                 : 
      12                 : /// An error happened while waiting for a number
      13 CBC           1 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      14                 : pub enum SeqWaitError {
      15                 :     /// The wait timeout was reached
      16                 :     #[error("seqwait timeout was reached")]
      17                 :     Timeout,
      18                 : 
      19                 :     /// [`SeqWait::shutdown`] was called
      20                 :     #[error("SeqWait::shutdown was called")]
      21                 :     Shutdown,
      22                 : }
      23                 : 
      24                 : /// Monotonically increasing value
      25                 : ///
      26                 : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      27                 : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      28                 : /// any type that can expose counter. `V` is the type of exposed counter.
      29                 : pub trait MonotonicCounter<V> {
      30                 :     /// Bump counter value and check that it goes forward
      31                 :     /// N.B.: new_val is an actual new value, not a difference.
      32                 :     fn cnt_advance(&mut self, new_val: V);
      33                 : 
      34                 :     /// Get counter value
      35                 :     fn cnt_value(&self) -> V;
      36                 : }
      37                 : 
      38                 : /// Internal components of a `SeqWait`
      39                 : struct SeqWaitInt<S, V>
      40                 : where
      41                 :     S: MonotonicCounter<V>,
      42                 :     V: Ord,
      43                 : {
      44                 :     waiters: BinaryHeap<Waiter<V>>,
      45                 :     current: S,
      46                 :     shutdown: bool,
      47                 : }
      48                 : 
      49                 : struct Waiter<T>
      50                 : where
      51                 :     T: Ord,
      52                 : {
      53                 :     wake_num: T,              // wake me when this number arrives ...
      54                 :     wake_channel: Sender<()>, // ... by sending a message to this channel
      55                 : }
      56                 : 
      57                 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
      58                 : // to get that.
      59                 : impl<T: Ord> PartialOrd for Waiter<T> {
      60          103818 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
      61          103818 :         Some(self.cmp(other))
      62          103818 :     }
      63                 : }
      64                 : 
      65                 : impl<T: Ord> Ord for Waiter<T> {
      66          103818 :     fn cmp(&self, other: &Self) -> Ordering {
      67          103818 :         other.wake_num.cmp(&self.wake_num)
      68          103818 :     }
      69                 : }
      70                 : 
      71                 : impl<T: Ord> PartialEq for Waiter<T> {
      72 UBC           0 :     fn eq(&self, other: &Self) -> bool {
      73               0 :         other.wake_num == self.wake_num
      74               0 :     }
      75                 : }
      76                 : 
      77                 : impl<T: Ord> Eq for Waiter<T> {}
      78                 : 
      79                 : /// A tool for waiting on a sequence number
      80                 : ///
      81                 : /// This provides a way to wait the arrival of a number.
      82                 : /// As soon as the number arrives by another caller calling
      83                 : /// [`advance`], then the waiter will be woken up.
      84                 : ///
      85                 : /// This implementation takes a blocking Mutex on both [`wait_for`]
      86                 : /// and [`advance`], meaning there may be unexpected executor blocking
      87                 : /// due to thread scheduling unfairness. There are probably better
      88                 : /// implementations, but we can probably live with this for now.
      89                 : ///
      90                 : /// [`wait_for`]: SeqWait::wait_for
      91                 : /// [`advance`]: SeqWait::advance
      92                 : ///
      93                 : /// `S` means Storage, `V` is type of counter that this storage exposes.
      94                 : ///
      95                 : pub struct SeqWait<S, V>
      96                 : where
      97                 :     S: MonotonicCounter<V>,
      98                 :     V: Ord,
      99                 : {
     100                 :     internal: Mutex<SeqWaitInt<S, V>>,
     101                 : }
     102                 : 
     103                 : impl<S, V> SeqWait<S, V>
     104                 : where
     105                 :     S: MonotonicCounter<V> + Copy,
     106                 :     V: Ord + Copy,
     107                 : {
     108                 :     /// Create a new `SeqWait`, initialized to a particular number
     109 CBC        1292 :     pub fn new(starting_num: S) -> Self {
     110            1292 :         let internal = SeqWaitInt {
     111            1292 :             waiters: BinaryHeap::new(),
     112            1292 :             current: starting_num,
     113            1292 :             shutdown: false,
     114            1292 :         };
     115            1292 :         SeqWait {
     116            1292 :             internal: Mutex::new(internal),
     117            1292 :         }
     118            1292 :     }
     119                 : 
     120                 :     /// Shut down a `SeqWait`, causing all waiters (present and
     121                 :     /// future) to return an error.
     122             745 :     pub fn shutdown(&self) {
     123             745 :         let waiters = {
     124             745 :             // Prevent new waiters; wake all those that exist.
     125             745 :             // Wake everyone with an error.
     126             745 :             let mut internal = self.internal.lock().unwrap();
     127             745 : 
     128             745 :             // Block any future waiters from starting
     129             745 :             internal.shutdown = true;
     130             745 : 
     131             745 :             // This will steal the entire waiters map.
     132             745 :             // When we drop it all waiters will be woken.
     133             745 :             mem::take(&mut internal.waiters)
     134             745 : 
     135             745 :             // Drop the lock as we exit this scope.
     136             745 :         };
     137             745 : 
     138             745 :         // When we drop the waiters list, each Receiver will
     139             745 :         // be woken with an error.
     140             745 :         // This drop doesn't need to be explicit; it's done
     141             745 :         // here to make it easier to read the code and understand
     142             745 :         // the order of events.
     143             745 :         drop(waiters);
     144             745 :     }
     145                 : 
     146                 :     /// Wait for a number to arrive
     147                 :     ///
     148                 :     /// This call won't complete until someone has called `advance`
     149                 :     /// with a number greater than or equal to the one we're waiting for.
     150                 :     ///
     151                 :     /// This function is async cancellation-safe.
     152               4 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     153               4 :         match self.queue_for_wait(num) {
     154               1 :             Ok(None) => Ok(()),
     155               3 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     156 UBC           0 :             Err(e) => Err(e),
     157                 :         }
     158 CBC           4 :     }
     159                 : 
     160                 :     /// Wait for a number to arrive
     161                 :     ///
     162                 :     /// This call won't complete until someone has called `advance`
     163                 :     /// with a number greater than or equal to the one we're waiting for.
     164                 :     ///
     165                 :     /// If that hasn't happened after the specified timeout duration,
     166                 :     /// [`SeqWaitError::Timeout`] will be returned.
     167                 :     ///
     168                 :     /// This function is async cancellation-safe.
     169         1419430 :     pub async fn wait_for_timeout(
     170         1419430 :         &self,
     171         1419430 :         num: V,
     172         1419430 :         timeout_duration: Duration,
     173         1419430 :     ) -> Result<(), SeqWaitError> {
     174         1419430 :         match self.queue_for_wait(num) {
     175         1366322 :             Ok(None) => Ok(()),
     176          103453 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     177           53102 :                 Ok(Ok(())) => Ok(()),
     178 UBC           0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     179 CBC           6 :                 Err(_) => Err(SeqWaitError::Timeout),
     180                 :             },
     181 UBC           0 :             Err(e) => Err(e),
     182                 :         }
     183 CBC     1419430 :     }
     184                 : 
     185                 :     /// Register and return a channel that will be notified when a number arrives,
     186                 :     /// or None, if it has already arrived.
     187         1419434 :     fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
     188         1419434 :         let mut internal = self.internal.lock().unwrap();
     189         1419434 :         if internal.current.cnt_value() >= num {
     190         1366323 :             return Ok(None);
     191           53111 :         }
     192           53111 :         if internal.shutdown {
     193 UBC           0 :             return Err(SeqWaitError::Shutdown);
     194 CBC       53111 :         }
     195           53111 : 
     196           53111 :         // Create a new channel.
     197           53111 :         let (tx, rx) = channel(());
     198           53111 :         internal.waiters.push(Waiter {
     199           53111 :             wake_num: num,
     200           53111 :             wake_channel: tx,
     201           53111 :         });
     202           53111 :         // Drop the lock as we exit this scope.
     203           53111 :         Ok(Some(rx))
     204         1419434 :     }
     205                 : 
     206                 :     /// Announce a new number has arrived
     207                 :     ///
     208                 :     /// All waiters at this value or below will be woken.
     209                 :     ///
     210                 :     /// Returns the old number.
     211        49362595 :     pub fn advance(&self, num: V) -> V {
     212                 :         let old_value;
     213        48222287 :         let wake_these = {
     214        49362595 :             let mut internal = self.internal.lock().unwrap();
     215        49362595 : 
     216        49362595 :             old_value = internal.current.cnt_value();
     217        49362595 :             if old_value >= num {
     218         1140308 :                 return old_value;
     219        48222287 :             }
     220        48222287 :             internal.current.cnt_advance(num);
     221        48222287 : 
     222        48222287 :             // Pop all waiters <= num from the heap. Collect them in a vector, and
     223        48222287 :             // wake them up after releasing the lock.
     224        48222287 :             let mut wake_these = Vec::new();
     225        48275393 :             while let Some(n) = internal.waiters.peek() {
     226         9546338 :                 if n.wake_num > num {
     227         9493232 :                     break;
     228           53106 :                 }
     229           53106 :                 wake_these.push(internal.waiters.pop().unwrap().wake_channel);
     230                 :             }
     231        48222287 :             wake_these
     232                 :         };
     233                 : 
     234        48275393 :         for tx in wake_these {
     235           53106 :             // This can fail if there are no receivers.
     236           53106 :             // We don't care; discard the error.
     237           53106 :             let _ = tx.send(());
     238           53106 :         }
     239        48222287 :         old_value
     240        49362595 :     }
     241                 : 
     242                 :     /// Read the current value, without waiting.
     243         6350045 :     pub fn load(&self) -> S {
     244         6350045 :         self.internal.lock().unwrap().current
     245         6350045 :     }
     246                 : }
     247                 : 
     248                 : #[cfg(test)]
     249                 : mod tests {
     250                 :     use super::*;
     251                 :     use std::sync::Arc;
     252                 :     use std::time::Duration;
     253                 : 
     254                 :     impl MonotonicCounter<i32> for i32 {
     255               3 :         fn cnt_advance(&mut self, val: i32) {
     256               3 :             assert!(*self <= val);
     257               3 :             *self = val;
     258               3 :         }
     259              10 :         fn cnt_value(&self) -> i32 {
     260              10 :             *self
     261              10 :         }
     262                 :     }
     263                 : 
     264               1 :     #[tokio::test]
     265               1 :     async fn seqwait() {
     266               1 :         let seq = Arc::new(SeqWait::new(0));
     267               1 :         let seq2 = Arc::clone(&seq);
     268               1 :         let seq3 = Arc::clone(&seq);
     269               1 :         let jh1 = tokio::task::spawn(async move {
     270               1 :             seq2.wait_for(42).await.expect("wait_for 42");
     271               1 :             let old = seq2.advance(100);
     272               1 :             assert_eq!(old, 99);
     273               1 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     274               1 :                 .await
     275               1 :                 .expect_err("no 999");
     276               1 :         });
     277               1 :         let jh2 = tokio::task::spawn(async move {
     278               1 :             seq3.wait_for(42).await.expect("wait_for 42");
     279               1 :             seq3.wait_for(0).await.expect("wait_for 0");
     280               1 :         });
     281               1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     282               1 :         let old = seq.advance(99);
     283               1 :         assert_eq!(old, 0);
     284               1 :         seq.wait_for(100).await.expect("wait_for 100");
     285               1 : 
     286               1 :         // Calling advance with a smaller value is a no-op
     287               1 :         assert_eq!(seq.advance(98), 100);
     288               1 :         assert_eq!(seq.load(), 100);
     289                 : 
     290               1 :         jh1.await.unwrap();
     291               1 :         jh2.await.unwrap();
     292               1 : 
     293               1 :         seq.shutdown();
     294                 :     }
     295                 : 
     296               1 :     #[tokio::test]
     297               1 :     async fn seqwait_timeout() {
     298               1 :         let seq = Arc::new(SeqWait::new(0));
     299               1 :         let seq2 = Arc::clone(&seq);
     300               1 :         let jh = tokio::task::spawn(async move {
     301               1 :             let timeout = Duration::from_millis(1);
     302               1 :             let res = seq2.wait_for_timeout(42, timeout).await;
     303               1 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     304               1 :         });
     305               1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     306                 :         // This will attempt to wake, but nothing will happen
     307                 :         // because the waiter already dropped its Receiver.
     308               1 :         let old = seq.advance(99);
     309               1 :         assert_eq!(old, 0);
     310               1 :         jh.await.unwrap();
     311               1 : 
     312               1 :         seq.shutdown();
     313                 :     }
     314                 : }
        

Generated by: LCOV version 2.1-beta