Line data Source code
1 : use std::{
2 : future::Future,
3 : pin::Pin,
4 : task::{Context, Poll},
5 : time::Duration,
6 : };
7 :
8 : use bytes::Bytes;
9 : use futures_util::Stream;
10 : use tokio_util::sync::CancellationToken;
11 :
12 : use crate::TimeoutOrCancel;
13 :
14 : pin_project_lite::pin_project! {
15 : /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
16 : pub(crate) struct PermitCarrying<S> {
17 : permit: tokio::sync::OwnedSemaphorePermit,
18 : #[pin]
19 : inner: S,
20 : }
21 : }
22 :
23 : impl<S> PermitCarrying<S> {
24 24 : pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
25 24 : Self { permit, inner }
26 24 : }
27 : }
28 :
29 : impl<S: Stream> Stream for PermitCarrying<S> {
30 : type Item = <S as Stream>::Item;
31 :
32 47 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33 47 : self.project().inner.poll_next(cx)
34 47 : }
35 :
36 0 : fn size_hint(&self) -> (usize, Option<usize>) {
37 0 : self.inner.size_hint()
38 0 : }
39 : }
40 :
41 : pin_project_lite::pin_project! {
42 : pub(crate) struct DownloadStream<F, S> {
43 : hit: bool,
44 : #[pin]
45 : cancellation: F,
46 : #[pin]
47 : inner: S,
48 : }
49 : }
50 :
51 : impl<F, S> DownloadStream<F, S> {
52 115 : pub(crate) fn new(cancellation: F, inner: S) -> Self {
53 115 : Self {
54 115 : cancellation,
55 115 : hit: false,
56 115 : inner,
57 115 : }
58 115 : }
59 : }
60 :
61 : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
62 : impl<E, F, S> Stream for DownloadStream<F, S>
63 : where
64 : std::io::Error: From<E>,
65 : F: Future<Output = E>,
66 : S: Stream<Item = std::io::Result<Bytes>>,
67 : {
68 : type Item = <S as Stream>::Item;
69 :
70 1884 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 1884 : let this = self.project();
72 1884 :
73 1884 : if !*this.hit {
74 1865 : if let Poll::Ready(e) = this.cancellation.poll(cx) {
75 16 : *this.hit = true;
76 16 :
77 16 : // most likely this will be a std::io::Error wrapping a DownloadError
78 16 : let e = Err(std::io::Error::from(e));
79 16 : return Poll::Ready(Some(e));
80 1849 : }
81 : } else {
82 : // this would be perfectly valid behaviour for doing a graceful completion on the
83 : // download for example, but not one we expect to do right now.
84 19 : tracing::warn!("continuing polling after having cancelled or timeouted");
85 : }
86 :
87 1868 : this.inner.poll_next(cx)
88 1884 : }
89 :
90 0 : fn size_hint(&self) -> (usize, Option<usize>) {
91 0 : self.inner.size_hint()
92 0 : }
93 : }
94 :
95 : /// Fires only on the first cancel or timeout, not on both.
96 122 : pub(crate) fn cancel_or_timeout(
97 122 : timeout: Duration,
98 122 : cancel: CancellationToken,
99 122 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
100 122 : // futures are lazy, they don't do anything before being polled.
101 122 : //
102 122 : // "precalculate" the wanted deadline before returning the future, so that we can use pause
103 122 : // failpoint to trigger a timeout in test.
104 122 : let deadline = tokio::time::Instant::now() + timeout;
105 120 : async move {
106 : tokio::select! {
107 : _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
108 : _ = cancel.cancelled() => {
109 : TimeoutOrCancel::Cancel
110 : },
111 : }
112 16 : }
113 122 : }
114 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DownloadError;
    use futures::stream::StreamExt;

    /// Cancelling while the first `next()` is pending turns it into `DownloadError::Cancelled`.
    /// After that, the 120s timeout must never produce another item.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        let cancel = CancellationToken::new();
        let stream = DownloadStream::new(
            cancel_or_timeout(Duration::from_secs(120), cancel.clone()),
            futures::stream::pending(),
        );
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        // Within the first second nothing may resolve: not cancelled, timeout still far away.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        let e = first
            .await
            .expect("there must be some")
            .unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");

        let inner = e.get_ref().expect("inner should be set");
        assert!(
            matches!(
                inner.downcast_ref::<DownloadError>(),
                Some(&DownloadError::Cancelled)
            ),
            "{inner:?}"
        );

        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // Already cancelled, so passing the original deadline must not yield anything.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// With the clock paused, the runtime auto-advances to the 120s deadline, so the first
    /// `next()` yields `DownloadError::Timeout`. A later cancellation must not surface.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        let cancel = CancellationToken::new();
        let stream = DownloadStream::new(
            cancel_or_timeout(Duration::from_secs(120), cancel.clone()),
            futures::stream::pending(),
        );
        let mut stream = std::pin::pin!(stream);

        let e = stream
            .next()
            .await
            .expect("there must be some")
            .unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");

        let inner = e.get_ref().expect("inner should be set");
        assert!(
            matches!(
                inner.downcast_ref::<DownloadError>(),
                Some(&DownloadError::Timeout)
            ),
            "{inner:?}"
        );

        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        cancel.cancel();

        // The timeout already fired, so cancelling now must not produce a second error.
        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// With the token cancelled up front, the stream first yields the cancellation error.
    /// It can then still be polled to receive the inner stream's item.
    #[tokio::test]
    async fn notified_but_pollable_after() {
        let payload = bytes::Bytes::from_static(b"hello world");
        let inner = futures::stream::once(futures::future::ready(Ok(payload)));

        let cancel = CancellationToken::new();
        cancel.cancel();

        let stream = DownloadStream::new(
            cancel_or_timeout(Duration::from_secs(120), cancel.clone()),
            inner,
        );
        let mut stream = std::pin::pin!(stream);

        let ioe = stream.next().await.unwrap().unwrap_err();
        assert!(
            matches!(
                ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
                Some(&DownloadError::Cancelled)
            ),
            "{ioe:?}"
        );

        let bytes = stream.next().await.unwrap().unwrap();
        assert_eq!(&b"hello world"[..], bytes);
    }
}
|