LCOV - code coverage report
Current view: top level - libs/utils/src/sync - spsc_fold.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 89.2 % 333 297
Test Date: 2025-03-12 18:28:53 Functions: 73.4 % 124 91

            Line data    Source code
       1              : use core::future::poll_fn;
       2              : use core::task::Poll;
       3              : use std::sync::{Arc, Mutex};
       4              : 
       5              : use diatomic_waker::DiatomicWaker;
       6              : 
/// Sending half of a channel created by [`channel`].
///
/// A value sent while another one is still buffered is merged into the
/// buffered one using the `try_fold` callback passed to [`Sender::send`].
pub struct Sender<T> {
    /// State shared with the paired [`Receiver`].
    state: Arc<Inner<T>>,
}
      10              : 
/// Receiving half of a channel created by [`channel`]; see [`Receiver::recv`].
pub struct Receiver<T> {
    /// State shared with the paired [`Sender`].
    state: Arc<Inner<T>>,
}
      14              : 
/// State shared by a [`Sender`] and its [`Receiver`] (held behind an `Arc`).
struct Inner<T> {
    /// Registered by `recv` while nothing is buffered; notified by `send` when it
    /// buffers a value and by the `Sender`'s `Drop`.
    wake_receiver: DiatomicWaker,
    /// Registered by `send` while it waits for the receiver to consume; notified
    /// by `recv` when it takes a value and by the `Receiver`'s `Drop`.
    wake_sender: DiatomicWaker,
    /// Protocol state of the channel, including any buffered value.
    value: Mutex<State<T>>,
}
      20              : 
/// Protocol state of the channel, stored in `Inner::value`.
///
/// `TryFoldFailed`, `SenderDropping` and `ReceiverDropping` are placeholders
/// written via `std::mem::replace` while the mutex is held and, in normal
/// operation, overwritten before the lock is released; every `match` in this
/// module treats them as unreachable.
enum State<T> {
    /// Nothing buffered. Initial state; `recv` waits in this state.
    NoData,
    /// One value is buffered; further `send`s try to fold into it.
    HasData(T),
    TryFoldFailed, // transient state
    /// `try_fold` refused to merge; the buffered value must be consumed by the
    /// receiver before the pending `send` can buffer its own value.
    SenderWaitsForReceiverToConsume(T),
    /// The sender was dropped. `Some` holds a value the receiver has not taken yet.
    SenderGone(Option<T>),
    /// The receiver was dropped; `send` fails with `SendError::ReceiverGone`.
    ReceiverGone,
    /// Both halves have been dropped.
    AllGone,
    SenderDropping,   // transient state
    ReceiverDropping, // transient state
}
      32              : 
      33           12 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
      34           12 :     let inner = Inner {
      35           12 :         wake_receiver: DiatomicWaker::new(),
      36           12 :         wake_sender: DiatomicWaker::new(),
      37           12 :         value: Mutex::new(State::NoData),
      38           12 :     };
      39           12 : 
      40           12 :     let state = Arc::new(inner);
      41           12 :     (
      42           12 :         Sender {
      43           12 :             state: state.clone(),
      44           12 :         },
      45           12 :         Receiver { state },
      46           12 :     )
      47           12 : }
      48              : 
