Line data Source code
1 : use std::future::Future;
2 : use std::pin::Pin;
3 : use std::task::{Context, Poll};
4 : use std::time::Duration;
5 :
6 : use bytes::Bytes;
7 : use futures_util::Stream;
8 : use tokio_util::sync::CancellationToken;
9 :
10 : use crate::TimeoutOrCancel;
11 :
12 : pin_project_lite::pin_project! {
13 : /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
14 : pub(crate) struct PermitCarrying<S> {
15 : permit: tokio::sync::OwnedSemaphorePermit,
16 : #[pin]
17 : inner: S,
18 : }
19 : }
20 :
21 : impl<S> PermitCarrying<S> {
22 28 : pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
23 28 : Self { permit, inner }
24 28 : }
25 : }
26 :
27 : impl<S: Stream> Stream for PermitCarrying<S> {
28 : type Item = <S as Stream>::Item;
29 :
30 46 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
31 46 : self.project().inner.poll_next(cx)
32 46 : }
33 :
34 0 : fn size_hint(&self) -> (usize, Option<usize>) {
35 0 : self.inner.size_hint()
36 0 : }
37 : }
38 :
pin_project_lite::pin_project! {
    /// A stream adapter that yields an error item once `cancellation` completes, and otherwise
    /// forwards items from `inner`.
    pub(crate) struct DownloadStream<F, S> {
        // Set once `cancellation` has completed; after that it is no longer polled.
        hit: bool,
        // Future whose completion (e.g. the future returned by `cancel_or_timeout`) is reported
        // as an error item.
        #[pin]
        cancellation: F,
        // The wrapped stream whose items are forwarded.
        #[pin]
        inner: S,
    }
}
48 :
49 : impl<F, S> DownloadStream<F, S> {
50 151 : pub(crate) fn new(cancellation: F, inner: S) -> Self {
51 151 : Self {
52 151 : cancellation,
53 151 : hit: false,
54 151 : inner,
55 151 : }
56 151 : }
57 : }
58 :
59 : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
60 : impl<E, F, S> Stream for DownloadStream<F, S>
61 : where
62 : std::io::Error: From<E>,
63 : F: Future<Output = E>,
64 : S: Stream<Item = std::io::Result<Bytes>>,
65 : {
66 : type Item = <S as Stream>::Item;
67 :
68 2731 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 2731 : let this = self.project();
70 2731 :
71 2731 : if !*this.hit {
72 2717 : if let Poll::Ready(e) = this.cancellation.poll(cx) {
73 13 : *this.hit = true;
74 13 :
75 13 : // most likely this will be a std::io::Error wrapping a DownloadError
76 13 : let e = Err(std::io::Error::from(e));
77 13 : return Poll::Ready(Some(e));
78 109 : }
79 : } else {
80 : // this would be perfectly valid behaviour for doing a graceful completion on the
81 : // download for example, but not one we expect to do right now.
82 14 : tracing::warn!("continuing polling after having cancelled or timeouted");
83 : }
84 :
85 2718 : this.inner.poll_next(cx)
86 132 : }
87 :
88 0 : fn size_hint(&self) -> (usize, Option<usize>) {
89 0 : self.inner.size_hint()
90 0 : }
91 : }
92 :
93 : /// Fires only on the first cancel or timeout, not on both.
94 164 : pub(crate) fn cancel_or_timeout(
95 164 : timeout: Duration,
96 164 : cancel: CancellationToken,
97 164 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
98 164 : // futures are lazy, they don't do anything before being polled.
99 164 : //
100 164 : // "precalculate" the wanted deadline before returning the future, so that we can use pause
101 164 : // failpoint to trigger a timeout in test.
102 164 : let deadline = tokio::time::Instant::now() + timeout;
103 154 : async move {
104 154 : tokio::select! {
105 154 : _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
106 154 : _ = cancel.cancelled() => {
107 8 : TimeoutOrCancel::Cancel
108 : },
109 : }
110 9 : }
111 164 : }
112 :
#[cfg(test)]
mod tests {
    use futures::stream::StreamExt;

    use super::*;
    use crate::DownloadError;

    /// Cancelling the token makes the pending `next()` resolve with `DownloadError::Cancelled`
    /// wrapped in an `std::io::Error`; the timeout never fires afterwards.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        // Poll `first` for one second of (paused) time. Neither cancellation nor the 120s
        // timeout has happened yet, so it must still be pending.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
            "{inner:?}"
        );
        // Converting the io::Error back must recover the original variant.
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // Advance past the original 120s deadline. The cancellation future has already
        // completed, so no timeout error is produced and only the pending inner stream is polled.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Letting the deadline pass makes `next()` resolve with `DownloadError::Timeout`;
    /// a later cancellation is not reported.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Timeout)),
            "{inner:?}"
        );
        // Converting the io::Error back must recover the original variant.
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        cancel.cancel();

        // The cancellation future already completed (with a timeout), so cancelling the token
        // now must not produce another error item.
        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// The token is cancelled before the stream is created, so the first item is the
    /// cancellation error. The inner stream is still polled on the next call and yields its
    /// single item.
    #[tokio::test]
    async fn notified_but_pollable_after() {
        let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
            b"hello world",
        ))));
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        cancel.cancel();
        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let next = stream.next().await;
        let ioe = next.unwrap().unwrap_err();
        assert!(
            matches!(
                ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
                Some(&DownloadError::Cancelled)
            ),
            "{ioe:?}"
        );

        // Polling after the cancellation error falls through to the inner stream.
        let next = stream.next().await;
        let bytes = next.unwrap().unwrap();
        assert_eq!(&b"hello world"[..], bytes);
    }
}
|