LCOV - code coverage report
Current view: top level - libs/remote_storage/src - support.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 90.9 % 110 100
Test Date: 2025-07-16 12:29:03 Functions: 84.6 % 26 22

            Line data    Source code
       1              : use std::future::Future;
       2              : use std::pin::Pin;
       3              : use std::task::{Context, Poll};
       4              : use std::time::Duration;
       5              : 
       6              : use bytes::Bytes;
       7              : use futures_util::Stream;
       8              : use tokio_util::sync::CancellationToken;
       9              : 
      10              : use crate::TimeoutOrCancel;
      11              : 
      12              : pin_project_lite::pin_project! {
      13              :     /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
      14              :     pub(crate) struct PermitCarrying<S> {
      15              :         permit: tokio::sync::OwnedSemaphorePermit,
      16              :         #[pin]
      17              :         inner: S,
      18              :     }
      19              : }
      20              : 
      21              : impl<S> PermitCarrying<S> {
      22           28 :     pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
      23           28 :         Self { permit, inner }
      24           28 :     }
      25              : }
      26              : 
      27              : impl<S: Stream> Stream for PermitCarrying<S> {
      28              :     type Item = <S as Stream>::Item;
      29              : 
      30           42 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      31           42 :         self.project().inner.poll_next(cx)
      32           42 :     }
      33              : 
      34            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      35            0 :         self.inner.size_hint()
      36            0 :     }
      37              : }
      38              : 
      39              : pin_project_lite::pin_project! {
      40              :     pub(crate) struct DownloadStream<F, S> {
      41              :         hit: bool,
      42              :         #[pin]
      43              :         cancellation: F,
      44              :         #[pin]
      45              :         inner: S,
      46              :     }
      47              : }
      48              : 
      49              : impl<F, S> DownloadStream<F, S> {
      50          141 :     pub(crate) fn new(cancellation: F, inner: S) -> Self {
      51          141 :         Self {
      52          141 :             cancellation,
      53          141 :             hit: false,
      54          141 :             inner,
      55          141 :         }
      56          141 :     }
      57              : }
      58              : 
      59              : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
      60              : impl<E, F, S> Stream for DownloadStream<F, S>
      61              : where
      62              :     std::io::Error: From<E>,
      63              :     F: Future<Output = E>,
      64              :     S: Stream<Item = std::io::Result<Bytes>>,
      65              : {
      66              :     type Item = <S as Stream>::Item;
      67              : 
      68         1117 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      69         1117 :         let this = self.project();
      70              : 
      71         1117 :         if !*this.hit {
      72         1103 :             if let Poll::Ready(e) = this.cancellation.poll(cx) {
      73           13 :                 *this.hit = true;
      74              : 
      75              :                 // most likely this will be a std::io::Error wrapping a DownloadError
      76           13 :                 let e = Err(std::io::Error::from(e));
      77           13 :                 return Poll::Ready(Some(e));
      78         1090 :             }
      79              :         } else {
      80              :             // this would be perfectly valid behaviour for doing a graceful completion on the
      81              :             // download for example, but not one we expect to do right now.
      82           14 :             tracing::warn!("continuing polling after having cancelled or timeouted");
      83              :         }
      84              : 
      85         1104 :         this.inner.poll_next(cx)
      86         1117 :     }
      87              : 
      88            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      89            0 :         self.inner.size_hint()
      90            0 :     }
      91              : }
      92              : 
      93              : /// Fires only on the first cancel or timeout, not on both.
      94          154 : pub(crate) fn cancel_or_timeout(
      95          154 :     timeout: Duration,
      96          154 :     cancel: CancellationToken,
      97          154 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
      98              :     // futures are lazy, they don't do anything before being polled.
      99              :     //
     100              :     // "precalculate" the wanted deadline before returning the future, so that we can use pause
     101              :     // failpoint to trigger a timeout in test.
     102          154 :     let deadline = tokio::time::Instant::now() + timeout;
     103          139 :     async move {
     104          139 :         tokio::select! {
     105          139 :             _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
     106          139 :             _ = cancel.cancelled() => {
     107            8 :                 TimeoutOrCancel::Cancel
     108              :             },
     109              :         }
     110           13 :     }
     111          154 : }
     112              : 
     113              : #[cfg(test)]
     114              : mod tests {
     115              :     use futures::stream::StreamExt;
     116              : 
     117              :     use super::*;
     118              :     use crate::DownloadError;
     119              : 
     120              :     #[tokio::test(start_paused = true)]
     121            3 :     async fn cancelled_download_stream() {
     122            3 :         let inner = futures::stream::pending();
     123            3 :         let timeout = Duration::from_secs(120);
     124            3 :         let cancel = CancellationToken::new();
     125              : 
     126            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     127            3 :         let mut stream = std::pin::pin!(stream);
     128              : 
     129            3 :         let mut first = stream.next();
     130              : 
     131            3 :         tokio::select! {
     132            3 :             _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
     133            3 :             _ = tokio::time::sleep(Duration::from_secs(1)) => {},
     134              :         }
     135              : 
     136            3 :         cancel.cancel();
     137              : 
     138            3 :         let e = first.await.expect("there must be some").unwrap_err();
     139            3 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     140            3 :         let inner = e.get_ref().expect("inner should be set");
     141            3 :         assert!(
     142            3 :             inner
     143            3 :                 .downcast_ref::<DownloadError>()
     144            3 :                 .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
     145            0 :             "{inner:?}"
     146              :         );
     147            3 :         let e = DownloadError::from(e);
     148            3 :         assert!(matches!(e, DownloadError::Cancelled), "{e:?}");
     149              : 
     150            3 :         tokio::select! {
     151            3 :             _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
     152            3 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     153            3 :         }
     154            3 :     }
     155              : 
     156              :     #[tokio::test(start_paused = true)]
     157            3 :     async fn timeouted_download_stream() {
     158            3 :         let inner = futures::stream::pending();
     159            3 :         let timeout = Duration::from_secs(120);
     160            3 :         let cancel = CancellationToken::new();
     161              : 
     162            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     163            3 :         let mut stream = std::pin::pin!(stream);
     164              : 
     165              :         // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
     166            3 :         let first = stream.next();
     167              : 
     168            3 :         let e = first.await.expect("there must be some").unwrap_err();
     169            3 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     170            3 :         let inner = e.get_ref().expect("inner should be set");
     171            3 :         assert!(
     172            3 :             inner
     173            3 :                 .downcast_ref::<DownloadError>()
     174            3 :                 .is_some_and(|e| matches!(e, DownloadError::Timeout)),
     175            0 :             "{inner:?}"
     176              :         );
     177            3 :         let e = DownloadError::from(e);
     178            3 :         assert!(matches!(e, DownloadError::Timeout), "{e:?}");
     179              : 
     180            3 :         cancel.cancel();
     181              : 
     182            3 :         tokio::select! {
     183            3 :             _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
     184            3 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     185            3 :         }
     186            3 :     }
     187              : 
     188              :     #[tokio::test]
     189            3 :     async fn notified_but_pollable_after() {
     190            3 :         let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
     191            3 :             b"hello world",
     192            3 :         ))));
     193            3 :         let timeout = Duration::from_secs(120);
     194            3 :         let cancel = CancellationToken::new();
     195              : 
     196            3 :         cancel.cancel();
     197            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     198            3 :         let mut stream = std::pin::pin!(stream);
     199              : 
     200            3 :         let next = stream.next().await;
     201            3 :         let ioe = next.unwrap().unwrap_err();
     202            3 :         assert!(
     203            0 :             matches!(
     204            3 :                 ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
     205              :                 Some(&DownloadError::Cancelled)
     206              :             ),
     207            0 :             "{ioe:?}"
     208              :         );
     209              : 
     210            3 :         let next = stream.next().await;
     211            3 :         let bytes = next.unwrap().unwrap();
     212            3 :         assert_eq!(&b"hello world"[..], bytes);
     213            3 :     }
     214              : }
        

Generated by: LCOV version 2.1-beta