LCOV - code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit
Test: 20b6afc7b7f34578dcaab2b3acdaecfe91cd8bf1.info Lines: 87.2 % 203 177
Test Date: 2024-11-25 17:48:16 Functions: 47.0 % 83 39

            Line data    Source code
       1              : #![warn(missing_docs)]
       2              : 
       3              : use std::cmp::{Eq, Ordering};
       4              : use std::collections::BinaryHeap;
       5              : use std::mem;
       6              : use std::sync::Mutex;
       7              : use std::time::Duration;
       8              : use tokio::sync::watch::{self, channel};
       9              : use tokio::time::timeout;
      10              : 
      11              : /// An error happened while waiting for a number
      12            0 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      13              : pub enum SeqWaitError {
      14              :     /// The wait timeout was reached
      15              :     #[error("seqwait timeout was reached")]
      16              :     Timeout,
      17              : 
      18              :     /// [`SeqWait::shutdown`] was called
      19              :     #[error("SeqWait::shutdown was called")]
      20              :     Shutdown,
      21              : }
      22              : 
      23              : /// Monotonically increasing value
      24              : ///
      25              : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      26              : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      27              : /// any type that can expose counter. `V` is the type of exposed counter.
      28              : pub trait MonotonicCounter<V> {
      29              :     /// Bump counter value and check that it goes forward
      30              :     /// N.B.: new_val is an actual new value, not a difference.
      31              :     fn cnt_advance(&mut self, new_val: V);
      32              : 
      33              :     /// Get counter value
      34              :     fn cnt_value(&self) -> V;
      35              : }
      36              : 
      37              : /// Heap of waiters, lowest numbers pop first.
      38              : struct Waiters<V>
      39              : where
      40              :     V: Ord,
      41              : {
      42              :     heap: BinaryHeap<Waiter<V>>,
      43              :     /// Number of the first waiter in the heap, or None if there are no waiters.
      44              :     status_channel: watch::Sender<Option<V>>,
      45              : }
      46              : 
      47              : impl<V> Waiters<V>
      48              : where
      49              :     V: Ord + Copy,
      50              : {
      51        25458 :     fn new() -> Self {
      52        25458 :         Waiters {
      53        25458 :             heap: BinaryHeap::new(),
      54        25458 :             status_channel: channel(None).0,
      55        25458 :         }
      56        25458 :     }
      57              : 
      58              :     /// `status_channel` contains the number of the first waiter in the heap.
      59              :     /// This function should be called whenever waiters heap changes.
      60           20 :     fn update_status(&self) {
      61           20 :         let first_waiter = self.heap.peek().map(|w| w.wake_num);
      62           20 :         let _ = self.status_channel.send_replace(first_waiter);
      63           20 :     }
      64              : 
      65              :     /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
      66            5 :     fn add(&mut self, num: V) -> watch::Receiver<()> {
      67            5 :         let (tx, rx) = channel(());
      68            5 :         self.heap.push(Waiter {
      69            5 :             wake_num: num,
      70            5 :             wake_channel: tx,
      71            5 :         });
      72            5 :         self.update_status();
      73            5 :         rx
      74            5 :     }
      75              : 
      76              :     /// Pop all waiters <= num from the heap. Collect channels in a vector,
      77              :     /// so that caller can wake them up.
      78      4804419 :     fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
      79      4804419 :         let mut wake_these = Vec::new();
      80      4804423 :         while let Some(n) = self.heap.peek() {
      81            4 :             if n.wake_num > num {
      82            0 :                 break;
      83            4 :             }
      84            4 :             wake_these.push(self.heap.pop().unwrap().wake_channel);
      85              :         }
      86      4804419 :         if !wake_these.is_empty() {
      87            3 :             self.update_status();
      88      4804416 :         }
      89      4804419 :         wake_these
      90      4804419 :     }
      91              : 
      92              :     /// Used on shutdown to efficiently drop all waiters.
      93           12 :     fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
      94           12 :         let heap = mem::take(&mut self.heap);
      95           12 :         self.update_status();
      96           12 :         heap
      97           12 :     }
      98              : }
      99              : 
     100              : struct Waiter<T>
     101              : where
     102              :     T: Ord,
     103              : {
     104              :     wake_num: T,                     // wake me when this number arrives ...
     105              :     wake_channel: watch::Sender<()>, // ... by sending a message to this channel
     106              : }
     107              : 
     108              : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
     109              : // to get that.
     110              : impl<T: Ord> PartialOrd for Waiter<T> {
     111            1 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
     112            1 :         Some(self.cmp(other))
     113            1 :     }
     114              : }
     115              : 
     116              : impl<T: Ord> Ord for Waiter<T> {
     117            1 :     fn cmp(&self, other: &Self) -> Ordering {
     118            1 :         other.wake_num.cmp(&self.wake_num)
     119            1 :     }
     120              : }
     121              : 
     122              : impl<T: Ord> PartialEq for Waiter<T> {
     123            0 :     fn eq(&self, other: &Self) -> bool {
     124            0 :         other.wake_num == self.wake_num
     125            0 :     }
     126              : }
     127              : 
     128              : impl<T: Ord> Eq for Waiter<T> {}
     129              : 
     130              : /// Internal components of a `SeqWait`
     131              : struct SeqWaitInt<S, V>
     132              : where
     133              :     S: MonotonicCounter<V>,
     134              :     V: Ord,
     135              : {
     136              :     waiters: Waiters<V>,
     137              :     current: S,
     138              :     shutdown: bool,
     139              : }
     140              : 
     141              : /// A tool for waiting on a sequence number
     142              : ///
     143              : /// This provides a way to wait the arrival of a number.
     144              : /// As soon as the number arrives by another caller calling
     145              : /// [`advance`], then the waiter will be woken up.
     146              : ///
     147              : /// This implementation takes a blocking Mutex on both [`wait_for`]
     148              : /// and [`advance`], meaning there may be unexpected executor blocking
     149              : /// due to thread scheduling unfairness. There are probably better
     150              : /// implementations, but we can probably live with this for now.
     151              : ///
     152              : /// [`wait_for`]: SeqWait::wait_for
     153              : /// [`advance`]: SeqWait::advance
     154              : ///
     155              : /// `S` means Storage, `V` is type of counter that this storage exposes.
     156              : ///
     157              : pub struct SeqWait<S, V>
     158              : where
     159              :     S: MonotonicCounter<V>,
     160              :     V: Ord,
     161              : {
     162              :     internal: Mutex<SeqWaitInt<S, V>>,
     163              : }
     164              : 
     165              : impl<S, V> SeqWait<S, V>
     166              : where
     167              :     S: MonotonicCounter<V> + Copy,
     168              :     V: Ord + Copy,
     169              : {
     170              :     /// Create a new `SeqWait`, initialized to a particular number
     171        25458 :     pub fn new(starting_num: S) -> Self {
     172        25458 :         let internal = SeqWaitInt {
     173        25458 :             waiters: Waiters::new(),
     174        25458 :             current: starting_num,
     175        25458 :             shutdown: false,
     176        25458 :         };
     177        25458 :         SeqWait {
     178        25458 :             internal: Mutex::new(internal),
     179        25458 :         }
     180        25458 :     }
     181              : 
     182              :     /// Shut down a `SeqWait`, causing all waiters (present and
     183              :     /// future) to return an error.
     184           12 :     pub fn shutdown(&self) {
     185           12 :         let waiters = {
     186           12 :             // Prevent new waiters; wake all those that exist.
     187           12 :             // Wake everyone with an error.
     188           12 :             let mut internal = self.internal.lock().unwrap();
     189           12 : 
     190           12 :             // Block any future waiters from starting
     191           12 :             internal.shutdown = true;
     192           12 : 
     193           12 :             // Take all waiters to drop them later.
     194           12 :             internal.waiters.take_all()
     195           12 : 
     196           12 :             // Drop the lock as we exit this scope.
     197           12 :         };
     198           12 : 
     199           12 :         // When we drop the waiters list, each Receiver will
     200           12 :         // be woken with an error.
     201           12 :         // This drop doesn't need to be explicit; it's done
     202           12 :         // here to make it easier to read the code and understand
     203           12 :         // the order of events.
     204           12 :         drop(waiters);
     205           12 :     }
     206              : 
     207              :     /// Wait for a number to arrive
     208              :     ///
     209              :     /// This call won't complete until someone has called `advance`
     210              :     /// with a number greater than or equal to the one we're waiting for.
     211              :     ///
     212              :     /// This function is async cancellation-safe.
     213            4 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     214            4 :         match self.queue_for_wait(num) {
     215            1 :             Ok(None) => Ok(()),
     216            3 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     217            0 :             Err(e) => Err(e),
     218              :         }
     219            4 :     }
     220              : 
     221              :     /// Wait for a number to arrive
     222              :     ///
     223              :     /// This call won't complete until someone has called `advance`
     224              :     /// with a number greater than or equal to the one we're waiting for.
     225              :     ///
     226              :     /// If that hasn't happened after the specified timeout duration,
     227              :     /// [`SeqWaitError::Timeout`] will be returned.
     228              :     ///
     229              :     /// This function is async cancellation-safe.
     230       226558 :     pub async fn wait_for_timeout(
     231       226558 :         &self,
     232       226558 :         num: V,
     233       226558 :         timeout_duration: Duration,
     234       226558 :     ) -> Result<(), SeqWaitError> {
     235       226558 :         match self.queue_for_wait(num) {
     236       226556 :             Ok(None) => Ok(()),
     237            2 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     238            0 :                 Ok(Ok(())) => Ok(()),
     239            0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     240            2 :                 Err(_) => Err(SeqWaitError::Timeout),
     241              :             },
     242            0 :             Err(e) => Err(e),
     243              :         }
     244       226558 :     }
     245              : 
     246              :     /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
     247            0 :     pub fn would_wait_for(&self, num: V) -> Result<(), V> {
     248            0 :         let internal = self.internal.lock().unwrap();
     249            0 :         let cnt = internal.current.cnt_value();
     250            0 :         drop(internal);
     251            0 :         if cnt >= num {
     252            0 :             Ok(())
     253              :         } else {
     254            0 :             Err(cnt)
     255              :         }
     256            0 :     }
     257              : 
     258              :     /// Register and return a channel that will be notified when a number arrives,
     259              :     /// or None, if it has already arrived.
     260       226562 :     fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
     261       226562 :         let mut internal = self.internal.lock().unwrap();
     262       226562 :         if internal.current.cnt_value() >= num {
     263       226557 :             return Ok(None);
     264            5 :         }
     265            5 :         if internal.shutdown {
     266            0 :             return Err(SeqWaitError::Shutdown);
     267            5 :         }
     268            5 : 
     269            5 :         // Add waiter channel to the queue.
     270            5 :         let rx = internal.waiters.add(num);
     271            5 :         // Drop the lock as we exit this scope.
     272            5 :         Ok(Some(rx))
     273       226562 :     }
     274              : 
     275              :     /// Announce a new number has arrived
     276              :     ///
     277              :     /// All waiters at this value or below will be woken.
     278              :     ///
     279              :     /// Returns the old number.
     280      5279108 :     pub fn advance(&self, num: V) -> V {
     281              :         let old_value;
     282      4804419 :         let wake_these = {
     283      5279108 :             let mut internal = self.internal.lock().unwrap();
     284      5279108 : 
     285      5279108 :             old_value = internal.current.cnt_value();
     286      5279108 :             if old_value >= num {
     287       474689 :                 return old_value;
     288      4804419 :             }
     289      4804419 :             internal.current.cnt_advance(num);
     290      4804419 : 
     291      4804419 :             // Pop all waiters <= num from the heap.
     292      4804419 :             internal.waiters.pop_leq(num)
     293              :         };
     294              : 
     295      4804423 :         for tx in wake_these {
     296            4 :             // This can fail if there are no receivers.
     297            4 :             // We don't care; discard the error.
     298            4 :             let _ = tx.send(());
     299            4 :         }
     300      4804419 :         old_value
     301      5279108 :     }
     302              : 
     303              :     /// Read the current value, without waiting.
     304       275665 :     pub fn load(&self) -> S {
     305       275665 :         self.internal.lock().unwrap().current
     306       275665 :     }
     307              : 
     308              :     /// Get a Receiver for the current status.
     309              :     ///
     310              :     /// The current status is the number of the first waiter in the queue,
     311              :     /// or None if there are no waiters.
     312              :     ///
     313              :     /// This receiver will be notified whenever the status changes.
     314              :     /// It is useful for receiving notifications when the first waiter
     315              :     /// starts waiting for a number, or when there are no more waiters left.
     316            0 :     pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
     317            0 :         self.internal
     318            0 :             .lock()
     319            0 :             .unwrap()
     320            0 :             .waiters
     321            0 :             .status_channel
     322            0 :             .subscribe()
     323            0 :     }
     324              : }
     325              : 
     326              : #[cfg(test)]
     327              : mod tests {
     328              :     use super::*;
     329              :     use std::sync::Arc;
     330              : 
     331              :     impl MonotonicCounter<i32> for i32 {
     332            3 :         fn cnt_advance(&mut self, val: i32) {
     333            3 :             assert!(*self <= val);
     334            3 :             *self = val;
     335            3 :         }
     336           10 :         fn cnt_value(&self) -> i32 {
     337           10 :             *self
     338           10 :         }
     339              :     }
     340              : 
     341              :     #[tokio::test]
     342            1 :     async fn seqwait() {
     343            1 :         let seq = Arc::new(SeqWait::new(0));
     344            1 :         let seq2 = Arc::clone(&seq);
     345            1 :         let seq3 = Arc::clone(&seq);
     346            1 :         let jh1 = tokio::task::spawn(async move {
     347            1 :             seq2.wait_for(42).await.expect("wait_for 42");
     348            1 :             let old = seq2.advance(100);
     349            1 :             assert_eq!(old, 99);
     350            1 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     351            1 :                 .await
     352            1 :                 .expect_err("no 999");
     353            1 :         });
     354            1 :         let jh2 = tokio::task::spawn(async move {
     355            1 :             seq3.wait_for(42).await.expect("wait_for 42");
     356            1 :             seq3.wait_for(0).await.expect("wait_for 0");
     357            1 :         });
     358            1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     359            1 :         let old = seq.advance(99);
     360            1 :         assert_eq!(old, 0);
     361            1 :         seq.wait_for(100).await.expect("wait_for 100");
     362            1 : 
     363            1 :         // Calling advance with a smaller value is a no-op
     364            1 :         assert_eq!(seq.advance(98), 100);
     365            1 :         assert_eq!(seq.load(), 100);
     366            1 : 
     367            1 :         jh1.await.unwrap();
     368            1 :         jh2.await.unwrap();
     369            1 : 
     370            1 :         seq.shutdown();
     371            1 :     }
     372              : 
     373              :     #[tokio::test]
     374            1 :     async fn seqwait_timeout() {
     375            1 :         let seq = Arc::new(SeqWait::new(0));
     376            1 :         let seq2 = Arc::clone(&seq);
     377            1 :         let jh = tokio::task::spawn(async move {
     378            1 :             let timeout = Duration::from_millis(1);
     379            1 :             let res = seq2.wait_for_timeout(42, timeout).await;
     380            1 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     381            1 :         });
     382            1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     383            1 :         // This will attempt to wake, but nothing will happen
     384            1 :         // because the waiter already dropped its Receiver.
     385            1 :         let old = seq.advance(99);
     386            1 :         assert_eq!(old, 0);
     387            1 :         jh.await.unwrap();
     388            1 : 
     389            1 :         seq.shutdown();
     390            1 :     }
     391              : }
        

Generated by: LCOV version 2.1-beta