|             Line data    Source code 
       1              : use tokio_util::task::TaskTracker;
       2              : use tokio_util::task::task_tracker::TaskTrackerToken;
       3              : 
/// Guard half of the pair returned by [`channel`].
///
/// While a reference is kept around, the associated [`Barrier::wait`] will wait.
///
/// Can be cloned, moved and kept around in futures as "guard objects".
#[derive(Clone)]
pub struct Completion {
    // Token handed out by the tracker that the paired `Barrier` wraps; it also
    // gives access back to that tracker through `task_tracker()`.
    token: TaskTrackerToken,
}
      11              : 
      12              : impl std::fmt::Debug for Completion {
      13            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      14            0 :         f.debug_struct("Completion")
      15            0 :             .field("siblings", &self.token.task_tracker().len())
      16            0 :             .finish()
      17            0 :     }
      18              : }
      19              : 
      20              : impl Completion {
      21              :     /// Returns true if this completion is associated with the given barrier.
      22            0 :     pub fn blocks(&self, barrier: &Barrier) -> bool {
      23            0 :         TaskTracker::ptr_eq(self.token.task_tracker(), &barrier.0)
      24            0 :     }
      25              : 
      26            0 :     pub fn barrier(&self) -> Barrier {
      27            0 :         Barrier(self.token.task_tracker().clone())
      28            0 :     }
      29              : }
      30              : 
/// Waiting half of the pair returned by [`channel`].
///
/// Barrier will wait until all clones of [`Completion`] have been dropped.
///
/// Wraps the `TaskTracker` that the paired completion's token was taken from.
#[derive(Clone)]
pub struct Barrier(TaskTracker);
      34              : 
      35              : impl std::fmt::Debug for Barrier {
      36            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      37            0 :         f.debug_struct("Barrier")
      38            0 :             .field("remaining", &self.0.len())
      39            0 :             .finish()
      40            0 :     }
      41              : }
      42              : 
      43              : impl Default for Barrier {
      44            3 :     fn default() -> Self {
      45            3 :         let (_, rx) = channel();
      46            3 :         rx
      47            3 :     }
      48              : }
      49              : 
      50              : impl Barrier {
      51         3623 :     pub async fn wait(self) {
      52         3623 :         self.0.wait().await;
      53         3623 :     }
      54              : 
      55            0 :     pub async fn maybe_wait(barrier: Option<Barrier>) {
      56            0 :         if let Some(b) = barrier {
      57            0 :             b.wait().await
      58            0 :         }
      59            0 :     }
      60              : 
      61              :     /// Return true if a call to wait() would complete immediately
      62            0 :     pub fn is_ready(&self) -> bool {
      63            0 :         futures::future::FutureExt::now_or_never(self.0.wait()).is_some()
      64            0 :     }
      65              : }
      66              : 
      67              : impl PartialEq for Barrier {
      68            0 :     fn eq(&self, other: &Self) -> bool {
      69            0 :         TaskTracker::ptr_eq(&self.0, &other.0)
      70            0 :     }
      71              : }
      72              : 
      73              : impl Eq for Barrier {}
      74              : 
      75              : /// Create new Guard and Barrier pair.
      76           50 : pub fn channel() -> (Completion, Barrier) {
      77           50 :     let tracker = TaskTracker::new();
      78              :     // otherwise wait never exits
      79           50 :     tracker.close();
      80              : 
      81           50 :     let token = tracker.token();
      82           50 :     (Completion { token }, Barrier(tracker))
      83           50 : }
         |