Line data Source code
1 : use std::{
2 : future::Future,
3 : pin::Pin,
4 : task::{Context, Poll},
5 : time::Duration,
6 : };
7 :
8 : use bytes::Bytes;
9 : use futures_util::Stream;
10 : use tokio_util::sync::CancellationToken;
11 :
12 : use crate::TimeoutOrCancel;
13 :
14 : pin_project_lite::pin_project! {
15 : /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
16 : pub(crate) struct PermitCarrying<S> {
17 : permit: tokio::sync::OwnedSemaphorePermit,
18 : #[pin]
19 : inner: S,
20 : }
21 : }
22 :
23 : impl<S> PermitCarrying<S> {
24 0 : pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
25 0 : Self { permit, inner }
26 0 : }
27 : }
28 :
29 : impl<S: Stream> Stream for PermitCarrying<S> {
30 : type Item = <S as Stream>::Item;
31 :
32 0 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33 0 : self.project().inner.poll_next(cx)
34 0 : }
35 :
36 0 : fn size_hint(&self) -> (usize, Option<usize>) {
37 0 : self.inner.size_hint()
38 0 : }
39 : }
40 :
pin_project_lite::pin_project! {
    /// Wraps an `inner` stream and checks a `cancellation` future before each poll of `inner`.
    ///
    /// When `cancellation` resolves, its output is yielded once as an error item; after that
    /// the wrapper only delegates to `inner`.
    pub(crate) struct DownloadStream<F, S> {
        // Set once `cancellation` has resolved, so the completed future is never polled again.
        hit: bool,
        // Future whose completion is reported as an error item.
        #[pin]
        cancellation: F,
        // The wrapped stream that provides all non-cancellation items.
        #[pin]
        inner: S,
    }
}
50 :
51 : impl<F, S> DownloadStream<F, S> {
52 48 : pub(crate) fn new(cancellation: F, inner: S) -> Self {
53 48 : Self {
54 48 : cancellation,
55 48 : hit: false,
56 48 : inner,
57 48 : }
58 48 : }
59 : }
60 :
61 : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
62 : impl<E, F, S> Stream for DownloadStream<F, S>
63 : where
64 : std::io::Error: From<E>,
65 : F: Future<Output = E>,
66 : S: Stream<Item = std::io::Result<Bytes>>,
67 : {
68 : type Item = <S as Stream>::Item;
69 :
70 1609 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 1609 : let this = self.project();
72 1609 :
73 1609 : if !*this.hit {
74 1601 : if let Poll::Ready(e) = this.cancellation.poll(cx) {
75 4 : *this.hit = true;
76 4 :
77 4 : // most likely this will be a std::io::Error wrapping a DownloadError
78 4 : let e = Err(std::io::Error::from(e));
79 4 : return Poll::Ready(Some(e));
80 1597 : }
81 8 : }
82 :
83 1605 : this.inner.poll_next(cx)
84 1609 : }
85 :
86 0 : fn size_hint(&self) -> (usize, Option<usize>) {
87 0 : self.inner.size_hint()
88 0 : }
89 : }
90 :
91 : /// Fires only on the first cancel or timeout, not on both.
92 72 : pub(crate) async fn cancel_or_timeout(
93 72 : timeout: Duration,
94 72 : cancel: CancellationToken,
95 72 : ) -> TimeoutOrCancel {
96 1601 : tokio::select! {
97 1601 : _ = tokio::time::sleep(timeout) => TimeoutOrCancel::Timeout,
98 1601 : _ = cancel.cancelled() => TimeoutOrCancel::Cancel,
99 1601 : }
100 4 : }
101 :
#[cfg(test)]
mod tests {
    //! Paused-clock tests for `DownloadStream` driven by `cancel_or_timeout`.
    //!
    //! With `start_paused = true`, the tokio test runtime auto-advances the mock clock when
    //! every task is idle, so long sleeps and the 120s timeout complete without real waiting.

    use super::*;
    use crate::DownloadError;
    use futures::stream::StreamExt;

    /// Cancelling the token yields exactly one `Cancelled` error, and nothing afterwards.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        // An inner stream that never yields, so every item observed comes from cancellation.
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        // Only 1s passes here, well under the 120s timeout, so `first` must still be pending.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        // The already-started `first` future now completes with the cancellation error.
        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        // Note: this shadows the (already moved) inner stream binding.
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // Sleep past the original 120s timeout: no second error may be yielded.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Letting the timeout elapse yields exactly one `Timeout` error; a later cancel is ignored.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        // An inner stream that never yields, so every item observed comes from the timeout.
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        // Note: this shadows the (already moved) inner stream binding.
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Timeout)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        // Cancelling after the timeout fired must not produce a second error item.
        cancel.cancel();

        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }
}
|