Line data Source code
1 : use std::{
2 : future::Future,
3 : pin::Pin,
4 : task::{Context, Poll},
5 : time::Duration,
6 : };
7 :
8 : use bytes::Bytes;
9 : use futures_util::Stream;
10 : use tokio_util::sync::CancellationToken;
11 :
12 : use crate::TimeoutOrCancel;
13 :
14 : pin_project_lite::pin_project! {
15 : /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
16 : pub(crate) struct PermitCarrying<S> {
17 : permit: tokio::sync::OwnedSemaphorePermit,
18 : #[pin]
19 : inner: S,
20 : }
21 : }
22 :
23 : impl<S> PermitCarrying<S> {
24 24 : pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
25 24 : Self { permit, inner }
26 24 : }
27 : }
28 :
29 : impl<S: Stream> Stream for PermitCarrying<S> {
30 : type Item = <S as Stream>::Item;
31 :
32 49 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33 49 : self.project().inner.poll_next(cx)
34 49 : }
35 :
36 0 : fn size_hint(&self) -> (usize, Option<usize>) {
37 0 : self.inner.size_hint()
38 0 : }
39 : }
40 :
41 : pin_project_lite::pin_project! {
42 : pub(crate) struct DownloadStream<F, S> {
43 : hit: bool,
44 : #[pin]
45 : cancellation: F,
46 : #[pin]
47 : inner: S,
48 : }
49 : }
50 :
51 : impl<F, S> DownloadStream<F, S> {
52 227 : pub(crate) fn new(cancellation: F, inner: S) -> Self {
53 227 : Self {
54 227 : cancellation,
55 227 : hit: false,
56 227 : inner,
57 227 : }
58 227 : }
59 : }
60 :
61 : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
62 : impl<E, F, S> Stream for DownloadStream<F, S>
63 : where
64 : std::io::Error: From<E>,
65 : F: Future<Output = E>,
66 : S: Stream<Item = std::io::Result<Bytes>>,
67 : {
68 : type Item = <S as Stream>::Item;
69 :
70 4826 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 4826 : let this = self.project();
72 4826 :
73 4826 : if !*this.hit {
74 4789 : if let Poll::Ready(e) = this.cancellation.poll(cx) {
75 28 : *this.hit = true;
76 28 :
77 28 : // most likely this will be a std::io::Error wrapping a DownloadError
78 28 : let e = Err(std::io::Error::from(e));
79 28 : return Poll::Ready(Some(e));
80 4761 : }
81 : } else {
82 : // this would be perfectly valid behaviour for doing a graceful completion on the
83 : // download for example, but not one we expect to do right now.
84 37 : tracing::warn!("continuing polling after having cancelled or timeouted");
85 : }
86 :
87 4798 : this.inner.poll_next(cx)
88 4826 : }
89 :
90 0 : fn size_hint(&self) -> (usize, Option<usize>) {
91 0 : self.inner.size_hint()
92 0 : }
93 : }
94 :
95 : /// Fires only on the first cancel or timeout, not on both.
96 234 : pub(crate) fn cancel_or_timeout(
97 234 : timeout: Duration,
98 234 : cancel: CancellationToken,
99 234 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
100 234 : // futures are lazy, they don't do anything before being polled.
101 234 : //
102 234 : // "precalculate" the wanted deadline before returning the future, so that we can use pause
103 234 : // failpoint to trigger a timeout in test.
104 234 : let deadline = tokio::time::Instant::now() + timeout;
105 232 : async move {
106 : tokio::select! {
107 : _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
108 : _ = cancel.cancelled() => {
109 : TimeoutOrCancel::Cancel
110 : },
111 : }
112 28 : }
113 234 : }
114 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DownloadError;
    use futures::stream::StreamExt;

    /// Asserts that `item` is an `Err` of kind `Other` wrapping a [`DownloadError`] accepted by
    /// `is_expected`, both when inspected through the `io::Error` and after converting the
    /// `io::Error` back into a `DownloadError`.
    fn assert_wrapped_download_error(
        item: Option<std::io::Result<Bytes>>,
        is_expected: fn(&DownloadError) -> bool,
    ) {
        let e = item.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");

        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(is_expected),
            "{inner:?}"
        );

        let e = DownloadError::from(e);
        assert!(is_expected(&e), "{e:?}");
    }

    /// Cancelling the token surfaces `DownloadError::Cancelled`, and the timeout does not fire
    /// afterwards.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        let cancel = CancellationToken::new();
        let guard = cancel_or_timeout(Duration::from_secs(120), cancel.clone());

        let stream = DownloadStream::new(guard, futures::stream::pending());
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        assert_wrapped_download_error(first.await, |e| matches!(e, DownloadError::Cancelled));

        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Letting the deadline pass surfaces `DownloadError::Timeout`, and a later cancellation
    /// does not produce a second error.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        let cancel = CancellationToken::new();
        let guard = cancel_or_timeout(Duration::from_secs(120), cancel.clone());

        let stream = DownloadStream::new(guard, futures::stream::pending());
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next().await;

        assert_wrapped_download_error(first, |e| matches!(e, DownloadError::Timeout));

        cancel.cancel();

        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// After the cancellation error has been yielded, the wrapped stream can still be drained.
    #[tokio::test]
    async fn notified_but_pollable_after() {
        let cancel = CancellationToken::new();
        cancel.cancel();

        let payload = futures::stream::once(futures::future::ready(Ok(
            bytes::Bytes::from_static(b"hello world"),
        )));
        let guard = cancel_or_timeout(Duration::from_secs(120), cancel.clone());

        let stream = DownloadStream::new(guard, payload);
        let mut stream = std::pin::pin!(stream);

        let ioe = stream.next().await.unwrap().unwrap_err();
        let wrapped = ioe.get_ref().unwrap().downcast_ref::<DownloadError>();
        assert!(
            matches!(wrapped, Some(&DownloadError::Cancelled)),
            "{ioe:?}"
        );

        let bytes = stream.next().await.unwrap().unwrap();
        assert_eq!(&b"hello world"[..], bytes);
    }
}
|