Line data Source code
1 : use core::future::poll_fn;
2 : use core::task::Poll;
3 : use std::sync::{Arc, Mutex};
4 :
5 : use diatomic_waker::DiatomicWaker;
6 :
7 : pub struct Sender<T> {
8 : state: Arc<Inner<T>>,
9 : }
10 :
11 : pub struct Receiver<T> {
12 : state: Arc<Inner<T>>,
13 : }
14 :
15 : struct Inner<T> {
16 : wake_receiver: DiatomicWaker,
17 : wake_sender: DiatomicWaker,
18 : value: Mutex<State<T>>,
19 : }
20 :
/// State machine stored behind the channel's mutex.
enum State<T> {
    /// Both halves are alive and no value is buffered.
    NoData,
    /// A value is buffered and waiting for the receiver.
    HasData(T),
    /// Transient placeholder used by `send` while it moves the buffered value
    /// out of `HasData` after `try_fold` rejected the new value.
    TryFoldFailed,
    /// A value is buffered and the sender is parked until the receiver takes
    /// it, because the sender's next value could not be folded into it.
    SenderWaitsForReceiverToConsume(T),
    /// The sender was dropped; holds the value that was still buffered, if any.
    SenderGone(Option<T>),
    /// The receiver was dropped while the sender is still alive.
    ReceiverGone,
    /// Both halves were dropped.
    AllGone,
    /// Transient placeholder used while `Sender::drop` computes the next state.
    SenderDropping,
    /// Transient placeholder used while `Receiver::drop` computes the next state.
    ReceiverDropping,
}
32 :
33 12 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
34 12 : let inner = Inner {
35 12 : wake_receiver: DiatomicWaker::new(),
36 12 : wake_sender: DiatomicWaker::new(),
37 12 : value: Mutex::new(State::NoData),
38 12 : };
39 12 :
40 12 : let state = Arc::new(inner);
41 12 : (
42 12 : Sender {
43 12 : state: state.clone(),
44 12 : },
45 12 : Receiver { state },
46 12 : )
47 12 : }
48 :
49 : #[derive(Debug, thiserror::Error)]
50 : pub enum SendError {
51 : #[error("receiver is gone")]
52 : ReceiverGone,
53 : }
54 :
55 : impl<T: Send> Sender<T> {
56 : /// # Panics
57 : ///
58 : /// If `try_fold` panics, any subsequent call to `send` panic.
59 14 : pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
60 14 : where
61 14 : F: Fn(&mut T, T) -> Result<(), T>,
62 14 : {
63 14 : let mut value = Some(value);
64 17 : poll_fn(|cx| {
65 17 : let mut guard = self.state.value.lock().unwrap();
66 17 : match &mut *guard {
67 : State::NoData => {
68 9 : *guard = State::HasData(value.take().unwrap());
69 9 : self.state.wake_receiver.notify();
70 9 : Poll::Ready(Ok(()))
71 : }
72 : State::HasData(_) => {
73 4 : let State::HasData(acc_mut) = &mut *guard else {
74 0 : unreachable!("this match arm guarantees that the guard is HasData");
75 : };
76 4 : match try_fold(acc_mut, value.take().unwrap()) {
77 : Ok(()) => {
78 : // no need to wake receiver, if it was waiting it already
79 : // got a wake-up when we transitioned from NoData to HasData
80 1 : Poll::Ready(Ok(()))
81 : }
82 3 : Err(unfoldable_value) => {
83 3 : value = Some(unfoldable_value);
84 3 : let State::HasData(acc) =
85 3 : std::mem::replace(&mut *guard, State::TryFoldFailed)
86 : else {
87 0 : unreachable!("this match arm guarantees that the guard is HasData");
88 : };
89 3 : *guard = State::SenderWaitsForReceiverToConsume(acc);
90 3 : // SAFETY: send is single threaded due to `&mut self` requirement,
91 3 : // therefore register is not concurrent.
92 3 : unsafe {
93 3 : self.state.wake_sender.register(cx.waker());
94 3 : }
95 3 : Poll::Pending
96 : }
97 : }
98 : }
99 0 : State::SenderWaitsForReceiverToConsume(_data) => {
100 0 : // SAFETY: send is single threaded due to `&mut self` requirement,
101 0 : // therefore register is not concurrent.
102 0 : unsafe {
103 0 : self.state.wake_sender.register(cx.waker());
104 0 : }
105 0 : Poll::Pending
106 : }
107 4 : State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
108 : State::SenderGone(_)
109 : | State::AllGone
110 : | State::SenderDropping
111 : | State::ReceiverDropping
112 : | State::TryFoldFailed => {
113 0 : unreachable!();
114 : }
115 : }
116 17 : })
117 14 : .await
118 14 : }
119 : }
120 :
121 : impl<T> Drop for Sender<T> {
122 12 : fn drop(&mut self) {
123 12 : scopeguard::defer! {
124 12 : self.state.wake_receiver.notify()
125 12 : };
126 12 : let Ok(mut guard) = self.state.value.lock() else {
127 0 : return;
128 : };
129 12 : *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
130 3 : State::NoData => State::SenderGone(None),
131 1 : State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
132 1 : State::SenderGone(Some(data))
133 : }
134 8 : State::ReceiverGone => State::AllGone,
135 : State::TryFoldFailed
136 : | State::SenderGone(_)
137 : | State::AllGone
138 : | State::SenderDropping
139 : | State::ReceiverDropping => {
140 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
141 : }
142 : }
143 12 : }
144 : }
145 :
146 : #[derive(Debug, thiserror::Error)]
147 : pub enum RecvError {
148 : #[error("sender is gone")]
149 : SenderGone,
150 : }
151 :
152 : impl<T: Send> Receiver<T> {
153 10 : pub async fn recv(&mut self) -> Result<T, RecvError> {
154 12 : poll_fn(|cx| {
155 12 : let mut guard = self.state.value.lock().unwrap();
156 12 : match &mut *guard {
157 : State::NoData => {
158 : // SAFETY: recv is single threaded due to `&mut self` requirement,
159 : // therefore register is not concurrent.
160 2 : unsafe {
161 2 : self.state.wake_receiver.register(cx.waker());
162 2 : }
163 2 : Poll::Pending
164 : }
165 4 : guard @ State::HasData(_)
166 1 : | guard @ State::SenderWaitsForReceiverToConsume(_)
167 1 : | guard @ State::SenderGone(Some(_)) => {
168 6 : let data = guard
169 6 : .take_data()
170 6 : .expect("in these states, data is guaranteed to be present");
171 6 : self.state.wake_sender.notify();
172 6 : Poll::Ready(Ok(data))
173 : }
174 4 : State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
175 : State::ReceiverGone
176 : | State::AllGone
177 : | State::SenderDropping
178 : | State::ReceiverDropping
179 : | State::TryFoldFailed => {
180 0 : unreachable!("unreachable state {:?}", guard.discriminant_str());
181 : }
182 : }
183 12 : })
184 10 : .await
185 10 : }
186 : }
187 :
188 : impl<T> Drop for Receiver<T> {
189 12 : fn drop(&mut self) {
190 12 : scopeguard::defer! {
191 12 : self.state.wake_sender.notify()
192 12 : };
193 12 : let Ok(mut guard) = self.state.value.lock() else {
194 0 : return;
195 : };
196 12 : *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
197 5 : State::NoData => State::ReceiverGone,
198 3 : State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
199 4 : State::SenderGone(_) => State::AllGone,
200 : State::TryFoldFailed
201 : | State::ReceiverGone
202 : | State::AllGone
203 : | State::SenderDropping
204 : | State::ReceiverDropping => {
205 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
206 : }
207 : }
208 12 : }
209 : }
210 :
211 : impl<T> State<T> {
212 6 : fn take_data(&mut self) -> Option<T> {
213 6 : match self {
214 : State::HasData(_) => {
215 4 : let State::HasData(data) = std::mem::replace(self, State::NoData) else {
216 0 : unreachable!("this match arm guarantees that the state is HasData");
217 : };
218 4 : Some(data)
219 : }
220 : State::SenderWaitsForReceiverToConsume(_) => {
221 1 : let State::SenderWaitsForReceiverToConsume(data) =
222 1 : std::mem::replace(self, State::NoData)
223 : else {
224 0 : unreachable!(
225 0 : "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
226 0 : );
227 : };
228 1 : Some(data)
229 : }
230 1 : State::SenderGone(data) => Some(data.take().unwrap()),
231 : State::NoData
232 : | State::TryFoldFailed
233 : | State::ReceiverGone
234 : | State::AllGone
235 : | State::SenderDropping
236 0 : | State::ReceiverDropping => None,
237 : }
238 6 : }
239 0 : fn discriminant_str(&self) -> &'static str {
240 0 : match self {
241 0 : State::NoData => "NoData",
242 0 : State::HasData(_) => "HasData",
243 0 : State::TryFoldFailed => "TryFoldFailed",
244 0 : State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
245 0 : State::SenderGone(_) => "SenderGone",
246 0 : State::ReceiverGone => "ReceiverGone",
247 0 : State::AllGone => "AllGone",
248 0 : State::SenderDropping => "SenderDropping",
249 0 : State::ReceiverDropping => "ReceiverDropping",
250 : }
251 0 : }
252 : }
253 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Sleep duration used as a "nothing else is going to happen" timeout.
    // NOTE(review): relies on tokio's paused clock completing the sleep once the
    // runtime has nothing else to run — confirm against the tokio version in use.
    const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);

    /// A single value sent into an empty channel is received unchanged.
    #[tokio::test]
    async fn test_send_recv() {
        let (mut tx, mut rx) = channel();

        let sent = tx
            .send(42, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await;
        sent.unwrap();

        assert_eq!(rx.recv().await.unwrap(), 42);
    }

    /// A second send is folded into the value that is still buffered.
    #[tokio::test]
    async fn test_send_recv_with_fold() {
        let (mut tx, mut rx) = channel();

        let first = tx
            .send(1, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await;
        first.unwrap();

        let second = tx
            .send(2, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await;
        second.unwrap();

        assert_eq!(rx.recv().await.unwrap(), 3);
    }

    /// When `try_fold` rejects the value, `send` stays pending until the
    /// receiver has consumed the buffered value, then delivers its own.
    #[tokio::test(start_paused = true)]
    async fn test_sender_waits_for_receiver_if_try_fold_fails() {
        let (mut tx, mut rx) = channel();

        tx.send(23, |_, _| panic!("first send")).await.unwrap();

        let mut blocked_send = std::pin::pin!(tx.send(42, |_, val| Err(val)));

        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut blocked_send => {
                panic!("send should not complete");
            },
        }

        assert_eq!(rx.recv().await.unwrap(), 23);

        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {
                panic!("receiver should have consumed the value");
            },
            _ = &mut blocked_send => { },
        }

        assert_eq!(rx.recv().await.unwrap(), 42);
    }

    /// A send that would have to wait fails once the receiver is dropped.
    #[tokio::test(start_paused = true)]
    async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
        let (mut tx, rx) = channel();

        tx.send(23, |_, _| unreachable!()).await.unwrap();

        let blocked_send = std::pin::pin!(tx.send(42, |_, val| Err(val)));

        drop(rx);

        let outcome = blocked_send.await;
        assert!(matches!(outcome, Err(SendError::ReceiverGone)));
    }

    /// A recv on an empty channel fails once the sender is dropped.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
        let (tx, mut rx) = channel::<()>();

        let waiting_recv = std::pin::pin!(rx.recv());

        drop(tx);

        let outcome = waiting_recv.await;
        assert!(matches!(outcome, Err(RecvError::SenderGone)));
    }

    /// A value buffered before the sender was dropped is still delivered;
    /// only the following recv reports the missing sender.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
        let (mut tx, mut rx) = channel();

        tx.send(42, |_, _| unreachable!()).await.unwrap();

        {
            let waiting_recv = std::pin::pin!(rx.recv());

            drop(tx);

            assert_eq!(waiting_recv.await.unwrap(), 42);
        }

        let outcome = rx.recv().await;
        assert!(matches!(outcome, Err(RecvError::SenderGone)));
    }

    /// recv stays pending on an empty channel and completes after a send.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_waits_for_sender_if_no_data() {
        let (mut tx, mut rx) = channel();

        let mut waiting_recv = std::pin::pin!(rx.recv());

        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut waiting_recv => {
                panic!("recv should not complete");
            },
        }

        tx.send(42, |_, _| Ok(())).await.unwrap();

        assert_eq!(waiting_recv.await.unwrap(), 42);
    }

    /// Sending into an empty channel whose receiver is already gone fails.
    #[tokio::test]
    async fn test_receiver_gone_while_nodata() {
        let (mut tx, rx) = channel();
        drop(rx);

        let outcome = tx.send(42, |_, _| Ok(())).await;
        assert!(matches!(outcome, Err(SendError::ReceiverGone)));
    }

    /// Receiving from an empty channel whose sender is already gone fails.
    #[tokio::test]
    async fn test_sender_gone_while_nodata() {
        let (tx, mut rx) = super::channel::<usize>();
        drop(tx);

        let outcome = rx.recv().await;
        assert!(matches!(outcome, Err(RecvError::SenderGone)));
    }

    /// Dropping the receiver wakes a sender task that is parked in
    /// `SenderWaitsForReceiverToConsume`.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_drops_after_sender_went_to_sleep() {
        let (mut tx, rx) = channel();
        let shared = rx.state.clone();

        tx.send(23, |_, _| unreachable!()).await.unwrap();

        let send_task = tokio::spawn(async move { tx.send(42, |_, v| Err(v)).await });

        tokio::time::sleep(FOREVER).await;

        assert!(matches!(
            &*shared.value.lock().unwrap(),
            &State::SenderWaitsForReceiverToConsume(_)
        ));

        drop(rx);

        let joined = send_task.await.unwrap();
        let err = joined.expect_err("should unblock immediately");
        assert!(matches!(err, SendError::ReceiverGone));
    }

    /// Dropping the sender wakes a receiver task that is parked on an empty
    /// channel.
    #[tokio::test(start_paused = true)]
    async fn test_sender_drops_after_receiver_went_to_sleep() {
        let (tx, mut rx) = channel::<usize>();
        let shared = tx.state.clone();

        let recv_task = tokio::spawn(async move { rx.recv().await });

        tokio::time::sleep(FOREVER).await;

        assert!(matches!(&*shared.value.lock().unwrap(), &State::NoData));

        drop(tx);

        let joined = recv_task.await.unwrap();
        let err = joined.expect_err("should error");
        assert!(matches!(err, RecvError::SenderGone));
    }

    /// Walks the state machine HasData -> SenderWaitsForReceiverToConsume ->
    /// ReceiverGone and checks that the parked send then fails.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_drop_while_waiting_for_receiver_to_consume_unblocks_sender() {
        let (mut tx, rx) = channel();

        let shared = rx.state.clone();

        tx.send((), |_, _| unreachable!()).await.unwrap();

        assert!(matches!(&*shared.value.lock().unwrap(), &State::HasData(_)));

        let mut unmergeable = std::pin::pin!(tx.send((), |_, _| Err(())));
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut unmergeable => {
                panic!("unmergeable should not complete");
            },
        }

        assert!(matches!(
            &*shared.value.lock().unwrap(),
            &State::SenderWaitsForReceiverToConsume(_)
        ));

        drop(rx);

        assert!(matches!(
            &*shared.value.lock().unwrap(),
            &State::ReceiverGone
        ));

        unmergeable.await.unwrap_err();
    }
}
|