LCOV - code coverage report
Current view: top level - libs/remote_storage/src - support.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 95.7 % 138 132
Test Date: 2025-03-12 18:28:53 Functions: 68.8 % 32 22

            Line data    Source code
       1              : use std::future::Future;
       2              : use std::pin::Pin;
       3              : use std::task::{Context, Poll};
       4              : use std::time::Duration;
       5              : 
       6              : use bytes::Bytes;
       7              : use futures_util::Stream;
       8              : use tokio_util::sync::CancellationToken;
       9              : 
      10              : use crate::TimeoutOrCancel;
      11              : 
      12              : pin_project_lite::pin_project! {
      13              :     /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
      14              :     pub(crate) struct PermitCarrying<S> {
      15              :         permit: tokio::sync::OwnedSemaphorePermit,
      16              :         #[pin]
      17              :         inner: S,
      18              :     }
      19              : }
      20              : 
      21              : impl<S> PermitCarrying<S> {
      22           28 :     pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
      23           28 :         Self { permit, inner }
      24           28 :     }
      25              : }
      26              : 
      27              : impl<S: Stream> Stream for PermitCarrying<S> {
      28              :     type Item = <S as Stream>::Item;
      29              : 
      30           46 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      31           46 :         self.project().inner.poll_next(cx)
      32           46 :     }
      33              : 
      34            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      35            0 :         self.inner.size_hint()
      36            0 :     }
      37              : }
      38              : 
pin_project_lite::pin_project! {
    /// Wraps a byte stream `inner` so that a `cancellation` future can interrupt it.
    ///
    /// Until `cancellation` resolves it is polled before `inner` on every poll; its output is
    /// then yielded once as an `Err` item, after which only `inner` is polled.
    pub(crate) struct DownloadStream<F, S> {
        // Set once `cancellation` has resolved, so the completed future is never polled again.
        hit: bool,
        // Future whose completion (e.g. cancel or timeout) is reported as an error item.
        #[pin]
        cancellation: F,
        // The wrapped stream whose items are passed through.
        #[pin]
        inner: S,
    }
}
      48              : 
      49              : impl<F, S> DownloadStream<F, S> {
      50          151 :     pub(crate) fn new(cancellation: F, inner: S) -> Self {
      51          151 :         Self {
      52          151 :             cancellation,
      53          151 :             hit: false,
      54          151 :             inner,
      55          151 :         }
      56          151 :     }
      57              : }
      58              : 
      59              : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
      60              : impl<E, F, S> Stream for DownloadStream<F, S>
      61              : where
      62              :     std::io::Error: From<E>,
      63              :     F: Future<Output = E>,
      64              :     S: Stream<Item = std::io::Result<Bytes>>,
      65              : {
      66              :     type Item = <S as Stream>::Item;
      67              : 
      68         2731 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      69         2731 :         let this = self.project();
      70         2731 : 
      71         2731 :         if !*this.hit {
      72         2717 :             if let Poll::Ready(e) = this.cancellation.poll(cx) {
      73           13 :                 *this.hit = true;
      74           13 : 
      75           13 :                 // most likely this will be a std::io::Error wrapping a DownloadError
      76           13 :                 let e = Err(std::io::Error::from(e));
      77           13 :                 return Poll::Ready(Some(e));
      78          109 :             }
      79              :         } else {
      80              :             // this would be perfectly valid behaviour for doing a graceful completion on the
      81              :             // download for example, but not one we expect to do right now.
      82           14 :             tracing::warn!("continuing polling after having cancelled or timeouted");
      83              :         }
      84              : 
      85         2718 :         this.inner.poll_next(cx)
      86          132 :     }
      87              : 
      88            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      89            0 :         self.inner.size_hint()
      90            0 :     }
      91              : }
      92              : 
      93              : /// Fires only on the first cancel or timeout, not on both.
      94          164 : pub(crate) fn cancel_or_timeout(
      95          164 :     timeout: Duration,
      96          164 :     cancel: CancellationToken,
      97          164 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
      98          164 :     // futures are lazy, they don't do anything before being polled.
      99          164 :     //
     100          164 :     // "precalculate" the wanted deadline before returning the future, so that we can use pause
     101          164 :     // failpoint to trigger a timeout in test.
     102          164 :     let deadline = tokio::time::Instant::now() + timeout;
     103          154 :     async move {
     104          154 :         tokio::select! {
     105          154 :             _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
     106          154 :             _ = cancel.cancelled() => {
     107            8 :                 TimeoutOrCancel::Cancel
     108              :             },
     109              :         }
     110            9 :     }
     111          164 : }
     112              : 
/// Tests for `DownloadStream` driven by a `cancel_or_timeout` future.
#[cfg(test)]
mod tests {
    use futures::stream::StreamExt;

    use super::*;
    use crate::DownloadError;

    /// Cancelling the token yields exactly one `DownloadError::Cancelled` error item, and the
    /// timeout never fires afterwards.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        // `pending()` never yields, so any item must come from the cancellation future.
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        // Poll `first` for one (paused-clock) second to show nothing is ready yet. It is
        // polled by reference so the same future can be awaited again below.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        // The error must be an `Other` io::Error wrapping `DownloadError::Cancelled`, and it
        // must convert back into that variant.
        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // Sleeping past the 120s timeout must not produce another item.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Reaching the timeout yields exactly one `DownloadError::Timeout` error item, and a later
    /// cancellation produces nothing further.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next();

        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Timeout)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        cancel.cancel();

        // The cancellation future already completed with a timeout, so cancelling now must
        // not yield another item.
        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// With the token cancelled before the stream is built, the first item is the
    /// `Cancelled` error and the inner stream's item is still delivered afterwards.
    #[tokio::test]
    async fn notified_but_pollable_after() {
        let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
            b"hello world",
        ))));
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        cancel.cancel();
        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let next = stream.next().await;
        let ioe = next.unwrap().unwrap_err();
        assert!(
            matches!(
                ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
                Some(&DownloadError::Cancelled)
            ),
            "{ioe:?}"
        );

        // Polling continues past the cancellation error and reaches the inner stream.
        let next = stream.next().await;
        let bytes = next.unwrap().unwrap();
        assert_eq!(&b"hello world"[..], bytes);
    }
}
        

Generated by: LCOV version 2.1-beta