/// Error returned by [`Sender::send`].
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// The [`Receiver`] has been dropped, so the value cannot be delivered.
    #[error("receiver is gone")]
    ReceiverGone,
}
      54              : 
      55              : impl<T: Send> Sender<T> {
      56              :     /// # Panics
      57              :     ///
      58              :     /// If `try_fold` panics,  any subsequent call to `send` panic.
      59           14 :     pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
      60           14 :     where
      61           14 :         F: Fn(&mut T, T) -> Result<(), T>,
      62           14 :     {
      63           14 :         let mut value = Some(value);
      64           17 :         poll_fn(|cx| {
      65           17 :             let mut guard = self.state.value.lock().unwrap();
      66           17 :             match &mut *guard {
      67              :                 State::NoData => {
      68            9 :                     *guard = State::HasData(value.take().unwrap());
      69            9 :                     self.state.wake_receiver.notify();
      70            9 :                     Poll::Ready(Ok(()))
      71              :                 }
      72              :                 State::HasData(_) => {
      73            4 :                     let State::HasData(acc_mut) = &mut *guard else {
      74            0 :                         unreachable!("this match arm guarantees that the guard is HasData");
      75              :                     };
      76            4 :                     match try_fold(acc_mut, value.take().unwrap()) {
      77              :                         Ok(()) => {
      78              :                             // no need to wake receiver, if it was waiting it already
      79              :                             // got a wake-up when we transitioned from NoData to HasData
      80            1 :                             Poll::Ready(Ok(()))
      81              :                         }
      82            3 :                         Err(unfoldable_value) => {
      83            3 :                             value = Some(unfoldable_value);
      84            3 :                             let State::HasData(acc) =
      85            3 :                                 std::mem::replace(&mut *guard, State::TryFoldFailed)
      86              :                             else {
      87            0 :                                 unreachable!("this match arm guarantees that the guard is HasData");
      88              :                             };
      89            3 :                             *guard = State::SenderWaitsForReceiverToConsume(acc);
      90            3 :                             // SAFETY: send is single threaded due to `&mut self` requirement,
      91            3 :                             // therefore register is not concurrent.
      92            3 :                             unsafe {
      93            3 :                                 self.state.wake_sender.register(cx.waker());
      94            3 :                             }
      95            3 :                             Poll::Pending
      96              :                         }
      97              :                     }
      98              :                 }
      99            0 :                 State::SenderWaitsForReceiverToConsume(_data) => {
     100            0 :                     // SAFETY: send is single threaded due to `&mut self` requirement,
     101            0 :                     // therefore register is not concurrent.
     102            0 :                     unsafe {
     103            0 :                         self.state.wake_sender.register(cx.waker());
     104            0 :                     }
     105            0 :                     Poll::Pending
     106              :                 }
     107            4 :                 State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
     108              :                 State::SenderGone(_)
     109              :                 | State::AllGone
     110              :                 | State::SenderDropping
     111              :                 | State::ReceiverDropping
     112              :                 | State::TryFoldFailed => {
     113            0 :                     unreachable!();
     114              :                 }
     115              :             }
     116           17 :         })
     117           14 :         .await
     118           14 :     }
     119              : }
     120              : 
     121              : impl<T> Drop for Sender<T> {
     122           12 :     fn drop(&mut self) {
     123           12 :         scopeguard::defer! {
     124           12 :             self.state.wake_receiver.notify()
     125           12 :         };
     126           12 :         let Ok(mut guard) = self.state.value.lock() else {
     127            0 :             return;
     128              :         };
     129           12 :         *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
     130            3 :             State::NoData => State::SenderGone(None),
     131            1 :             State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
     132            1 :                 State::SenderGone(Some(data))
     133              :             }
     134            8 :             State::ReceiverGone => State::AllGone,
     135              :             State::TryFoldFailed
     136              :             | State::SenderGone(_)
     137              :             | State::AllGone
     138              :             | State::SenderDropping
     139              :             | State::ReceiverDropping => {
     140            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     141              :             }
     142              :         }
     143           12 :     }
     144              : }
     145              : 
