Line data Source code
1 : use core::{future::poll_fn, task::Poll};
2 : use std::sync::{Arc, Mutex};
3 :
4 : use diatomic_waker::DiatomicWaker;
5 :
6 : pub struct Sender<T> {
7 : state: Arc<Inner<T>>,
8 : }
9 :
10 : pub struct Receiver<T> {
11 : state: Arc<Inner<T>>,
12 : }
13 :
14 : struct Inner<T> {
15 : wake_receiver: DiatomicWaker,
16 : wake_sender: DiatomicWaker,
17 : value: Mutex<State<T>>,
18 : }
19 :
/// The channel's state machine, stored behind the mutex in `Inner::value`.
///
/// Variants marked *transient* are only written as placeholders while the lock
/// is held and the real successor state is being computed, so they should
/// never be observed by another lock holder.
enum State<T> {
    /// Nothing is buffered; a pending `recv` waits for data here.
    NoData,
    /// One buffered value that later sends may fold into.
    HasData(T),
    /// Transient: placeholder while `send` moves a `HasData` payload into
    /// `SenderWaitsForReceiverToConsume`.
    TryFoldFailed,
    /// A buffered value that refused to absorb the latest send; the sender
    /// stays pending until the receiver takes it.
    SenderWaitsForReceiverToConsume(T),
    /// The sender was dropped; holds the value it left behind, if any.
    SenderGone(Option<T>),
    /// The receiver was dropped; further sends fail.
    ReceiverGone,
    /// Both halves have been dropped.
    AllGone,
    /// Transient: placeholder while `Sender::drop` computes the next state.
    SenderDropping,
    /// Transient: placeholder while `Receiver::drop` computes the next state.
    ReceiverDropping,
}
31 :
32 12 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
33 12 : let inner = Inner {
34 12 : wake_receiver: DiatomicWaker::new(),
35 12 : wake_sender: DiatomicWaker::new(),
36 12 : value: Mutex::new(State::NoData),
37 12 : };
38 12 :
39 12 : let state = Arc::new(inner);
40 12 : (
41 12 : Sender {
42 12 : state: state.clone(),
43 12 : },
44 12 : Receiver { state },
45 12 : )
46 12 : }
47 :
48 : #[derive(Debug, thiserror::Error)]
49 : pub enum SendError {
50 : #[error("receiver is gone")]
51 : ReceiverGone,
52 : }
53 :
54 : impl<T: Send> Sender<T> {
55 : /// # Panics
56 : ///
57 : /// If `try_fold` panics, any subsequent call to `send` panic.
58 14 : pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
59 14 : where
60 14 : F: Fn(&mut T, T) -> Result<(), T>,
61 14 : {
62 14 : let mut value = Some(value);
63 18 : poll_fn(|cx| {
64 18 : let mut guard = self.state.value.lock().unwrap();
65 18 : match &mut *guard {
66 : State::NoData => {
67 9 : *guard = State::HasData(value.take().unwrap());
68 9 : self.state.wake_receiver.notify();
69 9 : Poll::Ready(Ok(()))
70 : }
71 : State::HasData(_) => {
72 4 : let State::HasData(acc_mut) = &mut *guard else {
73 0 : unreachable!("this match arm guarantees that the guard is HasData");
74 : };
75 4 : match try_fold(acc_mut, value.take().unwrap()) {
76 : Ok(()) => {
77 : // no need to wake receiver, if it was waiting it already
78 : // got a wake-up when we transitioned from NoData to HasData
79 1 : Poll::Ready(Ok(()))
80 : }
81 3 : Err(unfoldable_value) => {
82 3 : value = Some(unfoldable_value);
83 3 : let State::HasData(acc) =
84 3 : std::mem::replace(&mut *guard, State::TryFoldFailed)
85 : else {
86 0 : unreachable!("this match arm guarantees that the guard is HasData");
87 : };
88 3 : *guard = State::SenderWaitsForReceiverToConsume(acc);
89 3 : // SAFETY: send is single threaded due to `&mut self` requirement,
90 3 : // therefore register is not concurrent.
91 3 : unsafe {
92 3 : self.state.wake_sender.register(cx.waker());
93 3 : }
94 3 : Poll::Pending
95 : }
96 : }
97 : }
98 1 : State::SenderWaitsForReceiverToConsume(_data) => {
99 1 : // SAFETY: send is single threaded due to `&mut self` requirement,
100 1 : // therefore register is not concurrent.
101 1 : unsafe {
102 1 : self.state.wake_sender.register(cx.waker());
103 1 : }
104 1 : Poll::Pending
105 : }
106 4 : State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
107 : State::SenderGone(_)
108 : | State::AllGone
109 : | State::SenderDropping
110 : | State::ReceiverDropping
111 : | State::TryFoldFailed => {
112 0 : unreachable!();
113 : }
114 : }
115 18 : })
116 14 : .await
117 14 : }
118 : }
119 :
120 : impl<T> Drop for Sender<T> {
121 12 : fn drop(&mut self) {
122 12 : scopeguard::defer! {
123 12 : self.state.wake_receiver.notify()
124 12 : };
125 12 : let Ok(mut guard) = self.state.value.lock() else {
126 0 : return;
127 : };
128 12 : *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
129 3 : State::NoData => State::SenderGone(None),
130 1 : State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
131 1 : State::SenderGone(Some(data))
132 : }
133 8 : State::ReceiverGone => State::AllGone,
134 : State::TryFoldFailed
135 : | State::SenderGone(_)
136 : | State::AllGone
137 : | State::SenderDropping
138 : | State::ReceiverDropping => {
139 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
140 : }
141 : }
142 12 : }
143 : }
144 :
145 : #[derive(Debug, thiserror::Error)]
146 : pub enum RecvError {
147 : #[error("sender is gone")]
148 : SenderGone,
149 : }
150 :
151 : impl<T: Send> Receiver<T> {
152 10 : pub async fn recv(&mut self) -> Result<T, RecvError> {
153 12 : poll_fn(|cx| {
154 12 : let mut guard = self.state.value.lock().unwrap();
155 12 : match &mut *guard {
156 : State::NoData => {
157 : // SAFETY: recv is single threaded due to `&mut self` requirement,
158 : // therefore register is not concurrent.
159 2 : unsafe {
160 2 : self.state.wake_receiver.register(cx.waker());
161 2 : }
162 2 : Poll::Pending
163 : }
164 4 : guard @ State::HasData(_)
165 1 : | guard @ State::SenderWaitsForReceiverToConsume(_)
166 1 : | guard @ State::SenderGone(Some(_)) => {
167 6 : let data = guard
168 6 : .take_data()
169 6 : .expect("in these states, data is guaranteed to be present");
170 6 : self.state.wake_sender.notify();
171 6 : Poll::Ready(Ok(data))
172 : }
173 4 : State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
174 : State::ReceiverGone
175 : | State::AllGone
176 : | State::SenderDropping
177 : | State::ReceiverDropping
178 : | State::TryFoldFailed => {
179 0 : unreachable!("unreachable state {:?}", guard.discriminant_str());
180 : }
181 : }
182 12 : })
183 10 : .await
184 10 : }
185 : }
186 :
187 : impl<T> Drop for Receiver<T> {
188 12 : fn drop(&mut self) {
189 12 : scopeguard::defer! {
190 12 : self.state.wake_sender.notify()
191 12 : };
192 12 : let Ok(mut guard) = self.state.value.lock() else {
193 0 : return;
194 : };
195 12 : *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
196 5 : State::NoData => State::ReceiverGone,
197 3 : State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
198 4 : State::SenderGone(_) => State::AllGone,
199 : State::TryFoldFailed
200 : | State::ReceiverGone
201 : | State::AllGone
202 : | State::SenderDropping
203 : | State::ReceiverDropping => {
204 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
205 : }
206 : }
207 12 : }
208 : }
209 :
210 : impl<T> State<T> {
211 6 : fn take_data(&mut self) -> Option<T> {
212 6 : match self {
213 : State::HasData(_) => {
214 4 : let State::HasData(data) = std::mem::replace(self, State::NoData) else {
215 0 : unreachable!("this match arm guarantees that the state is HasData");
216 : };
217 4 : Some(data)
218 : }
219 : State::SenderWaitsForReceiverToConsume(_) => {
220 1 : let State::SenderWaitsForReceiverToConsume(data) =
221 1 : std::mem::replace(self, State::NoData)
222 : else {
223 0 : unreachable!(
224 0 : "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
225 0 : );
226 : };
227 1 : Some(data)
228 : }
229 1 : State::SenderGone(data) => Some(data.take().unwrap()),
230 : State::NoData
231 : | State::TryFoldFailed
232 : | State::ReceiverGone
233 : | State::AllGone
234 : | State::SenderDropping
235 0 : | State::ReceiverDropping => None,
236 : }
237 6 : }
238 0 : fn discriminant_str(&self) -> &'static str {
239 0 : match self {
240 0 : State::NoData => "NoData",
241 0 : State::HasData(_) => "HasData",
242 0 : State::TryFoldFailed => "TryFoldFailed",
243 0 : State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
244 0 : State::SenderGone(_) => "SenderGone",
245 0 : State::ReceiverGone => "ReceiverGone",
246 0 : State::AllGone => "AllGone",
247 0 : State::SenderDropping => "SenderDropping",
248 0 : State::ReceiverDropping => "ReceiverDropping",
249 : }
250 0 : }
251 : }
252 :
// Tests for the fold channel's state machine. Tests marked
// `start_paused = true` run on tokio's paused clock.
// NOTE(review): they rely on tokio auto-advancing that clock when every task is
// idle, so `tokio::time::sleep(FOREVER)` acts as "let everything else run until
// it blocks" — confirm against the tokio version in use.
#[cfg(test)]
mod tests {

