Line data Source code
1 : //! Azure Blob Storage wrapper
2 :
3 : use std::borrow::Cow;
4 : use std::collections::HashMap;
5 : use std::env;
6 : use std::io;
7 : use std::num::NonZeroU32;
8 : use std::pin::Pin;
9 : use std::str::FromStr;
10 : use std::sync::Arc;
11 : use std::time::Duration;
12 : use std::time::SystemTime;
13 :
14 : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
15 : use anyhow::Result;
16 : use azure_core::request_options::{MaxResults, Metadata, Range};
17 : use azure_core::RetryOptions;
18 : use azure_identity::DefaultAzureCredential;
19 : use azure_storage::StorageCredentials;
20 : use azure_storage_blobs::blob::CopyStatus;
21 : use azure_storage_blobs::prelude::ClientBuilder;
22 : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
23 : use bytes::Bytes;
24 : use futures::future::Either;
25 : use futures::stream::Stream;
26 : use futures_util::StreamExt;
27 : use futures_util::TryStreamExt;
28 : use http_types::{StatusCode, Url};
29 : use tokio_util::sync::CancellationToken;
30 : use tracing::debug;
31 :
32 : use crate::RemoteStorageActivity;
33 : use crate::{
34 : error::Cancelled, s3_bucket::RequestKind, AzureConfig, ConcurrencyLimiter, Download,
35 : DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, StorageMetadata,
36 : TimeTravelError, TimeoutOrCancel,
37 : };
38 :
39 : pub struct AzureBlobStorage {
40 : client: ContainerClient,
41 : prefix_in_container: Option<String>,
42 : max_keys_per_list_response: Option<NonZeroU32>,
43 : concurrency_limiter: ConcurrencyLimiter,
44 : // Per-request timeout. Accessible for tests.
45 : pub timeout: Duration,
46 : }
47 :
48 : impl AzureBlobStorage {
49 6 : pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result<Self> {
50 6 : debug!(
51 0 : "Creating azure remote storage for azure container {}",
52 : azure_config.container_name
53 : );
54 :
55 6 : let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
56 :
57 : // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
58 : // otherwise try the token based credentials.
59 6 : let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
60 6 : StorageCredentials::access_key(account.clone(), access_key)
61 : } else {
62 0 : let token_credential = DefaultAzureCredential::default();
63 0 : StorageCredentials::token_credential(Arc::new(token_credential))
64 : };
65 :
66 : // we have an outer retry
67 6 : let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
68 6 :
69 6 : let client = builder.container_client(azure_config.container_name.to_owned());
70 :
71 6 : let max_keys_per_list_response =
72 6 : if let Some(limit) = azure_config.max_keys_per_list_response {
73 : Some(
74 2 : NonZeroU32::new(limit as u32)
75 2 : .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
76 : )
77 : } else {
78 4 : None
79 : };
80 :
81 6 : Ok(AzureBlobStorage {
82 6 : client,
83 6 : prefix_in_container: azure_config.prefix_in_container.to_owned(),
84 6 : max_keys_per_list_response,
85 6 : concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
86 6 : timeout,
87 6 : })
88 6 : }
89 :
90 107 : pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
91 107 : assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
92 107 : let path_string = path
93 107 : .get_path()
94 107 : .as_str()
95 107 : .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
96 107 : match &self.prefix_in_container {
97 107 : Some(prefix) => {
98 107 : if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
99 107 : prefix.clone() + path_string
100 : } else {
101 0 : format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
102 : }
103 : }
104 0 : None => path_string.to_string(),
105 : }
106 107 : }
107 :
108 53 : fn name_to_relative_path(&self, key: &str) -> RemotePath {
109 53 : let relative_path =
110 53 : match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
111 53 : Some(stripped) => stripped,
112 : // we rely on Azure to return properly prefixed paths
113 : // for requests with a certain prefix
114 0 : None => panic!(
115 0 : "Key {key} does not start with container prefix {:?}",
116 0 : self.prefix_in_container
117 0 : ),
118 : };
119 53 : RemotePath(
120 53 : relative_path
121 53 : .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
122 53 : .collect(),
123 53 : )
124 53 : }
125 :
126 7 : async fn download_for_builder(
127 7 : &self,
128 7 : builder: GetBlobBuilder,
129 7 : cancel: &CancellationToken,
130 7 : ) -> Result<Download, DownloadError> {
131 0 : let kind = RequestKind::Get;
132 :
133 0 : let _permit = self.permit(kind, cancel).await?;
134 0 : let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
135 0 : let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
136 0 :
137 0 : let mut etag = None;
138 0 : let mut last_modified = None;
139 0 : let mut metadata = HashMap::new();
140 0 :
141 0 : let download = async {
142 0 : let response = builder
143 0 : // convert to concrete Pageable
144 0 : .into_stream()
145 0 : // convert to TryStream
146 0 : .into_stream()
147 0 : .map_err(to_download_error);
148 0 :
149 0 : // apply per request timeout
150 0 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
151 0 :
152 0 : // flatten
153 0 : let response = response.map(|res| match res {
154 0 : Ok(res) => res,
155 0 : Err(_elapsed) => Err(DownloadError::Timeout),
156 0 : });
157 0 :
158 0 : let mut response = Box::pin(response);
159 :
160 0 : let Some(part) = response.next().await else {
161 0 : return Err(DownloadError::Other(anyhow::anyhow!(
162 0 : "Azure GET response contained no response body"
163 0 : )));
164 : };
165 0 : let part = part?;
166 0 : if etag.is_none() {
167 0 : etag = Some(part.blob.properties.etag);
168 0 : }
169 0 : if last_modified.is_none() {
170 0 : last_modified = Some(part.blob.properties.last_modified.into());
171 0 : }
172 0 : if let Some(blob_meta) = part.blob.metadata {
173 0 : metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
174 0 : }
175 :
176 : // unwrap safety: if these were None, bufs would be empty and we would have returned an error already
177 0 : let etag = etag.unwrap();
178 0 : let last_modified = last_modified.unwrap();
179 0 :
180 0 : let tail_stream = response
181 0 : .map(|part| match part {
182 0 : Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
183 0 : Err(e) => {
184 0 : Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
185 : }
186 0 : })
187 0 : .flatten();
188 0 : let stream = part
189 0 : .data
190 0 : .map(|r| r.map_err(io::Error::other))
191 0 : .chain(sync_wrapper::SyncStream::new(tail_stream));
192 0 : //.chain(SyncStream::from_pin(Box::pin(tail_stream)));
193 0 :
194 0 : let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
195 0 :
196 0 : Ok(Download {
197 0 : download_stream: Box::pin(download_stream),
198 0 : etag,
199 0 : last_modified,
200 0 : metadata: Some(StorageMetadata(metadata)),
201 0 : })
202 0 : };
203 :
204 : tokio::select! {
205 : bufs = download => bufs,
206 : cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
207 : TimeoutOrCancel::Timeout => Err(DownloadError::Timeout),
208 : TimeoutOrCancel::Cancel => Err(DownloadError::Cancelled),
209 : },
210 : }
211 0 : }
212 :
213 108 : async fn permit(
214 108 : &self,
215 108 : kind: RequestKind,
216 108 : cancel: &CancellationToken,
217 108 : ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
218 0 : let acquire = self.concurrency_limiter.acquire(kind);
219 :
220 : tokio::select! {
221 : permit = acquire => Ok(permit.expect("never closed")),
222 : _ = cancel.cancelled() => Err(Cancelled),
223 : }
224 0 : }
225 : }
226 :
227 0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
228 0 : let mut res = Metadata::new();
229 0 : for (k, v) in metadata.0.into_iter() {
230 0 : res.insert(k, v);
231 0 : }
232 0 : res
233 0 : }
234 :
235 0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
236 0 : if let Some(http_err) = error.as_http_error() {
237 0 : match http_err.status() {
238 0 : StatusCode::NotFound => DownloadError::NotFound,
239 0 : StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
240 0 : _ => DownloadError::Other(anyhow::Error::new(error)),
241 : }
242 : } else {
243 0 : DownloadError::Other(error.into())
244 : }
245 0 : }
246 :
247 : impl RemoteStorage for AzureBlobStorage {
248 6 : async fn list(
249 6 : &self,
250 6 : prefix: Option<&RemotePath>,
251 6 : mode: ListingMode,
252 6 : max_keys: Option<NonZeroU32>,
253 6 : cancel: &CancellationToken,
254 6 : ) -> anyhow::Result<Listing, DownloadError> {
255 0 : let _permit = self.permit(RequestKind::List, cancel).await?;
256 :
257 0 : let op = async {
258 0 : // get the passed prefix or if it is not set use prefix_in_bucket value
259 0 : let list_prefix = prefix
260 0 : .map(|p| self.relative_path_to_name(p))
261 0 : .or_else(|| self.prefix_in_container.clone())
262 0 : .map(|mut p| {
263 : // required to end with a separator
264 : // otherwise request will return only the entry of a prefix
265 0 : if matches!(mode, ListingMode::WithDelimiter)
266 0 : && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
267 0 : {
268 0 : p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
269 0 : }
270 0 : p
271 0 : });
272 0 :
273 0 : let mut builder = self.client.list_blobs();
274 0 :
275 0 : if let ListingMode::WithDelimiter = mode {
276 0 : builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
277 0 : }
278 :
279 0 : if let Some(prefix) = list_prefix {
280 0 : builder = builder.prefix(Cow::from(prefix.to_owned()));
281 0 : }
282 :
283 0 : if let Some(limit) = self.max_keys_per_list_response {
284 0 : builder = builder.max_results(MaxResults::new(limit));
285 0 : }
286 :
287 0 : let response = builder.into_stream();
288 0 : let response = response.into_stream().map_err(to_download_error);
289 0 : let response = tokio_stream::StreamExt::timeout(response, self.timeout);
290 0 : let response = response.map(|res| match res {
291 0 : Ok(res) => res,
292 0 : Err(_elapsed) => Err(DownloadError::Timeout),
293 0 : });
294 0 :
295 0 : let mut response = std::pin::pin!(response);
296 0 :
297 0 : let mut res = Listing::default();
298 0 :
299 0 : let mut max_keys = max_keys.map(|mk| mk.get());
300 0 : while let Some(entry) = response.next().await {
301 0 : let entry = entry?;
302 0 : let prefix_iter = entry
303 0 : .blobs
304 0 : .prefixes()
305 0 : .map(|prefix| self.name_to_relative_path(&prefix.name));
306 0 : res.prefixes.extend(prefix_iter);
307 0 :
308 0 : let blob_iter = entry
309 0 : .blobs
310 0 : .blobs()
311 0 : .map(|k| self.name_to_relative_path(&k.name));
312 :
313 0 : for key in blob_iter {
314 0 : res.keys.push(key);
315 :
316 0 : if let Some(mut mk) = max_keys {
317 0 : assert!(mk > 0);
318 0 : mk -= 1;
319 0 : if mk == 0 {
320 0 : return Ok(res); // limit reached
321 0 : }
322 0 : max_keys = Some(mk);
323 0 : }
324 : }
325 : }
326 :
327 0 : Ok(res)
328 0 : };
329 :
330 : tokio::select! {
331 : res = op => res,
332 : _ = cancel.cancelled() => Err(DownloadError::Cancelled),
333 : }
334 0 : }
335 :
336 0 : async fn upload(
337 0 : &self,
338 0 : from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
339 0 : data_size_bytes: usize,
340 0 : to: &RemotePath,
341 0 : metadata: Option<StorageMetadata>,
342 0 : cancel: &CancellationToken,
343 0 : ) -> anyhow::Result<()> {
344 0 : let _permit = self.permit(RequestKind::Put, cancel).await?;
345 :
346 0 : let op = async {
347 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
348 0 :
349 0 : let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
350 0 : Box::pin(from);
351 0 :
352 0 : let from = NonSeekableStream::new(from, data_size_bytes);
353 0 :
354 0 : let body = azure_core::Body::SeekableStream(Box::new(from));
355 0 :
356 0 : let mut builder = blob_client.put_block_blob(body);
357 :
358 0 : if let Some(metadata) = metadata {
359 0 : builder = builder.metadata(to_azure_metadata(metadata));
360 0 : }
361 :
362 0 : let fut = builder.into_future();
363 0 : let fut = tokio::time::timeout(self.timeout, fut);
364 0 :
365 0 : match fut.await {
366 0 : Ok(Ok(_response)) => Ok(()),
367 0 : Ok(Err(azure)) => Err(azure.into()),
368 0 : Err(_timeout) => Err(TimeoutOrCancel::Cancel.into()),
369 : }
370 0 : };
371 :
372 : tokio::select! {
373 : res = op => res,
374 : _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()),
375 : }
376 0 : }
377 :
378 2 : async fn download(
379 2 : &self,
380 2 : from: &RemotePath,
381 2 : cancel: &CancellationToken,
382 2 : ) -> Result<Download, DownloadError> {
383 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
384 0 :
385 0 : let builder = blob_client.get();
386 0 :
387 0 : self.download_for_builder(builder, cancel).await
388 0 : }
389 :
390 5 : async fn download_byte_range(
391 5 : &self,
392 5 : from: &RemotePath,
393 5 : start_inclusive: u64,
394 5 : end_exclusive: Option<u64>,
395 5 : cancel: &CancellationToken,
396 5 : ) -> Result<Download, DownloadError> {
397 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
398 0 :
399 0 : let mut builder = blob_client.get();
400 :
401 0 : let range: Range = if let Some(end_exclusive) = end_exclusive {
402 0 : (start_inclusive..end_exclusive).into()
403 : } else {
404 0 : (start_inclusive..).into()
405 : };
406 0 : builder = builder.range(range);
407 0 :
408 0 : self.download_for_builder(builder, cancel).await
409 0 : }
410 :
411 44 : async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
412 0 : self.delete_objects(std::array::from_ref(path), cancel)
413 0 : .await
414 0 : }
415 :
416 47 : async fn delete_objects<'a>(
417 47 : &self,
418 47 : paths: &'a [RemotePath],
419 47 : cancel: &CancellationToken,
420 47 : ) -> anyhow::Result<()> {
421 0 : let _permit = self.permit(RequestKind::Delete, cancel).await?;
422 :
423 0 : let op = async {
424 : // TODO batch requests are also not supported by the SDK
425 : // https://github.com/Azure/azure-sdk-for-rust/issues/1068
426 : // https://github.com/Azure/azure-sdk-for-rust/issues/1249
427 0 : for path in paths {
428 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(path));
429 0 :
430 0 : let request = blob_client.delete().into_future();
431 :
432 0 : let res = tokio::time::timeout(self.timeout, request).await;
433 :
434 0 : match res {
435 0 : Ok(Ok(_response)) => continue,
436 0 : Ok(Err(e)) => {
437 0 : if let Some(http_err) = e.as_http_error() {
438 0 : if http_err.status() == StatusCode::NotFound {
439 0 : continue;
440 0 : }
441 0 : }
442 0 : return Err(e.into());
443 : }
444 0 : Err(_elapsed) => return Err(TimeoutOrCancel::Timeout.into()),
445 : }
446 : }
447 :
448 0 : Ok(())
449 0 : };
450 :
451 : tokio::select! {
452 : res = op => res,
453 : _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()),
454 : }
455 0 : }
456 :
457 1 : async fn copy(
458 1 : &self,
459 1 : from: &RemotePath,
460 1 : to: &RemotePath,
461 1 : cancel: &CancellationToken,
462 1 : ) -> anyhow::Result<()> {
463 0 : let _permit = self.permit(RequestKind::Copy, cancel).await?;
464 :
465 0 : let timeout = tokio::time::sleep(self.timeout);
466 0 :
467 0 : let mut copy_status = None;
468 0 :
469 0 : let op = async {
470 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
471 :
472 0 : let source_url = format!(
473 0 : "{}/{}",
474 0 : self.client.url()?,
475 0 : self.relative_path_to_name(from)
476 : );
477 :
478 0 : let builder = blob_client.copy(Url::from_str(&source_url)?);
479 0 : let copy = builder.into_future();
480 :
481 0 : let result = copy.await?;
482 :
483 0 : copy_status = Some(result.copy_status);
484 : loop {
485 0 : match copy_status.as_ref().expect("we always set it to Some") {
486 : CopyStatus::Aborted => {
487 0 : anyhow::bail!("Received abort for copy from {from} to {to}.");
488 : }
489 : CopyStatus::Failed => {
490 0 : anyhow::bail!("Received failure response for copy from {from} to {to}.");
491 : }
492 0 : CopyStatus::Success => return Ok(()),
493 0 : CopyStatus::Pending => (),
494 0 : }
495 0 : // The copy is taking longer. Waiting a second and then re-trying.
496 0 : // TODO estimate time based on copy_progress and adjust time based on that
497 0 : tokio::time::sleep(Duration::from_millis(1000)).await;
498 0 : let properties = blob_client.get_properties().into_future().await?;
499 0 : let Some(status) = properties.blob.properties.copy_status else {
500 0 : tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
501 0 : return Ok(());
502 : };
503 0 : copy_status = Some(status);
504 : }
505 0 : };
506 :
507 : tokio::select! {
508 : res = op => res,
509 : _ = cancel.cancelled() => Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
510 : _ = timeout => {
511 : let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
512 : let e = e.context(format!("Timeout, last status: {copy_status:?}"));
513 : Err(e)
514 : },
515 : }
516 0 : }
517 :
518 0 : async fn time_travel_recover(
519 0 : &self,
520 0 : _prefix: Option<&RemotePath>,
521 0 : _timestamp: SystemTime,
522 0 : _done_if_after: SystemTime,
523 0 : _cancel: &CancellationToken,
524 0 : ) -> Result<(), TimeTravelError> {
525 0 : // TODO use Azure point in time recovery feature for this
526 0 : // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
527 0 : Err(TimeTravelError::Unimplemented)
528 0 : }
529 :
530 0 : fn activity(&self) -> RemoteStorageActivity {
531 0 : self.concurrency_limiter.activity()
532 0 : }
533 : }
534 :
535 : pin_project_lite::pin_project! {
536 : /// Hack to work around not being able to stream once with azure sdk.
537 : ///
538 : /// Azure sdk clones streams around with the assumption that they are like
539 : /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
540 : /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
541 : /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
542 : /// seekable, but we can also just re-try the request easier.
543 : #[project = NonSeekableStreamProj]
544 : enum NonSeekableStream<S> {
545 : /// A stream wrappers initial form.
546 : ///
547 : /// Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
548 : /// clone before first request, then this must be changed.
549 : Initial {
550 : inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
551 : len: usize,
552 : },
553 : /// The actually readable variant, produced by cloning the Initial variant.
554 : ///
555 : /// The sdk currently always clones once, even without retry policy.
556 : Actual {
557 : #[pin]
558 : inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
559 : len: usize,
560 : read_any: bool,
561 : },
562 : /// Most likely unneeded, but left to make life easier, in case more clones are added.
563 : Cloned {
564 : len_was: usize,
565 : }
566 : }
567 : }
568 :
569 : impl<S> NonSeekableStream<S>
570 : where
571 : S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
572 : {
573 0 : fn new(inner: S, len: usize) -> NonSeekableStream<S> {
574 0 : use tokio_util::compat::TokioAsyncReadCompatExt;
575 0 :
576 0 : let inner = tokio_util::io::StreamReader::new(inner).compat();
577 0 : let inner = Some(inner);
578 0 : let inner = std::sync::Mutex::new(inner);
579 0 : NonSeekableStream::Initial { inner, len }
580 0 : }
581 : }
582 :
583 : impl<S> std::fmt::Debug for NonSeekableStream<S> {
584 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585 0 : match self {
586 0 : Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
587 0 : Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
588 0 : Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
589 : }
590 0 : }
591 : }
592 :
593 : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
594 : where
595 : S: Stream<Item = std::io::Result<Bytes>>,
596 : {
597 0 : fn poll_read(
598 0 : self: std::pin::Pin<&mut Self>,
599 0 : cx: &mut std::task::Context<'_>,
600 0 : buf: &mut [u8],
601 0 : ) -> std::task::Poll<std::io::Result<usize>> {
602 0 : match self.project() {
603 : NonSeekableStreamProj::Actual {
604 0 : inner, read_any, ..
605 0 : } => {
606 0 : *read_any = true;
607 0 : inner.poll_read(cx, buf)
608 : }
609 : // NonSeekableStream::Initial does not support reading because it is just much easier
610 : // to have the mutex in place where one does not poll the contents, or that's how it
611 : // seemed originally. If there is a version upgrade which changes the cloning, then
612 : // that support needs to be hacked in.
613 : //
614 : // including {self:?} into the message would be useful, but unsure how to unproject.
615 0 : _ => std::task::Poll::Ready(Err(std::io::Error::new(
616 0 : std::io::ErrorKind::Other,
617 0 : "cloned or initial values cannot be read",
618 0 : ))),
619 : }
620 0 : }
621 : }
622 :
623 : impl<S> Clone for NonSeekableStream<S> {
624 : /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
625 : /// request, see type documentation.
626 0 : fn clone(&self) -> Self {
627 0 : use NonSeekableStream::*;
628 0 :
629 0 : match self {
630 0 : Initial { inner, len } => {
631 0 : if let Some(inner) = inner.lock().unwrap().take() {
632 0 : Actual {
633 0 : inner,
634 0 : len: *len,
635 0 : read_any: false,
636 0 : }
637 : } else {
638 0 : Self::Cloned { len_was: *len }
639 : }
640 : }
641 0 : Actual { len, .. } => Cloned { len_was: *len },
642 0 : Cloned { len_was } => Cloned { len_was: *len_was },
643 : }
644 0 : }
645 : }
646 :
647 : #[async_trait::async_trait]
648 : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
649 : where
650 : S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
651 : {
652 0 : async fn reset(&mut self) -> azure_core::error::Result<()> {
653 0 : use NonSeekableStream::*;
654 0 :
655 0 : let msg = match self {
656 0 : Initial { inner, .. } => {
657 0 : if inner.get_mut().unwrap().is_some() {
658 0 : return Ok(());
659 0 : } else {
660 0 : "reset after first clone is not supported"
661 0 : }
662 0 : }
663 0 : Actual { read_any, .. } if !*read_any => return Ok(()),
664 0 : Actual { .. } => "reset after reading is not supported",
665 0 : Cloned { .. } => "reset after second clone is not supported",
666 0 : };
667 0 : Err(azure_core::error::Error::new(
668 0 : azure_core::error::ErrorKind::Io,
669 0 : std::io::Error::new(std::io::ErrorKind::Other, msg),
670 0 : ))
671 0 : }
672 :
673 : // Note: it is not documented if this should be the total or remaining length, total passes the
674 : // tests.
675 0 : fn len(&self) -> usize {
676 0 : use NonSeekableStream::*;
677 0 : match self {
678 0 : Initial { len, .. } => *len,
679 0 : Actual { len, .. } => *len,
680 0 : Cloned { len_was, .. } => *len_was,
681 : }
682 0 : }
683 : }
|