/// Error returned by [`Receiver::recv`].
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
    /// The [`Sender`] has been dropped and no buffered value remains.
    #[error("sender is gone")]
    SenderGone,
}
     151              : 
     152              : impl<T: Send> Receiver<T> {
     153           10 :     pub async fn recv(&mut self) -> Result<T, RecvError> {
     154           12 :         poll_fn(|cx| {
     155           12 :             let mut guard = self.state.value.lock().unwrap();
     156           12 :             match &mut *guard {
     157              :                 State::NoData => {
     158              :                     // SAFETY: recv is single threaded due to `&mut self` requirement,
     159              :                     // therefore register is not concurrent.
     160            2 :                     unsafe {
     161            2 :                         self.state.wake_receiver.register(cx.waker());
     162            2 :                     }
     163            2 :                     Poll::Pending
     164              :                 }
     165            4 :                 guard @ State::HasData(_)
     166            1 :                 | guard @ State::SenderWaitsForReceiverToConsume(_)
     167            1 :                 | guard @ State::SenderGone(Some(_)) => {
     168            6 :                     let data = guard
     169            6 :                         .take_data()
     170            6 :                         .expect("in these states, data is guaranteed to be present");
     171            6 :                     self.state.wake_sender.notify();
     172            6 :                     Poll::Ready(Ok(data))
     173              :                 }
     174            4 :                 State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
     175              :                 State::ReceiverGone
     176              :                 | State::AllGone
     177              :                 | State::SenderDropping
     178              :                 | State::ReceiverDropping
     179              :                 | State::TryFoldFailed => {
     180            0 :                     unreachable!("unreachable state {:?}", guard.discriminant_str());
     181              :                 }
     182              :             }
     183           12 :         })
     184           10 :         .await
     185           10 :     }
     186              : }
     187              : 
     188              : impl<T> Drop for Receiver<T> {
     189           12 :     fn drop(&mut self) {
     190           12 :         scopeguard::defer! {
     191           12 :             self.state.wake_sender.notify()
     192           12 :         };
     193           12 :         let Ok(mut guard) = self.state.value.lock() else {
     194            0 :             return;
     195              :         };
     196           12 :         *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
     197            5 :             State::NoData => State::ReceiverGone,
     198            3 :             State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
     199            4 :             State::SenderGone(_) => State::AllGone,
     200              :             State::TryFoldFailed
     201              :             | State::ReceiverGone
     202              :             | State::AllGone
     203              :             | State::SenderDropping
     204              :             | State::ReceiverDropping => {
     205            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     206              :             }
     207              :         }
     208           12 :     }
     209              : }
     210              : 
     211              : impl<T> State<T> {
     212            6 :     fn take_data(&mut self) -> Option<T> {
     213            6 :         match self {
     214              :             State::HasData(_) => {
     215            4 :                 let State::HasData(data) = std::mem::replace(self, State::NoData) else {
     216            0 :                     unreachable!("this match arm guarantees that the state is HasData");
     217              :                 };
     218            4 :                 Some(data)
     219              :             }
     220              :             State::SenderWaitsForReceiverToConsume(_) => {
     221            1 :                 let State::SenderWaitsForReceiverToConsume(data) =
     222            1 :                     std::mem::replace(self, State::NoData)
     223              :                 else {
     224            0 :                     unreachable!(
     225            0 :                         "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
     226            0 :                     );
     227              :                 };
     228            1 :                 Some(data)
     229              :             }
     230            1 :             State::SenderGone(data) => Some(data.take().unwrap()),
     231              :             State::NoData
     232              :             | State::TryFoldFailed
     233              :             | State::ReceiverGone
     234              :             | State::AllGone
     235              :             | State::SenderDropping
     236            0 :             | State::ReceiverDropping => None,
     237              :         }
     238            6 :     }
     239            0 :     fn discriminant_str(&self) -> &'static str {
     240            0 :         match self {
     241            0 :             State::NoData => "NoData",
     242            0 :             State::HasData(_) => "HasData",
     243            0 :             State::TryFoldFailed => "TryFoldFailed",
     244            0 :             State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
     245            0 :             State::SenderGone(_) => "SenderGone",
     246            0 :             State::ReceiverGone => "ReceiverGone",
     247            0 :             State::AllGone => "AllGone",
     248            0 :             State::SenderDropping => "SenderDropping",
     249            0 :             State::ReceiverDropping => "ReceiverDropping",
     250              :         }
     251            0 :     }
     252              : }
     253              : 
     254              : #[cfg(test)]
     255              : mod tests {
     256              : 
     257              :     use super::*;
     258              : 
     259              :     const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);
     260              : 
     261              :     #[tokio::test]
     262            1 :     async fn test_send_recv() {
     263            1 :         let (mut sender, mut receiver) = channel();
     264            1 : 
     265            1 :         sender
     266            1 :             .send(42, |acc, val| {
     267            0 :                 *acc += val;
     268            0 :                 Ok(())
     269            1 :             })
     270            1 :             .await
     271            1 :             .unwrap();
     272            1 : 
     273            1 :         let received = receiver.recv().await.unwrap();
     274            1 :         assert_eq!(received, 42);
     275            1 :     }
     276              : 
     277              :     #[tokio::test]
     278            1 :     async fn test_send_recv_with_fold() {
     279            1 :         let (mut sender, mut receiver) = channel();
     280            1 : 
     281            1 :         sender
     282            1 :             .send(1, |acc, val| {
     283            0 :                 *acc += val;
     284            0 :                 Ok(())
     285            1 :             })
     286            1 :             .await
     287            1 :             .unwrap();
     288            1 :         sender
     289            1 :             .send(2, |acc, val| {
     290            1 :                 *acc += val;
     291            1 :                 Ok(())
     292            1 :             })
     293            1 :             .await
     294            1 :             .unwrap();
     295            1 : 
     296            1 :         let received = receiver.recv().await.unwrap();
     297            1 :         assert_eq!(received, 3);
     298            1 :     }
     299              : 
     300              :     #[tokio::test(start_paused = true)]
     301            1 :     async fn test_sender_waits_for_receiver_if_try_fold_fails() {
     302            1 :         let (mut sender, mut receiver) = channel();
     303            1 : 
     304            1 :         sender.send(23, |_, _| panic!("first send")).await.unwrap();
     305            1 : 
     306            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     307            1 :         let mut send_fut = std::pin::pin!(send_fut);
     308            1 : 
     309            1 :         tokio::select! {
     310            1 :             _ = tokio::time::sleep(FOREVER) => {},
     311            1 :             _ = &mut send_fut => {
     312            1 :                 panic!("send should not complete");
     313            1 :             },
     314            1 :         }
     315            1 : 
     316            1 :         let val = receiver.recv().await.unwrap();
     317            1 :         assert_eq!(val, 23);
     318            1 : 
     319            1 :         tokio::select! {
     320            1 :             _ = tokio::time::sleep(FOREVER) => {
     321            1 :                 panic!("receiver should have consumed the value");
     322            1 :             },
     323            1 :             _ = &mut send_fut => { },
     324            1 :         }
     325            1 : 
     326            1 :         let val = receiver.recv().await.unwrap();
     327            1 :         assert_eq!(val, 42);
     328            1 :     }
     329              : 
     330              :     #[tokio::test(start_paused = true)]
     331            1 :     async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
     332            1 :         let (mut sender, receiver) = channel();
     333            1 : 
     334            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     335            1 : 
     336            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     337            1 :         let send_fut = std::pin::pin!(send_fut);
     338            1 : 
     339            1 :         drop(receiver);
     340            1 : 
     341            1 :         let result = send_fut.await;
     342            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     343            1 :     }
     344              : 
     345              :     #[tokio::test(start_paused = true)]
     346            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
     347            1 :         let (sender, mut receiver) = channel::<()>();
     348            1 : 
     349            1 :         let recv_fut = receiver.recv();
     350            1 :         let recv_fut = std::pin::pin!(recv_fut);
     351            1 : 
     352            1 :         drop(sender);
     353            1 : 
     354            1 :         let result = recv_fut.await;
     355            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     356            1 :     }
     357              : 
     358              :     #[tokio::test(start_paused = true)]
     359            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
     360            1 :         let (mut sender, mut receiver) = channel();
     361            1 : 
     362            1 :         sender.send(42, |_, _| unreachable!()).await.unwrap();
     363            1 : 
     364            1 :         {
     365            1 :             let recv_fut = receiver.recv();
     366            1 :             let recv_fut = std::pin::pin!(recv_fut);
     367            1 : 
     368            1 :             drop(sender);
     369            1 : 
     370            1 :             let val = recv_fut.await.unwrap();
     371            1 :             assert_eq!(val, 42);
     372            1 :         }
     373            1 : 
     374            1 :         let result = receiver.recv().await;
     375            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     376            1 :     }
     377              : 
     378              :     #[tokio::test(start_paused = true)]
     379            1 :     async fn test_receiver_waits_for_sender_if_no_data() {
     380            1 :         let (mut sender, mut receiver) = channel();
     381            1 : 
     382            1 :         let recv_fut = receiver.recv();
     383            1 :         let mut recv_fut = std::pin::pin!(recv_fut);
     384            1 : 
     385            1 :         tokio::select! {
     386            1 :             _ = tokio::time::sleep(FOREVER) => {},
     387            1 :             _ = &mut recv_fut => {
     388            1 :                 panic!("recv should not complete");
     389            1 :             },
     390            1 :         }
     391            1 : 
     392            1 :         sender.send(42, |_, _| Ok(())).await.unwrap();
     393            1 : 
     394            1 :         let val = recv_fut.await.unwrap();
     395            1 :         assert_eq!(val, 42);
     396            1 :     }
     397              : 
     398              :     #[tokio::test]
     399            1 :     async fn test_receiver_gone_while_nodata() {
     400            1 :         let (mut sender, receiver) = channel();
     401            1 :         drop(receiver);
     402            1 : 
     403            1 :         let result = sender.send(42, |_, _| Ok(())).await;
     404            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     405            1 :     }
     406              : 
     407              :     #[tokio::test]
     408            1 :     async fn test_sender_gone_while_nodata() {
     409            1 :         let (sender, mut receiver) = super::channel::<usize>();
     410            1 :         drop(sender);
     411            1 : 
     412            1 :         let result = receiver.recv().await;
     413            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     414            1 :     }
     415              : 
     416              :     #[tokio::test(start_paused = true)]
     417            1 :     async fn test_receiver_drops_after_sender_went_to_sleep() {
     418            1 :         let (mut sender, receiver) = channel();
     419            1 :         let state = receiver.state.clone();
     420            1 : 
     421            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     422            1 : 
     423            1 :         let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await });
     424            1 : 
     425            1 :         tokio::time::sleep(FOREVER).await;
     426            1 : 
     427            1 :         assert!(matches!(
     428            1 :             &*state.value.lock().unwrap(),
     429            1 :             &State::SenderWaitsForReceiverToConsume(_)
     430            1 :         ));
     431            1 : 
     432            1 :         drop(receiver);
     433            1 : 
     434            1 :         let err = send_task
     435            1 :             .await
     436            1 :             .unwrap()
     437            1 :             .expect_err("should unblock immediately");
     438            1 :         assert!(matches!(err, SendError::ReceiverGone));
     439            1 :     }
     440              : 
     441              :     #[tokio::test(start_paused = true)]
     442            1 :     async fn test_sender_drops_after_receiver_went_to_sleep() {
     443            1 :         let (sender, mut receiver) = channel::<usize>();
     444            1 :         let state = sender.state.clone();
     445            1 : 
     446            1 :         let recv_task = tokio::spawn(async move { receiver.recv().await });
     447            1 : 
     448            1 :         tokio::time::sleep(FOREVER).await;
     449            1 : 
     450            1 :         assert!(matches!(&*state.value.lock().unwrap(), &State::NoData));
     451            1 : 
     452            1 :         drop(sender);
     453            1 : 
     454            1 :         let err = recv_task.await.unwrap().expect_err("should error");
     455            1 :         assert!(matches!(err, RecvError::SenderGone));
     456            1 :     }
     457              : 
     458              :     #[tokio::test(start_paused = true)]
     459            1 :     async fn test_receiver_drop_while_waiting_for_receiver_to_consume_unblocks_sender() {
     460            1 :         let (mut sender, receiver) = channel();
     461            1 : 
     462            1 :         let state = receiver.state.clone();
     463            1 : 
     464            1 :         sender.send((), |_, _| unreachable!()).await.unwrap();
     465            1 : 
     466            1 :         assert!(matches!(&*state.value.lock().unwrap(), &State::HasData(_)));
     467            1 : 
     468            1 :         let unmergeable = sender.send((), |_, _| Err(()));
     469            1 :         let mut unmergeable = std::pin::pin!(unmergeable);
     470            1 :         tokio::select! {
     471            1 :             _ = tokio::time::sleep(FOREVER) => {},
     472            1 :             _ = &mut unmergeable => {
     473            1 :                 panic!("unmergeable should not complete");
     474            1 :             },
     475            1 :         }
     476            1 : 
     477            1 :         assert!(matches!(
     478            1 :             &*state.value.lock().unwrap(),
     479            1 :             &State::SenderWaitsForReceiverToConsume(_)
     480            1 :         ));
     481            1 : 
     482            1 :         drop(receiver);
     483            1 : 
     484            1 :         assert!(matches!(
     485            1 :             &*state.value.lock().unwrap(),
     486            1 :             &State::ReceiverGone
     487            1 :         ));
     488            1 : 
     489            1 :         unmergeable.await.unwrap_err();
     490            1 :     }
     491              : }
        

Generated by: LCOV version 2.1-beta