Line data Source code
1 : use core::{future::poll_fn, task::Poll};
2 : use std::sync::{Arc, Mutex};
3 :
4 : use diatomic_waker::DiatomicWaker;
5 :
6 : pub struct Sender<T> {
7 : state: Arc<Inner<T>>,
8 : }
9 :
10 : pub struct Receiver<T> {
11 : state: Arc<Inner<T>>,
12 : }
13 :
14 : struct Inner<T> {
15 : wake_receiver: DiatomicWaker,
16 : wake_sender: DiatomicWaker,
17 : value: Mutex<State<T>>,
18 : }
19 :
/// Contents of the shared slot together with which halves are still alive.
enum State<T> {
    /// Both halves alive, nothing pending.
    NoData,
    /// Both halves alive; a value (possibly folded from several sends) is pending.
    HasData(T),
    /// Transient placeholder used inside `send` while moving the pending value from
    /// `HasData` to `SenderWaitsForReceiverToConsume`; the lock is held throughout,
    /// so no other party observes it.
    TryFoldFailed, // transient state
    /// A value is pending that the sender could not fold into; the sender waits
    /// for the receiver to consume it before storing its own value.
    SenderWaitsForReceiverToConsume(T),
    /// The sender was dropped; holds the value it left behind (if any) until the
    /// receiver drains it.
    SenderGone(Option<T>),
    /// The receiver was dropped; any pending value was discarded.
    ReceiverGone,
    /// Both halves were dropped.
    AllGone,
    /// Transient placeholder while `Sender::drop` computes the next state.
    SenderDropping, // transient state
    /// Transient placeholder while `Receiver::drop` computes the next state.
    ReceiverDropping, // transient state
}
31 :
32 11 : pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
33 11 : let inner = Inner {
34 11 : wake_receiver: DiatomicWaker::new(),
35 11 : wake_sender: DiatomicWaker::new(),
36 11 : value: Mutex::new(State::NoData),
37 11 : };
38 11 :
39 11 : let state = Arc::new(inner);
40 11 : (
41 11 : Sender {
42 11 : state: state.clone(),
43 11 : },
44 11 : Receiver { state },
45 11 : )
46 11 : }
47 :
48 : #[derive(Debug, thiserror::Error)]
49 : pub enum SendError {
50 : #[error("receiver is gone")]
51 : ReceiverGone,
52 : }
53 :
54 : impl<T: Send> Sender<T> {
55 : /// # Panics
56 : ///
57 : /// If `try_fold` panics, any subsequent call to `send` panic.
58 12 : pub async fn send<F>(&mut self, value: T, try_fold: F) -> Result<(), SendError>
59 12 : where
60 12 : F: Fn(&mut T, T) -> Result<(), T>,
61 12 : {
62 12 : let mut value = Some(value);
63 14 : poll_fn(|cx| {
64 14 : let mut guard = self.state.value.lock().unwrap();
65 14 : match &mut *guard {
66 : State::NoData => {
67 8 : *guard = State::HasData(value.take().unwrap());
68 8 : self.state.wake_receiver.notify();
69 8 : Poll::Ready(Ok(()))
70 : }
71 : State::HasData(_) => {
72 3 : let State::HasData(acc_mut) = &mut *guard else {
73 0 : unreachable!("this match arm guarantees that the guard is HasData");
74 : };
75 3 : match try_fold(acc_mut, value.take().unwrap()) {
76 : Ok(()) => {
77 : // no need to wake receiver, if it was waiting it already
78 : // got a wake-up when we transitioned from NoData to HasData
79 1 : Poll::Ready(Ok(()))
80 : }
81 2 : Err(unfoldable_value) => {
82 2 : value = Some(unfoldable_value);
83 2 : let State::HasData(acc) =
84 2 : std::mem::replace(&mut *guard, State::TryFoldFailed)
85 : else {
86 0 : unreachable!("this match arm guarantees that the guard is HasData");
87 : };
88 2 : *guard = State::SenderWaitsForReceiverToConsume(acc);
89 2 : // SAFETY: send is single threaded due to `&mut self` requirement,
90 2 : // therefore register is not concurrent.
91 2 : unsafe {
92 2 : self.state.wake_sender.register(cx.waker());
93 2 : }
94 2 : Poll::Pending
95 : }
96 : }
97 : }
98 0 : State::SenderWaitsForReceiverToConsume(_data) => {
99 0 : // Really, we shouldn't be polled until receiver has consumed and wakes us.
100 0 : Poll::Pending
101 : }
102 3 : State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)),
103 : State::SenderGone(_)
104 : | State::AllGone
105 : | State::SenderDropping
106 : | State::ReceiverDropping
107 : | State::TryFoldFailed => {
108 0 : unreachable!();
109 : }
110 : }
111 14 : })
112 12 : .await
113 12 : }
114 : }
115 :
116 : impl<T> Drop for Sender<T> {
117 11 : fn drop(&mut self) {
118 11 : scopeguard::defer! {
119 11 : self.state.wake_receiver.notify()
120 11 : };
121 11 : let Ok(mut guard) = self.state.value.lock() else {
122 0 : return;
123 : };
124 11 : *guard = match std::mem::replace(&mut *guard, State::SenderDropping) {
125 3 : State::NoData => State::SenderGone(None),
126 1 : State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => {
127 1 : State::SenderGone(Some(data))
128 : }
129 7 : State::ReceiverGone => State::AllGone,
130 : State::TryFoldFailed
131 : | State::SenderGone(_)
132 : | State::AllGone
133 : | State::SenderDropping
134 : | State::ReceiverDropping => {
135 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
136 : }
137 : }
138 11 : }
139 : }
140 :
141 : #[derive(Debug, thiserror::Error)]
142 : pub enum RecvError {
143 : #[error("sender is gone")]
144 : SenderGone,
145 : }
146 :
147 : impl<T: Send> Receiver<T> {
148 10 : pub async fn recv(&mut self) -> Result<T, RecvError> {
149 13 : poll_fn(|cx| {
150 13 : let mut guard = self.state.value.lock().unwrap();
151 13 : match &mut *guard {
152 : State::NoData => {
153 : // SAFETY: recv is single threaded due to `&mut self` requirement,
154 : // therefore register is not concurrent.
155 3 : unsafe {
156 3 : self.state.wake_receiver.register(cx.waker());
157 3 : }
158 3 : Poll::Pending
159 : }
160 4 : guard @ State::HasData(_)
161 1 : | guard @ State::SenderWaitsForReceiverToConsume(_)
162 1 : | guard @ State::SenderGone(Some(_)) => {
163 6 : let data = guard
164 6 : .take_data()
165 6 : .expect("in these states, data is guaranteed to be present");
166 6 : self.state.wake_sender.notify();
167 6 : Poll::Ready(Ok(data))
168 : }
169 4 : State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)),
170 : State::ReceiverGone
171 : | State::AllGone
172 : | State::SenderDropping
173 : | State::ReceiverDropping
174 : | State::TryFoldFailed => {
175 0 : unreachable!("unreachable state {:?}", guard.discriminant_str());
176 : }
177 : }
178 13 : })
179 10 : .await
180 10 : }
181 : }
182 :
183 : impl<T> Drop for Receiver<T> {
184 11 : fn drop(&mut self) {
185 11 : scopeguard::defer! {
186 11 : self.state.wake_sender.notify()
187 11 : };
188 11 : let Ok(mut guard) = self.state.value.lock() else {
189 0 : return;
190 : };
191 11 : *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) {
192 5 : State::NoData => State::ReceiverGone,
193 2 : State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone,
194 4 : State::SenderGone(_) => State::AllGone,
195 : State::TryFoldFailed
196 : | State::ReceiverGone
197 : | State::AllGone
198 : | State::SenderDropping
199 : | State::ReceiverDropping => {
200 0 : unreachable!("unreachable state {:?}", guard.discriminant_str())
201 : }
202 : }
203 11 : }
204 : }
205 :
206 : impl<T> State<T> {
207 6 : fn take_data(&mut self) -> Option<T> {
208 6 : match self {
209 : State::HasData(_) => {
210 4 : let State::HasData(data) = std::mem::replace(self, State::NoData) else {
211 0 : unreachable!("this match arm guarantees that the state is HasData");
212 : };
213 4 : Some(data)
214 : }
215 : State::SenderWaitsForReceiverToConsume(_) => {
216 1 : let State::SenderWaitsForReceiverToConsume(data) =
217 1 : std::mem::replace(self, State::NoData)
218 : else {
219 0 : unreachable!(
220 0 : "this match arm guarantees that the state is SenderWaitsForReceiverToConsume"
221 0 : );
222 : };
223 1 : Some(data)
224 : }
225 1 : State::SenderGone(data) => Some(data.take().unwrap()),
226 : State::NoData
227 : | State::TryFoldFailed
228 : | State::ReceiverGone
229 : | State::AllGone
230 : | State::SenderDropping
231 0 : | State::ReceiverDropping => None,
232 : }
233 6 : }
234 0 : fn discriminant_str(&self) -> &'static str {
235 0 : match self {
236 0 : State::NoData => "NoData",
237 0 : State::HasData(_) => "HasData",
238 0 : State::TryFoldFailed => "TryFoldFailed",
239 0 : State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume",
240 0 : State::SenderGone(_) => "SenderGone",
241 0 : State::ReceiverGone => "ReceiverGone",
242 0 : State::AllGone => "AllGone",
243 0 : State::SenderDropping => "SenderDropping",
244 0 : State::ReceiverDropping => "ReceiverDropping",
245 : }
246 0 : }
247 : }
248 :
// Unit tests for the fold channel.
//
// Several tests use `#[tokio::test(start_paused = true)]` together with
// `tokio::time::sleep(FOREVER)`. With the clock paused, tokio auto-advances time
// once the runtime has nothing else to do, so sleeping "forever" effectively means
// "until every other future/task is blocked". The tests rely on this to check
// that a future is (or is not) still pending.
#[cfg(test)]
mod tests {