    use super::*;

    // Effectively-infinite sleep; only completes via paused-clock auto-advance.
    const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);

    // A single send into an empty channel is received unchanged; the fold
    // closure is not called because nothing is buffered yet.
    #[tokio::test]
    async fn test_send_recv() {
        let (mut sender, mut receiver) = channel();

        sender
            .send(42, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(received, 42);
    }

    // The second send folds into the still-buffered first value, so the
    // receiver sees the combined result.
    #[tokio::test]
    async fn test_send_recv_with_fold() {
        let (mut sender, mut receiver) = channel();

        sender
            .send(1, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();
        sender
            .send(2, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(received, 3);
    }

    // When `try_fold` refuses, the send stays pending until the receiver takes
    // the buffered value; it then completes and its value is received next.
    #[tokio::test(start_paused = true)]
    async fn test_sender_waits_for_receiver_if_try_fold_fails() {
        let (mut sender, mut receiver) = channel();

        // Empty channel: the fold closure must not be called.
        sender.send(23, |_, _| panic!("first send")).await.unwrap();

        let send_fut = sender.send(42, |_, val| Err(val));
        let mut send_fut = std::pin::pin!(send_fut);

        // The send must still be pending once the runtime goes idle.
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut send_fut => {
                panic!("send should not complete");
            },
        }

        let val = receiver.recv().await.unwrap();
        assert_eq!(val, 23);

        // Now that 23 was consumed, the pending send must finish.
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {
                panic!("receiver should have consumed the value");
            },
            _ = &mut send_fut => { },
        }

        let val = receiver.recv().await.unwrap();
        assert_eq!(val, 42);
    }

    // A send on a channel whose receiver is gone fails with `ReceiverGone`.
    // NOTE(review): `send_fut` is first polled only after `drop(receiver)`, so
    // this exercises the immediate-`ReceiverGone` path rather than a send that
    // was already waiting.
    #[tokio::test(start_paused = true)]
    async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
        let (mut sender, receiver) = channel();

        sender.send(23, |_, _| unreachable!()).await.unwrap();

        let send_fut = sender.send(42, |_, val| Err(val));
        let send_fut = std::pin::pin!(send_fut);

        drop(receiver);

        let result = send_fut.await;
        assert!(matches!(result, Err(SendError::ReceiverGone)));
    }

    // A recv on an empty channel whose sender is gone fails with `SenderGone`.
    // NOTE(review): `recv_fut` is first polled only after `drop(sender)`, so no
    // waker is registered beforehand.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
        let (sender, mut receiver) = channel::<()>();

        let recv_fut = receiver.recv();
        let recv_fut = std::pin::pin!(recv_fut);

        drop(sender);

        let result = recv_fut.await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // A value buffered before the sender dropped is still delivered; only the
    // following recv reports `SenderGone`.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
        let (mut sender, mut receiver) = channel();

        sender.send(42, |_, _| unreachable!()).await.unwrap();

        {
            let recv_fut = receiver.recv();
            let recv_fut = std::pin::pin!(recv_fut);

            drop(sender);

            let val = recv_fut.await.unwrap();
            assert_eq!(val, 42);
        }

        let result = receiver.recv().await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // A recv on an empty channel stays pending until a value is sent.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_waits_for_sender_if_no_data() {
        let (mut sender, mut receiver) = channel();

        let recv_fut = receiver.recv();
        let mut recv_fut = std::pin::pin!(recv_fut);

        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut recv_fut => {
                panic!("recv should not complete");
            },
        }

        sender.send(42, |_, _| Ok(())).await.unwrap();

        let val = recv_fut.await.unwrap();
        assert_eq!(val, 42);
    }

    // Sending after the receiver dropped an empty channel fails.
    #[tokio::test]
    async fn test_receiver_gone_while_nodata() {
        let (mut sender, receiver) = channel();
        drop(receiver);

        let result = sender.send(42, |_, _| Ok(())).await;
        assert!(matches!(result, Err(SendError::ReceiverGone)));
    }

    // Receiving after the sender dropped an empty channel fails.
    #[tokio::test]
    async fn test_sender_gone_while_nodata() {
        let (sender, mut receiver) = super::channel::<usize>();
        drop(sender);

        let result = receiver.recv().await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // A send that is already blocked in `SenderWaitsForReceiverToConsume` is
    // woken by the receiver's drop and returns `ReceiverGone`.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_drops_after_sender_went_to_sleep() {
        let (mut sender, receiver) = channel();
        let state = receiver.state.clone();

        sender.send(23, |_, _| unreachable!()).await.unwrap();

        let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await });

        // Let the spawned send run until it blocks.
        tokio::time::sleep(FOREVER).await;

        assert!(matches!(
            &*state.value.lock().unwrap(),
            &State::SenderWaitsForReceiverToConsume(_)
        ));

        drop(receiver);

        let err = send_task
            .await
            .unwrap()
            .expect_err("should unblock immediately");
        assert!(matches!(err, SendError::ReceiverGone));
    }

    // A recv that is already blocked on an empty channel is woken by the
    // sender's drop and returns `SenderGone`.
    #[tokio::test(start_paused = true)]
    async fn test_sender_drops_after_receiver_went_to_sleep() {
        let (sender, mut receiver) = channel::<usize>();
        let state = sender.state.clone();

        let recv_task = tokio::spawn(async move { receiver.recv().await });

        // Let the spawned recv run until it blocks.
        tokio::time::sleep(FOREVER).await;

        assert!(matches!(&*state.value.lock().unwrap(), &State::NoData));

        drop(sender);

        let err = recv_task.await.unwrap().expect_err("should error");
        assert!(matches!(err, RecvError::SenderGone));
    }

    // Walks the state machine explicitly: HasData -> SenderWaitsForReceiverToConsume
    // -> (receiver drop) ReceiverGone, and checks that the blocked send then fails.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_drop_while_waiting_for_receiver_to_consume_unblocks_sender() {
        let (mut sender, receiver) = channel();

        let state = receiver.state.clone();

        sender.send((), |_, _| unreachable!()).await.unwrap();

        assert!(matches!(&*state.value.lock().unwrap(), &State::HasData(_)));

        // A fold that always refuses, forcing the sender to wait.
        let unmergeable = sender.send((), |_, _| Err(()));
        let mut unmergeable = std::pin::pin!(unmergeable);
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut unmergeable => {
                panic!("unmergeable should not complete");
            },
        }

        assert!(matches!(
            &*state.value.lock().unwrap(),
            &State::SenderWaitsForReceiverToConsume(_)
        ));

        drop(receiver);

        assert!(matches!(
            &*state.value.lock().unwrap(),
            &State::ReceiverGone
        ));

        unmergeable.await.unwrap_err();
    }
}
|