LCOV - code coverage report
Current view: top level - libs/utils/src/sync - spsc_fold.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 91.3 % 333 304
Test Date: 2025-02-20 13:11:02 Functions: 73.4 % 124 91

            Line data    Source code
       1              : use core::{future::poll_fn, task::Poll};
       2              : use std::sync::{Arc, Mutex};
       3              : 
       4              : use diatomic_waker::DiatomicWaker;
       5              : 
       6              : pub struct Sender<T> {
       7              :     state: Arc<Inner<T>>,
       8              : }
       9              : 
      10              : pub struct Receiver<T> {
      11              :     state: Arc<Inner<T>>,
      12              : }
      13              : 
      14              : struct Inner<T> {
      15              :     wake_receiver: DiatomicWaker,
      16              :     wake_sender: DiatomicWaker,
      17              :     value: Mutex<State<T>>,
      18              : }
      19              : 
      20              : enum State<T> {
      21              :     NoData,
      22              :     HasData(T),
      23              :     TryFoldFailed, // transient state
      24              :     SenderWaitsForReceiverToConsume(T),
      25              :     SenderGone(Option<T>),
      26              :     ReceiverGone,
      27              :     AllGone,
      28              :     SenderDropping,   // transient state
      29              :     ReceiverDropping, // transient state
      30              : }
      31              : 
      32           12 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
      33           12 :     let inner = Inner {
      34           12 :         wake_receiver: DiatomicWaker::new(),
      35           12 :         wake_sender: DiatomicWaker::new(),
      36           12 :         value: Mutex::new(State::NoData),
      37           12 :     };
      38           12 : 
      39           12 :     let state = Arc::new(inner);
      40           12 :     (
      41           12 :         Sender {
      42           12 :             state: state.clone(),
      43           12 :         },
      44           12 :         Receiver { state },
      45           12 :     )
      46           12 : }
      47              : 
      48              : #[derive(Debug, thiserror::Error)]
      49              : pub enum SendError {
      50              :     #[error("receiver is gone")]
      51              :     ReceiverGone,
      52              : }
      53              : 
      54              : impl<T: Send> Sender<T> {
      55              :     /// # Panics
      56              :     ///
      57              :     /// If `try_fold` panics,  any subsequent call to `send` panic.
      58           14 :     pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
      59           14 :     where
      60           14 :         F: Fn(&mut T, T) -> Result<(), T>,
      61           14 :     {
      62           14 :         let mut value = Some(value);
      63           18 :         poll_fn(|cx| {
      64           18 :             let mut guard = self.state.value.lock().unwrap();
      65           18 :             match &mut *guard {
      66              :                 State::NoData => {
      67            9 :                     *guard = State::HasData(value.take().unwrap());
      68            9 :                     self.state.wake_receiver.notify();
      69            9 :                     Poll::Ready(Ok(()))
      70              :                 }
      71              :                 State::HasData(_) => {
      72            4 :                     let State::HasData(acc_mut) = &mut *guard else {
      73            0 :                         unreachable!("this match arm guarantees that the guard is HasData");
      74              :                     };
      75            4 :                     match try_fold(acc_mut, value.take().unwrap()) {
      76              :                         Ok(()) => {
      77              :                             // no need to wake receiver, if it was waiting it already
      78              :                             // got a wake-up when we transitioned from NoData to HasData
      79            1 :                             Poll::Ready(Ok(()))
      80              :                         }
      81            3 :                         Err(unfoldable_value) => {
      82            3 :                             value = Some(unfoldable_value);
      83            3 :                             let State::HasData(acc) =
      84            3 :                                 std::mem::replace(&mut *guard, State::TryFoldFailed)
      85              :                             else {
      86            0 :                                 unreachable!("this match arm guarantees that the guard is HasData");
      87              :                             };
      88            3 :                             *guard = State::SenderWaitsForReceiverToConsume(acc);
      89            3 :                             // SAFETY: send is single threaded due to `&mut self` requirement,
      90            3 :                             // therefore register is not concurrent.
      91            3 :                             unsafe {
      92            3 :                                 self.state.wake_sender.register(cx.waker());
      93            3 :                             }
      94            3 :                             Poll::Pending
      95              :                         }
      96              :                     }
      97              :                 }
      98            1 :                 State::SenderWaitsForReceiverToConsume(_data) => {
      99            1 :                     // SAFETY: send is single threaded due to `&mut self` requirement,
     100            1 :                     // therefore register is not concurrent.
     101            1 :                     unsafe {
     102            1 :                         self.state.wake_sender.register(cx.waker());
     103            1 :                     }
     104            1 :                     Poll::Pending
     105              :                 }
     106            4 :                 State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
     107              :                 State::SenderGone(_)
     108              :                 | State::AllGone
     109              :                 | State::SenderDropping
     110              :                 | State::ReceiverDropping
     111              :                 | State::TryFoldFailed => {
     112            0 :                     unreachable!();
     113              :                 }
     114              :             }
     115           18 :         })
     116           14 :         .await
     117           14 :     }
     118              : }
     119              : 
     120              : impl<T> Drop for Sender<T> {
     121           12 :     fn drop(&mut self) {
     122           12 :         scopeguard::defer! {
     123           12 :             self.state.wake_receiver.notify()
     124           12 :         };
     125           12 :         let Ok(mut guard) = self.state.value.lock() else {
     126            0 :             return;
     127              :         };
     128           12 :         *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
     129            3 :             State::NoData => State::SenderGone(None),
     130            1 :             State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
     131            1 :                 State::SenderGone(Some(data))
     132              :             }
     133            8 :             State::ReceiverGone => State::AllGone,
     134              :             State::TryFoldFailed
     135              :             | State::SenderGone(_)
     136              :             | State::AllGone
     137              :             | State::SenderDropping
     138              :             | State::ReceiverDropping => {
     139            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     140              :             }
     141              :         }
     142           12 :     }
     143              : }
     144              : 
     145              : #[derive(Debug, thiserror::Error)]
     146              : pub enum RecvError {
     147              :     #[error("sender is gone")]
     148              :     SenderGone,
     149              : }
     150              : 
     151              : impl<T: Send> Receiver<T> {
     152           10 :     pub async fn recv(&mut self) -> Result<T, RecvError> {
     153           12 :         poll_fn(|cx| {
     154           12 :             let mut guard = self.state.value.lock().unwrap();
     155           12 :             match &mut *guard {
     156              :                 State::NoData => {
     157              :                     // SAFETY: recv is single threaded due to `&mut self` requirement,
     158              :                     // therefore register is not concurrent.
     159            2 :                     unsafe {
     160            2 :                         self.state.wake_receiver.register(cx.waker());
     161            2 :                     }
     162            2 :                     Poll::Pending
     163              :                 }
     164            4 :                 guard @ State::HasData(_)
     165            1 :                 | guard @ State::SenderWaitsForReceiverToConsume(_)
     166            1 :                 | guard @ State::SenderGone(Some(_)) => {
     167            6 :                     let data = guard
     168            6 :                         .take_data()
     169            6 :                         .expect("in these states, data is guaranteed to be present");
     170            6 :                     self.state.wake_sender.notify();
     171            6 :                     Poll::Ready(Ok(data))
     172              :                 }
     173            4 :                 State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
     174              :                 State::ReceiverGone
     175              :                 | State::AllGone
     176              :                 | State::SenderDropping
     177              :                 | State::ReceiverDropping
     178              :                 | State::TryFoldFailed => {
     179            0 :                     unreachable!("unreachable state {:?}", guard.discriminant_str());
     180              :                 }
     181              :             }
     182           12 :         })
     183           10 :         .await
     184           10 :     }
     185              : }
     186              : 
     187              : impl<T> Drop for Receiver<T> {
     188           12 :     fn drop(&mut self) {
     189           12 :         scopeguard::defer! {
     190           12 :             self.state.wake_sender.notify()
     191           12 :         };
     192           12 :         let Ok(mut guard) = self.state.value.lock() else {
     193            0 :             return;
     194              :         };
     195           12 :         *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
     196            5 :             State::NoData => State::ReceiverGone,
     197            3 :             State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
     198            4 :             State::SenderGone(_) => State::AllGone,
     199              :             State::TryFoldFailed
     200              :             | State::ReceiverGone
     201              :             | State::AllGone
     202              :             | State::SenderDropping
     203              :             | State::ReceiverDropping => {
     204            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     205              :             }
     206              :         }
     207           12 :     }
     208              : }
     209              : 
     210              : impl<T> State<T> {
     211            6 :     fn take_data(&mut self) -> Option<T> {
     212            6 :         match self {
     213              :             State::HasData(_) => {
     214            4 :                 let State::HasData(data) = std::mem::replace(self, State::NoData) else {
     215            0 :                     unreachable!("this match arm guarantees that the state is HasData");
     216              :                 };
     217            4 :                 Some(data)
     218              :             }
     219              :             State::SenderWaitsForReceiverToConsume(_) => {
     220            1 :                 let State::SenderWaitsForReceiverToConsume(data) =
     221            1 :                     std::mem::replace(self, State::NoData)
     222              :                 else {
     223            0 :                     unreachable!(
     224            0 :                         "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
     225            0 :                     );
     226              :                 };
     227            1 :                 Some(data)
     228              :             }
     229            1 :             State::SenderGone(data) => Some(data.take().unwrap()),
     230              :             State::NoData
     231              :             | State::TryFoldFailed
     232              :             | State::ReceiverGone
     233              :             | State::AllGone
     234              :             | State::SenderDropping
     235            0 :             | State::ReceiverDropping => None,
     236              :         }
     237            6 :     }
     238            0 :     fn discriminant_str(&self) -> &'static str {
     239            0 :         match self {
     240            0 :             State::NoData => "NoData",
     241            0 :             State::HasData(_) => "HasData",
     242            0 :             State::TryFoldFailed => "TryFoldFailed",
     243            0 :             State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
     244            0 :             State::SenderGone(_) => "SenderGone",
     245            0 :             State::ReceiverGone => "ReceiverGone",
     246            0 :             State::AllGone => "AllGone",
     247            0 :             State::SenderDropping => "SenderDropping",
     248            0 :             State::ReceiverDropping => "ReceiverDropping",
     249              :         }
     250            0 :     }
     251              : }
     252              : 
     253              : #[cfg(test)]
     254              : mod tests {
     255              : 
     256              :     use super::*;
     257              : 
     258              :     const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);
     259              : 
     260              :     #[tokio::test]
     261            1 :     async fn test_send_recv() {
     262            1 :         let (mut sender, mut receiver) = channel();
     263            1 : 
     264            1 :         sender
     265            1 :             .send(42, |acc, val| {
     266            0 :                 *acc += val;
     267            0 :                 Ok(())
     268            1 :             })
     269            1 :             .await
     270            1 :             .unwrap();
     271            1 : 
     272            1 :         let received = receiver.recv().await.unwrap();
     273            1 :         assert_eq!(received, 42);
     274            1 :     }
     275              : 
     276              :     #[tokio::test]
     277            1 :     async fn test_send_recv_with_fold() {
     278            1 :         let (mut sender, mut receiver) = channel();
     279            1 : 
     280            1 :         sender
     281            1 :             .send(1, |acc, val| {
     282            0 :                 *acc += val;
     283            0 :                 Ok(())
     284            1 :             })
     285            1 :             .await
     286            1 :             .unwrap();
     287            1 :         sender
     288            1 :             .send(2, |acc, val| {
     289            1 :                 *acc += val;
     290            1 :                 Ok(())
     291            1 :             })
     292            1 :             .await
     293            1 :             .unwrap();
     294            1 : 
     295            1 :         let received = receiver.recv().await.unwrap();
     296            1 :         assert_eq!(received, 3);
     297            1 :     }
     298              : 
     299              :     #[tokio::test(start_paused = true)]
     300            1 :     async fn test_sender_waits_for_receiver_if_try_fold_fails() {
     301            1 :         let (mut sender, mut receiver) = channel();
     302            1 : 
     303            1 :         sender.send(23, |_, _| panic!("first send")).await.unwrap();
     304            1 : 
     305            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     306            1 :         let mut send_fut = std::pin::pin!(send_fut);
     307            1 : 
     308            1 :         tokio::select! {
     309            1 :             _ = tokio::time::sleep(FOREVER) => {},
     310            1 :             _ = &mut send_fut => {
     311            1 :                 panic!("send should not complete");
     312            1 :             },
     313            1 :         }
     314            1 : 
     315            1 :         let val = receiver.recv().await.unwrap();
     316            1 :         assert_eq!(val, 23);
     317            1 : 
     318            1 :         tokio::select! {
     319            1 :             _ = tokio::time::sleep(FOREVER) => {
     320            1 :                 panic!("receiver should have consumed the value");
     321            1 :             },
     322            1 :             _ = &mut send_fut => { },
     323            1 :         }
     324            1 : 
     325            1 :         let val = receiver.recv().await.unwrap();
     326            1 :         assert_eq!(val, 42);
     327            1 :     }
     328              : 
     329              :     #[tokio::test(start_paused = true)]
     330            1 :     async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
     331            1 :         let (mut sender, receiver) = channel();
     332            1 : 
     333            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     334            1 : 
     335            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     336            1 :         let send_fut = std::pin::pin!(send_fut);
     337            1 : 
     338            1 :         drop(receiver);
     339            1 : 
     340            1 :         let result = send_fut.await;
     341            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     342            1 :     }
     343              : 
     344              :     #[tokio::test(start_paused = true)]
     345            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
     346            1 :         let (sender, mut receiver) = channel::<()>();
     347            1 : 
     348            1 :         let recv_fut = receiver.recv();
     349            1 :         let recv_fut = std::pin::pin!(recv_fut);
     350            1 : 
     351            1 :         drop(sender);
     352            1 : 
     353            1 :         let result = recv_fut.await;
     354            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     355            1 :     }
     356              : 
     357              :     #[tokio::test(start_paused = true)]
     358            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
     359            1 :         let (mut sender, mut receiver) = channel();
     360            1 : 
     361            1 :         sender.send(42, |_, _| unreachable!()).await.unwrap();
     362            1 : 
     363            1 :         {
     364            1 :             let recv_fut = receiver.recv();
     365            1 :             let recv_fut = std::pin::pin!(recv_fut);
     366            1 : 
     367            1 :             drop(sender);
     368            1 : 
     369            1 :             let val = recv_fut.await.unwrap();
     370            1 :             assert_eq!(val, 42);
     371            1 :         }
     372            1 : 
     373            1 :         let result = receiver.recv().await;
     374            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     375            1 :     }
     376              : 
     377              :     #[tokio::test(start_paused = true)]
     378            1 :     async fn test_receiver_waits_for_sender_if_no_data() {
     379            1 :         let (mut sender, mut receiver) = channel();
     380            1 : 
     381            1 :         let recv_fut = receiver.recv();
     382            1 :         let mut recv_fut = std::pin::pin!(recv_fut);
     383            1 : 
     384            1 :         tokio::select! {
     385            1 :             _ = tokio::time::sleep(FOREVER) => {},
     386            1 :             _ = &mut recv_fut => {
     387            1 :                 panic!("recv should not complete");
     388            1 :             },
     389            1 :         }
     390            1 : 
     391            1 :         sender.send(42, |_, _| Ok(())).await.unwrap();
     392            1 : 
     393            1 :         let val = recv_fut.await.unwrap();
     394            1 :         assert_eq!(val, 42);
     395            1 :     }
     396              : 
     397              :     #[tokio::test]
     398            1 :     async fn test_receiver_gone_while_nodata() {
     399            1 :         let (mut sender, receiver) = channel();
     400            1 :         drop(receiver);
     401            1 : 
     402            1 :         let result = sender.send(42, |_, _| Ok(())).await;
     403            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     404            1 :     }
     405              : 
     406              :     #[tokio::test]
     407            1 :     async fn test_sender_gone_while_nodata() {
     408            1 :         let (sender, mut receiver) = super::channel::<usize>();
     409            1 :         drop(sender);
     410            1 : 
     411            1 :         let result = receiver.recv().await;
     412            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     413            1 :     }
     414              : 
     415              :     #[tokio::test(start_paused = true)]
     416            1 :     async fn test_receiver_drops_after_sender_went_to_sleep() {
     417            1 :         let (mut sender, receiver) = channel();
     418            1 :         let state = receiver.state.clone();
     419            1 : 
     420            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     421            1 : 
     422            1 :         let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await });
     423            1 : 
     424            1 :         tokio::time::sleep(FOREVER).await;
     425            1 : 
     426            1 :         assert!(matches!(
     427            1 :             &*state.value.lock().unwrap(),
     428            1 :             &State::SenderWaitsForReceiverToConsume(_)
     429            1 :         ));
     430            1 : 
     431            1 :         drop(receiver);
     432            1 : 
     433            1 :         let err = send_task
     434            1 :             .await
     435            1 :             .unwrap()
     436            1 :             .expect_err("should unblock immediately");
     437            1 :         assert!(matches!(err, SendError::ReceiverGone));
     438            1 :     }
     439              : 
     440              :     #[tokio::test(start_paused = true)]
     441            1 :     async fn test_sender_drops_after_receiver_went_to_sleep() {
     442            1 :         let (sender, mut receiver) = channel::<usize>();
     443            1 :         let state = sender.state.clone();
     444            1 : 
     445            1 :         let recv_task = tokio::spawn(async move { receiver.recv().await });
     446            1 : 
     447            1 :         tokio::time::sleep(FOREVER).await;
     448            1 : 
     449            1 :         assert!(matches!(&*state.value.lock().unwrap(), &State::NoData));
     450            1 : 
     451            1 :         drop(sender);
     452            1 : 
     453            1 :         let err = recv_task.await.unwrap().expect_err("should error");
     454            1 :         assert!(matches!(err, RecvError::SenderGone));
     455            1 :     }
     456              : 
     457              :     #[tokio::test(start_paused = true)]
     458            1 :     async fn test_receiver_drop_while_waiting_for_receiver_to_consume_unblocks_sender() {
     459            1 :         let (mut sender, receiver) = channel();
     460            1 : 
     461            1 :         let state = receiver.state.clone();
     462            1 : 
     463            1 :         sender.send((), |_, _| unreachable!()).await.unwrap();
     464            1 : 
     465            1 :         assert!(matches!(&*state.value.lock().unwrap(), &State::HasData(_)));
     466            1 : 
     467            1 :         let unmergeable = sender.send((), |_, _| Err(()));
     468            1 :         let mut unmergeable = std::pin::pin!(unmergeable);
     469            1 :         tokio::select! {
     470            1 :             _ = tokio::time::sleep(FOREVER) => {},
     471            1 :             _ = &mut unmergeable => {
     472            1 :                 panic!("unmergeable should not complete");
     473            1 :             },
     474            1 :         }
     475            1 : 
     476            1 :         assert!(matches!(
     477            1 :             &*state.value.lock().unwrap(),
     478            1 :             &State::SenderWaitsForReceiverToConsume(_)
     479            1 :         ));
     480            1 : 
     481            1 :         drop(receiver);
     482            1 : 
     483            1 :         assert!(matches!(
     484            1 :             &*state.value.lock().unwrap(),
     485            1 :             &State::ReceiverGone
     486            1 :         ));
     487            1 : 
     488            1 :         unmergeable.await.unwrap_err();
     489            1 :     }
     490              : }
        

Generated by: LCOV version 2.1-beta