Line data Source code
1 : //! Azure Blob Storage wrapper
2 :
3 : use std::borrow::Cow;
4 : use std::collections::HashMap;
5 : use std::env;
6 : use std::fmt::Display;
7 : use std::io;
8 : use std::num::NonZeroU32;
9 : use std::pin::Pin;
10 : use std::str::FromStr;
11 : use std::sync::Arc;
12 : use std::time::Duration;
13 : use std::time::SystemTime;
14 :
15 : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
16 : use anyhow::Result;
17 : use azure_core::request_options::{MaxResults, Metadata, Range};
18 : use azure_core::{Continuable, RetryOptions};
19 : use azure_identity::DefaultAzureCredential;
20 : use azure_storage::StorageCredentials;
21 : use azure_storage_blobs::blob::CopyStatus;
22 : use azure_storage_blobs::prelude::ClientBuilder;
23 : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
24 : use bytes::Bytes;
25 : use futures::future::Either;
26 : use futures::stream::Stream;
27 : use futures_util::StreamExt;
28 : use futures_util::TryStreamExt;
29 : use http_types::{StatusCode, Url};
30 : use scopeguard::ScopeGuard;
31 : use tokio_util::sync::CancellationToken;
32 : use tracing::debug;
33 : use utils::backoff;
34 :
35 : use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind};
36 : use crate::ListingObject;
37 : use crate::{
38 : config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError, Listing,
39 : ListingMode, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
40 : };
41 :
42 : pub struct AzureBlobStorage {
43 : client: ContainerClient,
44 : container_name: String,
45 : prefix_in_container: Option<String>,
46 : max_keys_per_list_response: Option<NonZeroU32>,
47 : concurrency_limiter: ConcurrencyLimiter,
48 : // Per-request timeout. Accessible for tests.
49 : pub timeout: Duration,
50 : }
51 :
52 : impl AzureBlobStorage {
53 6 : pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
54 6 : debug!(
55 0 : "Creating azure remote storage for azure container {}",
56 : azure_config.container_name
57 : );
58 :
59 : // Use the storage account from the config by default, fall back to env var if not present.
60 6 : let account = azure_config.storage_account.clone().unwrap_or_else(|| {
61 6 : env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
62 6 : });
63 :
64 : // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
65 : // otherwise try the token based credentials.
66 6 : let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
67 6 : StorageCredentials::access_key(account.clone(), access_key)
68 : } else {
69 0 : let token_credential = DefaultAzureCredential::default();
70 0 : StorageCredentials::token_credential(Arc::new(token_credential))
71 : };
72 :
73 : // we have an outer retry
74 6 : let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
75 6 :
76 6 : let client = builder.container_client(azure_config.container_name.to_owned());
77 :
78 6 : let max_keys_per_list_response =
79 6 : if let Some(limit) = azure_config.max_keys_per_list_response {
80 : Some(
81 2 : NonZeroU32::new(limit as u32)
82 2 : .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
83 : )
84 : } else {
85 4 : None
86 : };
87 :
88 6 : Ok(AzureBlobStorage {
89 6 : client,
90 6 : container_name: azure_config.container_name.to_owned(),
91 6 : prefix_in_container: azure_config.prefix_in_container.to_owned(),
92 6 : max_keys_per_list_response,
93 6 : concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
94 6 : timeout,
95 6 : })
96 6 : }
97 :
98 108 : pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
99 108 : assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
100 108 : let path_string = path
101 108 : .get_path()
102 108 : .as_str()
103 108 : .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
104 108 : match &self.prefix_in_container {
105 108 : Some(prefix) => {
106 108 : if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
107 108 : prefix.clone() + path_string
108 : } else {
109 0 : format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
110 : }
111 : }
112 0 : None => path_string.to_string(),
113 : }
114 108 : }
115 :
116 74 : fn name_to_relative_path(&self, key: &str) -> RemotePath {
117 74 : let relative_path =
118 74 : match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
119 74 : Some(stripped) => stripped,
120 : // we rely on Azure to return properly prefixed paths
121 : // for requests with a certain prefix
122 0 : None => panic!(
123 0 : "Key {key} does not start with container prefix {:?}",
124 0 : self.prefix_in_container
125 0 : ),
126 : };
127 74 : RemotePath(
128 74 : relative_path
129 74 : .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
130 74 : .collect(),
131 74 : )
132 74 : }
133 :
134 7 : async fn download_for_builder(
135 7 : &self,
136 7 : builder: GetBlobBuilder,
137 7 : cancel: &CancellationToken,
138 7 : ) -> Result<Download, DownloadError> {
139 7 : let kind = RequestKind::Get;
140 :
141 7 : let _permit = self.permit(kind, cancel).await?;
142 7 : let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
143 7 : let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
144 7 :
145 7 : let mut etag = None;
146 7 : let mut last_modified = None;
147 7 : let mut metadata = HashMap::new();
148 7 :
149 7 : let started_at = start_measuring_requests(kind);
150 7 :
151 7 : let download = async {
152 7 : let response = builder
153 7 : // convert to concrete Pageable
154 7 : .into_stream()
155 7 : // convert to TryStream
156 7 : .into_stream()
157 7 : .map_err(to_download_error);
158 7 :
159 7 : // apply per request timeout
160 7 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
161 7 :
162 7 : // flatten
163 7 : let response = response.map(|res| match res {
164 7 : Ok(res) => res,
165 0 : Err(_elapsed) => Err(DownloadError::Timeout),
166 7 : });
167 7 :
168 7 : let mut response = Box::pin(response);
169 :
170 35 : let Some(part) = response.next().await else {
171 0 : return Err(DownloadError::Other(anyhow::anyhow!(
172 0 : "Azure GET response contained no response body"
173 0 : )));
174 : };
175 7 : let part = part?;
176 7 : if etag.is_none() {
177 7 : etag = Some(part.blob.properties.etag);
178 7 : }
179 7 : if last_modified.is_none() {
180 7 : last_modified = Some(part.blob.properties.last_modified.into());
181 7 : }
182 7 : if let Some(blob_meta) = part.blob.metadata {
183 0 : metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
184 7 : }
185 :
186 : // unwrap safety: if these were None, bufs would be empty and we would have returned an error already
187 7 : let etag = etag.unwrap();
188 7 : let last_modified = last_modified.unwrap();
189 7 :
190 7 : let tail_stream = response
191 7 : .map(|part| match part {
192 0 : Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
193 0 : Err(e) => {
194 0 : Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
195 : }
196 7 : })
197 7 : .flatten();
198 7 : let stream = part
199 7 : .data
200 7 : .map(|r| r.map_err(io::Error::other))
201 7 : .chain(sync_wrapper::SyncStream::new(tail_stream));
202 7 : //.chain(SyncStream::from_pin(Box::pin(tail_stream)));
203 7 :
204 7 : let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
205 7 :
206 7 : Ok(Download {
207 7 : download_stream: Box::pin(download_stream),
208 7 : etag,
209 7 : last_modified,
210 7 : metadata: Some(StorageMetadata(metadata)),
211 7 : })
212 7 : };
213 :
214 7 : let download = tokio::select! {
215 7 : bufs = download => bufs,
216 7 : cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
217 0 : TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
218 0 : TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
219 : },
220 : };
221 7 : let started_at = ScopeGuard::into_inner(started_at);
222 7 : let outcome = match &download {
223 7 : Ok(_) => AttemptOutcome::Ok,
224 0 : Err(_) => AttemptOutcome::Err,
225 : };
226 7 : crate::metrics::BUCKET_METRICS
227 7 : .req_seconds
228 7 : .observe_elapsed(kind, outcome, started_at);
229 7 : download
230 7 : }
231 :
232 109 : async fn permit(
233 109 : &self,
234 109 : kind: RequestKind,
235 109 : cancel: &CancellationToken,
236 109 : ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
237 109 : let acquire = self.concurrency_limiter.acquire(kind);
238 109 :
239 109 : tokio::select! {
240 109 : permit = acquire => Ok(permit.expect("never closed")),
241 109 : _ = cancel.cancelled() => Err(Cancelled),
242 : }
243 109 : }
244 :
245 0 : pub fn container_name(&self) -> &str {
246 0 : &self.container_name
247 0 : }
248 : }
249 :
250 0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
251 0 : let mut res = Metadata::new();
252 0 : for (k, v) in metadata.0.into_iter() {
253 0 : res.insert(k, v);
254 0 : }
255 0 : res
256 0 : }
257 :
258 0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
259 0 : if let Some(http_err) = error.as_http_error() {
260 0 : match http_err.status() {
261 0 : StatusCode::NotFound => DownloadError::NotFound,
262 0 : StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
263 0 : _ => DownloadError::Other(anyhow::Error::new(error)),
264 : }
265 : } else {
266 0 : DownloadError::Other(error.into())
267 : }
268 0 : }
269 :
270 : impl RemoteStorage for AzureBlobStorage {
271 7 : fn list_streaming(
272 7 : &self,
273 7 : prefix: Option<&RemotePath>,
274 7 : mode: ListingMode,
275 7 : max_keys: Option<NonZeroU32>,
276 7 : cancel: &CancellationToken,
277 7 : ) -> impl Stream<Item = Result<Listing, DownloadError>> {
278 7 : // get the passed prefix or if it is not set use prefix_in_bucket value
279 7 : let list_prefix = prefix
280 7 : .map(|p| self.relative_path_to_name(p))
281 7 : .or_else(|| self.prefix_in_container.clone())
282 7 : .map(|mut p| {
283 : // required to end with a separator
284 : // otherwise request will return only the entry of a prefix
285 7 : if matches!(mode, ListingMode::WithDelimiter)
286 4 : && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
287 2 : {
288 2 : p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
289 5 : }
290 7 : p
291 7 : });
292 7 :
293 7 : async_stream::stream! {
294 7 : let _permit = self.permit(RequestKind::List, cancel).await?;
295 7 :
296 7 : let mut builder = self.client.list_blobs();
297 7 :
298 7 : if let ListingMode::WithDelimiter = mode {
299 7 : builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
300 7 : }
301 7 :
302 7 : if let Some(prefix) = list_prefix {
303 7 : builder = builder.prefix(Cow::from(prefix.to_owned()));
304 7 : }
305 7 :
306 7 : if let Some(limit) = self.max_keys_per_list_response {
307 7 : builder = builder.max_results(MaxResults::new(limit));
308 7 : }
309 7 :
310 7 : let mut next_marker = None;
311 7 :
312 7 : 'outer: loop {
313 7 : let mut builder = builder.clone();
314 7 : if let Some(marker) = next_marker.clone() {
315 7 : builder = builder.marker(marker);
316 7 : }
317 7 : let response = builder.into_stream();
318 7 : let response = response.into_stream().map_err(to_download_error);
319 7 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
320 13 : let response = response.map(|res| match res {
321 13 : Ok(res) => res,
322 7 : Err(_elapsed) => Err(DownloadError::Timeout),
323 13 : });
324 7 :
325 7 : let mut response = std::pin::pin!(response);
326 7 :
327 7 : let mut max_keys = max_keys.map(|mk| mk.get());
328 7 : let next_item = tokio::select! {
329 7 : op = response.next() => Ok(op),
330 7 : _ = cancel.cancelled() => Err(DownloadError::Cancelled),
331 7 : }?;
332 7 : let Some(entry) = next_item else {
333 7 : // The list is complete, so yield it.
334 7 : break;
335 7 : };
336 7 :
337 7 : let mut res = Listing::default();
338 7 : let entry = match entry {
339 7 : Ok(entry) => entry,
340 7 : Err(e) => {
341 7 : // The error is potentially retryable, so we must rewind the loop after yielding.
342 7 : yield Err(e);
343 7 : continue;
344 7 : }
345 7 : };
346 7 : next_marker = entry.continuation();
347 7 : let prefix_iter = entry
348 7 : .blobs
349 7 : .prefixes()
350 44 : .map(|prefix| self.name_to_relative_path(&prefix.name));
351 7 : res.prefixes.extend(prefix_iter);
352 7 :
353 7 : let blob_iter = entry
354 7 : .blobs
355 7 : .blobs()
356 30 : .map(|k| ListingObject{
357 30 : key: self.name_to_relative_path(&k.name),
358 30 : last_modified: k.properties.last_modified.into(),
359 30 : size: k.properties.content_length,
360 30 : }
361 7 : );
362 7 :
363 7 : for key in blob_iter {
364 7 : res.keys.push(key);
365 7 :
366 7 : if let Some(mut mk) = max_keys {
367 7 : assert!(mk > 0);
368 7 : mk -= 1;
369 7 : if mk == 0 {
370 7 : yield Ok(res); // limit reached
371 7 : break 'outer;
372 7 : }
373 7 : max_keys = Some(mk);
374 7 : }
375 7 : }
376 7 : yield Ok(res);
377 7 :
378 7 : // We are done here
379 7 : if next_marker.is_none() {
380 7 : break;
381 7 : }
382 7 : }
383 7 : }
384 7 : }
385 :
386 0 : async fn head_object(
387 0 : &self,
388 0 : key: &RemotePath,
389 0 : cancel: &CancellationToken,
390 0 : ) -> Result<ListingObject, DownloadError> {
391 0 : let kind = RequestKind::Head;
392 0 : let _permit = self.permit(kind, cancel).await?;
393 :
394 0 : let started_at = start_measuring_requests(kind);
395 0 :
396 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(key));
397 0 : let properties_future = blob_client.get_properties().into_future();
398 0 :
399 0 : let properties_future = tokio::time::timeout(self.timeout, properties_future);
400 :
401 0 : let res = tokio::select! {
402 0 : res = properties_future => res,
403 0 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
404 : };
405 :
406 0 : if let Ok(inner) = &res {
407 0 : // do not incl. timeouts as errors in metrics but cancellations
408 0 : let started_at = ScopeGuard::into_inner(started_at);
409 0 : crate::metrics::BUCKET_METRICS
410 0 : .req_seconds
411 0 : .observe_elapsed(kind, inner, started_at);
412 0 : }
413 :
414 0 : let data = match res {
415 0 : Ok(Ok(data)) => Ok(data),
416 0 : Ok(Err(sdk)) => Err(to_download_error(sdk)),
417 0 : Err(_timeout) => Err(DownloadError::Timeout),
418 0 : }?;
419 :
420 0 : let properties = data.blob.properties;
421 0 : Ok(ListingObject {
422 0 : key: key.to_owned(),
423 0 : last_modified: SystemTime::from(properties.last_modified),
424 0 : size: properties.content_length,
425 0 : })
426 0 : }
427 :
428 47 : async fn upload(
429 47 : &self,
430 47 : from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
431 47 : data_size_bytes: usize,
432 47 : to: &RemotePath,
433 47 : metadata: Option<StorageMetadata>,
434 47 : cancel: &CancellationToken,
435 47 : ) -> anyhow::Result<()> {
436 47 : let kind = RequestKind::Put;
437 47 : let _permit = self.permit(kind, cancel).await?;
438 :
439 47 : let started_at = start_measuring_requests(kind);
440 47 :
441 47 : let op = async {
442 47 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
443 47 :
444 47 : let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
445 47 : Box::pin(from);
446 47 :
447 47 : let from = NonSeekableStream::new(from, data_size_bytes);
448 47 :
449 47 : let body = azure_core::Body::SeekableStream(Box::new(from));
450 47 :
451 47 : let mut builder = blob_client.put_block_blob(body);
452 :
453 47 : if let Some(metadata) = metadata {
454 0 : builder = builder.metadata(to_azure_metadata(metadata));
455 47 : }
456 :
457 47 : let fut = builder.into_future();
458 47 : let fut = tokio::time::timeout(self.timeout, fut);
459 47 :
460 270 : match fut.await {
461 47 : Ok(Ok(_response)) => Ok(()),
462 0 : Ok(Err(azure)) => Err(azure.into()),
463 0 : Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
464 : }
465 47 : };
466 :
467 47 : let res = tokio::select! {
468 47 : res = op => res,
469 47 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
470 : };
471 :
472 47 : let outcome = match res {
473 47 : Ok(_) => AttemptOutcome::Ok,
474 0 : Err(_) => AttemptOutcome::Err,
475 : };
476 47 : let started_at = ScopeGuard::into_inner(started_at);
477 47 : crate::metrics::BUCKET_METRICS
478 47 : .req_seconds
479 47 : .observe_elapsed(kind, outcome, started_at);
480 47 :
481 47 : res
482 47 : }
483 :
484 2 : async fn download(
485 2 : &self,
486 2 : from: &RemotePath,
487 2 : cancel: &CancellationToken,
488 2 : ) -> Result<Download, DownloadError> {
489 2 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
490 2 :
491 2 : let builder = blob_client.get();
492 2 :
493 10 : self.download_for_builder(builder, cancel).await
494 2 : }
495 :
496 5 : async fn download_byte_range(
497 5 : &self,
498 5 : from: &RemotePath,
499 5 : start_inclusive: u64,
500 5 : end_exclusive: Option<u64>,
501 5 : cancel: &CancellationToken,
502 5 : ) -> Result<Download, DownloadError> {
503 5 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
504 5 :
505 5 : let mut builder = blob_client.get();
506 :
507 5 : let range: Range = if let Some(end_exclusive) = end_exclusive {
508 3 : (start_inclusive..end_exclusive).into()
509 : } else {
510 2 : (start_inclusive..).into()
511 : };
512 5 : builder = builder.range(range);
513 5 :
514 25 : self.download_for_builder(builder, cancel).await
515 5 : }
516 :
517 44 : async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
518 44 : self.delete_objects(std::array::from_ref(path), cancel)
519 222 : .await
520 44 : }
521 :
522 47 : async fn delete_objects<'a>(
523 47 : &self,
524 47 : paths: &'a [RemotePath],
525 47 : cancel: &CancellationToken,
526 47 : ) -> anyhow::Result<()> {
527 47 : let kind = RequestKind::Delete;
528 47 : let _permit = self.permit(kind, cancel).await?;
529 47 : let started_at = start_measuring_requests(kind);
530 47 :
531 47 : let op = async {
532 : // TODO batch requests are not supported by the SDK
533 : // https://github.com/Azure/azure-sdk-for-rust/issues/1068
534 96 : for path in paths {
535 49 : #[derive(Debug)]
536 49 : enum AzureOrTimeout {
537 : AzureError(azure_core::Error),
538 : Timeout,
539 : Cancel,
540 49 : }
541 49 : impl Display for AzureOrTimeout {
542 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 0 : write!(f, "{self:?}")
544 0 : }
545 : }
546 49 : let warn_threshold = 3;
547 49 : let max_retries = 5;
548 49 : backoff::retry(
549 49 : || async {
550 49 : let blob_client = self.client.blob_client(self.relative_path_to_name(path));
551 49 :
552 49 : let request = blob_client.delete().into_future();
553 :
554 247 : let res = tokio::time::timeout(self.timeout, request).await;
555 :
556 49 : match res {
557 48 : Ok(Ok(_v)) => Ok(()),
558 1 : Ok(Err(azure_err)) => {
559 1 : if let Some(http_err) = azure_err.as_http_error() {
560 1 : if http_err.status() == StatusCode::NotFound {
561 1 : return Ok(());
562 0 : }
563 0 : }
564 0 : Err(AzureOrTimeout::AzureError(azure_err))
565 : }
566 0 : Err(_elapsed) => Err(AzureOrTimeout::Timeout),
567 : }
568 98 : },
569 49 : |err| match err {
570 0 : AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
571 0 : AzureOrTimeout::Cancel => true,
572 49 : },
573 49 : warn_threshold,
574 49 : max_retries,
575 49 : "deleting remote object",
576 49 : cancel,
577 49 : )
578 247 : .await
579 49 : .ok_or_else(|| AzureOrTimeout::Cancel)
580 49 : .and_then(|x| x)
581 49 : .map_err(|e| match e {
582 0 : AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
583 0 : AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
584 0 : AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
585 49 : })?;
586 : }
587 47 : Ok(())
588 47 : };
589 :
590 47 : let res = tokio::select! {
591 47 : res = op => res,
592 47 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
593 : };
594 :
595 47 : let started_at = ScopeGuard::into_inner(started_at);
596 47 : crate::metrics::BUCKET_METRICS
597 47 : .req_seconds
598 47 : .observe_elapsed(kind, &res, started_at);
599 47 : res
600 47 : }
601 :
602 1 : async fn copy(
603 1 : &self,
604 1 : from: &RemotePath,
605 1 : to: &RemotePath,
606 1 : cancel: &CancellationToken,
607 1 : ) -> anyhow::Result<()> {
608 1 : let kind = RequestKind::Copy;
609 1 : let _permit = self.permit(kind, cancel).await?;
610 1 : let started_at = start_measuring_requests(kind);
611 1 :
612 1 : let timeout = tokio::time::sleep(self.timeout);
613 1 :
614 1 : let mut copy_status = None;
615 1 :
616 1 : let op = async {
617 1 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
618 :
619 1 : let source_url = format!(
620 1 : "{}/{}",
621 1 : self.client.url()?,
622 1 : self.relative_path_to_name(from)
623 : );
624 :
625 1 : let builder = blob_client.copy(Url::from_str(&source_url)?);
626 1 : let copy = builder.into_future();
627 :
628 5 : let result = copy.await?;
629 :
630 1 : copy_status = Some(result.copy_status);
631 : loop {
632 1 : match copy_status.as_ref().expect("we always set it to Some") {
633 : CopyStatus::Aborted => {
634 0 : anyhow::bail!("Received abort for copy from {from} to {to}.");
635 : }
636 : CopyStatus::Failed => {
637 0 : anyhow::bail!("Received failure response for copy from {from} to {to}.");
638 : }
639 1 : CopyStatus::Success => return Ok(()),
640 0 : CopyStatus::Pending => (),
641 0 : }
642 0 : // The copy is taking longer. Waiting a second and then re-trying.
643 0 : // TODO estimate time based on copy_progress and adjust time based on that
644 0 : tokio::time::sleep(Duration::from_millis(1000)).await;
645 0 : let properties = blob_client.get_properties().into_future().await?;
646 0 : let Some(status) = properties.blob.properties.copy_status else {
647 0 : tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
648 0 : return Ok(());
649 : };
650 0 : copy_status = Some(status);
651 : }
652 1 : };
653 :
654 1 : let res = tokio::select! {
655 1 : res = op => res,
656 1 : _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
657 1 : _ = timeout => {
658 0 : let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
659 0 : let e = e.context(format!("Timeout, last status: {copy_status:?}"));
660 0 : Err(e)
661 : },
662 : };
663 :
664 1 : let started_at = ScopeGuard::into_inner(started_at);
665 1 : crate::metrics::BUCKET_METRICS
666 1 : .req_seconds
667 1 : .observe_elapsed(kind, &res, started_at);
668 1 : res
669 1 : }
670 :
671 0 : async fn time_travel_recover(
672 0 : &self,
673 0 : _prefix: Option<&RemotePath>,
674 0 : _timestamp: SystemTime,
675 0 : _done_if_after: SystemTime,
676 0 : _cancel: &CancellationToken,
677 0 : ) -> Result<(), TimeTravelError> {
678 0 : // TODO use Azure point in time recovery feature for this
679 0 : // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
680 0 : Err(TimeTravelError::Unimplemented)
681 0 : }
682 : }
683 :
684 : pin_project_lite::pin_project! {
685 : /// Hack to work around not being able to stream once with azure sdk.
686 : ///
687 : /// Azure sdk clones streams around with the assumption that they are like
688 : /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
689 : /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
690 : /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
691 : /// seekable, but we can also just re-try the request easier.
692 : #[project = NonSeekableStreamProj]
693 : enum NonSeekableStream<S> {
694 : /// A stream wrappers initial form.
695 : ///
696 : /// Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
697 : /// clone before first request, then this must be changed.
698 : Initial {
699 : inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
700 : len: usize,
701 : },
702 : /// The actually readable variant, produced by cloning the Initial variant.
703 : ///
704 : /// The sdk currently always clones once, even without retry policy.
705 : Actual {
706 : #[pin]
707 : inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
708 : len: usize,
709 : read_any: bool,
710 : },
711 : /// Most likely unneeded, but left to make life easier, in case more clones are added.
712 : Cloned {
713 : len_was: usize,
714 : }
715 : }
716 : }
717 :
718 : impl<S> NonSeekableStream<S>
719 : where
720 : S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
721 : {
722 47 : fn new(inner: S, len: usize) -> NonSeekableStream<S> {
723 : use tokio_util::compat::TokioAsyncReadCompatExt;
724 :
725 47 : let inner = tokio_util::io::StreamReader::new(inner).compat();
726 47 : let inner = Some(inner);
727 47 : let inner = std::sync::Mutex::new(inner);
728 47 : NonSeekableStream::Initial { inner, len }
729 47 : }
730 : }
731 :
732 : impl<S> std::fmt::Debug for NonSeekableStream<S> {
733 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
734 0 : match self {
735 0 : Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
736 0 : Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
737 0 : Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
738 : }
739 0 : }
740 : }
741 :
742 : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
743 : where
744 : S: Stream<Item = std::io::Result<Bytes>>,
745 : {
746 47 : fn poll_read(
747 47 : self: std::pin::Pin<&mut Self>,
748 47 : cx: &mut std::task::Context<'_>,
749 47 : buf: &mut [u8],
750 47 : ) -> std::task::Poll<std::io::Result<usize>> {
751 47 : match self.project() {
752 : NonSeekableStreamProj::Actual {
753 47 : inner, read_any, ..
754 47 : } => {
755 47 : *read_any = true;
756 47 : inner.poll_read(cx, buf)
757 : }
758 : // NonSeekableStream::Initial does not support reading because it is just much easier
759 : // to have the mutex in place where one does not poll the contents, or that's how it
760 : // seemed originally. If there is a version upgrade which changes the cloning, then
761 : // that support needs to be hacked in.
762 : //
763 : // including {self:?} into the message would be useful, but unsure how to unproject.
764 0 : _ => std::task::Poll::Ready(Err(std::io::Error::new(
765 0 : std::io::ErrorKind::Other,
766 0 : "cloned or initial values cannot be read",
767 0 : ))),
768 : }
769 47 : }
770 : }
771 :
772 : impl<S> Clone for NonSeekableStream<S> {
773 : /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
774 : /// request, see type documentation.
775 47 : fn clone(&self) -> Self {
776 : use NonSeekableStream::*;
777 :
778 47 : match self {
779 47 : Initial { inner, len } => {
780 47 : if let Some(inner) = inner.lock().unwrap().take() {
781 47 : Actual {
782 47 : inner,
783 47 : len: *len,
784 47 : read_any: false,
785 47 : }
786 : } else {
787 0 : Self::Cloned { len_was: *len }
788 : }
789 : }
790 0 : Actual { len, .. } => Cloned { len_was: *len },
791 0 : Cloned { len_was } => Cloned { len_was: *len_was },
792 : }
793 47 : }
794 : }
795 :
796 : #[async_trait::async_trait]
797 : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
798 : where
799 : S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
800 : {
801 0 : async fn reset(&mut self) -> azure_core::error::Result<()> {
802 : use NonSeekableStream::*;
803 :
804 0 : let msg = match self {
805 0 : Initial { inner, .. } => {
806 0 : if inner.get_mut().unwrap().is_some() {
807 0 : return Ok(());
808 : } else {
809 0 : "reset after first clone is not supported"
810 : }
811 : }
812 0 : Actual { read_any, .. } if !*read_any => return Ok(()),
813 0 : Actual { .. } => "reset after reading is not supported",
814 0 : Cloned { .. } => "reset after second clone is not supported",
815 : };
816 0 : Err(azure_core::error::Error::new(
817 0 : azure_core::error::ErrorKind::Io,
818 0 : std::io::Error::new(std::io::ErrorKind::Other, msg),
819 0 : ))
820 0 : }
821 :
822 : // Note: it is not documented if this should be the total or remaining length, total passes the
823 : // tests.
824 47 : fn len(&self) -> usize {
825 : use NonSeekableStream::*;
826 47 : match self {
827 47 : Initial { len, .. } => *len,
828 0 : Actual { len, .. } => *len,
829 0 : Cloned { len_was, .. } => *len_was,
830 : }
831 47 : }
832 : }
|