//! Azure Blob Storage wrapper
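//!
//! Implements the [`RemoteStorage`] trait on top of the `azure_storage_blobs`
//! SDK. The SDK's own retries are disabled (see [`AzureBlobStorage::new`]);
//! retrying is left to an outer layer, and every request is bounded by a
//! per-request timeout.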

use std::borrow::Cow;
use std::collections::HashMap;
use std::env;
use std::fmt::Display;
use std::io;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;

use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::{Continuable, RetryOptions};
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::CopyStatus;
use azure_storage_blobs::prelude::ClientBuilder;
use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
use bytes::Bytes;
use futures::future::Either;
use futures::stream::Stream;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use http_types::{StatusCode, Url};
use scopeguard::ScopeGuard;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use utils::backoff;

use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind};
use crate::ListingObject;
use crate::{
    config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError, Listing,
    ListingMode, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
};

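/// Azure Blob Storage client for a single container, optionally scoped to a
/// prefix within it.
///
/// A minimal construction sketch (`azure_config` is assumed to be an
/// [`AzureConfig`] loaded elsewhere; the timeout value is illustrative):
///
/// ```ignore
/// let storage = AzureBlobStorage::new(&azure_config, Duration::from_secs(120))?;
/// ```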
pub struct AzureBlobStorage {
    client: ContainerClient,
    container_name: String,
    prefix_in_container: Option<String>,
    max_keys_per_list_response: Option<NonZeroU32>,
    concurrency_limiter: ConcurrencyLimiter,
    // Per-request timeout. Accessible for tests.
    pub timeout: Duration,
}

impl AzureBlobStorage {
    pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
        debug!(
            "Creating azure remote storage for azure container {}",
            azure_config.container_name
        );

        // Use the storage account from the config by default, fall back to env var if not present.
        let account = azure_config.storage_account.clone().unwrap_or_else(|| {
            env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
        });

        // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
        // otherwise try the token based credentials.
        let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
            StorageCredentials::access_key(account.clone(), access_key)
        } else {
            let token_credential = DefaultAzureCredential::default();
            StorageCredentials::token_credential(Arc::new(token_credential))
        };

        // Disable the SDK's own retries: we have an outer retry layer.
        let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());

        let client = builder.container_client(azure_config.container_name.to_owned());

        let max_keys_per_list_response =
            if let Some(limit) = azure_config.max_keys_per_list_response {
                Some(
                    NonZeroU32::new(limit as u32)
                        .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
                )
            } else {
                None
            };

        Ok(AzureBlobStorage {
            client,
            container_name: azure_config.container_name.to_owned(),
            prefix_in_container: azure_config.prefix_in_container.to_owned(),
            max_keys_per_list_response,
            concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
            timeout,
        })
    }

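    /// Maps a [`RemotePath`] to the full blob name, prepending
    /// `prefix_in_container` when one is configured.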
    pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
        assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
        let path_string = path
            .get_path()
            .as_str()
            .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
        match &self.prefix_in_container {
            Some(prefix) => {
                if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
                    prefix.clone() + path_string
                } else {
                    format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
                }
            }
            None => path_string.to_string(),
        }
    }

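    /// Inverse of [`Self::relative_path_to_name`]: strips the configured
    /// container prefix from a blob name returned by Azure.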
    fn name_to_relative_path(&self, key: &str) -> RemotePath {
        let relative_path =
            match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
                Some(stripped) => stripped,
                // we rely on Azure to return properly prefixed paths
                // for requests with a certain prefix
                None => panic!(
                    "Key {key} does not start with container prefix {:?}",
                    self.prefix_in_container
                ),
            };
        RemotePath(
            relative_path
                .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
                .collect(),
        )
    }

    async fn download_for_builder(
        &self,
        builder: GetBlobBuilder,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let kind = RequestKind::Get;

        let _permit = self.permit(kind, cancel).await?;
        let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
        let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());

        let mut etag = None;
        let mut last_modified = None;
        let mut metadata = HashMap::new();

        let started_at = start_measuring_requests(kind);

        let download = async {
            let response = builder
                // convert to concrete Pageable
                .into_stream()
                // convert to TryStream
                .into_stream()
                .map_err(to_download_error);

            // apply per request timeout
            let response = tokio_stream::StreamExt::timeout(response, self.timeout);

            // flatten
            let response = response.map(|res| match res {
                Ok(res) => res,
                Err(_elapsed) => Err(DownloadError::Timeout),
            });

            let mut response = Box::pin(response);

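            // The first part of the paged response carries the blob properties;
            // capture the etag, last-modified timestamp and metadata from it
            // before streaming the remainder of the body.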
            let Some(part) = response.next().await else {
                return Err(DownloadError::Other(anyhow::anyhow!(
                    "Azure GET response contained no response body"
                )));
            };
            let part = part?;
            if etag.is_none() {
                etag = Some(part.blob.properties.etag);
            }
            if last_modified.is_none() {
                last_modified = Some(part.blob.properties.last_modified.into());
            }
            if let Some(blob_meta) = part.blob.metadata {
                metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
            }

            // unwrap safety: if these were still None, the response would have been
            // empty and we would have returned an error above
            let etag = etag.unwrap();
            let last_modified = last_modified.unwrap();

            let tail_stream = response
                .map(|part| match part {
                    Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
                    Err(e) => {
                        Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
                    }
                })
                .flatten();
            let stream = part
                .data
                .map(|r| r.map_err(io::Error::other))
                .chain(sync_wrapper::SyncStream::new(tail_stream));

            let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);

            Ok(Download {
                download_stream: Box::pin(download_stream),
                etag,
                last_modified,
                metadata: Some(StorageMetadata(metadata)),
            })
        };

        let download = tokio::select! {
            bufs = download => bufs,
            cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
                TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
                TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
            },
        };
        let started_at = ScopeGuard::into_inner(started_at);
        let outcome = match &download {
            Ok(_) => AttemptOutcome::Ok,
            Err(_) => AttemptOutcome::Err,
        };
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, outcome, started_at);
        download
    }

    async fn permit(
        &self,
        kind: RequestKind,
        cancel: &CancellationToken,
    ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
        let acquire = self.concurrency_limiter.acquire(kind);

        tokio::select! {
            permit = acquire => Ok(permit.expect("never closed")),
            _ = cancel.cancelled() => Err(Cancelled),
        }
    }

    pub fn container_name(&self) -> &str {
        &self.container_name
    }
}

fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
    let mut res = Metadata::new();
    for (k, v) in metadata.0.into_iter() {
        res.insert(k, v);
    }
    res
}

fn to_download_error(error: azure_core::Error) -> DownloadError {
    if let Some(http_err) = error.as_http_error() {
        match http_err.status() {
            StatusCode::NotFound => DownloadError::NotFound,
            StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
            _ => DownloadError::Other(anyhow::Error::new(error)),
        }
    } else {
        DownloadError::Other(error.into())
    }
}

impl RemoteStorage for AzureBlobStorage {
    fn list_streaming(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
        cancel: &CancellationToken,
    ) -> impl Stream<Item = Result<Listing, DownloadError>> {
        // Use the passed prefix, or fall back to `prefix_in_container` when it is not set.
        let list_prefix = prefix
            .map(|p| self.relative_path_to_name(p))
            .or_else(|| self.prefix_in_container.clone())
            .map(|mut p| {
                // The prefix must end with a separator, otherwise the request
                // returns only the entry of the prefix itself.
                if matches!(mode, ListingMode::WithDelimiter)
                    && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
                {
                    p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
                }
                p
            });

        async_stream::stream! {
            let _permit = self.permit(RequestKind::List, cancel).await?;

            let mut builder = self.client.list_blobs();

            if let ListingMode::WithDelimiter = mode {
                builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
            }

            if let Some(prefix) = list_prefix {
                builder = builder.prefix(Cow::from(prefix.to_owned()));
            }

            if let Some(limit) = self.max_keys_per_list_response {
                builder = builder.max_results(MaxResults::new(limit));
            }

            let mut next_marker = None;

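            // Track the caller's remaining key budget across pages. Each page is
            // requested with the continuation marker of the previous one, and each
            // page is yielded to the caller as a separate `Listing`.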
            let mut max_keys = max_keys.map(|mk| mk.get());

            'outer: loop {
                let mut builder = builder.clone();
                if let Some(marker) = next_marker.clone() {
                    builder = builder.marker(marker);
                }
                let response = builder.into_stream();
                let response = response.into_stream().map_err(to_download_error);
                let response = tokio_stream::StreamExt::timeout(response, self.timeout);
                let response = response.map(|res| match res {
                    Ok(res) => res,
                    Err(_elapsed) => Err(DownloadError::Timeout),
                });

                let mut response = std::pin::pin!(response);

                let next_item = tokio::select! {
                    op = response.next() => Ok(op),
                    _ = cancel.cancelled() => Err(DownloadError::Cancelled),
                }?;
                let Some(entry) = next_item else {
                    // The stream is exhausted: the listing is complete.
                    break;
                };

                let mut res = Listing::default();
                let entry = match entry {
                    Ok(entry) => entry,
                    Err(e) => {
                        // The error is potentially retryable: yield it, then restart the loop.
                        yield Err(e);
                        continue;
                    }
                };
                next_marker = entry.continuation();
                let prefix_iter = entry
                    .blobs
                    .prefixes()
                    .map(|prefix| self.name_to_relative_path(&prefix.name));
                res.prefixes.extend(prefix_iter);

                let blob_iter = entry
                    .blobs
                    .blobs()
                    .map(|k| ListingObject {
                        key: self.name_to_relative_path(&k.name),
                        last_modified: k.properties.last_modified.into(),
                        size: k.properties.content_length,
                    });

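                // Enforce the caller-supplied `max_keys` budget, ending the whole
                // listing once it is spent.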
                for key in blob_iter {
                    res.keys.push(key);

                    if let Some(mut mk) = max_keys {
                        assert!(mk > 0);
                        mk -= 1;
                        if mk == 0 {
                            yield Ok(res); // limit reached
                            break 'outer;
                        }
                        max_keys = Some(mk);
                    }
                }
                yield Ok(res);

                // We are done here
                if next_marker.is_none() {
                    break;
                }
            }
        }
    }

    async fn upload(
        &self,
        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
        data_size_bytes: usize,
        to: &RemotePath,
        metadata: Option<StorageMetadata>,
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Put;
        let _permit = self.permit(kind, cancel).await?;

        let started_at = start_measuring_requests(kind);

        let op = async {
            let blob_client = self.client.blob_client(self.relative_path_to_name(to));

            let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
                Box::pin(from);

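            // Wrap the stream so the SDK can treat it as a cloneable request body;
            // see `NonSeekableStream` at the bottom of this file for why.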
            let from = NonSeekableStream::new(from, data_size_bytes);

            let body = azure_core::Body::SeekableStream(Box::new(from));

            let mut builder = blob_client.put_block_blob(body);

            if let Some(metadata) = metadata {
                builder = builder.metadata(to_azure_metadata(metadata));
            }

            let fut = builder.into_future();
            let fut = tokio::time::timeout(self.timeout, fut);

            match fut.await {
                Ok(Ok(_response)) => Ok(()),
                Ok(Err(azure)) => Err(azure.into()),
                Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
            }
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
        };

        let outcome = match res {
            Ok(_) => AttemptOutcome::Ok,
            Err(_) => AttemptOutcome::Err,
        };
        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, outcome, started_at);

        res
    }

    async fn download(
        &self,
        from: &RemotePath,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let builder = blob_client.get();

        self.download_for_builder(builder, cancel).await
    }

    async fn download_byte_range(
        &self,
        from: &RemotePath,
        start_inclusive: u64,
        end_exclusive: Option<u64>,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let mut builder = blob_client.get();

        let range: Range = if let Some(end_exclusive) = end_exclusive {
            (start_inclusive..end_exclusive).into()
        } else {
            (start_inclusive..).into()
        };
        builder = builder.range(range);

        self.download_for_builder(builder, cancel).await
    }

    async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
        self.delete_objects(std::array::from_ref(path), cancel)
            .await
    }

    async fn delete_objects<'a>(
        &self,
        paths: &'a [RemotePath],
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Delete;
        let _permit = self.permit(kind, cancel).await?;
        let started_at = start_measuring_requests(kind);

        let op = async {
            // TODO batch requests are not supported by the SDK
            // https://github.com/Azure/azure-sdk-for-rust/issues/1068
            for path in paths {
                #[derive(Debug)]
                enum AzureOrTimeout {
                    AzureError(azure_core::Error),
                    Timeout,
                    Cancel,
                }
                impl Display for AzureOrTimeout {
                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        write!(f, "{self:?}")
                    }
                }
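                // Each object gets its own retry loop; a NotFound response counts
                // as success, which makes deletion idempotent.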
                let warn_threshold = 3;
                let max_retries = 5;
                backoff::retry(
                    || async {
                        let blob_client = self.client.blob_client(self.relative_path_to_name(path));

                        let request = blob_client.delete().into_future();

                        let res = tokio::time::timeout(self.timeout, request).await;

                        match res {
                            Ok(Ok(_v)) => Ok(()),
                            Ok(Err(azure_err)) => {
                                if let Some(http_err) = azure_err.as_http_error() {
                                    if http_err.status() == StatusCode::NotFound {
                                        return Ok(());
                                    }
                                }
                                Err(AzureOrTimeout::AzureError(azure_err))
                            }
                            Err(_elapsed) => Err(AzureOrTimeout::Timeout),
                        }
                    },
                    |err| match err {
                        AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
                        AzureOrTimeout::Cancel => true,
                    },
                    warn_threshold,
                    max_retries,
                    "deleting remote object",
                    cancel,
                )
                .await
                .ok_or_else(|| AzureOrTimeout::Cancel)
                .and_then(|x| x)
                .map_err(|e| match e {
                    AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
                    AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
                    AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
                })?;
            }
            Ok(())
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
        };

        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, &res, started_at);
        res
    }

    async fn copy(
        &self,
        from: &RemotePath,
        to: &RemotePath,
        cancel: &CancellationToken,
    ) -> anyhow::Result<()> {
        let kind = RequestKind::Copy;
        let _permit = self.permit(kind, cancel).await?;
        let started_at = start_measuring_requests(kind);

        let timeout = tokio::time::sleep(self.timeout);

        let mut copy_status = None;

        let op = async {
            let blob_client = self.client.blob_client(self.relative_path_to_name(to));

            let source_url = format!(
                "{}/{}",
                self.client.url()?,
                self.relative_path_to_name(from)
            );

            let builder = blob_client.copy(Url::from_str(&source_url)?);
            let copy = builder.into_future();

            let result = copy.await?;

            copy_status = Some(result.copy_status);
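            // Server-side copies are asynchronous: poll the blob's properties
            // until the copy succeeds, fails, or the outer timeout fires.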
            loop {
                match copy_status.as_ref().expect("we always set it to Some") {
                    CopyStatus::Aborted => {
                        anyhow::bail!("Received abort for copy from {from} to {to}.");
                    }
                    CopyStatus::Failed => {
                        anyhow::bail!("Received failure response for copy from {from} to {to}.");
                    }
                    CopyStatus::Success => return Ok(()),
                    CopyStatus::Pending => (),
                }
                // The copy is taking longer. Wait a second, then poll again.
                // TODO: estimate the wait based on copy_progress and adjust accordingly.
                tokio::time::sleep(Duration::from_millis(1000)).await;
                let properties = blob_client.get_properties().into_future().await?;
                let Some(status) = properties.blob.properties.copy_status else {
                    tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
                    return Ok(());
                };
                copy_status = Some(status);
            }
        };

        let res = tokio::select! {
            res = op => res,
            _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
            _ = timeout => {
                let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
                let e = e.context(format!("Timeout, last status: {copy_status:?}"));
                Err(e)
            },
        };

        let started_at = ScopeGuard::into_inner(started_at);
        crate::metrics::BUCKET_METRICS
            .req_seconds
            .observe_elapsed(kind, &res, started_at);
        res
    }

    async fn time_travel_recover(
        &self,
        _prefix: Option<&RemotePath>,
        _timestamp: SystemTime,
        _done_if_after: SystemTime,
        _cancel: &CancellationToken,
    ) -> Result<(), TimeTravelError> {
        // TODO: use the Azure point-in-time restore feature for this, see
        // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
        Err(TimeTravelError::Unimplemented)
    }
}

pin_project_lite::pin_project! {
    /// Hack to work around not being able to stream once with the Azure SDK.
    ///
    /// The Azure SDK clones streams around with the assumption that they are like
    /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not
    /// like that. For example, for an `index_part.json` we just have a single chunk of
    /// [`Bytes`] representing the whole serialized vec. It could be trivially cloneable
    /// and "semi-trivially" seekable, but it is easier to just retry the request.
    #[project = NonSeekableStreamProj]
    enum NonSeekableStream<S> {
        /// A stream wrapper's initial form.
        ///
        /// The Mutex exists to allow moving the inner stream out when cloning. If the SDK
        /// changes to do fewer than one clone before the first request, this must be changed.
        Initial {
            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
            len: usize,
        },
        /// The actually readable variant, produced by cloning the Initial variant.
        ///
        /// The SDK currently always clones once, even without a retry policy.
        Actual {
            #[pin]
            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
            len: usize,
            read_any: bool,
        },
        /// Most likely unneeded, but left to make life easier, in case more clones are added.
        Cloned {
            len_was: usize,
        }
    }
}

impl<S> NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
{
    fn new(inner: S, len: usize) -> NonSeekableStream<S> {
        use tokio_util::compat::TokioAsyncReadCompatExt;

        let inner = tokio_util::io::StreamReader::new(inner).compat();
        let inner = Some(inner);
        let inner = std::sync::Mutex::new(inner);
        NonSeekableStream::Initial { inner, len }
    }
}

impl<S> std::fmt::Debug for NonSeekableStream<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
            Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
            Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
        }
    }
}

impl<S> futures::io::AsyncRead for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>>,
{
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.project() {
            NonSeekableStreamProj::Actual {
                inner, read_any, ..
            } => {
                *read_any = true;
                inner.poll_read(cx, buf)
            }
            // NonSeekableStream::Initial does not support reading, because it is much easier
            // to keep the mutex in a place where the contents are never polled, or that's how
            // it seemed originally. If a version upgrade changes the cloning, that support
            // needs to be hacked in.
            //
            // Including {self:?} in the message would be useful, but it is unclear how to
            // unproject here.
            _ => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "cloned or initial values cannot be read",
            ))),
        }
    }
}

impl<S> Clone for NonSeekableStream<S> {
    /// This clone implementation exists to support the SDK cloning the stream before issuing
    /// the first request, see the type documentation.
    fn clone(&self) -> Self {
        use NonSeekableStream::*;

        match self {
            Initial { inner, len } => {
                if let Some(inner) = inner.lock().unwrap().take() {
                    Actual {
                        inner,
                        len: *len,
                        read_any: false,
                    }
                } else {
                    Self::Cloned { len_was: *len }
                }
            }
            Actual { len, .. } => Cloned { len_was: *len },
            Cloned { len_was } => Cloned { len_was: *len_was },
        }
    }
}

#[async_trait::async_trait]
impl<S> azure_core::SeekableStream for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
{
    async fn reset(&mut self) -> azure_core::error::Result<()> {
        use NonSeekableStream::*;

        let msg = match self {
            Initial { inner, .. } => {
                if inner.get_mut().unwrap().is_some() {
                    return Ok(());
                } else {
                    "reset after first clone is not supported"
                }
            }
            Actual { read_any, .. } if !*read_any => return Ok(()),
            Actual { .. } => "reset after reading is not supported",
            Cloned { .. } => "reset after second clone is not supported",
        };
        Err(azure_core::error::Error::new(
            azure_core::error::ErrorKind::Io,
            std::io::Error::new(std::io::ErrorKind::Other, msg),
        ))
    }

    // Note: it is not documented whether this should be the total or the remaining length;
    // the total passes the tests.
    fn len(&self) -> usize {
        use NonSeekableStream::*;
        match self {
            Initial { len, .. } => *len,
            Actual { len, .. } => *len,
            Cloned { len_was, .. } => *len_was,
        }
    }
}