LCOV - code coverage report
Current view: top level - libs/utils/src/sync - spsc_fold.rs (source / functions) Coverage Total Hit
Test: 4f58e98c51285c7fa348e0b410c88a10caf68ad2.info Lines: 89.2 % 297 265
Test Date: 2025-01-07 20:58:07 Functions: 71.3 % 115 82

            Line data    Source code
       1              : use core::{future::poll_fn, task::Poll};
       2              : use std::sync::{Arc, Mutex};
       3              : 
       4              : use diatomic_waker::DiatomicWaker;
       5              : 
       6              : pub struct Sender<T> {
       7              :     state: Arc<Inner<T>>,
       8              : }
       9              : 
      10              : pub struct Receiver<T> {
      11              :     state: Arc<Inner<T>>,
      12              : }
      13              : 
      14              : struct Inner<T> {
      15              :     wake_receiver: DiatomicWaker,
      16              :     wake_sender: DiatomicWaker,
      17              :     value: Mutex<State<T>>,
      18              : }
      19              : 
      20              : enum State<T> {
      21              :     NoData,
      22              :     HasData(T),
      23              :     TryFoldFailed, // transient state
      24              :     SenderWaitsForReceiverToConsume(T),
      25              :     SenderGone(Option<T>),
      26              :     ReceiverGone,
      27              :     AllGone,
      28              :     SenderDropping,   // transient state
      29              :     ReceiverDropping, // transient state
      30              : }
      31              : 
      32           11 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
      33           11 :     let inner = Inner {
      34           11 :         wake_receiver: DiatomicWaker::new(),
      35           11 :         wake_sender: DiatomicWaker::new(),
      36           11 :         value: Mutex::new(State::NoData),
      37           11 :     };
      38           11 : 
      39           11 :     let state = Arc::new(inner);
      40           11 :     (
      41           11 :         Sender {
      42           11 :             state: state.clone(),
      43           11 :         },
      44           11 :         Receiver { state },
      45           11 :     )
      46           11 : }
      47              : 
      48              : #[derive(Debug, thiserror::Error)]
      49              : pub enum SendError {
      50              :     #[error("receiver is gone")]
      51              :     ReceiverGone,
      52              : }
      53              : 
      54              : impl<T: Send> Sender<T> {
      55              :     /// # Panics
      56              :     ///
      57              :     /// If `try_fold` panics,  any subsequent call to `send` panic.
      58           12 :     pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
      59           12 :     where
      60           12 :         F: Fn(&mut T, T) -> Result<(), T>,
      61           12 :     {
      62           12 :         let mut value = Some(value);
      63           14 :         poll_fn(|cx| {
      64           14 :             let mut guard = self.state.value.lock().unwrap();
      65           14 :             match &mut *guard {
      66              :                 State::NoData => {
      67            8 :                     *guard = State::HasData(value.take().unwrap());
      68            8 :                     self.state.wake_receiver.notify();
      69            8 :                     Poll::Ready(Ok(()))
      70              :                 }
      71              :                 State::HasData(_) => {
      72            3 :                     let State::HasData(acc_mut) = &mut *guard else {
      73            0 :                         unreachable!("this match arm guarantees that the guard is HasData");
      74              :                     };
      75            3 :                     match try_fold(acc_mut, value.take().unwrap()) {
      76              :                         Ok(()) => {
      77              :                             // no need to wake receiver, if it was waiting it already
      78              :                             // got a wake-up when we transitioned from NoData to HasData
      79            1 :                             Poll::Ready(Ok(()))
      80              :                         }
      81            2 :                         Err(unfoldable_value) => {
      82            2 :                             value = Some(unfoldable_value);
      83            2 :                             let State::HasData(acc) =
      84            2 :                                 std::mem::replace(&mut *guard, State::TryFoldFailed)
      85              :                             else {
      86            0 :                                 unreachable!("this match arm guarantees that the guard is HasData");
      87              :                             };
      88            2 :                             *guard = State::SenderWaitsForReceiverToConsume(acc);
      89            2 :                             // SAFETY: send is single threaded due to `&mut self` requirement,
      90            2 :                             // therefore register is not concurrent.
      91            2 :                             unsafe {
      92            2 :                                 self.state.wake_sender.register(cx.waker());
      93            2 :                             }
      94            2 :                             Poll::Pending
      95              :                         }
      96              :                     }
      97              :                 }
      98            0 :                 State::SenderWaitsForReceiverToConsume(_data) => {
      99            0 :                     // Really, we shouldn't be polled until receiver has consumed and wakes us.
     100            0 :                     Poll::Pending
     101              :                 }
     102            3 :                 State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
     103              :                 State::SenderGone(_)
     104              :                 | State::AllGone
     105              :                 | State::SenderDropping
     106              :                 | State::ReceiverDropping
     107              :                 | State::TryFoldFailed => {
     108            0 :                     unreachable!();
     109              :                 }
     110              :             }
     111           14 :         })
     112           12 :         .await
     113           12 :     }
     114              : }
     115              : 
     116              : impl<T> Drop for Sender<T> {
     117           11 :     fn drop(&mut self) {
     118           11 :         scopeguard::defer! {
     119           11 :             self.state.wake_receiver.notify()
     120           11 :         };
     121           11 :         let Ok(mut guard) = self.state.value.lock() else {
     122            0 :             return;
     123              :         };
     124           11 :         *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
     125            3 :             State::NoData => State::SenderGone(None),
     126            1 :             State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
     127            1 :                 State::SenderGone(Some(data))
     128              :             }
     129            7 :             State::ReceiverGone => State::AllGone,
     130              :             State::TryFoldFailed
     131              :             | State::SenderGone(_)
     132              :             | State::AllGone
     133              :             | State::SenderDropping
     134              :             | State::ReceiverDropping => {
     135            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     136              :             }
     137              :         }
     138           11 :     }
     139              : }
     140              : 
     141              : #[derive(Debug, thiserror::Error)]
     142              : pub enum RecvError {
     143              :     #[error("sender is gone")]
     144              :     SenderGone,
     145              : }
     146              : 
     147              : impl<T: Send> Receiver<T> {
     148           10 :     pub async fn recv(&mut self) -> Result<T, RecvError> {
     149           13 :         poll_fn(|cx| {
     150           13 :             let mut guard = self.state.value.lock().unwrap();
     151           13 :             match &mut *guard {
     152              :                 State::NoData => {
     153              :                     // SAFETY: recv is single threaded due to `&mut self` requirement,
     154              :                     // therefore register is not concurrent.
     155            3 :                     unsafe {
     156            3 :                         self.state.wake_receiver.register(cx.waker());
     157            3 :                     }
     158            3 :                     Poll::Pending
     159              :                 }
     160            4 :                 guard @ State::HasData(_)
     161            1 :                 | guard @ State::SenderWaitsForReceiverToConsume(_)
     162            1 :                 | guard @ State::SenderGone(Some(_)) => {
     163            6 :                     let data = guard
     164            6 :                         .take_data()
     165            6 :                         .expect("in these states, data is guaranteed to be present");
     166            6 :                     self.state.wake_sender.notify();
     167            6 :                     Poll::Ready(Ok(data))
     168              :                 }
     169            4 :                 State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
     170              :                 State::ReceiverGone
     171              :                 | State::AllGone
     172              :                 | State::SenderDropping
     173              :                 | State::ReceiverDropping
     174              :                 | State::TryFoldFailed => {
     175            0 :                     unreachable!("unreachable state {:?}", guard.discriminant_str());
     176              :                 }
     177              :             }
     178           13 :         })
     179           10 :         .await
     180           10 :     }
     181              : }
     182              : 
     183              : impl<T> Drop for Receiver<T> {
     184           11 :     fn drop(&mut self) {
     185           11 :         scopeguard::defer! {
     186           11 :             self.state.wake_sender.notify()
     187           11 :         };
     188           11 :         let Ok(mut guard) = self.state.value.lock() else {
     189            0 :             return;
     190              :         };
     191           11 :         *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
     192            5 :             State::NoData => State::ReceiverGone,
     193            2 :             State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
     194            4 :             State::SenderGone(_) => State::AllGone,
     195              :             State::TryFoldFailed
     196              :             | State::ReceiverGone
     197              :             | State::AllGone
     198              :             | State::SenderDropping
     199              :             | State::ReceiverDropping => {
     200            0 :                 unreachable!("unreachable state {:?}", guard.discriminant_str())
     201              :             }
     202              :         }
     203           11 :     }
     204              : }
     205              : 
     206              : impl<T> State<T> {
     207            6 :     fn take_data(&mut self) -> Option<T> {
     208            6 :         match self {
     209              :             State::HasData(_) => {
     210            4 :                 let State::HasData(data) = std::mem::replace(self, State::NoData) else {
     211            0 :                     unreachable!("this match arm guarantees that the state is HasData");
     212              :                 };
     213            4 :                 Some(data)
     214              :             }
     215              :             State::SenderWaitsForReceiverToConsume(_) => {
     216            1 :                 let State::SenderWaitsForReceiverToConsume(data) =
     217            1 :                     std::mem::replace(self, State::NoData)
     218              :                 else {
     219            0 :                     unreachable!(
     220            0 :                         "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
     221            0 :                     );
     222              :                 };
     223            1 :                 Some(data)
     224              :             }
     225            1 :             State::SenderGone(data) => Some(data.take().unwrap()),
     226              :             State::NoData
     227              :             | State::TryFoldFailed
     228              :             | State::ReceiverGone
     229              :             | State::AllGone
     230              :             | State::SenderDropping
     231            0 :             | State::ReceiverDropping => None,
     232              :         }
     233            6 :     }
     234            0 :     fn discriminant_str(&self) -> &'static str {
     235            0 :         match self {
     236            0 :             State::NoData => "NoData",
     237            0 :             State::HasData(_) => "HasData",
     238            0 :             State::TryFoldFailed => "TryFoldFailed",
     239            0 :             State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
     240            0 :             State::SenderGone(_) => "SenderGone",
     241            0 :             State::ReceiverGone => "ReceiverGone",
     242            0 :             State::AllGone => "AllGone",
     243            0 :             State::SenderDropping => "SenderDropping",
     244            0 :             State::ReceiverDropping => "ReceiverDropping",
     245              :         }
     246            0 :     }
     247              : }
     248              : 
     249              : #[cfg(test)]
     250              : mod tests {
     251              : 
     252              :     use super::*;
     253              : 
     254              :     const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);
     255              : 
     256              :     #[tokio::test]
     257            1 :     async fn test_send_recv() {
     258            1 :         let (mut sender, mut receiver) = channel();
     259            1 : 
     260            1 :         sender
     261            1 :             .send(42, |acc, val| {
     262            0 :                 *acc += val;
     263            0 :                 Ok(())
     264            1 :             })
     265            1 :             .await
     266            1 :             .unwrap();
     267            1 : 
     268            1 :         let received = receiver.recv().await.unwrap();
     269            1 :         assert_eq!(received, 42);
     270            1 :     }
     271              : 
     272              :     #[tokio::test]
     273            1 :     async fn test_send_recv_with_fold() {
     274            1 :         let (mut sender, mut receiver) = channel();
     275            1 : 
     276            1 :         sender
     277            1 :             .send(1, |acc, val| {
     278            0 :                 *acc += val;
     279            0 :                 Ok(())
     280            1 :             })
     281            1 :             .await
     282            1 :             .unwrap();
     283            1 :         sender
     284            1 :             .send(2, |acc, val| {
     285            1 :                 *acc += val;
     286            1 :                 Ok(())
     287            1 :             })
     288            1 :             .await
     289            1 :             .unwrap();
     290            1 : 
     291            1 :         let received = receiver.recv().await.unwrap();
     292            1 :         assert_eq!(received, 3);
     293            1 :     }
     294              : 
     295              :     #[tokio::test(start_paused = true)]
     296            1 :     async fn test_sender_waits_for_receiver_if_try_fold_fails() {
     297            1 :         let (mut sender, mut receiver) = channel();
     298            1 : 
     299            1 :         sender.send(23, |_, _| panic!("first send")).await.unwrap();
     300            1 : 
     301            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     302            1 :         let mut send_fut = std::pin::pin!(send_fut);
     303            1 : 
     304            1 :         tokio::select! {
     305            1 :             _ = tokio::time::sleep(FOREVER) => {},
     306            1 :             _ = &mut send_fut => {
     307            1 :                 panic!("send should not complete");
     308            1 :             },
     309            1 :         }
     310            1 : 
     311            1 :         let val = receiver.recv().await.unwrap();
     312            1 :         assert_eq!(val, 23);
     313            1 : 
     314            1 :         tokio::select! {
     315            1 :             _ = tokio::time::sleep(FOREVER) => {
     316            1 :                 panic!("receiver should have consumed the value");
     317            1 :             },
     318            1 :             _ = &mut send_fut => { },
     319            1 :         }
     320            1 : 
     321            1 :         let val = receiver.recv().await.unwrap();
     322            1 :         assert_eq!(val, 42);
     323            1 :     }
     324              : 
     325              :     #[tokio::test(start_paused = true)]
     326            1 :     async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
     327            1 :         let (mut sender, receiver) = channel();
     328            1 : 
     329            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     330            1 : 
     331            1 :         let send_fut = sender.send(42, |_, val| Err(val));
     332            1 :         let send_fut = std::pin::pin!(send_fut);
     333            1 : 
     334            1 :         drop(receiver);
     335            1 : 
     336            1 :         let result = send_fut.await;
     337            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     338            1 :     }
     339              : 
     340              :     #[tokio::test(start_paused = true)]
     341            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
     342            1 :         let (sender, mut receiver) = channel::<()>();
     343            1 : 
     344            1 :         let recv_fut = receiver.recv();
     345            1 :         let recv_fut = std::pin::pin!(recv_fut);
     346            1 : 
     347            1 :         drop(sender);
     348            1 : 
     349            1 :         let result = recv_fut.await;
     350            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     351            1 :     }
     352              : 
     353              :     #[tokio::test(start_paused = true)]
     354            1 :     async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
     355            1 :         let (mut sender, mut receiver) = channel();
     356            1 : 
     357            1 :         sender.send(42, |_, _| unreachable!()).await.unwrap();
     358            1 : 
     359            1 :         {
     360            1 :             let recv_fut = receiver.recv();
     361            1 :             let recv_fut = std::pin::pin!(recv_fut);
     362            1 : 
     363            1 :             drop(sender);
     364            1 : 
     365            1 :             let val = recv_fut.await.unwrap();
     366            1 :             assert_eq!(val, 42);
     367            1 :         }
     368            1 : 
     369            1 :         let result = receiver.recv().await;
     370            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     371            1 :     }
     372              : 
     373              :     #[tokio::test(start_paused = true)]
     374            1 :     async fn test_receiver_waits_for_sender_if_no_data() {
     375            1 :         let (mut sender, mut receiver) = channel();
     376            1 : 
     377            1 :         let recv_fut = receiver.recv();
     378            1 :         let mut recv_fut = std::pin::pin!(recv_fut);
     379            1 : 
     380            1 :         tokio::select! {
     381            1 :             _ = tokio::time::sleep(FOREVER) => {},
     382            1 :             _ = &mut recv_fut => {
     383            1 :                 panic!("recv should not complete");
     384            1 :             },
     385            1 :         }
     386            1 : 
     387            1 :         sender.send(42, |_, _| Ok(())).await.unwrap();
     388            1 : 
     389            1 :         let val = recv_fut.await.unwrap();
     390            1 :         assert_eq!(val, 42);
     391            1 :     }
     392              : 
     393              :     #[tokio::test]
     394            1 :     async fn test_receiver_gone_while_nodata() {
     395            1 :         let (mut sender, receiver) = channel();
     396            1 :         drop(receiver);
     397            1 : 
     398            1 :         let result = sender.send(42, |_, _| Ok(())).await;
     399            1 :         assert!(matches!(result, Err(SendError::ReceiverGone)));
     400            1 :     }
     401              : 
     402              :     #[tokio::test]
     403            1 :     async fn test_sender_gone_while_nodata() {
     404            1 :         let (sender, mut receiver) = super::channel::<usize>();
     405            1 :         drop(sender);
     406            1 : 
     407            1 :         let result = receiver.recv().await;
     408            1 :         assert!(matches!(result, Err(RecvError::SenderGone)));
     409            1 :     }
     410              : 
     411              :     #[tokio::test(start_paused = true)]
     412            1 :     async fn test_receiver_drops_after_sender_went_to_sleep() {
     413            1 :         let (mut sender, receiver) = channel();
     414            1 :         let state = receiver.state.clone();
     415            1 : 
     416            1 :         sender.send(23, |_, _| unreachable!()).await.unwrap();
     417            1 : 
     418            1 :         let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await });
     419            1 : 
     420            1 :         tokio::time::sleep(FOREVER).await;
     421            1 : 
     422            1 :         assert!(matches!(
     423            1 :             &*state.value.lock().unwrap(),
     424            1 :             &State::SenderWaitsForReceiverToConsume(_)
     425            1 :         ));
     426            1 : 
     427            1 :         drop(receiver);
     428            1 : 
     429            1 :         let err = send_task
     430            1 :             .await
     431            1 :             .unwrap()
     432            1 :             .expect_err("should unblock immediately");
     433            1 :         assert!(matches!(err, SendError::ReceiverGone));
     434            1 :     }
     435              : 
     436              :     #[tokio::test(start_paused = true)]
     437            1 :     async fn test_sender_drops_after_receiver_went_to_sleep() {
     438            1 :         let (sender, mut receiver) = channel::<usize>();
     439            1 :         let state = sender.state.clone();
     440            1 : 
     441            1 :         let recv_task = tokio::spawn(async move { receiver.recv().await });
     442            1 : 
     443            1 :         tokio::time::sleep(FOREVER).await;
     444            1 : 
     445            1 :         assert!(matches!(&*state.value.lock().unwrap(), &State::NoData));
     446            1 : 
     447            1 :         drop(sender);
     448            1 : 
     449            1 :         let err = recv_task.await.unwrap().expect_err("should error");
     450            1 :         assert!(matches!(err, RecvError::SenderGone));
     451            1 :     }
     452              : }
        

Generated by: LCOV version 2.1-beta