LCOV - code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit
Test: 322b88762cba8ea666f63cda880cccab6936bf37.info Lines: 94.5 % 165 156
Test Date: 2024-02-29 11:57:12 Functions: 51.8 % 56 29

            Line data    Source code
       1              : #![warn(missing_docs)]
       2              : 
       3              : use std::cmp::{Eq, Ordering, PartialOrd};
       4              : use std::collections::BinaryHeap;
       5              : use std::fmt::Debug;
       6              : use std::mem;
       7              : use std::sync::Mutex;
       8              : use std::time::Duration;
       9              : use tokio::sync::watch::{channel, Receiver, Sender};
      10              : use tokio::time::timeout;
      11              : 
      12              : /// An error happened while waiting for a number
      13            2 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      14              : pub enum SeqWaitError {
      15              :     /// The wait timeout was reached
      16              :     #[error("seqwait timeout was reached")]
      17              :     Timeout,
      18              : 
      19              :     /// [`SeqWait::shutdown`] was called
      20              :     #[error("SeqWait::shutdown was called")]
      21              :     Shutdown,
      22              : }
      23              : 
      24              : /// Monotonically increasing value
      25              : ///
      26              : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      27              : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      28              : /// any type that can expose counter. `V` is the type of exposed counter.
      29              : pub trait MonotonicCounter<V> {
      30              :     /// Bump counter value and check that it goes forward
      31              :     /// N.B.: new_val is an actual new value, not a difference.
      32              :     fn cnt_advance(&mut self, new_val: V);
      33              : 
      34              :     /// Get counter value
      35              :     fn cnt_value(&self) -> V;
      36              : }
      37              : 
      38              : /// Internal components of a `SeqWait`
      39              : struct SeqWaitInt<S, V>
      40              : where
      41              :     S: MonotonicCounter<V>,
      42              :     V: Ord,
      43              : {
      44              :     waiters: BinaryHeap<Waiter<V>>,
      45              :     current: S,
      46              :     shutdown: bool,
      47              : }
      48              : 
      49              : struct Waiter<T>
      50              : where
      51              :     T: Ord,
      52              : {
      53              :     wake_num: T,              // wake me when this number arrives ...
      54              :     wake_channel: Sender<()>, // ... by sending a message to this channel
      55              : }
      56              : 
      57              : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
      58              : // to get that.
      59              : impl<T: Ord> PartialOrd for Waiter<T> {
      60            2 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
      61            2 :         Some(self.cmp(other))
      62            2 :     }
      63              : }
      64              : 
      65              : impl<T: Ord> Ord for Waiter<T> {
      66            2 :     fn cmp(&self, other: &Self) -> Ordering {
      67            2 :         other.wake_num.cmp(&self.wake_num)
      68            2 :     }
      69              : }
      70              : 
      71              : impl<T: Ord> PartialEq for Waiter<T> {
      72            0 :     fn eq(&self, other: &Self) -> bool {
      73            0 :         other.wake_num == self.wake_num
      74            0 :     }
      75              : }
      76              : 
      77              : impl<T: Ord> Eq for Waiter<T> {}
      78              : 
      79              : /// A tool for waiting on a sequence number
      80              : ///
      81              : /// This provides a way to wait the arrival of a number.
      82              : /// As soon as the number arrives by another caller calling
      83              : /// [`advance`], then the waiter will be woken up.
      84              : ///
      85              : /// This implementation takes a blocking Mutex on both [`wait_for`]
      86              : /// and [`advance`], meaning there may be unexpected executor blocking
      87              : /// due to thread scheduling unfairness. There are probably better
      88              : /// implementations, but we can probably live with this for now.
      89              : ///
      90              : /// [`wait_for`]: SeqWait::wait_for
      91              : /// [`advance`]: SeqWait::advance
      92              : ///
      93              : /// `S` means Storage, `V` is type of counter that this storage exposes.
      94              : ///
      95              : pub struct SeqWait<S, V>
      96              : where
      97              :     S: MonotonicCounter<V>,
      98              :     V: Ord,
      99              : {
     100              :     internal: Mutex<SeqWaitInt<S, V>>,
     101              : }
     102              : 
     103              : impl<S, V> SeqWait<S, V>
     104              : where
     105              :     S: MonotonicCounter<V> + Copy,
     106              :     V: Ord + Copy,
     107              : {
     108              :     /// Create a new `SeqWait`, initialized to a particular number
     109          304 :     pub fn new(starting_num: S) -> Self {
     110          304 :         let internal = SeqWaitInt {
     111          304 :             waiters: BinaryHeap::new(),
     112          304 :             current: starting_num,
     113          304 :             shutdown: false,
     114          304 :         };
     115          304 :         SeqWait {
     116          304 :             internal: Mutex::new(internal),
     117          304 :         }
     118          304 :     }
     119              : 
     120              :     /// Shut down a `SeqWait`, causing all waiters (present and
     121              :     /// future) to return an error.
     122           18 :     pub fn shutdown(&self) {
     123           18 :         let waiters = {
     124           18 :             // Prevent new waiters; wake all those that exist.
     125           18 :             // Wake everyone with an error.
     126           18 :             let mut internal = self.internal.lock().unwrap();
     127           18 : 
     128           18 :             // Block any future waiters from starting
     129           18 :             internal.shutdown = true;
     130           18 : 
     131           18 :             // This will steal the entire waiters map.
     132           18 :             // When we drop it all waiters will be woken.
     133           18 :             mem::take(&mut internal.waiters)
     134           18 : 
     135           18 :             // Drop the lock as we exit this scope.
     136           18 :         };
     137           18 : 
     138           18 :         // When we drop the waiters list, each Receiver will
     139           18 :         // be woken with an error.
     140           18 :         // This drop doesn't need to be explicit; it's done
     141           18 :         // here to make it easier to read the code and understand
     142           18 :         // the order of events.
     143           18 :         drop(waiters);
     144           18 :     }
     145              : 
     146              :     /// Wait for a number to arrive
     147              :     ///
     148              :     /// This call won't complete until someone has called `advance`
     149              :     /// with a number greater than or equal to the one we're waiting for.
     150              :     ///
     151              :     /// This function is async cancellation-safe.
     152            8 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     153            8 :         match self.queue_for_wait(num) {
     154            2 :             Ok(None) => Ok(()),
     155            6 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     156            0 :             Err(e) => Err(e),
     157              :         }
     158            8 :     }
     159              : 
     160              :     /// Wait for a number to arrive
     161              :     ///
     162              :     /// This call won't complete until someone has called `advance`
     163              :     /// with a number greater than or equal to the one we're waiting for.
     164              :     ///
     165              :     /// If that hasn't happened after the specified timeout duration,
     166              :     /// [`SeqWaitError::Timeout`] will be returned.
     167              :     ///
     168              :     /// This function is async cancellation-safe.
     169       226771 :     pub async fn wait_for_timeout(
     170       226771 :         &self,
     171       226771 :         num: V,
     172       226771 :         timeout_duration: Duration,
     173       226771 :     ) -> Result<(), SeqWaitError> {
     174       226771 :         match self.queue_for_wait(num) {
     175       226767 :             Ok(None) => Ok(()),
     176            4 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     177            0 :                 Ok(Ok(())) => Ok(()),
     178            0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     179            4 :                 Err(_) => Err(SeqWaitError::Timeout),
     180              :             },
     181            0 :             Err(e) => Err(e),
     182              :         }
     183       226771 :     }
     184              : 
     185              :     /// Register and return a channel that will be notified when a number arrives,
     186              :     /// or None, if it has already arrived.
     187       226779 :     fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
     188       226779 :         let mut internal = self.internal.lock().unwrap();
     189       226779 :         if internal.current.cnt_value() >= num {
     190       226769 :             return Ok(None);
     191           10 :         }
     192           10 :         if internal.shutdown {
     193            0 :             return Err(SeqWaitError::Shutdown);
     194           10 :         }
     195           10 : 
     196           10 :         // Create a new channel.
     197           10 :         let (tx, rx) = channel(());
     198           10 :         internal.waiters.push(Waiter {
     199           10 :             wake_num: num,
     200           10 :             wake_channel: tx,
     201           10 :         });
     202           10 :         // Drop the lock as we exit this scope.
     203           10 :         Ok(Some(rx))
     204       226779 :     }
     205              : 
     206              :     /// Announce a new number has arrived
     207              :     ///
     208              :     /// All waiters at this value or below will be woken.
     209              :     ///
     210              :     /// Returns the old number.
     211      3102920 :     pub fn advance(&self, num: V) -> V {
     212              :         let old_value;
     213      2628230 :         let wake_these = {
     214      3102920 :             let mut internal = self.internal.lock().unwrap();
     215      3102920 : 
     216      3102920 :             old_value = internal.current.cnt_value();
     217      3102920 :             if old_value >= num {
     218       474690 :                 return old_value;
     219      2628230 :             }
     220      2628230 :             internal.current.cnt_advance(num);
     221      2628230 : 
     222      2628230 :             // Pop all waiters <= num from the heap. Collect them in a vector, and
     223      2628230 :             // wake them up after releasing the lock.
     224      2628230 :             let mut wake_these = Vec::new();
     225      2628238 :             while let Some(n) = internal.waiters.peek() {
     226            8 :                 if n.wake_num > num {
     227            0 :                     break;
     228            8 :                 }
     229            8 :                 wake_these.push(internal.waiters.pop().unwrap().wake_channel);
     230              :             }
     231      2628230 :             wake_these
     232              :         };
     233              : 
     234      2628238 :         for tx in wake_these {
     235            8 :             // This can fail if there are no receivers.
     236            8 :             // We don't care; discard the error.
     237            8 :             let _ = tx.send(());
     238            8 :         }
     239      2628230 :         old_value
     240      3102920 :     }
     241              : 
     242              :     /// Read the current value, without waiting.
     243      2901270 :     pub fn load(&self) -> S {
     244      2901270 :         self.internal.lock().unwrap().current
     245      2901270 :     }
     246              : }
     247              : 
     248              : #[cfg(test)]
     249              : mod tests {
     250              :     use super::*;
     251              :     use std::sync::Arc;
     252              :     use std::time::Duration;
     253              : 
     254              :     impl MonotonicCounter<i32> for i32 {
     255            6 :         fn cnt_advance(&mut self, val: i32) {
     256            6 :             assert!(*self <= val);
     257            6 :             *self = val;
     258            6 :         }
     259           20 :         fn cnt_value(&self) -> i32 {
     260           20 :             *self
     261           20 :         }
     262              :     }
     263              : 
     264            2 :     #[tokio::test]
     265            2 :     async fn seqwait() {
     266            2 :         let seq = Arc::new(SeqWait::new(0));
     267            2 :         let seq2 = Arc::clone(&seq);
     268            2 :         let seq3 = Arc::clone(&seq);
     269            2 :         let jh1 = tokio::task::spawn(async move {
     270            2 :             seq2.wait_for(42).await.expect("wait_for 42");
     271            2 :             let old = seq2.advance(100);
     272            2 :             assert_eq!(old, 99);
     273            2 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     274            2 :                 .await
     275            2 :                 .expect_err("no 999");
     276            2 :         });
     277            2 :         let jh2 = tokio::task::spawn(async move {
     278            2 :             seq3.wait_for(42).await.expect("wait_for 42");
     279            2 :             seq3.wait_for(0).await.expect("wait_for 0");
     280            2 :         });
     281            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     282            2 :         let old = seq.advance(99);
     283            2 :         assert_eq!(old, 0);
     284            2 :         seq.wait_for(100).await.expect("wait_for 100");
     285            2 : 
     286            2 :         // Calling advance with a smaller value is a no-op
     287            2 :         assert_eq!(seq.advance(98), 100);
     288            2 :         assert_eq!(seq.load(), 100);
     289            2 : 
     290            2 :         jh1.await.unwrap();
     291            2 :         jh2.await.unwrap();
     292            2 : 
     293            2 :         seq.shutdown();
     294            2 :     }
     295              : 
     296            2 :     #[tokio::test]
     297            2 :     async fn seqwait_timeout() {
     298            2 :         let seq = Arc::new(SeqWait::new(0));
     299            2 :         let seq2 = Arc::clone(&seq);
     300            2 :         let jh = tokio::task::spawn(async move {
     301            2 :             let timeout = Duration::from_millis(1);
     302            2 :             let res = seq2.wait_for_timeout(42, timeout).await;
     303            2 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     304            2 :         });
     305            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     306            2 :         // This will attempt to wake, but nothing will happen
     307            2 :         // because the waiter already dropped its Receiver.
     308            2 :         let old = seq.advance(99);
     309            2 :         assert_eq!(old, 0);
     310            2 :         jh.await.unwrap();
     311            2 : 
     312            2 :         seq.shutdown();
     313            2 :     }
     314              : }
        

Generated by: LCOV version 2.1-beta