LCOV - code coverage report
Current view: top level - libs/remote_storage/src - azure_blob.rs (source / functions)
Test:      496e96cdfff2df79370229591d6427cda12fde29.info
Test Date: 2024-05-21 18:28:29

             Coverage    Total    Hit
Lines:       19.4 %      427      83
Functions:   15.5 %      71       11

            Line data    Source code
       1              : //! Azure Blob Storage wrapper
       2              : 
       3              : use std::borrow::Cow;
       4              : use std::collections::HashMap;
       5              : use std::env;
       6              : use std::io;
       7              : use std::num::NonZeroU32;
       8              : use std::pin::Pin;
       9              : use std::str::FromStr;
      10              : use std::sync::Arc;
      11              : use std::time::Duration;
      12              : use std::time::SystemTime;
      13              : 
      14              : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
      15              : use anyhow::Result;
      16              : use azure_core::request_options::{MaxResults, Metadata, Range};
      17              : use azure_core::RetryOptions;
      18              : use azure_identity::DefaultAzureCredential;
      19              : use azure_storage::StorageCredentials;
      20              : use azure_storage_blobs::blob::CopyStatus;
      21              : use azure_storage_blobs::prelude::ClientBuilder;
      22              : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
      23              : use bytes::Bytes;
      24              : use futures::future::Either;
      25              : use futures::stream::Stream;
      26              : use futures_util::StreamExt;
      27              : use futures_util::TryStreamExt;
      28              : use http_types::{StatusCode, Url};
      29              : use tokio_util::sync::CancellationToken;
      30              : use tracing::debug;
      31              : 
      32              : use crate::RemoteStorageActivity;
      33              : use crate::{
      34              :     error::Cancelled, s3_bucket::RequestKind, AzureConfig, ConcurrencyLimiter, Download,
      35              :     DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, StorageMetadata,
      36              :     TimeTravelError, TimeoutOrCancel,
      37              : };
      38              : 
      39              : pub struct AzureBlobStorage {
      40              :     client: ContainerClient,
      41              :     prefix_in_container: Option<String>,
      42              :     max_keys_per_list_response: Option<NonZeroU32>,
      43              :     concurrency_limiter: ConcurrencyLimiter,
      44              :     // Per-request timeout. Accessible for tests.
      45              :     pub timeout: Duration,
      46              : }
      47              : 
      48              : impl AzureBlobStorage {
      49            6 :     pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
      50            6 :         debug!(
      51            0 :             "Creating azure remote storage for azure container {}",
      52              :             azure_config.container_name
      53              :         );
      54              : 
      55            6 :         let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
      56              : 
      57              :         // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
      58              :         // otherwise try the token based credentials.
      59            6 :         let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
      60            6 :             StorageCredentials::access_key(account.clone(), access_key)
      61              :         } else {
      62            0 :             let token_credential = DefaultAzureCredential::default();
      63            0 :             StorageCredentials::token_credential(Arc::new(token_credential))
      64              :         };
      65              : 
       66              :         // Retries are handled by an outer retry layer, so disable the SDK's built-in retries.
      67            6 :         let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
      68            6 : 
      69            6 :         let client = builder.container_client(azure_config.container_name.to_owned());
      70              : 
      71            6 :         let max_keys_per_list_response =
      72            6 :             if let Some(limit) = azure_config.max_keys_per_list_response {
      73              :                 Some(
      74            2 :                     NonZeroU32::new(limit as u32)
      75            2 :                         .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
      76              :                 )
      77              :             } else {
      78            4 :                 None
      79              :             };
      80              : 
      81            6 :         Ok(AzureBlobStorage {
      82            6 :             client,
      83            6 :             prefix_in_container: azure_config.prefix_in_container.to_owned(),
      84            6 :             max_keys_per_list_response,
      85            6 :             concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
      86            6 :             timeout,
      87            6 :         })
      88            6 :     }
      89              : 
      90          107 :     pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
      91          107 :         assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
      92          107 :         let path_string = path
      93          107 :             .get_path()
      94          107 :             .as_str()
      95          107 :             .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
      96          107 :         match &self.prefix_in_container {
      97          107 :             Some(prefix) => {
      98          107 :                 if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
      99          107 :                     prefix.clone() + path_string
     100              :                 } else {
     101            0 :                     format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
     102              :                 }
     103              :             }
     104            0 :             None => path_string.to_string(),
     105              :         }
     106          107 :     }
     107              : 
     108           53 :     fn name_to_relative_path(&self, key: &str) -> RemotePath {
     109           53 :         let relative_path =
     110           53 :             match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
     111           53 :                 Some(stripped) => stripped,
     112              :                 // we rely on Azure to return properly prefixed paths
     113              :                 // for requests with a certain prefix
     114            0 :                 None => panic!(
     115            0 :                     "Key {key} does not start with container prefix {:?}",
     116            0 :                     self.prefix_in_container
     117            0 :                 ),
     118              :             };
     119           53 :         RemotePath(
     120           53 :             relative_path
     121           53 :                 .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
     122           53 :                 .collect(),
     123           53 :         )
     124           53 :     }
     125              : 
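
relative_path_to_name and name_to_relative_path above are inverses of each other around the configured container prefix. The following standalone sketch (hypothetical helpers, with '/' standing in for REMOTE_STORAGE_PREFIX_SEPARATOR) illustrates the same join/strip round trip; it is not the crate's API:

// Illustrative sketch only; join_prefix / strip_prefix_or_panic are hypothetical
// stand-ins for relative_path_to_name / name_to_relative_path.
fn join_prefix(prefix: Option<&str>, path: &str) -> String {
    let path = path.trim_end_matches('/');
    match prefix {
        Some(p) if p.ends_with('/') => format!("{p}{path}"),
        Some(p) => format!("{p}/{path}"),
        None => path.to_string(),
    }
}

fn strip_prefix_or_panic(prefix: Option<&str>, key: &str) -> String {
    key.strip_prefix(prefix.unwrap_or_default())
        .unwrap_or_else(|| panic!("Key {key} does not start with prefix {prefix:?}"))
        .to_string()
}

fn main() {
    let full = join_prefix(Some("tenant/"), "timelines/0001/");
    assert_eq!(full, "tenant/timelines/0001");
    assert_eq!(strip_prefix_or_panic(Some("tenant/"), &full), "timelines/0001");
    println!("round trip ok");
}
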
     126            7 :     async fn download_for_builder(
     127            7 :         &self,
     128            7 :         builder: GetBlobBuilder,
     129            7 :         cancel: &CancellationToken,
     130            7 :     ) -> Result<Download, DownloadError> {
     131            0 :         let kind = RequestKind::Get;
     132              : 
     133            0 :         let _permit = self.permit(kind, cancel).await?;
     134            0 :         let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
     135            0 :         let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
     136            0 : 
     137            0 :         let mut etag = None;
     138            0 :         let mut last_modified = None;
     139            0 :         let mut metadata = HashMap::new();
     140            0 : 
     141            0 :         let download = async {
     142            0 :             let response = builder
     143            0 :                 // convert to concrete Pageable
     144            0 :                 .into_stream()
     145            0 :                 // convert to TryStream
     146            0 :                 .into_stream()
     147            0 :                 .map_err(to_download_error);
     148            0 : 
     149            0 :             // apply per request timeout
     150            0 :             let response = tokio_stream::StreamExt::timeout(response, self.timeout);
     151            0 : 
     152            0 :             // flatten
     153            0 :             let response = response.map(|res| match res {
     154            0 :                 Ok(res) => res,
     155            0 :                 Err(_elapsed) => Err(DownloadError::Timeout),
     156            0 :             });
     157            0 : 
     158            0 :             let mut response = Box::pin(response);
     159              : 
     160            0 :             let Some(part) = response.next().await else {
     161            0 :                 return Err(DownloadError::Other(anyhow::anyhow!(
     162            0 :                     "Azure GET response contained no response body"
     163            0 :                 )));
     164              :             };
     165            0 :             let part = part?;
     166            0 :             if etag.is_none() {
     167            0 :                 etag = Some(part.blob.properties.etag);
     168            0 :             }
     169            0 :             if last_modified.is_none() {
     170            0 :                 last_modified = Some(part.blob.properties.last_modified.into());
     171            0 :             }
     172            0 :             if let Some(blob_meta) = part.blob.metadata {
     173            0 :                 metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
     174            0 :             }
     175              : 
      176              :             // unwrap safety: if these were still None, the response stream yielded no parts and we would already have returned an error above
     177            0 :             let etag = etag.unwrap();
     178            0 :             let last_modified = last_modified.unwrap();
     179            0 : 
     180            0 :             let tail_stream = response
     181            0 :                 .map(|part| match part {
     182            0 :                     Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
     183            0 :                     Err(e) => {
     184            0 :                         Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
     185              :                     }
     186            0 :                 })
     187            0 :                 .flatten();
     188            0 :             let stream = part
     189            0 :                 .data
     190            0 :                 .map(|r| r.map_err(io::Error::other))
     191            0 :                 .chain(sync_wrapper::SyncStream::new(tail_stream));
     192            0 :             //.chain(SyncStream::from_pin(Box::pin(tail_stream)));
     193            0 : 
     194            0 :             let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
     195            0 : 
     196            0 :             Ok(Download {
     197            0 :                 download_stream: Box::pin(download_stream),
     198            0 :                 etag,
     199            0 :                 last_modified,
     200            0 :                 metadata: Some(StorageMetadata(metadata)),
     201            0 :             })
     202            0 :         };
     203              : 
     204              :         tokio::select! {
     205              :             bufs = download => bufs,
     206              :             cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
     207              :                 TimeoutOrCancel::Timeout => Err(DownloadError::Timeout),
     208              :                 TimeoutOrCancel::Cancel => Err(DownloadError::Cancelled),
     209              :             },
     210              :         }
     211            0 :     }
     212              : 
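
download_for_builder applies the per-request timeout to the paged response stream and then races the whole download against a combined cancel-or-timeout future. Below is a minimal sketch of that select pattern, assuming only tokio and tokio-util; it is not the crate's actual support code:

// Race an operation against cancellation or a per-request timeout (sketch).
use std::time::Duration;
use tokio_util::sync::CancellationToken;

#[derive(Debug)]
enum Interrupted {
    Timeout,
    Cancelled,
}

async fn with_cancel_or_timeout<T>(
    op: impl std::future::Future<Output = T>,
    timeout: Duration,
    cancel: CancellationToken,
) -> Result<T, Interrupted> {
    tokio::select! {
        out = op => Ok(out),
        _ = tokio::time::sleep(timeout) => Err(Interrupted::Timeout),
        _ = cancel.cancelled() => Err(Interrupted::Cancelled),
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let res = with_cancel_or_timeout(
        async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            42
        },
        Duration::from_secs(1),
        cancel,
    )
    .await;
    assert_eq!(res.unwrap(), 42);
}
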
     213          108 :     async fn permit(
     214          108 :         &self,
     215          108 :         kind: RequestKind,
     216          108 :         cancel: &CancellationToken,
     217          108 :     ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
     218            0 :         let acquire = self.concurrency_limiter.acquire(kind);
     219              : 
     220              :         tokio::select! {
     221              :             permit = acquire => Ok(permit.expect("never closed")),
     222              :             _ = cancel.cancelled() => Err(Cancelled),
     223              :         }
     224            0 :     }
     225              : }
     226              : 
     227            0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
     228            0 :     let mut res = Metadata::new();
     229            0 :     for (k, v) in metadata.0.into_iter() {
     230            0 :         res.insert(k, v);
     231            0 :     }
     232            0 :     res
     233            0 : }
     234              : 
     235            0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
     236            0 :     if let Some(http_err) = error.as_http_error() {
     237            0 :         match http_err.status() {
     238            0 :             StatusCode::NotFound => DownloadError::NotFound,
     239            0 :             StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
     240            0 :             _ => DownloadError::Other(anyhow::Error::new(error)),
     241              :         }
     242              :     } else {
     243            0 :         DownloadError::Other(error.into())
     244              :     }
     245            0 : }
     246              : 
     247              : impl RemoteStorage for AzureBlobStorage {
     248            6 :     async fn list(
     249            6 :         &self,
     250            6 :         prefix: Option<&RemotePath>,
     251            6 :         mode: ListingMode,
     252            6 :         max_keys: Option<NonZeroU32>,
     253            6 :         cancel: &CancellationToken,
     254            6 :     ) -> anyhow::Result<Listing, DownloadError> {
     255            0 :         let _permit = self.permit(RequestKind::List, cancel).await?;
     256              : 
     257            0 :         let op = async {
     258            0 :             // get the passed prefix or if it is not set use prefix_in_bucket value
     259            0 :             let list_prefix = prefix
     260            0 :                 .map(|p| self.relative_path_to_name(p))
     261            0 :                 .or_else(|| self.prefix_in_container.clone())
     262            0 :                 .map(|mut p| {
      263              :                     // The prefix is required to end with a separator;
      264              :                     // otherwise the request returns only the entry for the prefix itself.
     265            0 :                     if matches!(mode, ListingMode::WithDelimiter)
     266            0 :                         && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
     267            0 :                     {
     268            0 :                         p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
     269            0 :                     }
     270            0 :                     p
     271            0 :                 });
     272            0 : 
     273            0 :             let mut builder = self.client.list_blobs();
     274            0 : 
     275            0 :             if let ListingMode::WithDelimiter = mode {
     276            0 :                 builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
     277            0 :             }
     278              : 
     279            0 :             if let Some(prefix) = list_prefix {
     280            0 :                 builder = builder.prefix(Cow::from(prefix.to_owned()));
     281            0 :             }
     282              : 
     283            0 :             if let Some(limit) = self.max_keys_per_list_response {
     284            0 :                 builder = builder.max_results(MaxResults::new(limit));
     285            0 :             }
     286              : 
     287            0 :             let response = builder.into_stream();
     288            0 :             let response = response.into_stream().map_err(to_download_error);
     289            0 :             let response = tokio_stream::StreamExt::timeout(response, self.timeout);
     290            0 :             let response = response.map(|res| match res {
     291            0 :                 Ok(res) => res,
     292            0 :                 Err(_elapsed) => Err(DownloadError::Timeout),
     293            0 :             });
     294            0 : 
     295            0 :             let mut response = std::pin::pin!(response);
     296            0 : 
     297            0 :             let mut res = Listing::default();
     298            0 : 
     299            0 :             let mut max_keys = max_keys.map(|mk| mk.get());
     300            0 :             while let Some(entry) = response.next().await {
     301            0 :                 let entry = entry?;
     302            0 :                 let prefix_iter = entry
     303            0 :                     .blobs
     304            0 :                     .prefixes()
     305            0 :                     .map(|prefix| self.name_to_relative_path(&prefix.name));
     306            0 :                 res.prefixes.extend(prefix_iter);
     307            0 : 
     308            0 :                 let blob_iter = entry
     309            0 :                     .blobs
     310            0 :                     .blobs()
     311            0 :                     .map(|k| self.name_to_relative_path(&k.name));
     312              : 
     313            0 :                 for key in blob_iter {
     314            0 :                     res.keys.push(key);
     315              : 
     316            0 :                     if let Some(mut mk) = max_keys {
     317            0 :                         assert!(mk > 0);
     318            0 :                         mk -= 1;
     319            0 :                         if mk == 0 {
     320            0 :                             return Ok(res); // limit reached
     321            0 :                         }
     322            0 :                         max_keys = Some(mk);
     323            0 :                     }
     324              :                 }
     325              :             }
     326              : 
     327            0 :             Ok(res)
     328            0 :         };
     329              : 
     330              :         tokio::select! {
     331              :             res = op => res,
     332              :             _ = cancel.cancelled() => Err(DownloadError::Cancelled),
     333              :         }
     334            0 :     }
     335              : 
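
Two details of list() are easy to miss: in ListingMode::WithDelimiter the prefix must end with the separator, and max_keys is enforced by counting down while draining pages. A small self-contained sketch (hypothetical helper, '/' as separator), not the crate's code:

// Force the trailing separator in delimiter mode; cap results like the counter above.
fn normalize_list_prefix(mut prefix: String, with_delimiter: bool) -> String {
    // Without the trailing separator, a delimiter listing only returns the prefix entry itself.
    if with_delimiter && !prefix.ends_with('/') {
        prefix.push('/');
    }
    prefix
}

fn main() {
    assert_eq!(normalize_list_prefix("tenant/timelines".into(), true), "tenant/timelines/");
    assert_eq!(normalize_list_prefix("tenant/timelines".into(), false), "tenant/timelines");

    // Capping keys: the loop above decrements a counter; over a finished list, `take` is equivalent.
    let keys = ["a", "b", "c", "d"];
    let capped: Vec<_> = keys.iter().take(2).collect();
    assert_eq!(capped, vec![&"a", &"b"]);
}
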
     336            0 :     async fn upload(
     337            0 :         &self,
     338            0 :         from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
     339            0 :         data_size_bytes: usize,
     340            0 :         to: &RemotePath,
     341            0 :         metadata: Option<StorageMetadata>,
     342            0 :         cancel: &CancellationToken,
     343            0 :     ) -> anyhow::Result<()> {
     344            0 :         let _permit = self.permit(RequestKind::Put, cancel).await?;
     345              : 
     346            0 :         let op = async {
     347            0 :             let blob_client = self.client.blob_client(self.relative_path_to_name(to));
     348            0 : 
     349            0 :             let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
     350            0 :                 Box::pin(from);
     351            0 : 
     352            0 :             let from = NonSeekableStream::new(from, data_size_bytes);
     353            0 : 
     354            0 :             let body = azure_core::Body::SeekableStream(Box::new(from));
     355            0 : 
     356            0 :             let mut builder = blob_client.put_block_blob(body);
     357              : 
     358            0 :             if let Some(metadata) = metadata {
     359            0 :                 builder = builder.metadata(to_azure_metadata(metadata));
     360            0 :             }
     361              : 
     362            0 :             let fut = builder.into_future();
     363            0 :             let fut = tokio::time::timeout(self.timeout, fut);
     364            0 : 
     365            0 :             match fut.await {
     366            0 :                 Ok(Ok(_response)) => Ok(()),
     367            0 :                 Ok(Err(azure)) => Err(azure.into()),
     368            0 :                 Err(_timeout) => Err(TimeoutOrCancel::Cancel.into()),
     369              :             }
     370            0 :         };
     371              : 
     372              :         tokio::select! {
     373              :             res = op => res,
     374              :             _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()),
     375              :         }
     376            0 :     }
     377              : 
     378            2 :     async fn download(
     379            2 :         &self,
     380            2 :         from: &RemotePath,
     381            2 :         cancel: &CancellationToken,
     382            2 :     ) -> Result<Download, DownloadError> {
     383            0 :         let blob_client = self.client.blob_client(self.relative_path_to_name(from));
     384            0 : 
     385            0 :         let builder = blob_client.get();
     386            0 : 
     387            0 :         self.download_for_builder(builder, cancel).await
     388            0 :     }
     389              : 
     390            5 :     async fn download_byte_range(
     391            5 :         &self,
     392            5 :         from: &RemotePath,
     393            5 :         start_inclusive: u64,
     394            5 :         end_exclusive: Option<u64>,
     395            5 :         cancel: &CancellationToken,
     396            5 :     ) -> Result<Download, DownloadError> {
     397            0 :         let blob_client = self.client.blob_client(self.relative_path_to_name(from));
     398            0 : 
     399            0 :         let mut builder = blob_client.get();
     400              : 
     401            0 :         let range: Range = if let Some(end_exclusive) = end_exclusive {
     402            0 :             (start_inclusive..end_exclusive).into()
     403              :         } else {
     404            0 :             (start_inclusive..).into()
     405              :         };
     406            0 :         builder = builder.range(range);
     407            0 : 
     408            0 :         self.download_for_builder(builder, cancel).await
     409            0 :     }
     410              : 
     411           44 :     async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
     412            0 :         self.delete_objects(std::array::from_ref(path), cancel)
     413            0 :             .await
     414            0 :     }
     415              : 
     416           47 :     async fn delete_objects<'a>(
     417           47 :         &self,
     418           47 :         paths: &'a [RemotePath],
     419           47 :         cancel: &CancellationToken,
     420           47 :     ) -> anyhow::Result<()> {
     421            0 :         let _permit = self.permit(RequestKind::Delete, cancel).await?;
     422              : 
     423            0 :         let op = async {
     424              :             // TODO batch requests are also not supported by the SDK
     425              :             // https://github.com/Azure/azure-sdk-for-rust/issues/1068
     426              :             // https://github.com/Azure/azure-sdk-for-rust/issues/1249
     427            0 :             for path in paths {
     428            0 :                 let blob_client = self.client.blob_client(self.relative_path_to_name(path));
     429            0 : 
     430            0 :                 let request = blob_client.delete().into_future();
     431              : 
     432            0 :                 let res = tokio::time::timeout(self.timeout, request).await;
     433              : 
     434            0 :                 match res {
     435            0 :                     Ok(Ok(_response)) => continue,
     436            0 :                     Ok(Err(e)) => {
     437            0 :                         if let Some(http_err) = e.as_http_error() {
     438            0 :                             if http_err.status() == StatusCode::NotFound {
     439            0 :                                 continue;
     440            0 :                             }
     441            0 :                         }
     442            0 :                         return Err(e.into());
     443              :                     }
     444            0 :                     Err(_elapsed) => return Err(TimeoutOrCancel::Timeout.into()),
     445              :                 }
     446              :             }
     447              : 
     448            0 :             Ok(())
     449            0 :         };
     450              : 
     451              :         tokio::select! {
     452              :             res = op => res,
     453              :             _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()),
     454              :         }
     455            0 :     }
     456              : 
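
Because the SDK offers no batch delete, delete_objects issues one DELETE per blob and treats NotFound as success, which makes deletion idempotent. The sketch below shows that classification rule in isolation (hypothetical types, status codes assumed, not the Azure SDK):

#[derive(Debug)]
enum DeleteOutcome {
    Deleted,
    AlreadyGone,
}

// Map an HTTP status from a per-blob delete to an outcome: 404 means the blob is
// already absent, which is what the caller wanted anyway.
fn classify_delete(status: u16) -> Result<DeleteOutcome, String> {
    match status {
        202 | 204 => Ok(DeleteOutcome::Deleted),
        404 => Ok(DeleteOutcome::AlreadyGone),
        other => Err(format!("unexpected delete status: {other}")),
    }
}

fn main() {
    assert!(matches!(classify_delete(404), Ok(DeleteOutcome::AlreadyGone)));
    assert!(classify_delete(500).is_err());
}
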
     457            1 :     async fn copy(
     458            1 :         &self,
     459            1 :         from: &RemotePath,
     460            1 :         to: &RemotePath,
     461            1 :         cancel: &CancellationToken,
     462            1 :     ) -> anyhow::Result<()> {
     463            0 :         let _permit = self.permit(RequestKind::Copy, cancel).await?;
     464              : 
     465            0 :         let timeout = tokio::time::sleep(self.timeout);
     466            0 : 
     467            0 :         let mut copy_status = None;
     468            0 : 
     469            0 :         let op = async {
     470            0 :             let blob_client = self.client.blob_client(self.relative_path_to_name(to));
     471              : 
     472            0 :             let source_url = format!(
     473            0 :                 "{}/{}",
     474            0 :                 self.client.url()?,
     475            0 :                 self.relative_path_to_name(from)
     476              :             );
     477              : 
     478            0 :             let builder = blob_client.copy(Url::from_str(&source_url)?);
     479            0 :             let copy = builder.into_future();
     480              : 
     481            0 :             let result = copy.await?;
     482              : 
     483            0 :             copy_status = Some(result.copy_status);
     484              :             loop {
     485            0 :                 match copy_status.as_ref().expect("we always set it to Some") {
     486              :                     CopyStatus::Aborted => {
     487            0 :                         anyhow::bail!("Received abort for copy from {from} to {to}.");
     488              :                     }
     489              :                     CopyStatus::Failed => {
     490            0 :                         anyhow::bail!("Received failure response for copy from {from} to {to}.");
     491              :                     }
     492            0 :                     CopyStatus::Success => return Ok(()),
     493            0 :                     CopyStatus::Pending => (),
     494            0 :                 }
     495            0 :                 // The copy is taking longer. Waiting a second and then re-trying.
     496            0 :                 // TODO estimate time based on copy_progress and adjust time based on that
     497            0 :                 tokio::time::sleep(Duration::from_millis(1000)).await;
     498            0 :                 let properties = blob_client.get_properties().into_future().await?;
     499            0 :                 let Some(status) = properties.blob.properties.copy_status else {
     500            0 :                     tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
     501            0 :                     return Ok(());
     502              :                 };
     503            0 :                 copy_status = Some(status);
     504              :             }
     505            0 :         };
     506              : 
     507              :         tokio::select! {
     508              :             res = op => res,
     509              :             _ = cancel.cancelled() => Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
     510              :             _ = timeout => {
     511              :                 let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
     512              :                 let e = e.context(format!("Timeout, last status: {copy_status:?}"));
     513              :                 Err(e)
     514              :             },
     515              :         }
     516            0 :     }
     517              : 
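
copy() starts a server-side copy and then polls the blob properties until copy_status reaches a terminal state, sleeping between polls. A minimal sketch of such a poll-until-terminal loop (hypothetical poll_status closure and Status type, not the Azure API):

use std::time::Duration;

enum Status {
    Pending,
    Success,
    Failed,
}

// Keep checking a status until it is terminal, sleeping between attempts.
async fn wait_until_done(
    mut poll_status: impl FnMut() -> Status,
    interval: Duration,
) -> Result<(), String> {
    loop {
        match poll_status() {
            Status::Success => return Ok(()),
            Status::Failed => return Err("copy failed".to_string()),
            Status::Pending => tokio::time::sleep(interval).await,
        }
    }
}

#[tokio::main]
async fn main() {
    let mut remaining = 3;
    let poll = move || {
        remaining -= 1;
        if remaining == 0 { Status::Success } else { Status::Pending }
    };
    assert!(wait_until_done(poll, Duration::from_millis(1)).await.is_ok());

    let fail = || Status::Failed;
    assert!(wait_until_done(fail, Duration::from_millis(1)).await.is_err());
}
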
     518            0 :     async fn time_travel_recover(
     519            0 :         &self,
     520            0 :         _prefix: Option<&RemotePath>,
     521            0 :         _timestamp: SystemTime,
     522            0 :         _done_if_after: SystemTime,
     523            0 :         _cancel: &CancellationToken,
     524            0 :     ) -> Result<(), TimeTravelError> {
     525            0 :         // TODO use Azure point in time recovery feature for this
     526            0 :         // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
     527            0 :         Err(TimeTravelError::Unimplemented)
     528            0 :     }
     529              : 
     530            0 :     fn activity(&self) -> RemoteStorageActivity {
     531            0 :         self.concurrency_limiter.activity()
     532            0 :     }
     533              : }
     534              : 
     535              : pin_project_lite::pin_project! {
      536              :     /// Hack to work around the Azure SDK not supporting streams that can only be read once.
      537              :     ///
      538              :     /// The Azure SDK clones streams around with the assumption that they behave like
      539              :     /// `Arc<tokio::fs::File>` (except without tokio support), but our streams do not.
      540              :     /// For example, for an `index_part.json` we have just a single chunk of [`Bytes`]
      541              :     /// representing the whole serialized vec. It could be made trivially cloneable and
      542              :     /// "semi-trivially" seekable, but it is easier to simply retry the request.
     543              :     #[project = NonSeekableStreamProj]
     544              :     enum NonSeekableStream<S> {
      545              :         /// A stream wrapper's initial form.
     546              :         ///
      547              :         /// The Mutex exists to allow moving the inner reader out when cloning. If the SDK
      548              :         /// changes to clone less than once before the first request, this must be changed.
     549              :         Initial {
     550              :             inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
     551              :             len: usize,
     552              :         },
     553              :         /// The actually readable variant, produced by cloning the Initial variant.
     554              :         ///
     555              :         /// The sdk currently always clones once, even without retry policy.
     556              :         Actual {
     557              :             #[pin]
     558              :             inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
     559              :             len: usize,
     560              :             read_any: bool,
     561              :         },
     562              :         /// Most likely unneeded, but left to make life easier, in case more clones are added.
     563              :         Cloned {
     564              :             len_was: usize,
     565              :         }
     566              :     }
     567              : }
     568              : 
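
The Initial variant's Mutex<Option<_>> lets the inner reader be moved out through a shared reference when the SDK clones the stream for its first request. The sketch below shows that idiom in isolation (the TakeOnce name is hypothetical):

use std::sync::Mutex;

/// A slot whose contents can be moved out exactly once through a shared reference,
/// mirroring how NonSeekableStream::Initial hands its reader to the first clone.
struct TakeOnce<T>(Mutex<Option<T>>);

impl<T> TakeOnce<T> {
    fn new(value: T) -> Self {
        TakeOnce(Mutex::new(Some(value)))
    }

    // The first call moves the value out through the Mutex (which only needs &self);
    // later calls observe that it is gone.
    fn take(&self) -> Option<T> {
        self.0.lock().unwrap().take()
    }
}

fn main() {
    let slot = TakeOnce::new(String::from("reader"));
    assert_eq!(slot.take().as_deref(), Some("reader"));
    assert_eq!(slot.take(), None);
}
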
     569              : impl<S> NonSeekableStream<S>
     570              : where
     571              :     S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
     572              : {
     573            0 :     fn new(inner: S, len: usize) -> NonSeekableStream<S> {
     574            0 :         use tokio_util::compat::TokioAsyncReadCompatExt;
     575            0 : 
     576            0 :         let inner = tokio_util::io::StreamReader::new(inner).compat();
     577            0 :         let inner = Some(inner);
     578            0 :         let inner = std::sync::Mutex::new(inner);
     579            0 :         NonSeekableStream::Initial { inner, len }
     580            0 :     }
     581              : }
     582              : 
     583              : impl<S> std::fmt::Debug for NonSeekableStream<S> {
     584            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     585            0 :         match self {
     586            0 :             Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
     587            0 :             Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
     588            0 :             Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
     589              :         }
     590            0 :     }
     591              : }
     592              : 
     593              : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
     594              : where
     595              :     S: Stream<Item = std::io::Result<Bytes>>,
     596              : {
     597            0 :     fn poll_read(
     598            0 :         self: std::pin::Pin<&mut Self>,
     599            0 :         cx: &mut std::task::Context<'_>,
     600            0 :         buf: &mut [u8],
     601            0 :     ) -> std::task::Poll<std::io::Result<usize>> {
     602            0 :         match self.project() {
     603              :             NonSeekableStreamProj::Actual {
     604            0 :                 inner, read_any, ..
     605            0 :             } => {
     606            0 :                 *read_any = true;
     607            0 :                 inner.poll_read(cx, buf)
     608              :             }
      609              :             // NonSeekableStream::Initial does not support reading: it is much simpler to keep
      610              :             // the mutex somewhere its contents are never polled, or at least that was the
      611              :             // original reasoning. If a version upgrade changes the cloning behaviour, read
      612              :             // support will have to be hacked in here.
      613              :             //
      614              :             // Including {self:?} in the message would be useful, but it is unclear how to unproject here.
     615            0 :             _ => std::task::Poll::Ready(Err(std::io::Error::new(
     616            0 :                 std::io::ErrorKind::Other,
     617            0 :                 "cloned or initial values cannot be read",
     618            0 :             ))),
     619              :         }
     620            0 :     }
     621              : }
     622              : 
     623              : impl<S> Clone for NonSeekableStream<S> {
      624              :     /// This unusual Clone implementation exists because the SDK clones the stream before
      625              :     /// issuing the first request; see the type documentation.
     626            0 :     fn clone(&self) -> Self {
     627            0 :         use NonSeekableStream::*;
     628            0 : 
     629            0 :         match self {
     630            0 :             Initial { inner, len } => {
     631            0 :                 if let Some(inner) = inner.lock().unwrap().take() {
     632            0 :                     Actual {
     633            0 :                         inner,
     634            0 :                         len: *len,
     635            0 :                         read_any: false,
     636            0 :                     }
     637              :                 } else {
     638            0 :                     Self::Cloned { len_was: *len }
     639              :                 }
     640              :             }
     641            0 :             Actual { len, .. } => Cloned { len_was: *len },
     642            0 :             Cloned { len_was } => Cloned { len_was: *len_was },
     643              :         }
     644            0 :     }
     645              : }
     646              : 
     647              : #[async_trait::async_trait]
     648              : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
     649              : where
     650              :     S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
     651              : {
     652            0 :     async fn reset(&mut self) -> azure_core::error::Result<()> {
     653            0 :         use NonSeekableStream::*;
     654            0 : 
     655            0 :         let msg = match self {
     656            0 :             Initial { inner, .. } => {
     657            0 :                 if inner.get_mut().unwrap().is_some() {
     658            0 :                     return Ok(());
     659            0 :                 } else {
     660            0 :                     "reset after first clone is not supported"
     661            0 :                 }
     662            0 :             }
     663            0 :             Actual { read_any, .. } if !*read_any => return Ok(()),
     664            0 :             Actual { .. } => "reset after reading is not supported",
     665            0 :             Cloned { .. } => "reset after second clone is not supported",
     666            0 :         };
     667            0 :         Err(azure_core::error::Error::new(
     668            0 :             azure_core::error::ErrorKind::Io,
     669            0 :             std::io::Error::new(std::io::ErrorKind::Other, msg),
     670            0 :         ))
     671            0 :     }
     672              : 
      673              :     // Note: it is not documented whether this should be the total or the remaining length;
      674              :     // the total length passes the tests.
     675            0 :     fn len(&self) -> usize {
     676            0 :         use NonSeekableStream::*;
     677            0 :         match self {
     678            0 :             Initial { len, .. } => *len,
     679            0 :             Actual { len, .. } => *len,
     680            0 :             Cloned { len_was, .. } => *len_was,
     681              :         }
     682            0 :     }
     683              : }
        

Generated by: LCOV version 2.1-beta