LCOV - code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 95.6 % 160 153
Test Date: 2024-02-07 07:37:29 Functions: 66.1 % 56 37

            Line data    Source code
       1              : #![warn(missing_docs)]
       2              : 
       3              : use std::cmp::{Eq, Ordering, PartialOrd};
       4              : use std::collections::BinaryHeap;
       5              : use std::fmt::Debug;
       6              : use std::mem;
       7              : use std::sync::Mutex;
       8              : use std::time::Duration;
       9              : use tokio::sync::watch::{channel, Receiver, Sender};
      10              : use tokio::time::timeout;
      11              : 
      12              : /// An error happened while waiting for a number
      13            2 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      14              : pub enum SeqWaitError {
      15              :     /// The wait timeout was reached
      16              :     #[error("seqwait timeout was reached")]
      17              :     Timeout,
      18              : 
      19              :     /// [`SeqWait::shutdown`] was called
      20              :     #[error("SeqWait::shutdown was called")]
      21              :     Shutdown,
      22              : }
      23              : 
      24              : /// Monotonically increasing value
      25              : ///
      26              : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      27              : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      28              : /// any type that can expose counter. `V` is the type of exposed counter.
      29              : pub trait MonotonicCounter<V> {
      30              :     /// Bump counter value and check that it goes forward
      31              :     /// N.B.: new_val is an actual new value, not a difference.
      32              :     fn cnt_advance(&mut self, new_val: V);
      33              : 
      34              :     /// Get counter value
      35              :     fn cnt_value(&self) -> V;
      36              : }
      37              : 
      38              : /// Internal components of a `SeqWait`
      39              : struct SeqWaitInt<S, V>
      40              : where
      41              :     S: MonotonicCounter<V>,
      42              :     V: Ord,
      43              : {
      44              :     waiters: BinaryHeap<Waiter<V>>,
      45              :     current: S,
      46              :     shutdown: bool,
      47              : }
      48              : 
      49              : struct Waiter<T>
      50              : where
      51              :     T: Ord,
      52              : {
      53              :     wake_num: T,              // wake me when this number arrives ...
      54              :     wake_channel: Sender<()>, // ... by sending a message to this channel
      55              : }
      56              : 
      57              : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
      58              : // to get that.
      59              : impl<T: Ord> PartialOrd for Waiter<T> {
      60        96551 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
      61        96551 :         Some(self.cmp(other))
      62        96551 :     }
      63              : }
      64              : 
      65              : impl<T: Ord> Ord for Waiter<T> {
      66        96551 :     fn cmp(&self, other: &Self) -> Ordering {
      67        96551 :         other.wake_num.cmp(&self.wake_num)
      68        96551 :     }
      69              : }
      70              : 
      71              : impl<T: Ord> PartialEq for Waiter<T> {
      72            0 :     fn eq(&self, other: &Self) -> bool {
      73            0 :         other.wake_num == self.wake_num
      74            0 :     }
      75              : }
      76              : 
      77              : impl<T: Ord> Eq for Waiter<T> {}
      78              : 
      79              : /// A tool for waiting on a sequence number
      80              : ///
      81              : /// This provides a way to wait the arrival of a number.
      82              : /// As soon as the number arrives by another caller calling
      83              : /// [`advance`], then the waiter will be woken up.
      84              : ///
      85              : /// This implementation takes a blocking Mutex on both [`wait_for`]
      86              : /// and [`advance`], meaning there may be unexpected executor blocking
      87              : /// due to thread scheduling unfairness. There are probably better
      88              : /// implementations, but we can probably live with this for now.
      89              : ///
      90              : /// [`wait_for`]: SeqWait::wait_for
      91              : /// [`advance`]: SeqWait::advance
      92              : ///
      93              : /// `S` means Storage, `V` is type of counter that this storage exposes.
      94              : ///
      95              : pub struct SeqWait<S, V>
      96              : where
      97              :     S: MonotonicCounter<V>,
      98              :     V: Ord,
      99              : {
     100              :     internal: Mutex<SeqWaitInt<S, V>>,
     101              : }
     102              : 
     103              : impl<S, V> SeqWait<S, V>
     104              : where
     105              :     S: MonotonicCounter<V> + Copy,
     106              :     V: Ord + Copy,
     107              : {
     108              :     /// Create a new `SeqWait`, initialized to a particular number
     109         2584 :     pub fn new(starting_num: S) -> Self {
     110         2584 :         let internal = SeqWaitInt {
     111         2584 :             waiters: BinaryHeap::new(),
     112         2584 :             current: starting_num,
     113         2584 :             shutdown: false,
     114         2584 :         };
     115         2584 :         SeqWait {
     116         2584 :             internal: Mutex::new(internal),
     117         2584 :         }
     118         2584 :     }
     119              : 
     120              :     /// Shut down a `SeqWait`, causing all waiters (present and
     121              :     /// future) to return an error.
     122          830 :     pub fn shutdown(&self) {
     123          830 :         let waiters = {
     124          830 :             // Prevent new waiters; wake all those that exist.
     125          830 :             // Wake everyone with an error.
     126          830 :             let mut internal = self.internal.lock().unwrap();
     127          830 : 
     128          830 :             // Block any future waiters from starting
     129          830 :             internal.shutdown = true;
     130          830 : 
     131          830 :             // This will steal the entire waiters map.
     132          830 :             // When we drop it all waiters will be woken.
     133          830 :             mem::take(&mut internal.waiters)
     134          830 : 
     135          830 :             // Drop the lock as we exit this scope.
     136          830 :         };
     137          830 : 
     138          830 :         // When we drop the waiters list, each Receiver will
     139          830 :         // be woken with an error.
     140          830 :         // This drop doesn't need to be explicit; it's done
     141          830 :         // here to make it easier to read the code and understand
     142          830 :         // the order of events.
     143          830 :         drop(waiters);
     144          830 :     }
     145              : 
     146              :     /// Wait for a number to arrive
     147              :     ///
     148              :     /// This call won't complete until someone has called `advance`
     149              :     /// with a number greater than or equal to the one we're waiting for.
     150              :     ///
     151              :     /// This function is async cancellation-safe.
     152          492 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     153          485 :         match self.queue_for_wait(num) {
     154            2 :             Ok(None) => Ok(()),
     155          483 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     156            0 :             Err(e) => Err(e),
     157              :         }
     158            9 :     }
     159              : 
     160              :     /// Wait for a number to arrive
     161              :     ///
     162              :     /// This call won't complete until someone has called `advance`
     163              :     /// with a number greater than or equal to the one we're waiting for.
     164              :     ///
     165              :     /// If that hasn't happened after the specified timeout duration,
     166              :     /// [`SeqWaitError::Timeout`] will be returned.
     167              :     ///
     168              :     /// This function is async cancellation-safe.
     169      1358731 :     pub async fn wait_for_timeout(
     170      1358731 :         &self,
     171      1358731 :         num: V,
     172      1358731 :         timeout_duration: Duration,
     173      1358731 :     ) -> Result<(), SeqWaitError> {
     174      1358731 :         match self.queue_for_wait(num) {
     175      1298941 :             Ok(None) => Ok(()),
     176       113027 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     177        59760 :                 Ok(Ok(())) => Ok(()),
     178            0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     179           28 :                 Err(_) => Err(SeqWaitError::Timeout),
     180              :             },
     181            0 :             Err(e) => Err(e),
     182              :         }
     183      1358729 :     }
     184              : 
     185              :     /// Register and return a channel that will be notified when a number arrives,
     186              :     /// or None, if it has already arrived.
     187      1359216 :     fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
     188      1359216 :         let mut internal = self.internal.lock().unwrap();
     189      1359216 :         if internal.current.cnt_value() >= num {
     190      1298943 :             return Ok(None);
     191        60273 :         }
     192        60273 :         if internal.shutdown {
     193            0 :             return Err(SeqWaitError::Shutdown);
     194        60273 :         }
     195        60273 : 
     196        60273 :         // Create a new channel.
     197        60273 :         let (tx, rx) = channel(());
     198        60273 :         internal.waiters.push(Waiter {
     199        60273 :             wake_num: num,
     200        60273 :             wake_channel: tx,
     201        60273 :         });
     202        60273 :         // Drop the lock as we exit this scope.
     203        60273 :         Ok(Some(rx))
     204      1359216 :     }
     205              : 
     206              :     /// Announce a new number has arrived
     207              :     ///
     208              :     /// All waiters at this value or below will be woken.
     209              :     ///
     210              :     /// Returns the old number.
     211     76570822 :     pub fn advance(&self, num: V) -> V {
     212              :         let old_value;
     213     74771397 :         let wake_these = {
     214     76570822 :             let mut internal = self.internal.lock().unwrap();
     215     76570822 : 
     216     76570822 :             old_value = internal.current.cnt_value();
     217     76570822 :             if old_value >= num {
     218      1799425 :                 return old_value;
     219     74771397 :             }
     220     74771397 :             internal.current.cnt_advance(num);
     221     74771397 : 
     222     74771397 :             // Pop all waiters <= num from the heap. Collect them in a vector, and
     223     74771397 :             // wake them up after releasing the lock.
     224     74771397 :             let mut wake_these = Vec::new();
     225     74831166 :             while let Some(n) = internal.waiters.peek() {
     226     12310981 :                 if n.wake_num > num {
     227     12251212 :                     break;
     228        59769 :                 }
     229        59769 :                 wake_these.push(internal.waiters.pop().unwrap().wake_channel);
     230              :             }
     231     74771397 :             wake_these
     232              :         };
     233              : 
     234     74831166 :         for tx in wake_these {
     235        59769 :             // This can fail if there are no receivers.
     236        59769 :             // We don't care; discard the error.
     237        59769 :             let _ = tx.send(());
     238        59769 :         }
     239     74771397 :         old_value
     240     76570822 :     }
     241              : 
     242              :     /// Read the current value, without waiting.
     243     10871532 :     pub fn load(&self) -> S {
     244     10871532 :         self.internal.lock().unwrap().current
     245     10871532 :     }
     246              : }
     247              : 
     248              : #[cfg(test)]
     249              : mod tests {
     250              :     use super::*;
     251              :     use std::sync::Arc;
     252              :     use std::time::Duration;
     253              : 
     254              :     impl MonotonicCounter<i32> for i32 {
     255            6 :         fn cnt_advance(&mut self, val: i32) {
     256            6 :             assert!(*self <= val);
     257            6 :             *self = val;
     258            6 :         }
     259           20 :         fn cnt_value(&self) -> i32 {
     260           20 :             *self
     261           20 :         }
     262              :     }
     263              : 
     264            2 :     #[tokio::test]
     265            2 :     async fn seqwait() {
     266            2 :         let seq = Arc::new(SeqWait::new(0));
     267            2 :         let seq2 = Arc::clone(&seq);
     268            2 :         let seq3 = Arc::clone(&seq);
     269            2 :         let jh1 = tokio::task::spawn(async move {
     270            2 :             seq2.wait_for(42).await.expect("wait_for 42");
     271            2 :             let old = seq2.advance(100);
     272            2 :             assert_eq!(old, 99);
     273            2 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     274            2 :                 .await
     275            2 :                 .expect_err("no 999");
     276            2 :         });
     277            2 :         let jh2 = tokio::task::spawn(async move {
     278            2 :             seq3.wait_for(42).await.expect("wait_for 42");
     279            2 :             seq3.wait_for(0).await.expect("wait_for 0");
     280            2 :         });
     281            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     282            2 :         let old = seq.advance(99);
     283            2 :         assert_eq!(old, 0);
     284            2 :         seq.wait_for(100).await.expect("wait_for 100");
     285            2 : 
     286            2 :         // Calling advance with a smaller value is a no-op
     287            2 :         assert_eq!(seq.advance(98), 100);
     288            2 :         assert_eq!(seq.load(), 100);
     289              : 
     290            2 :         jh1.await.unwrap();
     291            2 :         jh2.await.unwrap();
     292            2 : 
     293            2 :         seq.shutdown();
     294              :     }
     295              : 
     296            2 :     #[tokio::test]
     297            2 :     async fn seqwait_timeout() {
     298            2 :         let seq = Arc::new(SeqWait::new(0));
     299            2 :         let seq2 = Arc::clone(&seq);
     300            2 :         let jh = tokio::task::spawn(async move {
     301            2 :             let timeout = Duration::from_millis(1);
     302            2 :             let res = seq2.wait_for_timeout(42, timeout).await;
     303            2 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     304            2 :         });
     305            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     306              :         // This will attempt to wake, but nothing will happen
     307              :         // because the waiter already dropped its Receiver.
     308            2 :         let old = seq.advance(99);
     309            2 :         assert_eq!(old, 0);
     310            2 :         jh.await.unwrap();
     311            2 : 
     312            2 :         seq.shutdown();
     313              :     }
     314              : }
        

Generated by: LCOV version 2.1-beta