//! Azure Blob Storage wrapper

use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::Display;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};

use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::operations::GetBlobBuilder;
use azure_storage_blobs::blob::{Blob, CopyStatus};
use azure_storage_blobs::container::operations::ListBlobsBuilder;
use azure_storage_blobs::prelude::{ClientBuilder, ContainerClient};
use bytes::Bytes;
use futures::FutureExt;
use futures::future::Either;
use futures::stream::Stream;
use futures_util::{StreamExt, TryStreamExt};
use http_types::{StatusCode, Url};
use scopeguard::ScopeGuard;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use utils::backoff;
use utils::backoff::exponential_backoff_duration_seconds;

use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use crate::config::AzureConfig;
use crate::error::Cancelled;
use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
    ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
    ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
};
pub struct AzureBlobStorage {
    client: ContainerClient,
    container_name: String,
    prefix_in_container: Option<String>,
    max_keys_per_list_response: Option<NonZeroU32>,
    concurrency_limiter: ConcurrencyLimiter,
    // Per-request timeout. Accessible for tests.
    pub timeout: Duration,

    // Alternative timeout used for metadata objects which are expected to be small
    pub small_timeout: Duration,
}

impl AzureBlobStorage {
    pub fn new(
        azure_config: &AzureConfig,
        timeout: Duration,
        small_timeout: Duration,
    ) -> Result<Self> {
        debug!(
            "Creating azure remote storage for azure container {}",
            azure_config.container_name
        );

        // Use the storage account from the config by default, fall back to the env var if not present.
        let account = azure_config.storage_account.clone().unwrap_or_else(|| {
            env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
        });

        // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
        // otherwise try the token based credentials.
        let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
            StorageCredentials::access_key(account.clone(), access_key)
        } else {
            let token_credential = azure_identity::create_default_credential()
                .context("trying to obtain Azure default credentials")?;
            StorageCredentials::token_credential(token_credential)
        };

        let builder = ClientBuilder::new(account, credentials)
            // we have an outer retry
            .retry(RetryOptions::none())
            // Customize the transport to configure connection pooling.
            .transport(TransportOptions::new(Self::reqwest_client(
                azure_config.conn_pool_size,
            )));

        let client = builder.container_client(azure_config.container_name.to_owned());

        let max_keys_per_list_response =
            if let Some(limit) = azure_config.max_keys_per_list_response {
                Some(
                    NonZeroU32::new(limit as u32)
                        .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
                )
            } else {
                None
            };

        Ok(AzureBlobStorage {
            client,
            container_name: azure_config.container_name.to_owned(),
            prefix_in_container: azure_config.prefix_in_container.to_owned(),
            max_keys_per_list_response,
            concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
            timeout,
            small_timeout,
        })
    }
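
    // Construction sketch (hypothetical values; only the `AzureConfig` fields referenced
    // in this file are shown -- see `crate::config` for the exact schema):
    //
    //     let config = AzureConfig {
    //         container_name: "my-container".to_string(),
    //         storage_account: Some("myaccount".to_string()),
    //         prefix_in_container: Some("pageserver/".to_string()),
    //         // max_keys_per_list_response, concurrency_limit, conn_pool_size, ...
    //     };
    //     let storage = AzureBlobStorage::new(
    //         &config,
    //         Duration::from_secs(120), // per-request timeout
    //         Duration::from_secs(30),  // timeout for small metadata objects
    //     )?;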

    fn reqwest_client(conn_pool_size: usize) -> Arc<dyn HttpClient> {
        let client = reqwest::ClientBuilder::new()
            .pool_max_idle_per_host(conn_pool_size)
            .build()
            .expect("failed to build `reqwest` client");
        Arc::new(client)
    }

    pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
        assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
        let path_string = path.get_path().as_str();
        match &self.prefix_in_container {
            Some(prefix) => {
                if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
                    prefix.clone() + path_string
                } else {
                    format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
                }
            }
            None => path_string.to_string(),
        }
    }

    fn name_to_relative_path(&self, key: &str) -> RemotePath {
        let relative_path =
            match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
                Some(stripped) => stripped,
                // we rely on Azure to return properly prefixed paths
                // for requests with a certain prefix
                None => panic!(
                    "Key {key} does not start with container prefix {:?}",
                    self.prefix_in_container
                ),
            };
        RemotePath(
            relative_path
                .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
                .collect(),
        )
    }
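
    // Sketch of the path mapping above, assuming `prefix_in_container` is
    // `Some("tenant/".to_string())` (a hypothetical value):
    //
    //     relative_path_to_name("timelines/t1/index_part.json")
    //         -> "tenant/timelines/t1/index_part.json"
    //     name_to_relative_path("tenant/timelines/t1/index_part.json")
    //         -> "timelines/t1/index_part.json"
    //
    // The two functions are inverses of each other as long as the prefix ends with the
    // separator, which `list_streaming_for_fn` below ensures when building list prefixes.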

    async fn download_for_builder(
        &self,
        builder: GetBlobBuilder,
        timeout: Duration,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let kind = RequestKind::Get;

        let _permit = self.permit(kind, cancel).await?;
        let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
        let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());

        let mut etag = None;
        let mut last_modified = None;
        let mut metadata = HashMap::new();

        let started_at = start_measuring_requests(kind);

        let download = async {
            let response = builder
                // convert to concrete Pageable
                .into_stream()
                // convert to TryStream
                .into_stream()
                .map_err(to_download_error);

            // apply per request timeout
            let response = tokio_stream::StreamExt::timeout(response, timeout);

            // flatten
            let response = response.map(|res| match res {
                Ok(res) => res,
                Err(_elapsed) => Err(DownloadError::Timeout),
            });

            let mut response = Box::pin(response);

            let Some(part) = response.next().await else {
                return Err(DownloadError::Other(anyhow::anyhow!(
                    "Azure GET response contained no response body"
                )));
            };
            let part = part?;
            if etag.is_none() {
                etag = Some(part.blob.properties.etag);
            }
            if last_modified.is_none() {
                last_modified = Some(part.blob.properties.last_modified.into());
            }
            if let Some(blob_meta) = part.blob.metadata {
                metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
            }

            // unwrap safety: if these were None, the first part would have been missing and we
            // would have returned an error already
            let etag = etag.unwrap();
            let last_modified = last_modified.unwrap();

            let tail_stream = response
                .map(|part| match part {
                    Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
                    Err(e) => {
                        Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
                    }
                })
                .flatten();
            let stream = part
                .data
                .map(|r| r.map_err(io::Error::other))
                .chain(sync_wrapper::SyncStream::new(tail_stream));

            let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);

            Ok(Download {
                download_stream: Box::pin(download_stream),
                etag,
                last_modified,
                metadata: Some(StorageMetadata(metadata)),
            })
        };

        let download = tokio::select! {
            bufs = download => bufs,
            cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
                TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
                TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
            },
        };
        let started_at = ScopeGuard::into_inner(started_at);
        let outcome = match &download {
            Ok(_) => AttemptOutcome::Ok,
            // At this level in the stack, 404 and 304 responses do not indicate an error.
            // There are expected cases when a blob may not exist or hasn't been modified since
            // the last get (e.g. probing for timeline indices and heatmap downloads).
            // Callers should handle errors if they are unexpected.
            Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok,
            Err(_) => AttemptOutcome::Err,
        };
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, outcome, started_at);
        download
    }
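
    // Consumption sketch for the returned `Download` (hypothetical caller code; the body
    // is the `download_stream` field, a stream of `std::io::Result<Bytes>` chunks):
    //
    //     let mut download = storage.download(&path, &opts, &cancel).await?;
    //     let mut buf = Vec::new();
    //     while let Some(chunk) = download.download_stream.next().await {
    //         buf.extend_from_slice(&chunk?);
    //     }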

    fn list_streaming_for_fn<T: Default + ListingCollector>(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
        cancel: &CancellationToken,
        request_kind: RequestKind,
        customize_builder: impl Fn(ListBlobsBuilder) -> ListBlobsBuilder,
    ) -> impl Stream<Item = Result<T, DownloadError>> {
        // Use the passed prefix, or fall back to the `prefix_in_container` value if it is not set.
        let list_prefix = prefix.map(|p| self.relative_path_to_name(p)).or_else(|| {
            self.prefix_in_container.clone().map(|mut s| {
                if !s.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
                    s.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
                }
                s
            })
        });

        async_stream::stream! {
            let _permit = self.permit(request_kind, cancel).await?;

            let mut builder = self.client.list_blobs();

            if let ListingMode::WithDelimiter = mode {
                builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
            }

            if let Some(prefix) = list_prefix {
                builder = builder.prefix(Cow::from(prefix.to_owned()));
            }

            if let Some(limit) = self.max_keys_per_list_response {
                builder = builder.max_results(MaxResults::new(limit));
            }

            builder = customize_builder(builder);

            let mut next_marker = None;

            let mut timeout_try_cnt = 1;

            'outer: loop {
                let mut builder = builder.clone();
                if let Some(marker) = next_marker.clone() {
                    builder = builder.marker(marker);
                }
                // The Azure Blob Rust SDK does not expose the list blobs API directly. Users have
                // to go through its pageable iterator wrapper that returns all keys as a stream.
                // We want full control of paging, and therefore we only take the first item from
                // the stream.
                let mut response_stream = builder.into_stream();
                let response = response_stream.next();
                // Timeout mechanism: the Azure client can sometimes get stuck on a request, while
                // retrying that same request would immediately succeed. Therefore, we use an
                // exponentially growing timeout to retry the request. (Usually, exponential
                // backoff is used to determine the sleep time between two retries.) We start with
                // a 10.0 second timeout and double it for each failure, for up to 5 failures:
                // timeout = min(5 * (1.0+1.0)^n, self.timeout).
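                // Assuming `exponential_backoff_duration_seconds(n, 1.0, max)` evaluates to
                // roughly `min(2^n, max)`, the per-try timeouts below work out to:
                //
                //     try 1 -> 10s, try 2 -> 20s, try 3 -> 40s, try 4 -> 80s, try 5 -> 160s
                //
                // each additionally capped at `self.timeout`.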
                let this_timeout = (5.0 * exponential_backoff_duration_seconds(timeout_try_cnt, 1.0, self.timeout.as_secs_f64())).min(self.timeout.as_secs_f64());
                let response = tokio::time::timeout(Duration::from_secs_f64(this_timeout), response);
                let response = response.map(|res| {
                    match res {
                        Ok(Some(Ok(res))) => Ok(Some(res)),
                        Ok(Some(Err(e))) => Err(to_download_error(e)),
                        Ok(None) => Ok(None),
                        Err(_elapsed) => Err(DownloadError::Timeout),
                    }
                });
                let mut max_keys = max_keys.map(|mk| mk.get());
                let next_item = tokio::select! {
                    op = response => op,
                    _ = cancel.cancelled() => Err(DownloadError::Cancelled),
                };

                if let Err(DownloadError::Timeout) = &next_item {
                    timeout_try_cnt += 1;
                    if timeout_try_cnt <= 5 {
                        continue;
                    }
                }

                let next_item = next_item?;

                // Log a warning if we saw two or more timeouts in a row before a successful request.
                if timeout_try_cnt > 2 {
                    tracing::warn!("Azure Blob Storage list timed out and succeeded after {} tries", timeout_try_cnt);
                }
                timeout_try_cnt = 1;

                let Some(entry) = next_item else {
                    // The response stream is exhausted, so the listing is complete.
                    break;
                };

                let mut res = T::default();
                next_marker = entry.continuation();
                let prefix_iter = entry
                    .blobs
                    .prefixes()
                    .map(|prefix| self.name_to_relative_path(&prefix.name));
                res.add_prefixes(self, prefix_iter);

                let blob_iter = entry.blobs.blobs();

                for key in blob_iter {
                    res.add_blob(self, key);

                    if let Some(mut mk) = max_keys {
                        assert!(mk > 0);
                        mk -= 1;
                        if mk == 0 {
                            yield Ok(res); // limit reached
                            break 'outer;
                        }
                        max_keys = Some(mk);
                    }
                }
                yield Ok(res);

                // We are done here
                if next_marker.is_none() {
                    break;
                }
            }
        }
    }

    async fn permit(
        &self,
        kind: RequestKind,
        cancel: &CancellationToken,
    ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
        let acquire = self.concurrency_limiter.acquire(kind);

        tokio::select! {
            permit = acquire => Ok(permit.expect("never closed")),
            _ = cancel.cancelled() => Err(Cancelled),
        }
    }

    pub fn container_name(&self) -> &str {
        &self.container_name
    }
}

trait ListingCollector {
    fn add_prefixes(&mut self, abs: &AzureBlobStorage, prefix_it: impl Iterator<Item = RemotePath>);
    fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob);
}

impl ListingCollector for Listing {
    fn add_prefixes(
        &mut self,
        _abs: &AzureBlobStorage,
        prefix_it: impl Iterator<Item = RemotePath>,
    ) {
        self.prefixes.extend(prefix_it);
    }
    fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
        self.keys.push(ListingObject {
            key: abs.name_to_relative_path(&blob.name),
            last_modified: blob.properties.last_modified.into(),
            size: blob.properties.content_length,
        });
    }
}

impl ListingCollector for crate::VersionListing {
    fn add_prefixes(
        &mut self,
        _abs: &AzureBlobStorage,
        _prefix_it: impl Iterator<Item = RemotePath>,
    ) {
        // nothing
    }
    fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
        let id = crate::VersionId(blob.version_id.clone().expect("didn't find version ID"));
        self.versions.push(crate::Version {
            key: abs.name_to_relative_path(&blob.name),
            last_modified: blob.properties.last_modified.into(),
            kind: crate::VersionKind::Version(id),
        });
    }
}

fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
    let mut res = Metadata::new();
    for (k, v) in metadata.0.into_iter() {
        res.insert(k, v);
    }
    res
}

fn to_download_error(error: azure_core::Error) -> DownloadError {
    if let Some(http_err) = error.as_http_error() {
        match http_err.status() {
            StatusCode::NotFound => DownloadError::NotFound,
            StatusCode::NotModified => DownloadError::Unmodified,
            StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
            _ => DownloadError::Other(anyhow::Error::new(error)),
        }
    } else {
        DownloadError::Other(error.into())
    }
}

impl RemoteStorage for AzureBlobStorage {
    fn list_streaming(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
        cancel: &CancellationToken,
    ) -> impl Stream<Item = Result<Listing, DownloadError>> {
        let customize_builder = |builder| builder;
        // This is a plain listing, not a version listing, so record it as `List`.
        let kind = RequestKind::List;
        self.list_streaming_for_fn(prefix, mode, max_keys, cancel, kind, customize_builder)
    }

    async fn list_versions(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
        cancel: &CancellationToken,
    ) -> std::result::Result<crate::VersionListing, DownloadError> {
        let customize_builder = |mut builder: ListBlobsBuilder| {
            builder = builder.include_versions(true);
            builder
        };
        let kind = RequestKind::ListVersions;

        let mut stream = std::pin::pin!(self.list_streaming_for_fn(
            prefix,
            mode,
            max_keys,
            cancel,
            kind,
            customize_builder
        ));
        let mut combined: crate::VersionListing =
            stream.next().await.expect("At least one item required")?;
        while let Some(list) = stream.next().await {
            let list = list?;
            combined.versions.extend(list.versions.into_iter());
        }
        Ok(combined)
    }

    async fn head_object(
        &self,
        key: &RemotePath,
        cancel: &CancellationToken,
    ) -> Result<ListingObject, DownloadError> {
        let kind = RequestKind::Head;
        let _permit = self.permit(kind, cancel).await?;

        let started_at = start_measuring_requests(kind);

        let blob_client = self.client.blob_client(self.relative_path_to_name(key));
        let properties_future = blob_client.get_properties().into_future();

        let properties_future = tokio::time::timeout(self.small_timeout, properties_future);

        let res = tokio::select! {
            res = properties_future => res,
            _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
        };

        if let Ok(inner) = &res {
            // Do not include timeouts as errors in the metrics; only completed SDK responses
            // are recorded here (cancellations return early above).
            let started_at = ScopeGuard::into_inner(started_at);
            crate::metrics::BUCKET_METRICS
                .req_seconds
                .observe_elapsed(kind, inner, started_at);
        }

        let data = match res {
            Ok(Ok(data)) => Ok(data),
            Ok(Err(sdk)) => Err(to_download_error(sdk)),
            Err(_timeout) => Err(DownloadError::Timeout),
        }?;

        let properties = data.blob.properties;
        Ok(ListingObject {
            key: key.to_owned(),
            last_modified: SystemTime::from(properties.last_modified),
            size: properties.content_length,
        })
    }

    async fn upload(
        &self,
        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
        data_size_bytes: usize,
        to: &RemotePath,
        metadata: Option<StorageMetadata>,
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Put;
        let _permit = self.permit(kind, cancel).await?;

        let started_at = start_measuring_requests(kind);

        let op = async {
            let blob_client = self.client.blob_client(self.relative_path_to_name(to));

            let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
                Box::pin(from);

            let from = NonSeekableStream::new(from, data_size_bytes);

            let body = azure_core::Body::SeekableStream(Box::new(from));

            let mut builder = blob_client.put_block_blob(body);

            if let Some(metadata) = metadata {
                builder = builder.metadata(to_azure_metadata(metadata));
            }

            let fut = builder.into_future();
            let fut = tokio::time::timeout(self.timeout, fut);

            match fut.await {
                Ok(Ok(_response)) => Ok(()),
                Ok(Err(azure)) => Err(azure.into()),
                Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
            }
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
        };

        let outcome = match res {
            Ok(_) => AttemptOutcome::Ok,
            Err(_) => AttemptOutcome::Err,
        };
        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, outcome, started_at);

        res
    }
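
    // Caller sketch (hypothetical): uploading a single in-memory buffer as a one-chunk stream.
    //
    //     let bytes = Bytes::from_static(b"hello");
    //     let len = bytes.len();
    //     let stream = futures::stream::once(futures::future::ready(Ok(bytes)));
    //     storage.upload(stream, len, &path, None, &cancel).await?;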

    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let mut builder = blob_client.get();

        if let Some(ref etag) = opts.etag {
            builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()));
        }

        if let Some(ref version_id) = opts.version_id {
            let version_id = azure_storage_blobs::prelude::VersionId::new(version_id.0.clone());
            builder = builder.blob_versioning(version_id);
        }

        if let Some((start, end)) = opts.byte_range() {
            builder = builder.range(match end {
                Some(end) => Range::Range(start..end),
                None => Range::RangeFrom(start..),
            });
        }

        let timeout = match opts.kind {
            DownloadKind::Small => self.small_timeout,
            DownloadKind::Large => self.timeout,
        };

        self.download_for_builder(builder, timeout, cancel).await
    }
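
    // Caller sketch (field names assumed from `opts.byte_range()` and the usage above;
    // check `DownloadOpts` in this crate for the exact schema): a conditional ranged read.
    //
    //     let opts = DownloadOpts {
    //         etag: Some(cached_etag), // yields DownloadError::Unmodified if unchanged
    //         byte_start: std::ops::Bound::Included(0),
    //         byte_end: std::ops::Bound::Excluded(1024),
    //         ..Default::default()
    //     };
    //     let download = storage.download(&path, &opts, &cancel).await?;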

    async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
        self.delete_objects(std::array::from_ref(path), cancel)
            .await
    }

    async fn delete_objects(
        &self,
        paths: &[RemotePath],
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Delete;
        let _permit = self.permit(kind, cancel).await?;
        let started_at = start_measuring_requests(kind);

        let op = async {
            // TODO batch requests are not supported by the SDK
            // https://github.com/Azure/azure-sdk-for-rust/issues/1068
            for path in paths {
                #[derive(Debug)]
                enum AzureOrTimeout {
                    AzureError(azure_core::Error),
                    Timeout,
                    Cancel,
                }
                impl Display for AzureOrTimeout {
                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        write!(f, "{self:?}")
                    }
                }
                let warn_threshold = 3;
                let max_retries = 5;
                backoff::retry(
                    || async {
                        let blob_client = self.client.blob_client(self.relative_path_to_name(path));

                        let request = blob_client.delete().into_future();

                        let res = tokio::time::timeout(self.timeout, request).await;

                        match res {
                            Ok(Ok(_v)) => Ok(()),
                            Ok(Err(azure_err)) => {
                                if let Some(http_err) = azure_err.as_http_error() {
                                    if http_err.status() == StatusCode::NotFound {
                                        return Ok(());
                                    }
                                }
                                Err(AzureOrTimeout::AzureError(azure_err))
                            }
                            Err(_elapsed) => Err(AzureOrTimeout::Timeout),
                        }
                    },
                    |err| match err {
                        AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
                        AzureOrTimeout::Cancel => true,
                    },
                    warn_threshold,
                    max_retries,
                    "deleting remote object",
                    cancel,
                )
                .await
                .ok_or_else(|| AzureOrTimeout::Cancel)
                .and_then(|x| x)
                .map_err(|e| match e {
                    AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
                    AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
                    AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
                })?;
            }
            Ok(())
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
        };

        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, &res, started_at);
        res
    }

    fn max_keys_per_delete(&self) -> usize {
        super::MAX_KEYS_PER_DELETE_AZURE
    }

    async fn copy(
        &self,
        from: &RemotePath,
        to: &RemotePath,
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Copy;
        let _permit = self.permit(kind, cancel).await?;
        let started_at = start_measuring_requests(kind);

        let timeout = tokio::time::sleep(self.timeout);

        let mut copy_status = None;

        let op = async {
            let blob_client = self.client.blob_client(self.relative_path_to_name(to));

            let source_url = format!(
                "{}/{}",
                self.client.url()?,
                self.relative_path_to_name(from)
            );

            let builder = blob_client.copy(Url::from_str(&source_url)?);
            let copy = builder.into_future();

            let result = copy.await?;

            copy_status = Some(result.copy_status);
            loop {
                match copy_status.as_ref().expect("we always set it to Some") {
                    CopyStatus::Aborted => {
                        anyhow::bail!("Received abort for copy from {from} to {to}.");
                    }
                    CopyStatus::Failed => {
                        anyhow::bail!("Received failure response for copy from {from} to {to}.");
                    }
                    CopyStatus::Success => return Ok(()),
                    CopyStatus::Pending => (),
                }
                // The copy is taking longer; wait a second and then poll again.
                // TODO estimate time based on copy_progress and adjust the wait time based on that
                tokio::time::sleep(Duration::from_millis(1000)).await;
                let properties = blob_client.get_properties().into_future().await?;
                let Some(status) = properties.blob.properties.copy_status else {
                    tracing::warn!("copy_status for copy is None! from={from}, to={to}");
                    return Ok(());
                };
                copy_status = Some(status);
            }
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
            _ = timeout => {
                let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
                let e = e.context(format!("Timeout, last status: {copy_status:?}"));
                Err(e)
            },
        };

        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, &res, started_at);
        res
    }

    async fn time_travel_recover(
        &self,
        _prefix: Option<&RemotePath>,
        _timestamp: SystemTime,
        _done_if_after: SystemTime,
        _cancel: &CancellationToken,
    ) -> Result<(), TimeTravelError> {
        // TODO use the Azure point-in-time restore feature for this
        // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
        Err(TimeTravelError::Unimplemented)
    }
}

pin_project_lite::pin_project! {
    /// Hack to work around not being able to stream once with the azure sdk.
    ///
    /// The azure sdk clones streams around with the assumption that they are like
    /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
    /// that. For example, for an `index_part.json` we just have a single chunk of [`Bytes`]
    /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
    /// seekable, but we can also just re-try the request more easily.
    #[project = NonSeekableStreamProj]
    enum NonSeekableStream<S> {
        /// A stream wrapper's initial form.
        ///
        /// The Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
        /// clone before the first request, then this must be changed.
        Initial {
            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
            len: usize,
        },
        /// The actually readable variant, produced by cloning the Initial variant.
        ///
        /// The sdk currently always clones once, even without a retry policy.
        Actual {
            #[pin]
            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
            len: usize,
            read_any: bool,
        },
        /// Most likely unneeded, but left to make life easier, in case more clones are added.
        Cloned {
            len_was: usize,
        }
    }
}
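
// State transitions of `NonSeekableStream` (a sketch of the expected sdk behavior,
// not an API guarantee):
//
//     Initial --clone--> Actual   (the first clone takes the reader out of the Mutex)
//     Initial --clone--> Cloned   (second and later clones only remember the length)
//     Actual  --clone--> Cloned
//
// Only `Actual` can be read, and `reset` succeeds only before any bytes have been read.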

impl<S> NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
{
    fn new(inner: S, len: usize) -> NonSeekableStream<S> {
        use tokio_util::compat::TokioAsyncReadCompatExt;

        let inner = tokio_util::io::StreamReader::new(inner).compat();
        let inner = Some(inner);
        let inner = std::sync::Mutex::new(inner);
        NonSeekableStream::Initial { inner, len }
    }
}

impl<S> std::fmt::Debug for NonSeekableStream<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
            Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
            Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
        }
    }
}

impl<S> futures::io::AsyncRead for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>>,
{
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.project() {
            NonSeekableStreamProj::Actual {
                inner, read_any, ..
            } => {
                *read_any = true;
                inner.poll_read(cx, buf)
            }
            // NonSeekableStream::Initial does not support reading because it is just much easier
            // to have the mutex in a place where one does not poll the contents, or that's how it
            // seemed originally. If there is a version upgrade which changes the cloning, then
            // that support needs to be hacked in.
            //
            // Including {self:?} in the message would be useful, but it's unclear how to unproject.
            _ => std::task::Poll::Ready(Err(std::io::Error::other(
                "cloned or initial values cannot be read",
            ))),
        }
    }
}

impl<S> Clone for NonSeekableStream<S> {
    /// This weird clone implementation exists to support the sdk cloning before issuing the
    /// first request, see the type documentation.
    fn clone(&self) -> Self {
        use NonSeekableStream::*;

        match self {
            Initial { inner, len } => {
                if let Some(inner) = inner.lock().unwrap().take() {
                    Actual {
                        inner,
                        len: *len,
                        read_any: false,
                    }
                } else {
                    Self::Cloned { len_was: *len }
                }
            }
            Actual { len, .. } => Cloned { len_was: *len },
            Cloned { len_was } => Cloned { len_was: *len_was },
        }
    }
}

#[async_trait::async_trait]
impl<S> azure_core::SeekableStream for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
{
    async fn reset(&mut self) -> azure_core::error::Result<()> {
        use NonSeekableStream::*;

        let msg = match self {
            Initial { inner, .. } => {
                if inner.get_mut().unwrap().is_some() {
                    return Ok(());
                } else {
                    "reset after first clone is not supported"
                }
            }
            Actual { read_any, .. } if !*read_any => return Ok(()),
            Actual { .. } => "reset after reading is not supported",
            Cloned { .. } => "reset after second clone is not supported",
        };
        Err(azure_core::error::Error::new(
            azure_core::error::ErrorKind::Io,
            std::io::Error::other(msg),
        ))
    }

    // Note: it is not documented whether this should be the total or the remaining length;
    // total passes the tests.
    fn len(&self) -> usize {
        use NonSeekableStream::*;
        match self {
            Initial { len, .. } => *len,
            Actual { len, .. } => *len,
            Cloned { len_was, .. } => *len_was,
        }
    }
}