LCOV - code coverage report
Current view: top level - libs/remote_storage/src - support.rs (source / functions) Coverage Total Hit
Test: a43a77853355b937a79c57b07a8f05607cf29e6c.info Lines: 95.7 % 138 132
Test Date: 2024-09-19 12:04:32 Functions: 68.6 % 35 24

            Line data    Source code
       1              : use std::{
       2              :     future::Future,
       3              :     pin::Pin,
       4              :     task::{Context, Poll},
       5              :     time::Duration,
       6              : };
       7              : 
       8              : use bytes::Bytes;
       9              : use futures_util::Stream;
      10              : use tokio_util::sync::CancellationToken;
      11              : 
      12              : use crate::TimeoutOrCancel;
      13              : 
      14              : pin_project_lite::pin_project! {
      15              :     /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
      16              :     pub(crate) struct PermitCarrying<S> {
      17              :         permit: tokio::sync::OwnedSemaphorePermit,
      18              :         #[pin]
      19              :         inner: S,
      20              :     }
      21              : }
      22              : 
      23              : impl<S> PermitCarrying<S> {
      24           24 :     pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
      25           24 :         Self { permit, inner }
      26           24 :     }
      27              : }
      28              : 
      29              : impl<S: Stream> Stream for PermitCarrying<S> {
      30              :     type Item = <S as Stream>::Item;
      31              : 
      32           41 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      33           41 :         self.project().inner.poll_next(cx)
      34           41 :     }
      35              : 
      36            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      37            0 :         self.inner.size_hint()
      38            0 :     }
      39              : }
      40              : 
      41              : pin_project_lite::pin_project! {
      42              :     pub(crate) struct DownloadStream<F, S> {
      43              :         hit: bool,
      44              :         #[pin]
      45              :         cancellation: F,
      46              :         #[pin]
      47              :         inner: S,
      48              :     }
      49              : }
      50              : 
      51              : impl<F, S> DownloadStream<F, S> {
      52          157 :     pub(crate) fn new(cancellation: F, inner: S) -> Self {
      53          157 :         Self {
      54          157 :             cancellation,
      55          157 :             hit: false,
      56          157 :             inner,
      57          157 :         }
      58          157 :     }
      59              : }
      60              : 
      61              : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
      62              : impl<E, F, S> Stream for DownloadStream<F, S>
      63              : where
      64              :     std::io::Error: From<E>,
      65              :     F: Future<Output = E>,
      66              :     S: Stream<Item = std::io::Result<Bytes>>,
      67              : {
      68              :     type Item = <S as Stream>::Item;
      69              : 
      70         4883 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      71         4883 :         let this = self.project();
      72         4883 : 
      73         4883 :         if !*this.hit {
      74         4866 :             if let Poll::Ready(e) = this.cancellation.poll(cx) {
      75           13 :                 *this.hit = true;
      76           13 : 
      77           13 :                 // most likely this will be a std::io::Error wrapping a DownloadError
      78           13 :                 let e = Err(std::io::Error::from(e));
      79           13 :                 return Poll::Ready(Some(e));
      80         4853 :             }
      81              :         } else {
      82              :             // this would be perfectly valid behaviour for doing a graceful completion on the
      83              :             // download for example, but not one we expect to do right now.
      84           17 :             tracing::warn!("continuing polling after having cancelled or timeouted");
      85              :         }
      86              : 
      87         4870 :         this.inner.poll_next(cx)
      88         4883 :     }
      89              : 
      90            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      91            0 :         self.inner.size_hint()
      92            0 :     }
      93              : }
      94              : 
      95              : /// Fires only on the first cancel or timeout, not on both.
      96          164 : pub(crate) fn cancel_or_timeout(
      97          164 :     timeout: Duration,
      98          164 :     cancel: CancellationToken,
      99          164 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
     100          164 :     // futures are lazy, they don't do anything before being polled.
     101          164 :     //
     102          164 :     // "precalculate" the wanted deadline before returning the future, so that we can use pause
     103          164 :     // failpoint to trigger a timeout in test.
     104          164 :     let deadline = tokio::time::Instant::now() + timeout;
     105          162 :     async move {
     106          162 :         tokio::select! {
     107          162 :             _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
     108          162 :             _ = cancel.cancelled() => {
     109            8 :                 TimeoutOrCancel::Cancel
     110              :             },
     111              :         }
     112           13 :     }
     113          164 : }
     114              : 
     115              : #[cfg(test)]
     116              : mod tests {
     117              :     use super::*;
     118              :     use crate::DownloadError;
     119              :     use futures::stream::StreamExt;
     120              : 
     121              :     #[tokio::test(start_paused = true)]
     122            3 :     async fn cancelled_download_stream() {
     123            3 :         let inner = futures::stream::pending();
     124            3 :         let timeout = Duration::from_secs(120);
     125            3 :         let cancel = CancellationToken::new();
     126            3 : 
     127            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     128            3 :         let mut stream = std::pin::pin!(stream);
     129            3 : 
     130            3 :         let mut first = stream.next();
     131            3 : 
     132            3 :         tokio::select! {
     133            3 :             _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
     134            3 :             _ = tokio::time::sleep(Duration::from_secs(1)) => {},
     135            3 :         }
     136            3 : 
     137            3 :         cancel.cancel();
     138            3 : 
     139            3 :         let e = first.await.expect("there must be some").unwrap_err();
     140            3 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     141            3 :         let inner = e.get_ref().expect("inner should be set");
     142            3 :         assert!(
     143            3 :             inner
     144            3 :                 .downcast_ref::<DownloadError>()
     145            3 :                 .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
     146            3 :             "{inner:?}"
     147            3 :         );
     148            3 :         let e = DownloadError::from(e);
     149            3 :         assert!(matches!(e, DownloadError::Cancelled), "{e:?}");
     150            3 : 
     151            3 :         tokio::select! {
     152            3 :             _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
     153            3 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     154            3 :         }
     155            3 :     }
     156              : 
     157              :     #[tokio::test(start_paused = true)]
     158            3 :     async fn timeouted_download_stream() {
     159            3 :         let inner = futures::stream::pending();
     160            3 :         let timeout = Duration::from_secs(120);
     161            3 :         let cancel = CancellationToken::new();
     162            3 : 
     163            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     164            3 :         let mut stream = std::pin::pin!(stream);
     165            3 : 
     166            3 :         // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
     167            3 :         let first = stream.next();
     168            3 : 
     169            3 :         let e = first.await.expect("there must be some").unwrap_err();
     170            3 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     171            3 :         let inner = e.get_ref().expect("inner should be set");
     172            3 :         assert!(
     173            3 :             inner
     174            3 :                 .downcast_ref::<DownloadError>()
     175            3 :                 .is_some_and(|e| matches!(e, DownloadError::Timeout)),
     176            3 :             "{inner:?}"
     177            3 :         );
     178            3 :         let e = DownloadError::from(e);
     179            3 :         assert!(matches!(e, DownloadError::Timeout), "{e:?}");
     180            3 : 
     181            3 :         cancel.cancel();
     182            3 : 
     183            3 :         tokio::select! {
     184            3 :             _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
     185            3 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     186            3 :         }
     187            3 :     }
     188              : 
     189              :     #[tokio::test]
     190            3 :     async fn notified_but_pollable_after() {
     191            3 :         let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
     192            3 :             b"hello world",
     193            3 :         ))));
     194            3 :         let timeout = Duration::from_secs(120);
     195            3 :         let cancel = CancellationToken::new();
     196            3 : 
     197            3 :         cancel.cancel();
     198            3 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     199            3 :         let mut stream = std::pin::pin!(stream);
     200            3 : 
     201            3 :         let next = stream.next().await;
     202            3 :         let ioe = next.unwrap().unwrap_err();
     203            3 :         assert!(
     204            3 :             matches!(
     205            3 :                 ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
     206            3 :                 Some(&DownloadError::Cancelled)
     207            3 :             ),
     208            3 :             "{ioe:?}"
     209            3 :         );
     210            3 : 
     211            3 :         let next = stream.next().await;
     212            3 :         let bytes = next.unwrap().unwrap();
     213            3 :         assert_eq!(&b"hello world"[..], bytes);
     214            3 :     }
     215              : }
        

Generated by: LCOV version 2.1-beta