    use super::*;

    // Deadline large enough that the sleep only completes via auto-advance.
    const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX);

    // A single send into an empty slot is received unchanged; try_fold is not called.
    #[tokio::test]
    async fn test_send_recv() {
        let (mut sender, mut receiver) = channel();

        sender
            .send(42, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(received, 42);
    }

    // A second send while data is pending is folded into the pending value.
    #[tokio::test]
    async fn test_send_recv_with_fold() {
        let (mut sender, mut receiver) = channel();

        sender
            .send(1, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();
        sender
            .send(2, |acc, val| {
                *acc += val;
                Ok(())
            })
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(received, 3);
    }

    // When try_fold refuses (returns Err), send stays pending until the receiver
    // drains the slot, then stores the refused value.
    #[tokio::test(start_paused = true)]
    async fn test_sender_waits_for_receiver_if_try_fold_fails() {
        let (mut sender, mut receiver) = channel();

        // Slot is empty, so this fold closure is never invoked.
        sender.send(23, |_, _| panic!("first send")).await.unwrap();

        let send_fut = sender.send(42, |_, val| Err(val));
        let mut send_fut = std::pin::pin!(send_fut);

        // The send must still be pending once the runtime is idle.
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut send_fut => {
                panic!("send should not complete");
            },
        }

        let val = receiver.recv().await.unwrap();
        assert_eq!(val, 23);

        // Draining the slot must have woken and completed the pending send.
        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {
                panic!("receiver should have consumed the value");
            },
            _ = &mut send_fut => { },
        }

        let val = receiver.recv().await.unwrap();
        assert_eq!(val, 42);
    }

    // NOTE(review): `send_fut` is first polled only after `drop(receiver)`, so this
    // exercises the `ReceiverGone` fast path rather than a send that is already
    // waiting; `test_receiver_drops_after_sender_went_to_sleep` covers the latter.
    #[tokio::test(start_paused = true)]
    async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() {
        let (mut sender, receiver) = channel();

        sender.send(23, |_, _| unreachable!()).await.unwrap();

        let send_fut = sender.send(42, |_, val| Err(val));
        let send_fut = std::pin::pin!(send_fut);

        drop(receiver);

        let result = send_fut.await;
        assert!(matches!(result, Err(SendError::ReceiverGone)));
    }

    // NOTE(review): `recv_fut` is first polled only after `drop(sender)`; see
    // `test_sender_drops_after_receiver_went_to_sleep` for a receiver that is
    // already waiting when the sender goes away.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() {
        let (sender, mut receiver) = channel::<()>();

        let recv_fut = receiver.recv();
        let recv_fut = std::pin::pin!(recv_fut);

        drop(sender);

        let result = recv_fut.await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // Data left behind by a dropped sender is still delivered once; the next recv
    // then reports SenderGone.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() {
        let (mut sender, mut receiver) = channel();

        sender.send(42, |_, _| unreachable!()).await.unwrap();

        {
            let recv_fut = receiver.recv();
            let recv_fut = std::pin::pin!(recv_fut);

            drop(sender);

            let val = recv_fut.await.unwrap();
            assert_eq!(val, 42);
        }

        let result = receiver.recv().await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // A recv on an empty slot stays pending until a value is sent.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_waits_for_sender_if_no_data() {
        let (mut sender, mut receiver) = channel();

        let recv_fut = receiver.recv();
        let mut recv_fut = std::pin::pin!(recv_fut);

        tokio::select! {
            _ = tokio::time::sleep(FOREVER) => {},
            _ = &mut recv_fut => {
                panic!("recv should not complete");
            },
        }

        sender.send(42, |_, _| Ok(())).await.unwrap();

        let val = recv_fut.await.unwrap();
        assert_eq!(val, 42);
    }

    #[tokio::test]
    async fn test_receiver_gone_while_nodata() {
        let (mut sender, receiver) = channel();
        drop(receiver);

        let result = sender.send(42, |_, _| Ok(())).await;
        assert!(matches!(result, Err(SendError::ReceiverGone)));
    }

    #[tokio::test]
    async fn test_sender_gone_while_nodata() {
        let (sender, mut receiver) = super::channel::<usize>();
        drop(sender);

        let result = receiver.recv().await;
        assert!(matches!(result, Err(RecvError::SenderGone)));
    }

    // A send already parked in SenderWaitsForReceiverToConsume (in another task)
    // must be woken and fail when the receiver is dropped.
    #[tokio::test(start_paused = true)]
    async fn test_receiver_drops_after_sender_went_to_sleep() {
        let (mut sender, receiver) = channel();
        // Keep a handle on the shared state to inspect it directly.
        let state = receiver.state.clone();

        sender.send(23, |_, _| unreachable!()).await.unwrap();

        let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await });

        // Let the spawned task run until it blocks.
        tokio::time::sleep(FOREVER).await;

        assert!(matches!(
            &*state.value.lock().unwrap(),
            &State::SenderWaitsForReceiverToConsume(_)
        ));

        drop(receiver);

        let err = send_task
            .await
            .unwrap()
            .expect_err("should unblock immediately");
        assert!(matches!(err, SendError::ReceiverGone));
    }

    // A recv already parked on an empty slot (in another task) must be woken and
    // fail when the sender is dropped.
    #[tokio::test(start_paused = true)]
    async fn test_sender_drops_after_receiver_went_to_sleep() {
        let (sender, mut receiver) = channel::<usize>();
        let state = sender.state.clone();

        let recv_task = tokio::spawn(async move { receiver.recv().await });

        tokio::time::sleep(FOREVER).await;

        assert!(matches!(&*state.value.lock().unwrap(), &State::NoData));

        drop(sender);

        let err = recv_task.await.unwrap().expect_err("should error");
        assert!(matches!(err, RecvError::SenderGone));
    }
}
|