LCOV - code coverage report
Current view: top level - libs/remote_storage/src - azure_blob.rs (source / functions) Coverage Total Hit
Test: fabb29a6339542ee130cd1d32b534fafdc0be240.info Lines: 79.1 % 497 393
Test Date: 2024-06-25 13:20:00 Functions: 55.1 % 78 43

            Line data    Source code
       1              : //! Azure Blob Storage wrapper
       2              : 
       3              : use std::borrow::Cow;
       4              : use std::collections::HashMap;
       5              : use std::env;
       6              : use std::fmt::Display;
       7              : use std::io;
       8              : use std::num::NonZeroU32;
       9              : use std::pin::Pin;
      10              : use std::str::FromStr;
      11              : use std::sync::Arc;
      12              : use std::time::Duration;
      13              : use std::time::SystemTime;
      14              : 
      15              : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
      16              : use anyhow::Result;
      17              : use azure_core::request_options::{MaxResults, Metadata, Range};
      18              : use azure_core::RetryOptions;
      19              : use azure_identity::DefaultAzureCredential;
      20              : use azure_storage::StorageCredentials;
      21              : use azure_storage_blobs::blob::CopyStatus;
      22              : use azure_storage_blobs::prelude::ClientBuilder;
      23              : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
      24              : use bytes::Bytes;
      25              : use futures::future::Either;
      26              : use futures::stream::Stream;
      27              : use futures_util::StreamExt;
      28              : use futures_util::TryStreamExt;
      29              : use http_types::{StatusCode, Url};
      30              : use scopeguard::ScopeGuard;
      31              : use tokio_util::sync::CancellationToken;
      32              : use tracing::debug;
      33              : use utils::backoff;
      34              : 
      35              : use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind};
      36              : use crate::{
      37              :     config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError, Listing,
      38              :     ListingMode, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
      39              : };
      40              : 
      41              : pub struct AzureBlobStorage {
      42              :     client: ContainerClient,
      43              :     prefix_in_container: Option<String>,
      44              :     max_keys_per_list_response: Option<NonZeroU32>,
      45              :     concurrency_limiter: ConcurrencyLimiter,
      46              :     // Per-request timeout. Accessible for tests.
      47              :     pub timeout: Duration,
      48              : }
      49              : 
      50              : impl AzureBlobStorage {
      51            6 :     pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
      52            6 :         debug!(
      53            0 :             "Creating azure remote storage for azure container {}",
      54              :             azure_config.container_name
      55              :         );
      56              : 
      57              :         // Use the storage account from the config by default, fall back to env var if not present.
      58            6 :         let account = azure_config.storage_account.clone().unwrap_or_else(|| {
      59            6 :             env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
      60            6 :         });
      61              : 
      62              :         // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
      63              :         // otherwise try the token based credentials.
      64            6 :         let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
      65            6 :             StorageCredentials::access_key(account.clone(), access_key)
      66              :         } else {
      67            0 :             let token_credential = DefaultAzureCredential::default();
      68            0 :             StorageCredentials::token_credential(Arc::new(token_credential))
      69              :         };
      70              : 
      71              :         // disable the SDK's built-in retries: we have an outer retry
      72            6 :         let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
      73            6 : 
      74            6 :         let client = builder.container_client(azure_config.container_name.to_owned());
      75              : 
      76            6 :         let max_keys_per_list_response =
      77            6 :             if let Some(limit) = azure_config.max_keys_per_list_response {
      78              :                 Some(
      79            2 :                     NonZeroU32::new(limit as u32)
      80            2 :                         .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
      81              :                 )
      82              :             } else {
      83            4 :                 None
      84              :             };
      85              : 
      86            6 :         Ok(AzureBlobStorage {
      87            6 :             client,
      88            6 :             prefix_in_container: azure_config.prefix_in_container.to_owned(),
      89            6 :             max_keys_per_list_response,
      90            6 :             concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
      91            6 :             timeout,
      92            6 :         })
      93            6 :     }
      94              : 
      95          107 :     pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
      96          107 :         assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
      97          107 :         let path_string = path
      98          107 :             .get_path()
      99          107 :             .as_str()
     100          107 :             .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
     101          107 :         match &self.prefix_in_container {
     102          107 :             Some(prefix) => {
     103          107 :                 if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
     104          107 :                     prefix.clone() + path_string
     105              :                 } else {
     106            0 :                     format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
     107              :                 }
     108              :             }
     109            0 :             None => path_string.to_string(),
     110              :         }
     111          107 :     }
     112              : 
     113           53 :     fn name_to_relative_path(&self, key: &str) -> RemotePath {
     114           53 :         let relative_path =
     115           53 :             match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
     116           53 :                 Some(stripped) => stripped,
     117              :                 // we rely on Azure to return properly prefixed paths
     118              :                 // for requests with a certain prefix
     119            0 :                 None => panic!(
     120            0 :                     "Key {key} does not start with container prefix {:?}",
     121            0 :                     self.prefix_in_container
     122            0 :                 ),
     123              :             };
     124           53 :         RemotePath(
     125           53 :             relative_path
     126           53 :                 .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
     127           53 :                 .collect(),
     128           53 :         )
     129           53 :     }
     130              : 
     131            7 :     async fn download_for_builder(
     132            7 :         &self,
     133            7 :         builder: GetBlobBuilder,
     134            7 :         cancel: &CancellationToken,
     135            7 :     ) -> Result<Download, DownloadError> {
     136            7 :         let kind = RequestKind::Get;
     137              : 
     138            7 :         let _permit = self.permit(kind, cancel).await?;
     139            7 :         let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
     140            7 :         let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
     141            7 : 
     142            7 :         let mut etag = None;
     143            7 :         let mut last_modified = None;
     144            7 :         let mut metadata = HashMap::new();
     145            7 : 
     146            7 :         let started_at = start_measuring_requests(kind);
     147            7 : 
     148            7 :         let download = async {
     149            7 :             let response = builder
     150            7 :                 // convert to concrete Pageable
     151            7 :                 .into_stream()
     152            7 :                 // convert to TryStream
     153            7 :                 .into_stream()
     154            7 :                 .map_err(to_download_error);
     155            7 : 
     156            7 :             // apply per request timeout
     157            7 :             let response = tokio_stream::StreamExt::timeout(response, self.timeout);
     158            7 : 
     159            7 :             // flatten
     160            7 :             let response = response.map(|res| match res {
     161            7 :                 Ok(res) => res,
     162            0 :                 Err(_elapsed) => Err(DownloadError::Timeout),
     163            7 :             });
     164            7 : 
     165            7 :             let mut response = Box::pin(response);
     166              : 
     167           35 :             let Some(part) = response.next().await else {
     168            0 :                 return Err(DownloadError::Other(anyhow::anyhow!(
     169            0 :                     "Azure GET response contained no response body"
     170            0 :                 )));
     171              :             };
     172            7 :             let part = part?;
     173            7 :             if etag.is_none() {
     174            7 :                 etag = Some(part.blob.properties.etag);
     175            7 :             }
     176            7 :             if last_modified.is_none() {
     177            7 :                 last_modified = Some(part.blob.properties.last_modified.into());
     178            7 :             }
     179            7 :             if let Some(blob_meta) = part.blob.metadata {
     180            0 :                 metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
     181            7 :             }
     182              : 
     183              :             // unwrap safety: both were set to Some just above from the first response part; a missing first part has already returned an error
     184            7 :             let etag = etag.unwrap();
     185            7 :             let last_modified = last_modified.unwrap();
     186            7 : 
     187            7 :             let tail_stream = response
     188            7 :                 .map(|part| match part {
     189            0 :                     Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
     190            0 :                     Err(e) => {
     191            0 :                         Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
     192              :                     }
     193            7 :                 })
     194            7 :                 .flatten();
     195            7 :             let stream = part
     196            7 :                 .data
     197            7 :                 .map(|r| r.map_err(io::Error::other))
     198            7 :                 .chain(sync_wrapper::SyncStream::new(tail_stream));
     199            7 :             //.chain(SyncStream::from_pin(Box::pin(tail_stream)));
     200            7 : 
     201            7 :             let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
     202            7 : 
     203            7 :             Ok(Download {
     204            7 :                 download_stream: Box::pin(download_stream),
     205            7 :                 etag,
     206            7 :                 last_modified,
     207            7 :                 metadata: Some(StorageMetadata(metadata)),
     208            7 :             })
     209            7 :         };
     210              : 
     211            7 :         let download = tokio::select! {
     212              :             bufs = download => bufs,
     213              :             cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
     214              :                 TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
     215              :                 TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
     216              :             },
     217              :         };
     218            7 :         let started_at = ScopeGuard::into_inner(started_at);
     219            7 :         let outcome = match &download {
     220            7 :             Ok(_) => AttemptOutcome::Ok,
     221            0 :             Err(_) => AttemptOutcome::Err,
     222              :         };
     223            7 :         crate::metrics::BUCKET_METRICS
     224            7 :             .req_seconds
     225            7 :             .observe_elapsed(kind, outcome, started_at);
     226            7 :         download
     227            7 :     }
     228              : 
     229          108 :     async fn permit(
     230          108 :         &self,
     231          108 :         kind: RequestKind,
     232          108 :         cancel: &CancellationToken,
     233          108 :     ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
     234          108 :         let acquire = self.concurrency_limiter.acquire(kind);
     235              : 
     236              :         tokio::select! {
     237              :             permit = acquire => Ok(permit.expect("never closed")),
     238              :             _ = cancel.cancelled() => Err(Cancelled),
     239              :         }
     240          108 :     }
     241              : }
     242              : 
     243            0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
     244            0 :     let mut res = Metadata::new();
     245            0 :     for (k, v) in metadata.0.into_iter() {
     246            0 :         res.insert(k, v);
     247            0 :     }
     248            0 :     res
     249            0 : }
     250              : 
     251            0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
     252            0 :     if let Some(http_err) = error.as_http_error() {
     253            0 :         match http_err.status() {
     254            0 :             StatusCode::NotFound => DownloadError::NotFound,
     255            0 :             StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
     256            0 :             _ => DownloadError::Other(anyhow::Error::new(error)),
     257              :         }
     258              :     } else {
     259            0 :         DownloadError::Other(error.into())
     260              :     }
     261            0 : }
     262              : 
     263              : impl RemoteStorage for AzureBlobStorage {
     264            6 :     async fn list(
     265            6 :         &self,
     266            6 :         prefix: Option<&RemotePath>,
     267            6 :         mode: ListingMode,
     268            6 :         max_keys: Option<NonZeroU32>,
     269            6 :         cancel: &CancellationToken,
     270            6 :     ) -> anyhow::Result<Listing, DownloadError> {
     271            6 :         let _permit = self.permit(RequestKind::List, cancel).await?;
     272              : 
     273            6 :         let op = async {
     274            6 :             // use the passed prefix, or if it is not set, fall back to the prefix_in_container value
     275            6 :             let list_prefix = prefix
     276            6 :                 .map(|p| self.relative_path_to_name(p))
     277            6 :                 .or_else(|| self.prefix_in_container.clone())
     278            6 :                 .map(|mut p| {
     279              :                     // required to end with a separator
     280              :                     // otherwise request will return only the entry of a prefix
     281            6 :                     if matches!(mode, ListingMode::WithDelimiter)
     282            3 :                         && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
     283            1 :                     {
     284            1 :                         p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
     285            5 :                     }
     286            6 :                     p
     287            6 :                 });
     288            6 : 
     289            6 :             let mut builder = self.client.list_blobs();
     290            6 : 
     291            6 :             if let ListingMode::WithDelimiter = mode {
     292            3 :                 builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
     293            3 :             }
     294              : 
     295            6 :             if let Some(prefix) = list_prefix {
     296            6 :                 builder = builder.prefix(Cow::from(prefix.to_owned()));
     297            6 :             }
     298              : 
     299            6 :             if let Some(limit) = self.max_keys_per_list_response {
     300            5 :                 builder = builder.max_results(MaxResults::new(limit));
     301            5 :             }
     302              : 
     303            6 :             let response = builder.into_stream();
     304            6 :             let response = response.into_stream().map_err(to_download_error);
     305            6 :             let response = tokio_stream::StreamExt::timeout(response, self.timeout);
     306           10 :             let response = response.map(|res| match res {
     307           10 :                 Ok(res) => res,
     308            0 :                 Err(_elapsed) => Err(DownloadError::Timeout),
     309           10 :             });
     310            6 : 
     311            6 :             let mut response = std::pin::pin!(response);
     312            6 : 
     313            6 :             let mut res = Listing::default();
     314            6 : 
     315            6 :             let mut max_keys = max_keys.map(|mk| mk.get());
     316           60 :             while let Some(entry) = response.next().await {
     317           10 :                 let entry = entry?;
     318           10 :                 let prefix_iter = entry
     319           10 :                     .blobs
     320           10 :                     .prefixes()
     321           23 :                     .map(|prefix| self.name_to_relative_path(&prefix.name));
     322           10 :                 res.prefixes.extend(prefix_iter);
     323           10 : 
     324           10 :                 let blob_iter = entry
     325           10 :                     .blobs
     326           10 :                     .blobs()
     327           30 :                     .map(|k| self.name_to_relative_path(&k.name));
     328              : 
     329           39 :                 for key in blob_iter {
     330           30 :                     res.keys.push(key);
     331              : 
     332           30 :                     if let Some(mut mk) = max_keys {
     333            2 :                         assert!(mk > 0);
     334            2 :                         mk -= 1;
     335            2 :                         if mk == 0 {
     336            1 :                             return Ok(res); // limit reached
     337            1 :                         }
     338            1 :                         max_keys = Some(mk);
     339           28 :                     }
     340              :                 }
     341              :             }
     342              : 
     343            5 :             Ok(res)
     344            6 :         };
     345              : 
     346              :         tokio::select! {
     347              :             res = op => res,
     348              :             _ = cancel.cancelled() => Err(DownloadError::Cancelled),
     349              :         }
     350            6 :     }
     351              : 
     352           47 :     async fn upload(
     353           47 :         &self,
     354           47 :         from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
     355           47 :         data_size_bytes: usize,
     356           47 :         to: &RemotePath,
     357           47 :         metadata: Option<StorageMetadata>,
     358           47 :         cancel: &CancellationToken,
     359           47 :     ) -> anyhow::Result<()> {
     360           47 :         let kind = RequestKind::Put;
     361           47 :         let _permit = self.permit(kind, cancel).await?;
     362              : 
     363           47 :         let started_at = start_measuring_requests(kind);
     364           47 : 
     365           47 :         let op = async {
     366           47 :             let blob_client = self.client.blob_client(self.relative_path_to_name(to));
     367           47 : 
     368           47 :             let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
     369           47 :                 Box::pin(from);
     370           47 : 
     371           47 :             let from = NonSeekableStream::new(from, data_size_bytes);
     372           47 : 
     373           47 :             let body = azure_core::Body::SeekableStream(Box::new(from));
     374           47 : 
     375           47 :             let mut builder = blob_client.put_block_blob(body);
     376              : 
     377           47 :             if let Some(metadata) = metadata {
     378            0 :                 builder = builder.metadata(to_azure_metadata(metadata));
     379           47 :             }
     380              : 
     381           47 :             let fut = builder.into_future();
     382           47 :             let fut = tokio::time::timeout(self.timeout, fut);
     383           47 : 
     384          272 :             match fut.await {
     385           47 :                 Ok(Ok(_response)) => Ok(()),
     386            0 :                 Ok(Err(azure)) => Err(azure.into()),
     387            0 :                 Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
     388              :             }
     389           47 :         };
     390              : 
     391           47 :         let res = tokio::select! {
     392              :             res = op => res,
     393              :             _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
     394              :         };
     395              : 
     396           47 :         let outcome = match res {
     397           47 :             Ok(_) => AttemptOutcome::Ok,
     398            0 :             Err(_) => AttemptOutcome::Err,
     399              :         };
     400           47 :         let started_at = ScopeGuard::into_inner(started_at);
     401           47 :         crate::metrics::BUCKET_METRICS
     402           47 :             .req_seconds
     403           47 :             .observe_elapsed(kind, outcome, started_at);
     404           47 : 
     405           47 :         res
     406           47 :     }
     407              : 
     408            2 :     async fn download(
     409            2 :         &self,
     410            2 :         from: &RemotePath,
     411            2 :         cancel: &CancellationToken,
     412            2 :     ) -> Result<Download, DownloadError> {
     413            2 :         let blob_client = self.client.blob_client(self.relative_path_to_name(from));
     414            2 : 
     415            2 :         let builder = blob_client.get();
     416            2 : 
     417           10 :         self.download_for_builder(builder, cancel).await
     418            2 :     }
     419              : 
     420            5 :     async fn download_byte_range(
     421            5 :         &self,
     422            5 :         from: &RemotePath,
     423            5 :         start_inclusive: u64,
     424            5 :         end_exclusive: Option<u64>,
     425            5 :         cancel: &CancellationToken,
     426            5 :     ) -> Result<Download, DownloadError> {
     427            5 :         let blob_client = self.client.blob_client(self.relative_path_to_name(from));
     428            5 : 
     429            5 :         let mut builder = blob_client.get();
     430              : 
     431            5 :         let range: Range = if let Some(end_exclusive) = end_exclusive {
     432            3 :             (start_inclusive..end_exclusive).into()
     433              :         } else {
     434            2 :             (start_inclusive..).into()
     435              :         };
     436            5 :         builder = builder.range(range);
     437            5 : 
     438           25 :         self.download_for_builder(builder, cancel).await
     439            5 :     }
     440              : 
     441           44 :     async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
     442           44 :         self.delete_objects(std::array::from_ref(path), cancel)
     443          221 :             .await
     444           44 :     }
     445              : 
     446           47 :     async fn delete_objects<'a>(
     447           47 :         &self,
     448           47 :         paths: &'a [RemotePath],
     449           47 :         cancel: &CancellationToken,
     450           47 :     ) -> anyhow::Result<()> {
     451           47 :         let kind = RequestKind::Delete;
     452           47 :         let _permit = self.permit(kind, cancel).await?;
     453           47 :         let started_at = start_measuring_requests(kind);
     454           47 : 
     455           47 :         let op = async {
     456              :             // TODO batch requests are not supported by the SDK
     457              :             // https://github.com/Azure/azure-sdk-for-rust/issues/1068
     458           96 :             for path in paths {
     459              :                 #[derive(Debug)]
     460              :                 enum AzureOrTimeout {
     461              :                     AzureError(azure_core::Error),
     462              :                     Timeout,
     463              :                     Cancel,
     464              :                 }
     465              :                 impl Display for AzureOrTimeout {
     466            0 :                     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     467            0 :                         write!(f, "{self:?}")
     468            0 :                     }
     469              :                 }
     470           49 :                 let warn_threshold = 3;
     471           49 :                 let max_retries = 5;
     472           49 :                 backoff::retry(
     473           49 :                     || async {
     474           49 :                         let blob_client = self.client.blob_client(self.relative_path_to_name(path));
     475           49 : 
     476           49 :                         let request = blob_client.delete().into_future();
     477           49 : 
     478          247 :                         let res = tokio::time::timeout(self.timeout, request).await;
     479           49 : 
     480           49 :                         match res {
     481           49 :                             Ok(Ok(_v)) => Ok(()),
     482           49 :                             Ok(Err(azure_err)) => {
     483           49 :                                 if let Some(http_err) = azure_err.as_http_error() {
     484           49 :                                     if http_err.status() == StatusCode::NotFound {
     485           49 :                                         return Ok(());
     486           49 :                                     }
     487           49 :                                 }
     488           49 :                                 Err(AzureOrTimeout::AzureError(azure_err))
     489           49 :                             }
     490           49 :                             Err(_elapsed) => Err(AzureOrTimeout::Timeout),
     491           49 :                         }
     492           49 :                     },
     493           49 :                     |err| match err {
     494            0 :                         AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
     495            0 :                         AzureOrTimeout::Cancel => true,
     496           49 :                     },
     497           49 :                     warn_threshold,
     498           49 :                     max_retries,
     499           49 :                     "deleting remote object",
     500           49 :                     cancel,
     501           49 :                 )
     502          247 :                 .await
     503           49 :                 .ok_or_else(|| AzureOrTimeout::Cancel)
     504           49 :                 .and_then(|x| x)
     505           49 :                 .map_err(|e| match e {
     506            0 :                     AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
     507            0 :                     AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
     508            0 :                     AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
     509           49 :                 })?;
     510              :             }
     511           47 :             Ok(())
     512           47 :         };
     513              : 
     514           47 :         let res = tokio::select! {
     515              :             res = op => res,
     516              :             _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
     517              :         };
     518              : 
     519           47 :         let started_at = ScopeGuard::into_inner(started_at);
     520           47 :         crate::metrics::BUCKET_METRICS
     521           47 :             .req_seconds
     522           47 :             .observe_elapsed(kind, &res, started_at);
     523           47 :         res
     524           47 :     }
     525              : 
     526            1 :     async fn copy(
     527            1 :         &self,
     528            1 :         from: &RemotePath,
     529            1 :         to: &RemotePath,
     530            1 :         cancel: &CancellationToken,
     531            1 :     ) -> anyhow::Result<()> {
     532            1 :         let kind = RequestKind::Copy;
     533            1 :         let _permit = self.permit(kind, cancel).await?;
     534            1 :         let started_at = start_measuring_requests(kind);
     535            1 : 
     536            1 :         let timeout = tokio::time::sleep(self.timeout);
     537            1 : 
     538            1 :         let mut copy_status = None;
     539            1 : 
     540            1 :         let op = async {
     541            1 :             let blob_client = self.client.blob_client(self.relative_path_to_name(to));
     542              : 
     543            1 :             let source_url = format!(
     544            1 :                 "{}/{}",
     545            1 :                 self.client.url()?,
     546            1 :                 self.relative_path_to_name(from)
     547              :             );
     548              : 
     549            1 :             let builder = blob_client.copy(Url::from_str(&source_url)?);
     550            1 :             let copy = builder.into_future();
     551              : 
     552            5 :             let result = copy.await?;
     553              : 
     554            1 :             copy_status = Some(result.copy_status);
     555              :             loop {
     556            1 :                 match copy_status.as_ref().expect("we always set it to Some") {
     557              :                     CopyStatus::Aborted => {
     558            0 :                         anyhow::bail!("Received abort for copy from {from} to {to}.");
     559              :                     }
     560              :                     CopyStatus::Failed => {
     561            0 :                         anyhow::bail!("Received failure response for copy from {from} to {to}.");
     562              :                     }
     563            1 :                     CopyStatus::Success => return Ok(()),
     564            0 :                     CopyStatus::Pending => (),
     565            0 :                 }
     566            0 :                 // The copy is taking longer. Waiting a second and then re-trying.
     567            0 :                 // TODO estimate time based on copy_progress and adjust time based on that
     568            0 :                 tokio::time::sleep(Duration::from_millis(1000)).await;
     569            0 :                 let properties = blob_client.get_properties().into_future().await?;
     570            0 :                 let Some(status) = properties.blob.properties.copy_status else {
     571            0 :                     tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
     572            0 :                     return Ok(());
     573              :                 };
     574            0 :                 copy_status = Some(status);
     575              :             }
     576            1 :         };
     577              : 
     578            1 :         let res = tokio::select! {
     579              :             res = op => res,
     580              :             _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
     581              :             _ = timeout => {
     582              :                 let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
     583              :                 let e = e.context(format!("Timeout, last status: {copy_status:?}"));
     584              :                 Err(e)
     585              :             },
     586              :         };
     587              : 
     588            1 :         let started_at = ScopeGuard::into_inner(started_at);
     589            1 :         crate::metrics::BUCKET_METRICS
     590            1 :             .req_seconds
     591            1 :             .observe_elapsed(kind, &res, started_at);
     592            1 :         res
     593            1 :     }
     594              : 
     595            0 :     async fn time_travel_recover(
     596            0 :         &self,
     597            0 :         _prefix: Option<&RemotePath>,
     598            0 :         _timestamp: SystemTime,
     599            0 :         _done_if_after: SystemTime,
     600            0 :         _cancel: &CancellationToken,
     601            0 :     ) -> Result<(), TimeTravelError> {
     602            0 :         // TODO use Azure point in time recovery feature for this
     603            0 :         // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
     604            0 :         Err(TimeTravelError::Unimplemented)
     605            0 :     }
     606              : }
     607              : 
     608              : pin_project_lite::pin_project! {
     609              :     /// Hack to work around not being able to stream once with azure sdk.
     610              :     ///
     611              :     /// Azure sdk clones streams around with the assumption that they are like
     612              :     /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
     613              :     /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
     614              :     /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
     615              :     /// seekable, but we can also just re-try the request easier.
     616              :     #[project = NonSeekableStreamProj]
     617              :     enum NonSeekableStream<S> {
     618              :         /// A stream wrappers initial form.
     619              :         ///
     620              :         /// Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
     621              :         /// clone before first request, then this must be changed.
     622              :         Initial {
     623              :             inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
     624              :             len: usize,
     625              :         },
     626              :         /// The actually readable variant, produced by cloning the Initial variant.
     627              :         ///
     628              :         /// The sdk currently always clones once, even without retry policy.
     629              :         Actual {
     630              :             #[pin]
     631              :             inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
     632              :             len: usize,
     633              :             read_any: bool,
     634              :         },
     635              :         /// Most likely unneeded, but left to make life easier, in case more clones are added.
     636              :         Cloned {
     637              :             len_was: usize,
     638              :         }
     639              :     }
     640              : }
     641              : 
     642              : impl<S> NonSeekableStream<S>
     643              : where
     644              :     S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
     645              : {
     646           47 :     fn new(inner: S, len: usize) -> NonSeekableStream<S> {
     647           47 :         use tokio_util::compat::TokioAsyncReadCompatExt;
     648           47 : 
     649           47 :         let inner = tokio_util::io::StreamReader::new(inner).compat();
     650           47 :         let inner = Some(inner);
     651           47 :         let inner = std::sync::Mutex::new(inner);
     652           47 :         NonSeekableStream::Initial { inner, len }
     653           47 :     }
     654              : }
     655              : 
     656              : impl<S> std::fmt::Debug for NonSeekableStream<S> {
     657            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     658            0 :         match self {
     659            0 :             Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
     660            0 :             Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
     661            0 :             Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
     662              :         }
     663            0 :     }
     664              : }
     665              : 
     666              : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
     667              : where
     668              :     S: Stream<Item = std::io::Result<Bytes>>,
     669              : {
     670           47 :     fn poll_read(
     671           47 :         self: std::pin::Pin<&mut Self>,
     672           47 :         cx: &mut std::task::Context<'_>,
     673           47 :         buf: &mut [u8],
     674           47 :     ) -> std::task::Poll<std::io::Result<usize>> {
     675           47 :         match self.project() {
     676              :             NonSeekableStreamProj::Actual {
     677           47 :                 inner, read_any, ..
     678           47 :             } => {
     679           47 :                 *read_any = true;
     680           47 :                 inner.poll_read(cx, buf)
     681              :             }
     682              :             // NonSeekableStream::Initial does not support reading because it is just much easier
     683              :             // to have the mutex in place where one does not poll the contents, or that's how it
     684              :             // seemed originally. If there is a version upgrade which changes the cloning, then
     685              :             // that support needs to be hacked in.
     686              :             //
     687              :             // including {self:?} into the message would be useful, but unsure how to unproject.
     688            0 :             _ => std::task::Poll::Ready(Err(std::io::Error::new(
     689            0 :                 std::io::ErrorKind::Other,
     690            0 :                 "cloned or initial values cannot be read",
     691            0 :             ))),
     692              :         }
     693           47 :     }
     694              : }
     695              : 
     696              : impl<S> Clone for NonSeekableStream<S> {
     697              :     /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
     698              :     /// request, see type documentation.
     699           47 :     fn clone(&self) -> Self {
     700           47 :         use NonSeekableStream::*;
     701           47 : 
     702           47 :         match self {
     703           47 :             Initial { inner, len } => {
     704           47 :                 if let Some(inner) = inner.lock().unwrap().take() {
     705           47 :                     Actual {
     706           47 :                         inner,
     707           47 :                         len: *len,
     708           47 :                         read_any: false,
     709           47 :                     }
     710              :                 } else {
     711            0 :                     Self::Cloned { len_was: *len }
     712              :                 }
     713              :             }
     714            0 :             Actual { len, .. } => Cloned { len_was: *len },
     715            0 :             Cloned { len_was } => Cloned { len_was: *len_was },
     716              :         }
     717           47 :     }
     718              : }
     719              : 
     720              : #[async_trait::async_trait]
     721              : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
     722              : where
     723              :     S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
     724              : {
     725            0 :     async fn reset(&mut self) -> azure_core::error::Result<()> {
     726            0 :         use NonSeekableStream::*;
     727            0 : 
     728            0 :         let msg = match self {
     729            0 :             Initial { inner, .. } => {
     730            0 :                 if inner.get_mut().unwrap().is_some() {
     731            0 :                     return Ok(());
     732            0 :                 } else {
     733            0 :                     "reset after first clone is not supported"
     734            0 :                 }
     735            0 :             }
     736            0 :             Actual { read_any, .. } if !*read_any => return Ok(()),
     737            0 :             Actual { .. } => "reset after reading is not supported",
     738            0 :             Cloned { .. } => "reset after second clone is not supported",
     739            0 :         };
     740            0 :         Err(azure_core::error::Error::new(
     741            0 :             azure_core::error::ErrorKind::Io,
     742            0 :             std::io::Error::new(std::io::ErrorKind::Other, msg),
     743            0 :         ))
     744            0 :     }
     745              : 
     746              :     // Note: it is not documented if this should be the total or remaining length, total passes the
     747              :     // tests.
     748           47 :     fn len(&self) -> usize {
     749           47 :         use NonSeekableStream::*;
     750           47 :         match self {
     751           47 :             Initial { len, .. } => *len,
     752            0 :             Actual { len, .. } => *len,
     753            0 :             Cloned { len_was, .. } => *len_was,
     754              :         }
     755           47 :     }
     756              : }
        

Generated by: LCOV version 2.1-beta