LCOV - code coverage report
Current view: top level - libs/remote_storage/src - support.rs (source / functions) Coverage Total Hit
Test: 691a4c28fe7169edd60b367c52d448a0a6605f1f.info Lines: 88.2 % 102 90
Test Date: 2024-05-10 13:18:37 Functions: 45.2 % 31 14

            Line data    Source code
       1              : use std::{
       2              :     future::Future,
       3              :     pin::Pin,
       4              :     task::{Context, Poll},
       5              :     time::Duration,
       6              : };
       7              : 
       8              : use bytes::Bytes;
       9              : use futures_util::Stream;
      10              : use tokio_util::sync::CancellationToken;
      11              : 
      12              : use crate::TimeoutOrCancel;
      13              : 
      14              : pin_project_lite::pin_project! {
      15              :     /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
      16              :     pub(crate) struct PermitCarrying<S> {
      17              :         permit: tokio::sync::OwnedSemaphorePermit,
      18              :         #[pin]
      19              :         inner: S,
      20              :     }
      21              : }
      22              : 
      23              : impl<S> PermitCarrying<S> {
      24            0 :     pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
      25            0 :         Self { permit, inner }
      26            0 :     }
      27              : }
      28              : 
      29              : impl<S: Stream> Stream for PermitCarrying<S> {
      30              :     type Item = <S as Stream>::Item;
      31              : 
      32            0 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      33            0 :         self.project().inner.poll_next(cx)
      34            0 :     }
      35              : 
      36            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      37            0 :         self.inner.size_hint()
      38            0 :     }
      39              : }
      40              : 
pin_project_lite::pin_project! {
    /// Stream adapter that races the wrapped stream against a `cancellation` future.
    ///
    /// - `hit`: set once `cancellation` has completed, so that the completed future is
    ///   never polled again.
    /// - `cancellation`: when it completes, its output is converted into a
    ///   `std::io::Error` and yielded as an `Err` item (see the `Stream` impl). The tests
    ///   in this file construct it from `cancel_or_timeout`.
    /// - `inner`: the wrapped stream; its items are forwarded unchanged.
    ///
    /// NOTE(review): field order also fixes drop order (`cancellation` before `inner`).
    pub(crate) struct DownloadStream<F, S> {
        hit: bool,
        #[pin]
        cancellation: F,
        #[pin]
        inner: S,
    }
}
      50              : 
      51              : impl<F, S> DownloadStream<F, S> {
      52           54 :     pub(crate) fn new(cancellation: F, inner: S) -> Self {
      53           54 :         Self {
      54           54 :             cancellation,
      55           54 :             hit: false,
      56           54 :             inner,
      57           54 :         }
      58           54 :     }
      59              : }
      60              : 
      61              : /// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used.
      62              : impl<E, F, S> Stream for DownloadStream<F, S>
      63              : where
      64              :     std::io::Error: From<E>,
      65              :     F: Future<Output = E>,
      66              :     S: Stream<Item = std::io::Result<Bytes>>,
      67              : {
      68              :     type Item = <S as Stream>::Item;
      69              : 
      70         1683 :     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      71         1683 :         let this = self.project();
      72         1683 : 
      73         1683 :         if !*this.hit {
      74         1675 :             if let Poll::Ready(e) = this.cancellation.poll(cx) {
      75            4 :                 *this.hit = true;
      76            4 : 
      77            4 :                 // most likely this will be a std::io::Error wrapping a DownloadError
      78            4 :                 let e = Err(std::io::Error::from(e));
      79            4 :                 return Poll::Ready(Some(e));
      80         1671 :             }
      81            8 :         }
      82              : 
      83         1679 :         this.inner.poll_next(cx)
      84         1683 :     }
      85              : 
      86            0 :     fn size_hint(&self) -> (usize, Option<usize>) {
      87            0 :         self.inner.size_hint()
      88            0 :     }
      89              : }
      90              : 
      91              : /// Fires only on the first cancel or timeout, not on both.
      92           92 : pub(crate) async fn cancel_or_timeout(
      93           92 :     timeout: Duration,
      94           92 :     cancel: CancellationToken,
      95           92 : ) -> TimeoutOrCancel {
      96              :     tokio::select! {
      97              :         _ = tokio::time::sleep(timeout) => TimeoutOrCancel::Timeout,
      98              :         _ = cancel.cancelled() => TimeoutOrCancel::Cancel,
      99              :     }
     100            4 : }
     101              : 
#[cfg(test)]
mod tests {
    // Both tests run with tokio's clock paused (`start_paused = true`). Whenever the runtime
    // has nothing else to do, virtual time auto-advances to the next pending timer, so the
    // long sleeps below (120 s timeout, 121 s guard) complete without real waiting.
    use super::*;
    use crate::DownloadError;
    use futures::stream::StreamExt;

    /// Cancelling the token makes the stream yield exactly one `Err` carrying
    /// `DownloadError::Cancelled`; the later timeout is then never observed.
    #[tokio::test(start_paused = true)]
    async fn cancelled_download_stream() {
        // `inner` never yields, so any item must come from the cancellation future.
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        let mut first = stream.next();

        // Poll `first` by reference for 1 s of virtual time: nothing may fire yet. Polling
        // `&mut first` keeps the same `next()` future alive so it can be awaited below.
        tokio::select! {
            _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"),
            _ = tokio::time::sleep(Duration::from_secs(1)) => {},
        }

        cancel.cancel();

        // The pending `next()` now resolves with the error produced from the cancellation.
        // The test expects an `ErrorKind::Other` io::Error wrapping `DownloadError::Cancelled`
        // that converts back into `DownloadError::Cancelled`.
        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Cancelled)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Cancelled), "{e:?}");

        // The cancellation already fired (`hit` is set), so no second error appears even
        // after advancing past the 120 s timeout.
        tokio::select! {
            _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }

    /// Letting the timeout elapse makes the stream yield exactly one `Err` carrying
    /// `DownloadError::Timeout`; a later cancellation is then never observed.
    #[tokio::test(start_paused = true)]
    async fn timeouted_download_stream() {
        // `inner` never yields, so any item must come from the cancellation future.
        let inner = futures::stream::pending();
        let timeout = Duration::from_secs(120);
        let cancel = CancellationToken::new();

        let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner);
        let mut stream = std::pin::pin!(stream);

        // because the stream uses 120s timeout and we are paused, we advance to 120s right away.
        let first = stream.next();

        // The test expects an `ErrorKind::Other` io::Error wrapping `DownloadError::Timeout`
        // that converts back into `DownloadError::Timeout`.
        let e = first.await.expect("there must be some").unwrap_err();
        assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}");
        let inner = e.get_ref().expect("inner should be set");
        assert!(
            inner
                .downcast_ref::<DownloadError>()
                .is_some_and(|e| matches!(e, DownloadError::Timeout)),
            "{inner:?}"
        );
        let e = DownloadError::from(e);
        assert!(matches!(e, DownloadError::Timeout), "{e:?}");

        cancel.cancel();

        // The timeout already fired (`hit` is set), so cancelling afterwards produces no
        // further item.
        tokio::select! {
            _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"),
            _ = tokio::time::sleep(Duration::from_secs(121)) => {},
        }
    }
}
        

Generated by: LCOV version 2.1-beta