Line data Source code
1 : //! Azure Blob Storage wrapper
2 :
3 : use std::borrow::Cow;
4 : use std::collections::HashMap;
5 : use std::env;
6 : use std::fmt::Display;
7 : use std::io;
8 : use std::num::NonZeroU32;
9 : use std::pin::Pin;
10 : use std::str::FromStr;
11 : use std::sync::Arc;
12 : use std::time::Duration;
13 : use std::time::SystemTime;
14 :
15 : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
16 : use anyhow::Result;
17 : use azure_core::request_options::{MaxResults, Metadata, Range};
18 : use azure_core::RetryOptions;
19 : use azure_identity::DefaultAzureCredential;
20 : use azure_storage::StorageCredentials;
21 : use azure_storage_blobs::blob::CopyStatus;
22 : use azure_storage_blobs::prelude::ClientBuilder;
23 : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
24 : use bytes::Bytes;
25 : use futures::future::Either;
26 : use futures::stream::Stream;
27 : use futures_util::StreamExt;
28 : use futures_util::TryStreamExt;
29 : use http_types::{StatusCode, Url};
30 : use scopeguard::ScopeGuard;
31 : use tokio_util::sync::CancellationToken;
32 : use tracing::debug;
33 : use utils::backoff;
34 :
35 : use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind};
36 : use crate::{
37 : error::Cancelled, AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing,
38 : ListingMode, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
39 : };
40 :
/// [`RemoteStorage`] implementation backed by a single Azure Blob Storage container.
pub struct AzureBlobStorage {
    // Client for the configured container; every blob operation derives a blob client from it.
    client: ContainerClient,
    // Optional prefix prepended to every object name (see `relative_path_to_name`) and
    // stripped from names returned by listings.
    prefix_in_container: Option<String>,
    // Page size requested from Azure for list operations, when configured.
    max_keys_per_list_response: Option<NonZeroU32>,
    // Limits the number of concurrent requests, per request kind.
    concurrency_limiter: ConcurrencyLimiter,
    /// Per-request timeout. Accessible for tests.
    pub timeout: Duration,
}
49 :
50 : impl AzureBlobStorage {
51 6 : pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
52 6 : debug!(
53 0 : "Creating azure remote storage for azure container {}",
54 : azure_config.container_name
55 : );
56 :
57 : // Use the storage account from the config by default, fall back to env var if not present.
58 6 : let account = azure_config.storage_account.clone().unwrap_or_else(|| {
59 6 : env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
60 6 : });
61 :
62 : // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
63 : // otherwise try the token based credentials.
64 6 : let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
65 6 : StorageCredentials::access_key(account.clone(), access_key)
66 : } else {
67 0 : let token_credential = DefaultAzureCredential::default();
68 0 : StorageCredentials::token_credential(Arc::new(token_credential))
69 : };
70 :
71 : // we have an outer retry
72 6 : let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
73 6 :
74 6 : let client = builder.container_client(azure_config.container_name.to_owned());
75 :
76 6 : let max_keys_per_list_response =
77 6 : if let Some(limit) = azure_config.max_keys_per_list_response {
78 : Some(
79 2 : NonZeroU32::new(limit as u32)
80 2 : .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
81 : )
82 : } else {
83 4 : None
84 : };
85 :
86 6 : Ok(AzureBlobStorage {
87 6 : client,
88 6 : prefix_in_container: azure_config.prefix_in_container.to_owned(),
89 6 : max_keys_per_list_response,
90 6 : concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
91 6 : timeout,
92 6 : })
93 6 : }
94 :
95 107 : pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
96 107 : assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
97 107 : let path_string = path
98 107 : .get_path()
99 107 : .as_str()
100 107 : .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
101 107 : match &self.prefix_in_container {
102 107 : Some(prefix) => {
103 107 : if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
104 107 : prefix.clone() + path_string
105 : } else {
106 0 : format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
107 : }
108 : }
109 0 : None => path_string.to_string(),
110 : }
111 107 : }
112 :
113 53 : fn name_to_relative_path(&self, key: &str) -> RemotePath {
114 53 : let relative_path =
115 53 : match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
116 53 : Some(stripped) => stripped,
117 : // we rely on Azure to return properly prefixed paths
118 : // for requests with a certain prefix
119 0 : None => panic!(
120 0 : "Key {key} does not start with container prefix {:?}",
121 0 : self.prefix_in_container
122 0 : ),
123 : };
124 53 : RemotePath(
125 53 : relative_path
126 53 : .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
127 53 : .collect(),
128 53 : )
129 53 : }
130 :
131 7 : async fn download_for_builder(
132 7 : &self,
133 7 : builder: GetBlobBuilder,
134 7 : cancel: &CancellationToken,
135 7 : ) -> Result<Download, DownloadError> {
136 7 : let kind = RequestKind::Get;
137 :
138 7 : let _permit = self.permit(kind, cancel).await?;
139 7 : let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
140 7 : let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
141 7 :
142 7 : let mut etag = None;
143 7 : let mut last_modified = None;
144 7 : let mut metadata = HashMap::new();
145 7 :
146 7 : let started_at = start_measuring_requests(kind);
147 7 :
148 7 : let download = async {
149 7 : let response = builder
150 7 : // convert to concrete Pageable
151 7 : .into_stream()
152 7 : // convert to TryStream
153 7 : .into_stream()
154 7 : .map_err(to_download_error);
155 7 :
156 7 : // apply per request timeout
157 7 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
158 7 :
159 7 : // flatten
160 7 : let response = response.map(|res| match res {
161 7 : Ok(res) => res,
162 0 : Err(_elapsed) => Err(DownloadError::Timeout),
163 7 : });
164 7 :
165 7 : let mut response = Box::pin(response);
166 :
167 36 : let Some(part) = response.next().await else {
168 0 : return Err(DownloadError::Other(anyhow::anyhow!(
169 0 : "Azure GET response contained no response body"
170 0 : )));
171 : };
172 7 : let part = part?;
173 7 : if etag.is_none() {
174 7 : etag = Some(part.blob.properties.etag);
175 7 : }
176 7 : if last_modified.is_none() {
177 7 : last_modified = Some(part.blob.properties.last_modified.into());
178 7 : }
179 7 : if let Some(blob_meta) = part.blob.metadata {
180 0 : metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
181 7 : }
182 :
183 : // unwrap safety: if these were None, bufs would be empty and we would have returned an error already
184 7 : let etag = etag.unwrap();
185 7 : let last_modified = last_modified.unwrap();
186 7 :
187 7 : let tail_stream = response
188 7 : .map(|part| match part {
189 0 : Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
190 0 : Err(e) => {
191 0 : Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
192 : }
193 7 : })
194 7 : .flatten();
195 7 : let stream = part
196 7 : .data
197 7 : .map(|r| r.map_err(io::Error::other))
198 7 : .chain(sync_wrapper::SyncStream::new(tail_stream));
199 7 : //.chain(SyncStream::from_pin(Box::pin(tail_stream)));
200 7 :
201 7 : let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
202 7 :
203 7 : Ok(Download {
204 7 : download_stream: Box::pin(download_stream),
205 7 : etag,
206 7 : last_modified,
207 7 : metadata: Some(StorageMetadata(metadata)),
208 7 : })
209 7 : };
210 :
211 7 : let download = tokio::select! {
212 : bufs = download => bufs,
213 : cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
214 : TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
215 : TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
216 : },
217 : };
218 7 : let started_at = ScopeGuard::into_inner(started_at);
219 7 : let outcome = match &download {
220 7 : Ok(_) => AttemptOutcome::Ok,
221 0 : Err(_) => AttemptOutcome::Err,
222 : };
223 7 : crate::metrics::BUCKET_METRICS
224 7 : .req_seconds
225 7 : .observe_elapsed(kind, outcome, started_at);
226 7 : download
227 7 : }
228 :
229 108 : async fn permit(
230 108 : &self,
231 108 : kind: RequestKind,
232 108 : cancel: &CancellationToken,
233 108 : ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
234 108 : let acquire = self.concurrency_limiter.acquire(kind);
235 :
236 : tokio::select! {
237 : permit = acquire => Ok(permit.expect("never closed")),
238 : _ = cancel.cancelled() => Err(Cancelled),
239 : }
240 108 : }
241 : }
242 :
243 0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
244 0 : let mut res = Metadata::new();
245 0 : for (k, v) in metadata.0.into_iter() {
246 0 : res.insert(k, v);
247 0 : }
248 0 : res
249 0 : }
250 :
251 0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
252 0 : if let Some(http_err) = error.as_http_error() {
253 0 : match http_err.status() {
254 0 : StatusCode::NotFound => DownloadError::NotFound,
255 0 : StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
256 0 : _ => DownloadError::Other(anyhow::Error::new(error)),
257 : }
258 : } else {
259 0 : DownloadError::Other(error.into())
260 : }
261 0 : }
262 :
263 : impl RemoteStorage for AzureBlobStorage {
264 6 : async fn list(
265 6 : &self,
266 6 : prefix: Option<&RemotePath>,
267 6 : mode: ListingMode,
268 6 : max_keys: Option<NonZeroU32>,
269 6 : cancel: &CancellationToken,
270 6 : ) -> anyhow::Result<Listing, DownloadError> {
271 6 : let _permit = self.permit(RequestKind::List, cancel).await?;
272 :
273 6 : let op = async {
274 6 : // get the passed prefix or if it is not set use prefix_in_bucket value
275 6 : let list_prefix = prefix
276 6 : .map(|p| self.relative_path_to_name(p))
277 6 : .or_else(|| self.prefix_in_container.clone())
278 6 : .map(|mut p| {
279 : // required to end with a separator
280 : // otherwise request will return only the entry of a prefix
281 6 : if matches!(mode, ListingMode::WithDelimiter)
282 3 : && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
283 1 : {
284 1 : p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
285 5 : }
286 6 : p
287 6 : });
288 6 :
289 6 : let mut builder = self.client.list_blobs();
290 6 :
291 6 : if let ListingMode::WithDelimiter = mode {
292 3 : builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
293 3 : }
294 :
295 6 : if let Some(prefix) = list_prefix {
296 6 : builder = builder.prefix(Cow::from(prefix.to_owned()));
297 6 : }
298 :
299 6 : if let Some(limit) = self.max_keys_per_list_response {
300 5 : builder = builder.max_results(MaxResults::new(limit));
301 5 : }
302 :
303 6 : let response = builder.into_stream();
304 6 : let response = response.into_stream().map_err(to_download_error);
305 6 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
306 10 : let response = response.map(|res| match res {
307 10 : Ok(res) => res,
308 0 : Err(_elapsed) => Err(DownloadError::Timeout),
309 10 : });
310 6 :
311 6 : let mut response = std::pin::pin!(response);
312 6 :
313 6 : let mut res = Listing::default();
314 6 :
315 6 : let mut max_keys = max_keys.map(|mk| mk.get());
316 60 : while let Some(entry) = response.next().await {
317 10 : let entry = entry?;
318 10 : let prefix_iter = entry
319 10 : .blobs
320 10 : .prefixes()
321 23 : .map(|prefix| self.name_to_relative_path(&prefix.name));
322 10 : res.prefixes.extend(prefix_iter);
323 10 :
324 10 : let blob_iter = entry
325 10 : .blobs
326 10 : .blobs()
327 30 : .map(|k| self.name_to_relative_path(&k.name));
328 :
329 39 : for key in blob_iter {
330 30 : res.keys.push(key);
331 :
332 30 : if let Some(mut mk) = max_keys {
333 2 : assert!(mk > 0);
334 2 : mk -= 1;
335 2 : if mk == 0 {
336 1 : return Ok(res); // limit reached
337 1 : }
338 1 : max_keys = Some(mk);
339 28 : }
340 : }
341 : }
342 :
343 5 : Ok(res)
344 6 : };
345 :
346 : tokio::select! {
347 : res = op => res,
348 : _ = cancel.cancelled() => Err(DownloadError::Cancelled),
349 : }
350 6 : }
351 :
352 47 : async fn upload(
353 47 : &self,
354 47 : from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
355 47 : data_size_bytes: usize,
356 47 : to: &RemotePath,
357 47 : metadata: Option<StorageMetadata>,
358 47 : cancel: &CancellationToken,
359 47 : ) -> anyhow::Result<()> {
360 47 : let kind = RequestKind::Put;
361 47 : let _permit = self.permit(kind, cancel).await?;
362 :
363 47 : let started_at = start_measuring_requests(kind);
364 47 :
365 47 : let op = async {
366 47 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
367 47 :
368 47 : let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
369 47 : Box::pin(from);
370 47 :
371 47 : let from = NonSeekableStream::new(from, data_size_bytes);
372 47 :
373 47 : let body = azure_core::Body::SeekableStream(Box::new(from));
374 47 :
375 47 : let mut builder = blob_client.put_block_blob(body);
376 :
377 47 : if let Some(metadata) = metadata {
378 0 : builder = builder.metadata(to_azure_metadata(metadata));
379 47 : }
380 :
381 47 : let fut = builder.into_future();
382 47 : let fut = tokio::time::timeout(self.timeout, fut);
383 47 :
384 268 : match fut.await {
385 47 : Ok(Ok(_response)) => Ok(()),
386 0 : Ok(Err(azure)) => Err(azure.into()),
387 0 : Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
388 : }
389 47 : };
390 :
391 47 : let res = tokio::select! {
392 : res = op => res,
393 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
394 : };
395 :
396 47 : let outcome = match res {
397 47 : Ok(_) => AttemptOutcome::Ok,
398 0 : Err(_) => AttemptOutcome::Err,
399 : };
400 47 : let started_at = ScopeGuard::into_inner(started_at);
401 47 : crate::metrics::BUCKET_METRICS
402 47 : .req_seconds
403 47 : .observe_elapsed(kind, outcome, started_at);
404 47 :
405 47 : res
406 47 : }
407 :
408 2 : async fn download(
409 2 : &self,
410 2 : from: &RemotePath,
411 2 : cancel: &CancellationToken,
412 2 : ) -> Result<Download, DownloadError> {
413 2 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
414 2 :
415 2 : let builder = blob_client.get();
416 2 :
417 11 : self.download_for_builder(builder, cancel).await
418 2 : }
419 :
420 5 : async fn download_byte_range(
421 5 : &self,
422 5 : from: &RemotePath,
423 5 : start_inclusive: u64,
424 5 : end_exclusive: Option<u64>,
425 5 : cancel: &CancellationToken,
426 5 : ) -> Result<Download, DownloadError> {
427 5 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
428 5 :
429 5 : let mut builder = blob_client.get();
430 :
431 5 : let range: Range = if let Some(end_exclusive) = end_exclusive {
432 3 : (start_inclusive..end_exclusive).into()
433 : } else {
434 2 : (start_inclusive..).into()
435 : };
436 5 : builder = builder.range(range);
437 5 :
438 25 : self.download_for_builder(builder, cancel).await
439 5 : }
440 :
441 44 : async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
442 44 : self.delete_objects(std::array::from_ref(path), cancel)
443 221 : .await
444 44 : }
445 :
446 47 : async fn delete_objects<'a>(
447 47 : &self,
448 47 : paths: &'a [RemotePath],
449 47 : cancel: &CancellationToken,
450 47 : ) -> anyhow::Result<()> {
451 47 : let kind = RequestKind::Delete;
452 47 : let _permit = self.permit(kind, cancel).await?;
453 47 : let started_at = start_measuring_requests(kind);
454 47 :
455 47 : let op = async {
456 : // TODO batch requests are not supported by the SDK
457 : // https://github.com/Azure/azure-sdk-for-rust/issues/1068
458 96 : for path in paths {
459 : #[derive(Debug)]
460 : enum AzureOrTimeout {
461 : AzureError(azure_core::Error),
462 : Timeout,
463 : Cancel,
464 : }
465 : impl Display for AzureOrTimeout {
466 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467 0 : write!(f, "{self:?}")
468 0 : }
469 : }
470 49 : let warn_threshold = 3;
471 49 : let max_retries = 5;
472 49 : backoff::retry(
473 49 : || async {
474 49 : let blob_client = self.client.blob_client(self.relative_path_to_name(path));
475 49 :
476 49 : let request = blob_client.delete().into_future();
477 49 :
478 246 : let res = tokio::time::timeout(self.timeout, request).await;
479 49 :
480 49 : match res {
481 49 : Ok(Ok(_v)) => Ok(()),
482 49 : Ok(Err(azure_err)) => {
483 49 : if let Some(http_err) = azure_err.as_http_error() {
484 49 : if http_err.status() == StatusCode::NotFound {
485 49 : return Ok(());
486 49 : }
487 49 : }
488 49 : Err(AzureOrTimeout::AzureError(azure_err))
489 49 : }
490 49 : Err(_elapsed) => Err(AzureOrTimeout::Timeout),
491 49 : }
492 49 : },
493 49 : |err| match err {
494 0 : AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
495 0 : AzureOrTimeout::Cancel => true,
496 49 : },
497 49 : warn_threshold,
498 49 : max_retries,
499 49 : "deleting remote object",
500 49 : cancel,
501 49 : )
502 246 : .await
503 49 : .ok_or_else(|| AzureOrTimeout::Cancel)
504 49 : .and_then(|x| x)
505 49 : .map_err(|e| match e {
506 0 : AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
507 0 : AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
508 0 : AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
509 49 : })?;
510 : }
511 47 : Ok(())
512 47 : };
513 :
514 47 : let res = tokio::select! {
515 : res = op => res,
516 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
517 : };
518 :
519 47 : let started_at = ScopeGuard::into_inner(started_at);
520 47 : crate::metrics::BUCKET_METRICS
521 47 : .req_seconds
522 47 : .observe_elapsed(kind, &res, started_at);
523 47 : res
524 47 : }
525 :
526 1 : async fn copy(
527 1 : &self,
528 1 : from: &RemotePath,
529 1 : to: &RemotePath,
530 1 : cancel: &CancellationToken,
531 1 : ) -> anyhow::Result<()> {
532 1 : let kind = RequestKind::Copy;
533 1 : let _permit = self.permit(kind, cancel).await?;
534 1 : let started_at = start_measuring_requests(kind);
535 1 :
536 1 : let timeout = tokio::time::sleep(self.timeout);
537 1 :
538 1 : let mut copy_status = None;
539 1 :
540 1 : let op = async {
541 1 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
542 :
543 1 : let source_url = format!(
544 1 : "{}/{}",
545 1 : self.client.url()?,
546 1 : self.relative_path_to_name(from)
547 : );
548 :
549 1 : let builder = blob_client.copy(Url::from_str(&source_url)?);
550 1 : let copy = builder.into_future();
551 :
552 5 : let result = copy.await?;
553 :
554 1 : copy_status = Some(result.copy_status);
555 : loop {
556 1 : match copy_status.as_ref().expect("we always set it to Some") {
557 : CopyStatus::Aborted => {
558 0 : anyhow::bail!("Received abort for copy from {from} to {to}.");
559 : }
560 : CopyStatus::Failed => {
561 0 : anyhow::bail!("Received failure response for copy from {from} to {to}.");
562 : }
563 1 : CopyStatus::Success => return Ok(()),
564 0 : CopyStatus::Pending => (),
565 0 : }
566 0 : // The copy is taking longer. Waiting a second and then re-trying.
567 0 : // TODO estimate time based on copy_progress and adjust time based on that
568 0 : tokio::time::sleep(Duration::from_millis(1000)).await;
569 0 : let properties = blob_client.get_properties().into_future().await?;
570 0 : let Some(status) = properties.blob.properties.copy_status else {
571 0 : tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
572 0 : return Ok(());
573 : };
574 0 : copy_status = Some(status);
575 : }
576 1 : };
577 :
578 1 : let res = tokio::select! {
579 : res = op => res,
580 : _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
581 : _ = timeout => {
582 : let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
583 : let e = e.context(format!("Timeout, last status: {copy_status:?}"));
584 : Err(e)
585 : },
586 : };
587 :
588 1 : let started_at = ScopeGuard::into_inner(started_at);
589 1 : crate::metrics::BUCKET_METRICS
590 1 : .req_seconds
591 1 : .observe_elapsed(kind, &res, started_at);
592 1 : res
593 1 : }
594 :
595 0 : async fn time_travel_recover(
596 0 : &self,
597 0 : _prefix: Option<&RemotePath>,
598 0 : _timestamp: SystemTime,
599 0 : _done_if_after: SystemTime,
600 0 : _cancel: &CancellationToken,
601 0 : ) -> Result<(), TimeTravelError> {
602 0 : // TODO use Azure point in time recovery feature for this
603 0 : // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
604 0 : Err(TimeTravelError::Unimplemented)
605 0 : }
606 : }
607 :
pin_project_lite::pin_project! {
    /// Hack to work around not being able to stream once with azure sdk.
    ///
    /// Azure sdk clones streams around with the assumption that they are like
    /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
    /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
    /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
    /// seekable, but we can also just re-try the request easier.
    ///
    /// Lifecycle: the first clone of an `Initial` value moves the reader out and yields a
    /// readable `Actual`; every other clone yields an unreadable `Cloned`.
    #[project = NonSeekableStreamProj]
    enum NonSeekableStream<S> {
        /// The wrapper's initial form, before the SDK has cloned it.
        ///
        /// The mutex exists so that the reader can be moved out through `&self` when cloning.
        /// If the SDK changes to do less than 1 clone before the first request, then this must
        /// be changed.
        Initial {
            // `Some` until the first clone takes the reader out.
            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
            // Body length, as passed to `NonSeekableStream::new`.
            len: usize,
        },
        /// The actually readable variant, produced by cloning the Initial variant.
        ///
        /// The sdk currently always clones once, even without retry policy.
        Actual {
            #[pin]
            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
            len: usize,
            // Set once `poll_read` has been called; `reset` is refused afterwards.
            read_any: bool,
        },
        /// Most likely unneeded, but left to make life easier, in case more clones are added.
        Cloned {
            // Length of the value this was cloned from.
            len_was: usize,
        }
    }
}
641 :
642 : impl<S> NonSeekableStream<S>
643 : where
644 : S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
645 : {
646 47 : fn new(inner: S, len: usize) -> NonSeekableStream<S> {
647 47 : use tokio_util::compat::TokioAsyncReadCompatExt;
648 47 :
649 47 : let inner = tokio_util::io::StreamReader::new(inner).compat();
650 47 : let inner = Some(inner);
651 47 : let inner = std::sync::Mutex::new(inner);
652 47 : NonSeekableStream::Initial { inner, len }
653 47 : }
654 : }
655 :
656 : impl<S> std::fmt::Debug for NonSeekableStream<S> {
657 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
658 0 : match self {
659 0 : Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
660 0 : Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
661 0 : Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
662 : }
663 0 : }
664 : }
665 :
666 : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
667 : where
668 : S: Stream<Item = std::io::Result<Bytes>>,
669 : {
670 47 : fn poll_read(
671 47 : self: std::pin::Pin<&mut Self>,
672 47 : cx: &mut std::task::Context<'_>,
673 47 : buf: &mut [u8],
674 47 : ) -> std::task::Poll<std::io::Result<usize>> {
675 47 : match self.project() {
676 : NonSeekableStreamProj::Actual {
677 47 : inner, read_any, ..
678 47 : } => {
679 47 : *read_any = true;
680 47 : inner.poll_read(cx, buf)
681 : }
682 : // NonSeekableStream::Initial does not support reading because it is just much easier
683 : // to have the mutex in place where one does not poll the contents, or that's how it
684 : // seemed originally. If there is a version upgrade which changes the cloning, then
685 : // that support needs to be hacked in.
686 : //
687 : // including {self:?} into the message would be useful, but unsure how to unproject.
688 0 : _ => std::task::Poll::Ready(Err(std::io::Error::new(
689 0 : std::io::ErrorKind::Other,
690 0 : "cloned or initial values cannot be read",
691 0 : ))),
692 : }
693 47 : }
694 : }
695 :
696 : impl<S> Clone for NonSeekableStream<S> {
697 : /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
698 : /// request, see type documentation.
699 47 : fn clone(&self) -> Self {
700 47 : use NonSeekableStream::*;
701 47 :
702 47 : match self {
703 47 : Initial { inner, len } => {
704 47 : if let Some(inner) = inner.lock().unwrap().take() {
705 47 : Actual {
706 47 : inner,
707 47 : len: *len,
708 47 : read_any: false,
709 47 : }
710 : } else {
711 0 : Self::Cloned { len_was: *len }
712 : }
713 : }
714 0 : Actual { len, .. } => Cloned { len_was: *len },
715 0 : Cloned { len_was } => Cloned { len_was: *len_was },
716 : }
717 47 : }
718 : }
719 :
720 : #[async_trait::async_trait]
721 : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
722 : where
723 : S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
724 : {
725 0 : async fn reset(&mut self) -> azure_core::error::Result<()> {
726 0 : use NonSeekableStream::*;
727 0 :
728 0 : let msg = match self {
729 0 : Initial { inner, .. } => {
730 0 : if inner.get_mut().unwrap().is_some() {
731 0 : return Ok(());
732 0 : } else {
733 0 : "reset after first clone is not supported"
734 0 : }
735 0 : }
736 0 : Actual { read_any, .. } if !*read_any => return Ok(()),
737 0 : Actual { .. } => "reset after reading is not supported",
738 0 : Cloned { .. } => "reset after second clone is not supported",
739 0 : };
740 0 : Err(azure_core::error::Error::new(
741 0 : azure_core::error::ErrorKind::Io,
742 0 : std::io::Error::new(std::io::ErrorKind::Other, msg),
743 0 : ))
744 0 : }
745 :
746 : // Note: it is not documented if this should be the total or remaining length, total passes the
747 : // tests.
748 47 : fn len(&self) -> usize {
749 47 : use NonSeekableStream::*;
750 47 : match self {
751 47 : Initial { len, .. } => *len,
752 0 : Actual { len, .. } => *len,
753 0 : Cloned { len_was, .. } => *len_was,
754 : }
755 47 : }
756 : }
|