Line data Source code
1 : use std::{
2 : future::Future,
3 : pin::Pin,
4 : task::{Context, Poll},
5 : time::Duration,
6 : };
7 :
8 : use bytes::Bytes;
9 : use futures_util::Stream;
10 : use tokio_util::sync::CancellationToken;
11 :
12 : use crate::TimeoutOrCancel;
13 :
14 : pin_project_lite::pin_project! {
15 : /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
16 : pub(crate) struct PermitCarrying<S> {
17 : permit: tokio::sync::OwnedSemaphorePermit,
18 : #[pin]
19 : inner: S,
20 : }
21 : }
22 :
23 : impl<S> PermitCarrying<S> {
24 24 : pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
25 24 : Self { permit, inner }
26 24 : }
27 : }
28 :
29 : impl<S: Stream> Stream for PermitCarrying<S> {
30 : type Item = <S as Stream>::Item;
31 :
32 49 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33 49 : self.project().inner.poll_next(cx)
34 49 : }
35 :
36 0 : fn size_hint(&self) -> (usize, Option<usize>) {
37 0 : self.inner.size_hint()
38 0 : }
39 : }
40 :
pin_project_lite::pin_project! {
    /// A [`Stream`] adapter pairing the wrapped `inner` stream with a `cancellation` future
    /// that is polled ahead of each item; see the `Stream` impl for how a completed
    /// `cancellation` is surfaced to the consumer.
    pub(crate) struct DownloadStream<F, S> {
        // Set once `cancellation` has resolved; after that it is never polled again.
        hit: bool,
        #[pin]
        cancellation: F,
        #[pin]
        inner: S,
    }
}
50 :
51 : impl<F, S> DownloadStream<F, S> {
52 157 : pub(crate) fn new(cancellation: F, inner: S) -> Self {
53 157 : Self {
54 157 : cancellation,
55 157 : hit: false,
56 157 : inner,
57 157 : }
58 157 : }
59 : }
60 :
61 : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
62 : impl<E, F, S> Stream for DownloadStream<F, S>
63 : where
64 : std::io::Error: From<E>,
65 : F: Future<Output = E>,
66 : S: Stream<Item = std::io::Result<Bytes>>,
67 : {
68 : type Item = <S as Stream>::Item;
69 :
70 5020 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 5020 : let this = self.project();
72 5020 :
73 5020 : if !*this.hit {
74 5004 : if let Poll::Ready(e) = this.cancellation.poll(cx) {
75 13 : *this.hit = true;
76 13 :
77 13 : // most likely this will be a std::io::Error wrapping a DownloadError
78 13 : let e = Err(std::io::Error::from(e));
79 13 : return Poll::Ready(Some(e));
80 4991 : }
81 : } else {
82 : // this would be perfectly valid behaviour for doing a graceful completion on the
83 : // download for example, but not one we expect to do right now.
84 16 : tracing::warn!("continuing polling after having cancelled or timeouted");
85 : }
86 :
87 5007 : this.inner.poll_next(cx)
88 5020 : }
89 :
90 0 : fn size_hint(&self) -> (usize, Option<usize>) {
91 0 : self.inner.size_hint()
92 0 : }
93 : }
94 :
95 : /// Fires only on the first cancel or timeout, not on both.
96 164 : pub(crate) fn cancel_or_timeout(
97 164 : timeout: Duration,
98 164 : cancel: CancellationToken,
99 164 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
100 164 : // futures are lazy, they don't do anything before being polled.
101 164 : //
102 164 : // "precalculate" the wanted deadline before returning the future, so that we can use pause
103 164 : // failpoint to trigger a timeout in test.
104 164 : let deadline = tokio::time::Instant::now() + timeout;
105 162 : async move {
106 162 : tokio::select! {
107 162 : _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
108 162 : _ = cancel.cancelled() => {
109 8 : TimeoutOrCancel::Cancel
110 : },
111 : }
112 13 : }
113 164 : }
114 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DownloadError;
    use futures::stream::StreamExt;

    /// A `next()` that is already pending resolves with `DownloadError::Cancelled` (wrapped in
    /// `std::io::Error`) once the token is cancelled. The timeout never fires afterwards.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // Keep the same `next()` future alive across the select below, so the one that
        // eventually resolves is the one that was already pending when we cancelled.
        let mut first = stream.next();

        // With the clock paused, the 1s sleep is the earliest timer, so it completes first and
        // `first` must still be pending here.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // Sleep past the original 120s deadline. The stream must stay pending: the
        // cancellation future already completed and is not polled again, and `inner` is
        // `pending()`.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Without cancellation, the first `next()` resolves with `DownloadError::Timeout` once
    /// the 120s deadline passes. Cancelling afterwards produces no further error.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Timeout)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        // The cancellation future already completed with `Timeout`, so cancelling now must not
        // surface a second error.
        cancel.cancel();

        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// After the cancellation error has been yielded, the stream can still be polled and
    /// returns the inner stream's items.
    #[tokio::test]
    async fn notified_but_pollable_after() {
        let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
            b"hello world",
        ))));
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        // Cancel before building the stream, so the very first poll observes the cancellation.
        cancel.cancel();
        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let next = stream.next().await;
        let ioe = next.unwrap().unwrap_err();
        assert!(
            matches!(
                ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
                Some(&DownloadError::Cancelled)
            ),
            "{ioe:?}"
        );

        // Polling continues through to the inner stream's single item.
        let next = stream.next().await;
        let bytes = next.unwrap().unwrap();
        assert_eq!(&b"hello world"[..], bytes);
    }
}
|