LCOV - code coverage report
Current view: top level - libs/remote_storage/src - support.rs (source / functions) Coverage Total Hit
Test: fabb29a6339542ee130cd1d32b534fafdc0be240.info Lines: 95.5 % 134 128
Test Date: 2024-06-25 13:20:00 Functions: 68.6 % 35 24

            Line data    Source code
       1              : use std::{
       2              :     future::Future,
       3              :     pin::Pin,
       4              :     task::{Context, Poll},
       5              :     time::Duration,
       6              : };
       7              : 
       8              : use bytes::Bytes;
       9              : use futures_util::Stream;
      10              : use tokio_util::sync::CancellationToken;
      11              : 
      12              : use crate::TimeoutOrCancel;
      13              : 
      14              : pin_project_lite::pin_project! {
      15              :     /// A `Stream` adapter which carries a permit for the lifetime of the value.
      16              :     pub(crate) struct PermitCarrying<S> {
      17              :         permit: tokio::sync::OwnedSemaphorePermit,
      18              :         #[pin]
      19              :         inner: S,
      20              :     }
      21              : }
      22              : 
      23              : impl<S> PermitCarrying<S> {
      24           24 :     pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
      25           24 :         Self { permit, inner }
      26           24 :     }
      27              : }
      28              : 
      29              : impl<S: Stream> Stream for PermitCarrying<S> {
      30              :     type Item = <S as Stream>::Item;
      31              : 
      32           48 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      33           48 :         self.project().inner.poll_next(cx)
      34           48 :     }
      35              : 
      36            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      37            0 :         self.inner.size_hint()
      38            0 :     }
      39              : }
      40              : 
      41              : pin_project_lite::pin_project! {
      42              :     pub(crate) struct DownloadStream<F, S> {
      43              :         hit: bool,
      44              :         #[pin]
      45              :         cancellation: F,
      46              :         #[pin]
      47              :         inner: S,
      48              :     }
      49              : }
      50              : 
      51              : impl<F, S> DownloadStream<F, S> {
      52          115 :     pub(crate) fn new(cancellation: F, inner: S) -> Self {
      53          115 :         Self {
      54          115 :             cancellation,
      55          115 :             hit: false,
      56          115 :             inner,
      57          115 :         }
      58          115 :     }
      59              : }
      60              : 
      61              : /// See the documentation on [`crate::DownloadStream`] for the rationale behind using `std::io::Error`.
      62              : impl<E, F, S> Stream for DownloadStream<F, S>
      63              : where
      64              :     std::io::Error: From<E>,
      65              :     F: Future<Output = E>,
      66              :     S: Stream<Item = std::io::Result<Bytes>>,
      67              : {
      68              :     type Item = <S as Stream>::Item;
      69              : 
      70         1630 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      71         1630 :         let this = self.project();
      72         1630 : 
      73         1630 :         if !*this.hit {
      74         1610 :             if let Poll::Ready(e) = this.cancellation.poll(cx) {
      75           16 :                 *this.hit = true;
      76           16 : 
      77           16 :                 // most likely this will be a std::io::Error wrapping a DownloadError
      78           16 :                 let e = Err(std::io::Error::from(e));
      79           16 :                 return Poll::Ready(Some(e));
      80         1594 :             }
      81              :         } else {
      82              :             // this would be perfectly valid behaviour for doing a graceful completion on the
      83              :             // download for example, but not one we expect to do right now.
      84           20 :             tracing::warn!("continuing polling after having cancelled or timeouted");
      85              :         }
      86              : 
      87         1614 :         this.inner.poll_next(cx)
      88         1630 :     }
      89              : 
      90            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      91            0 :         self.inner.size_hint()
      92            0 :     }
      93              : }
      94              : 
      95              : /// Resolves once, with whichever of cancellation or timeout happens first; never reports both.
      96          122 : pub(crate) fn cancel_or_timeout(
      97          122 :     timeout: Duration,
      98          122 :     cancel: CancellationToken,
      99          122 : ) -> impl std::future::Future<Output = TimeoutOrCancel> + 'static {
     100          122 :     // futures are lazy, they don't do anything before being polled.
     101          122 :     //
     102          122 :     // "precalculate" the wanted deadline before returning the future, so that we can use a
     103          122 :     // pause failpoint to trigger a timeout in tests.
     104          122 :     let deadline = tokio::time::Instant::now() + timeout;
     105          120 :     async move {
     106              :         tokio::select! {
     107              :             _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout,
     108              :             _ = cancel.cancelled() => {
     109              :                 TimeoutOrCancel::Cancel
     110              :             },
     111              :         }
     112           16 :     }
     113          122 : }
     114              : 
     115              : #[cfg(test)]
     116              : mod tests {
     117              :     use super::*;
     118              :     use crate::DownloadError;
     119              :     use futures::stream::StreamExt;
     120              : 
     121              :     #[tokio::test(start_paused = true)]
     122            4 :     async fn cancelled_download_stream() {
     123            4 :         let inner = futures::stream::pending();
     124            4 :         let timeout = Duration::from_secs(120);
     125            4 :         let cancel = CancellationToken::new();
     126            4 : 
     127            4 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     128            4 :         let mut stream = std::pin::pin!(stream);
     129            4 : 
     130            4 :         let mut first = stream.next();
     131            4 : 
     132            4 :         tokio::select! {
     133            4 :             _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
     134            4 :             _ = tokio::time::sleep(Duration::from_secs(1)) => {},
     135            4 :         }
     136            4 : 
     137            4 :         cancel.cancel();
     138            4 : 
     139            4 :         let e = first.await.expect("there must be some").unwrap_err();
     140            4 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     141            4 :         let inner = e.get_ref().expect("inner should be set");
     142            4 :         assert!(
     143            4 :             inner
     144            4 :                 .downcast_ref::<DownloadError>()
     145            4 :                 .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
     146            4 :             "{inner:?}"
     147            4 :         );
     148            4 :         let e = DownloadError::from(e);
     149            4 :         assert!(matches!(e, DownloadError::Cancelled), "{e:?}");
     150            4 : 
     151            4 :         tokio::select! {
     152            4 :             _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
     153            4 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     154            4 :         }
     155            4 :     }
     156              : 
     157              :     #[tokio::test(start_paused = true)]
     158            4 :     async fn timeouted_download_stream() {
     159            4 :         let inner = futures::stream::pending();
     160            4 :         let timeout = Duration::from_secs(120);
     161            4 :         let cancel = CancellationToken::new();
     162            4 : 
     163            4 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     164            4 :         let mut stream = std::pin::pin!(stream);
     165            4 : 
     166            4 :         // because the stream uses a 120s timeout and the clock is paused, we advance to 120s right away.
     167            4 :         let first = stream.next();
     168            4 : 
     169            4 :         let e = first.await.expect("there must be some").unwrap_err();
     170            4 :         assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
     171            4 :         let inner = e.get_ref().expect("inner should be set");
     172            4 :         assert!(
     173            4 :             inner
     174            4 :                 .downcast_ref::<DownloadError>()
     175            4 :                 .is_some_and(|e| matches!(e, DownloadError::Timeout)),
     176            4 :             "{inner:?}"
     177            4 :         );
     178            4 :         let e = DownloadError::from(e);
     179            4 :         assert!(matches!(e, DownloadError::Timeout), "{e:?}");
     180            4 : 
     181            4 :         cancel.cancel();
     182            4 : 
     183            4 :         tokio::select! {
     184            4 :             _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
     185            4 :             _ = tokio::time::sleep(Duration::from_secs(121)) => {},
     186            4 :         }
     187            4 :     }
     188              : 
     189              :     #[tokio::test]
     190            4 :     async fn notified_but_pollable_after() {
     191            4 :         let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static(
     192            4 :             b"hello world",
     193            4 :         ))));
     194            4 :         let timeout = Duration::from_secs(120);
     195            4 :         let cancel = CancellationToken::new();
     196            4 : 
     197            4 :         cancel.cancel();
     198            4 :         let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
     199            4 :         let mut stream = std::pin::pin!(stream);
     200            4 : 
     201            4 :         let next = stream.next().await;
     202            4 :         let ioe = next.unwrap().unwrap_err();
     203            4 :         assert!(
     204            4 :             matches!(
     205            4 :                 ioe.get_ref().unwrap().downcast_ref::<DownloadError>(),
     206            4 :                 Some(&DownloadError::Cancelled)
     207            4 :             ),
     208            4 :             "{ioe:?}"
     209            4 :         );
     210            4 : 
     211            4 :         let next = stream.next().await;
     212            4 :         let bytes = next.unwrap().unwrap();
     213            4 :         assert_eq!(&b"hello world"[..], bytes);
     214            4 :     }
     215              : }
        

Generated by: LCOV version 2.1-beta