Line data Source code
1 : //! Azure Blob Storage wrapper
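: //!
: //! A minimal usage sketch (illustrative only: `config` is an `AzureConfig` built by the
: //! caller, `opts` is a `DownloadOpts` constructed elsewhere, and the path is made up):
: //!
: //! ```ignore
: //! use std::time::Duration;
: //! use tokio_util::sync::CancellationToken;
: //!
: //! let storage = AzureBlobStorage::new(&config, Duration::from_secs(120), Duration::from_secs(30))?;
: //! let cancel = CancellationToken::new();
: //! let path = RemotePath::from_string("tenants/t1/heatmap-v1.json")?;
: //! let download = storage.download(&path, &opts, &cancel).await?;
: //! ```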
2 :
3 : use std::borrow::Cow;
4 : use std::collections::HashMap;
5 : use std::fmt::Display;
6 : use std::num::NonZeroU32;
7 : use std::pin::Pin;
8 : use std::str::FromStr;
9 : use std::sync::Arc;
10 : use std::time::{Duration, SystemTime};
11 : use std::{env, io};
12 :
13 : use anyhow::{Context, Result, anyhow};
14 : use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
15 : use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
16 : use azure_storage::StorageCredentials;
17 : use azure_storage_blobs::blob::BlobBlockType;
18 : use azure_storage_blobs::blob::BlockList;
19 : use azure_storage_blobs::blob::{Blob, CopyStatus};
20 : use azure_storage_blobs::container::operations::ListBlobsBuilder;
21 : use azure_storage_blobs::prelude::ClientBuilder;
22 : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
24 : use byteorder::{BigEndian, ByteOrder};
25 : use bytes::Bytes;
26 : use camino::Utf8Path;
27 : use futures::FutureExt;
28 : use futures::future::Either;
29 : use futures::stream::Stream;
30 : use futures_util::{StreamExt, TryStreamExt};
31 : use http_types::{StatusCode, Url};
32 : use scopeguard::ScopeGuard;
33 : use tokio::fs::File;
34 : use tokio::io::AsyncReadExt;
35 : use tokio::io::AsyncSeekExt;
36 : use tokio_util::sync::CancellationToken;
37 : use tracing::debug;
38 : use utils::backoff;
39 : use utils::backoff::exponential_backoff_duration_seconds;
40 :
41 : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
42 : use crate::config::AzureConfig;
43 : use crate::error::Cancelled;
44 : use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
45 : use crate::{
46 : ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
47 : ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
48 : Version, VersionKind,
49 : };
50 :
51 : pub struct AzureBlobStorage {
52 : client: ContainerClient,
53 : container_name: String,
54 : prefix_in_container: Option<String>,
55 : max_keys_per_list_response: Option<NonZeroU32>,
56 : concurrency_limiter: ConcurrencyLimiter,
57 : // Per-request timeout. Accessible for tests.
58 : pub timeout: Duration,
59 :
60 : // Alternative timeout used for metadata objects which are expected to be small
61 : pub small_timeout: Duration,
62 : /* BEGIN_HADRON */
63 : pub put_block_size_mb: Option<usize>,
64 : /* END_HADRON */
65 : }
66 :
67 : impl AzureBlobStorage {
68 10 : pub fn new(
69 10 : azure_config: &AzureConfig,
70 10 : timeout: Duration,
71 10 : small_timeout: Duration,
72 10 : ) -> Result<Self> {
73 10 : debug!(
74 0 : "Creating azure remote storage for azure container {}",
75 : azure_config.container_name
76 : );
77 :
78 : // Use the storage account from the config by default, fall back to env var if not present.
79 10 : let account = azure_config.storage_account.clone().unwrap_or_else(|| {
80 10 : env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT")
81 10 : });
82 :
83 : // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
84 : // otherwise try the token based credentials.
85 10 : let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
86 10 : StorageCredentials::access_key(account.clone(), access_key)
87 : } else {
88 0 : let token_credential = azure_identity::create_default_credential()
89 0 : .context("trying to obtain Azure default credentials")?;
90 0 : StorageCredentials::token_credential(token_credential)
91 : };
92 :
93 10 : let builder = ClientBuilder::new(account, credentials)
94 : // we have an outer retry
95 10 : .retry(RetryOptions::none())
96 : // Customize transport to configure connection pooling
97 10 : .transport(TransportOptions::new(Self::reqwest_client(
98 10 : azure_config.conn_pool_size,
99 10 : )));
100 :
101 10 : let client = builder.container_client(azure_config.container_name.to_owned());
102 :
103 10 : let max_keys_per_list_response =
104 10 : if let Some(limit) = azure_config.max_keys_per_list_response {
105 : Some(
106 4 : NonZeroU32::new(limit as u32)
107 4 : .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
108 : )
109 : } else {
110 6 : None
111 : };
112 :
113 10 : Ok(AzureBlobStorage {
114 10 : client,
115 10 : container_name: azure_config.container_name.to_owned(),
116 10 : prefix_in_container: azure_config.prefix_in_container.to_owned(),
117 10 : max_keys_per_list_response,
118 10 : concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
119 10 : timeout,
120 10 : small_timeout,
121 10 : /* BEGIN_HADRON */
122 10 : put_block_size_mb: azure_config.put_block_size_mb,
123 10 : /* END_HADRON */
124 10 : })
125 10 : }
126 :
127 10 : fn reqwest_client(conn_pool_size: usize) -> Arc<dyn HttpClient> {
128 10 : let client = reqwest::ClientBuilder::new()
129 10 : .pool_max_idle_per_host(conn_pool_size)
130 10 : .build()
131 10 : .expect("failed to build `reqwest` client");
132 10 : Arc::new(client)
133 10 : }
134 :
135 237 : pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
136 237 : assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
137 237 : let path_string = path.get_path().as_str();
138 237 : match &self.prefix_in_container {
139 237 : Some(prefix) => {
140 237 : if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
141 237 : prefix.clone() + path_string
142 : } else {
143 0 : format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
144 : }
145 : }
146 0 : None => path_string.to_string(),
147 : }
148 237 : }
149 :
150 249 : fn name_to_relative_path(&self, key: &str) -> RemotePath {
151 249 : let relative_path =
152 249 : match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
153 249 : Some(stripped) => stripped,
154 : // we rely on Azure to return properly prefixed paths
155 : // for requests with a certain prefix
156 0 : None => panic!(
157 0 : "Key {key} does not start with container prefix {:?}",
158 : self.prefix_in_container
159 : ),
160 : };
161 249 : RemotePath(
162 249 : relative_path
163 249 : .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
164 249 : .collect(),
165 249 : )
166 249 : }
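:
: // Illustration (values made up): with prefix_in_container = Some("pageserver/v1/".to_string()),
: // relative_path_to_name maps the remote path "tenants/t1/timelines" to the blob name
: // "pageserver/v1/tenants/t1/timelines", and name_to_relative_path inverts the mapping by
: // stripping that prefix and re-splitting the remainder on '/'.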
167 :
168 11 : async fn download_for_builder(
169 11 : &self,
170 11 : builder: GetBlobBuilder,
171 11 : timeout: Duration,
172 11 : cancel: &CancellationToken,
173 11 : ) -> Result<Download, DownloadError> {
174 11 : let kind = RequestKind::Get;
175 :
176 11 : let _permit = self.permit(kind, cancel).await?;
177 11 : let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
178 11 : let cancel_or_timeout_ = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
179 :
180 11 : let mut etag = None;
181 11 : let mut last_modified = None;
182 11 : let mut metadata = HashMap::new();
183 :
184 11 : let started_at = start_measuring_requests(kind);
185 :
186 11 : let download = async {
187 11 : let response = builder
188 : // convert to concrete Pageable
189 11 : .into_stream()
190 : // convert to TryStream
191 11 : .into_stream()
192 11 : .map_err(to_download_error);
193 :
194 : // apply per request timeout
195 11 : let response = tokio_stream::StreamExt::timeout(response, timeout);
196 :
197 : // flatten
198 11 : let response = response.map(|res| match res {
199 11 : Ok(res) => res,
200 0 : Err(_elapsed) => Err(DownloadError::Timeout),
201 11 : });
202 :
203 11 : let mut response = Box::pin(response);
204 :
205 11 : let Some(part) = response.next().await else {
206 0 : return Err(DownloadError::Other(anyhow::anyhow!(
207 0 : "Azure GET response contained no response body"
208 0 : )));
209 : };
210 11 : let part = part?;
211 9 : if etag.is_none() {
212 9 : etag = Some(part.blob.properties.etag);
213 9 : }
214 9 : if last_modified.is_none() {
215 9 : last_modified = Some(part.blob.properties.last_modified.into());
216 9 : }
217 9 : if let Some(blob_meta) = part.blob.metadata {
218 0 : metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
219 9 : }
220 :
221 : // unwrap safety: both fields were just set from the first response part; if the stream had been empty we would have returned an error above
222 9 : let etag = etag.unwrap();
223 9 : let last_modified = last_modified.unwrap();
224 :
225 9 : let tail_stream = response
226 9 : .map(|part| match part {
227 0 : Ok(part) => Either::Left(part.data.map(|r| r.map_err(io::Error::other))),
228 0 : Err(e) => {
229 0 : Either::Right(futures::stream::once(async { Err(io::Error::other(e)) }))
230 : }
231 0 : })
232 9 : .flatten();
233 9 : let stream = part
234 9 : .data
235 9 : .map(|r| r.map_err(io::Error::other))
236 9 : .chain(sync_wrapper::SyncStream::new(tail_stream));
238 :
239 9 : let download_stream = crate::support::DownloadStream::new(cancel_or_timeout_, stream);
240 :
241 9 : Ok(Download {
242 9 : download_stream: Box::pin(download_stream),
243 9 : etag,
244 9 : last_modified,
245 9 : metadata: Some(StorageMetadata(metadata)),
246 9 : })
247 11 : };
248 :
249 11 : let download = tokio::select! {
250 11 : bufs = download => bufs,
251 11 : cancel_or_timeout = cancel_or_timeout => match cancel_or_timeout {
252 0 : TimeoutOrCancel::Timeout => return Err(DownloadError::Timeout),
253 0 : TimeoutOrCancel::Cancel => return Err(DownloadError::Cancelled),
254 : },
255 : };
256 11 : let started_at = ScopeGuard::into_inner(started_at);
257 11 : let outcome = match &download {
258 9 : Ok(_) => AttemptOutcome::Ok,
259 : // At this level in the stack 404 and 304 responses do not indicate an error.
260 : // There are expected cases in which a blob may not exist or has not been modified since
261 : // the last get (e.g. probing for timeline indices and heatmap downloads).
262 : // Callers should handle errors if they are unexpected.
263 2 : Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok,
264 0 : Err(_) => AttemptOutcome::Err,
265 : };
266 11 : crate::metrics::BUCKET_METRICS
267 11 : .req_seconds
268 11 : .observe_elapsed(kind, outcome, started_at);
269 11 : download
270 11 : }
271 :
272 26 : fn list_streaming_for_fn<T: Default + ListingCollector>(
273 26 : &self,
274 26 : prefix: Option<&RemotePath>,
275 26 : mode: ListingMode,
276 26 : max_keys: Option<NonZeroU32>,
277 26 : cancel: &CancellationToken,
278 26 : request_kind: RequestKind,
279 26 : customize_builder: impl Fn(ListBlobsBuilder) -> ListBlobsBuilder,
280 26 : ) -> impl Stream<Item = Result<T, DownloadError>> {
281 : // Use the passed prefix; if it is not set, fall back to the prefix_in_container value.
282 26 : let list_prefix = prefix.map(|p| self.relative_path_to_name(p)).or_else(|| {
283 10 : self.prefix_in_container.clone().map(|mut s| {
284 10 : if !s.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
285 0 : s.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
286 10 : }
287 10 : s
288 10 : })
289 10 : });
290 :
291 26 : async_stream::stream! {
292 : let _permit = self.permit(request_kind, cancel).await?;
293 :
294 : let mut builder = self.client.list_blobs();
295 :
296 26 : if let ListingMode::WithDelimiter = mode {
297 : builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
298 : }
299 :
300 26 : if let Some(prefix) = list_prefix {
301 : builder = builder.prefix(Cow::from(prefix.to_owned()));
302 : }
303 :
304 26 : if let Some(limit) = self.max_keys_per_list_response {
305 : builder = builder.max_results(MaxResults::new(limit));
306 : }
307 :
308 : builder = customize_builder(builder);
309 :
310 : let mut next_marker = None;
311 :
312 : let mut timeout_try_cnt = 1;
: // Remaining number of keys we may still return; tracked outside the retry loop below
: // so that max_keys limits the whole listing rather than each page.
: let mut max_keys = max_keys.map(|mk| mk.get());
313 :
314 : 'outer: loop {
315 : let mut builder = builder.clone();
316 : if let Some(marker) = next_marker.clone() {
317 : builder = builder.marker(marker);
318 : }
319 : // The Azure Blob Rust SDK does not expose the list-blobs API directly; callers have to go
320 : // through its pageable iterator wrapper, which returns all keys as a stream. We want to have
321 : // full control of paging, and therefore we only take the first item from the stream.
322 : let mut response_stream = builder.into_stream();
323 : let response = response_stream.next();
324 : // Timeout mechanism: the Azure client sometimes gets stuck on a request, while retrying that
325 : // request would immediately succeed. Therefore, we apply an exponentially growing timeout to
326 : // each attempt. (Usually, exponential backoff determines the sleep time between two retries;
327 : // here it determines the per-attempt timeout.) We start with a 10-second timeout and double it
328 : // for each failure, for up to 5 attempts: timeout = min(5 * 2^n, self.timeout).
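: // For example, if self.timeout were 120 s (an illustrative value), successive attempts
: // would use timeouts of 10 s, 20 s, 40 s, 80 s, and 120 s (capped by self.timeout).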
329 : let this_timeout = (5.0 * exponential_backoff_duration_seconds(timeout_try_cnt, 1.0, self.timeout.as_secs_f64())).min(self.timeout.as_secs_f64());
330 : let response = tokio::time::timeout(Duration::from_secs_f64(this_timeout), response);
331 45 : let response = response.map(|res| {
332 45 : match res {
333 45 : Ok(Some(Ok(res))) => Ok(Some(res)),
334 0 : Ok(Some(Err(e))) => Err(to_download_error(e)),
335 0 : Ok(None) => Ok(None),
336 0 : Err(_elapsed) => Err(DownloadError::Timeout),
337 : }
338 45 : });
340 : let next_item = tokio::select! {
341 : op = response => op,
342 : _ = cancel.cancelled() => Err(DownloadError::Cancelled),
343 : };
344 :
345 : if let Err(DownloadError::Timeout) = &next_item {
346 : timeout_try_cnt += 1;
347 : if timeout_try_cnt <= 5 {
348 : continue 'outer;
349 : }
350 : }
351 :
352 : let next_item = match next_item {
353 : Ok(next_item) => next_item,
354 : Err(e) => {
355 : // The error is potentially retryable, so we must rewind the loop after yielding.
356 : yield Err(e);
357 : continue 'outer;
358 : },
359 : };
360 :
361 : // Log a warning if we saw two timeouts in a row before a successful request
362 : if timeout_try_cnt > 2 {
363 : tracing::warn!("Azure Blob Storage list timed out and succeeded after {} tries", timeout_try_cnt);
364 : }
365 : timeout_try_cnt = 1;
366 :
367 : let Some(entry) = next_item else {
368 : // The response stream is exhausted; the listing is complete.
369 : break;
370 : };
371 :
372 : let mut res = T::default();
373 : next_marker = entry.continuation();
374 : let prefix_iter = entry
375 : .blobs
376 : .prefixes()
377 52 : .map(|prefix| self.name_to_relative_path(&prefix.name));
378 : res.add_prefixes(self, prefix_iter);
379 :
380 : let blob_iter = entry
381 : .blobs
382 : .blobs();
383 :
384 : for key in blob_iter {
385 : res.add_blob(self, key);
386 :
387 : if let Some(mut mk) = max_keys {
388 : assert!(mk > 0);
389 : mk -= 1;
390 : if mk == 0 {
391 : yield Ok(res); // limit reached
392 : break 'outer;
393 : }
394 : max_keys = Some(mk);
395 : }
396 : }
397 : yield Ok(res);
398 :
399 : // We are done here
400 : if next_marker.is_none() {
401 : break;
402 : }
403 : }
404 : }
405 26 : }
406 :
407 227 : async fn permit(
408 227 : &self,
409 227 : kind: RequestKind,
410 227 : cancel: &CancellationToken,
411 227 : ) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
412 227 : let acquire = self.concurrency_limiter.acquire(kind);
413 :
414 227 : tokio::select! {
415 227 : permit = acquire => Ok(permit.expect("never closed")),
416 227 : _ = cancel.cancelled() => Err(Cancelled),
417 : }
418 227 : }
419 :
420 0 : pub fn container_name(&self) -> &str {
421 0 : &self.container_name
422 0 : }
423 :
424 0 : async fn list_versions_with_permit(
425 0 : &self,
426 0 : _permit: &tokio::sync::SemaphorePermit<'_>,
427 0 : prefix: Option<&RemotePath>,
428 0 : mode: ListingMode,
429 0 : max_keys: Option<NonZeroU32>,
430 0 : cancel: &CancellationToken,
431 0 : ) -> Result<crate::VersionListing, DownloadError> {
432 0 : let customize_builder = |mut builder: ListBlobsBuilder| {
433 0 : builder = builder.include_versions(true);
434 : // We do not return this info back to `VersionListing` yet.
435 0 : builder = builder.include_deleted(true);
436 0 : builder
437 0 : };
438 0 : let kind = RequestKind::ListVersions;
439 :
440 0 : let mut stream = std::pin::pin!(self.list_streaming_for_fn(
441 0 : prefix,
442 0 : mode,
443 0 : max_keys,
444 0 : cancel,
445 0 : kind,
446 0 : customize_builder
447 : ));
448 0 : let mut combined: crate::VersionListing =
449 0 : stream.next().await.expect("At least one item required")?;
450 0 : while let Some(list) = stream.next().await {
451 0 : let list = list?;
452 0 : combined.versions.extend(list.versions.into_iter());
453 : }
454 0 : Ok(combined)
455 0 : }
456 : }
457 :
458 : trait ListingCollector {
459 : fn add_prefixes(&mut self, abs: &AzureBlobStorage, prefix_it: impl Iterator<Item = RemotePath>);
460 : fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob);
461 : }
462 :
463 : impl ListingCollector for Listing {
464 45 : fn add_prefixes(
465 45 : &mut self,
466 45 : _abs: &AzureBlobStorage,
467 45 : prefix_it: impl Iterator<Item = RemotePath>,
468 45 : ) {
469 45 : self.prefixes.extend(prefix_it);
470 45 : }
471 197 : fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
472 197 : self.keys.push(ListingObject {
473 197 : key: abs.name_to_relative_path(&blob.name),
474 197 : last_modified: blob.properties.last_modified.into(),
475 197 : size: blob.properties.content_length,
476 197 : });
477 197 : }
478 : }
479 :
480 : impl ListingCollector for crate::VersionListing {
481 0 : fn add_prefixes(
482 0 : &mut self,
483 0 : _abs: &AzureBlobStorage,
484 0 : _prefix_it: impl Iterator<Item = RemotePath>,
485 0 : ) {
486 : // nothing
487 0 : }
488 0 : fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
489 0 : let id = crate::VersionId(blob.version_id.clone().expect("didn't find version ID"));
490 0 : self.versions.push(crate::Version {
491 0 : key: abs.name_to_relative_path(&blob.name),
492 0 : last_modified: blob.properties.last_modified.into(),
493 0 : kind: crate::VersionKind::Version(id),
494 0 : });
495 0 : }
496 : }
497 :
498 0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
499 0 : let mut res = Metadata::new();
500 0 : for (k, v) in metadata.0.into_iter() {
501 0 : res.insert(k, v);
502 0 : }
503 0 : res
504 0 : }
505 :
506 3 : fn to_download_error(error: azure_core::Error) -> DownloadError {
507 3 : if let Some(http_err) = error.as_http_error() {
508 3 : match http_err.status() {
509 1 : StatusCode::NotFound => DownloadError::NotFound,
510 2 : StatusCode::NotModified => DownloadError::Unmodified,
511 0 : StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
512 0 : _ => DownloadError::Other(anyhow::Error::new(error)),
513 : }
514 : } else {
515 0 : DownloadError::Other(error.into())
516 : }
517 3 : }
518 :
519 : impl RemoteStorage for AzureBlobStorage {
520 26 : fn list_streaming(
521 26 : &self,
522 26 : prefix: Option<&RemotePath>,
523 26 : mode: ListingMode,
524 26 : max_keys: Option<NonZeroU32>,
525 26 : cancel: &CancellationToken,
526 26 : ) -> impl Stream<Item = Result<Listing, DownloadError>> {
527 26 : let customize_builder = |builder| builder;
528 26 : let kind = RequestKind::List;
529 26 : self.list_streaming_for_fn(prefix, mode, max_keys, cancel, kind, customize_builder)
530 26 : }
531 :
532 0 : async fn list_versions(
533 0 : &self,
534 0 : prefix: Option<&RemotePath>,
535 0 : mode: ListingMode,
536 0 : max_keys: Option<NonZeroU32>,
537 0 : cancel: &CancellationToken,
538 0 : ) -> std::result::Result<crate::VersionListing, DownloadError> {
539 0 : let kind = RequestKind::ListVersions;
540 0 : let permit = self.permit(kind, cancel).await?;
541 0 : self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
542 0 : .await
543 0 : }
544 :
545 3 : async fn head_object(
546 3 : &self,
547 3 : key: &RemotePath,
548 3 : cancel: &CancellationToken,
549 3 : ) -> Result<ListingObject, DownloadError> {
550 3 : let kind = RequestKind::Head;
551 3 : let _permit = self.permit(kind, cancel).await?;
552 :
553 3 : let started_at = start_measuring_requests(kind);
554 :
555 3 : let blob_client = self.client.blob_client(self.relative_path_to_name(key));
556 3 : let properties_future = blob_client.get_properties().into_future();
557 :
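: // A blob-properties response is small, so bound it with the shorter metadata timeout.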
558 3 : let properties_future = tokio::time::timeout(self.small_timeout, properties_future);
559 :
560 3 : let res = tokio::select! {
561 3 : res = properties_future => res,
562 3 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
563 : };
564 :
565 3 : if let Ok(inner) = &res {
566 3 : // Record metrics for completed requests only: timeouts are not counted as errors (cancellations return early above).
567 3 : let started_at = ScopeGuard::into_inner(started_at);
568 3 : crate::metrics::BUCKET_METRICS
569 3 : .req_seconds
570 3 : .observe_elapsed(kind, inner, started_at);
571 3 : }
572 :
573 3 : let data = match res {
574 2 : Ok(Ok(data)) => Ok(data),
575 1 : Ok(Err(sdk)) => Err(to_download_error(sdk)),
576 0 : Err(_timeout) => Err(DownloadError::Timeout),
577 1 : }?;
578 :
579 2 : let properties = data.blob.properties;
580 2 : Ok(ListingObject {
581 2 : key: key.to_owned(),
582 2 : last_modified: SystemTime::from(properties.last_modified),
583 2 : size: properties.content_length,
584 2 : })
585 3 : }
586 :
587 93 : async fn upload(
588 93 : &self,
589 93 : from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
590 93 : data_size_bytes: usize,
591 93 : to: &RemotePath,
592 93 : metadata: Option<StorageMetadata>,
593 93 : cancel: &CancellationToken,
594 93 : ) -> anyhow::Result<()> {
595 93 : let kind = RequestKind::Put;
596 93 : let _permit = self.permit(kind, cancel).await?;
597 :
598 93 : let started_at = start_measuring_requests(kind);
599 :
600 93 : let mut metadata_map = metadata.unwrap_or([].into());
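: // Hadron side channel: this metadata key, when present, carries the local path of the
: // file being uploaded so the Put Block path below can re-read it in chunks; it is
: // removed here so it is never sent to Azure as real blob metadata.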
601 93 : let timeline_file_path = metadata_map.0.remove("databricks_azure_put_block");
602 :
603 : /* BEGIN_HADRON */
604 93 : let op = async move {
605 93 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
606 93 : let put_block_size = self.put_block_size_mb.unwrap_or(0) * 1024 * 1024;
607 93 : if timeline_file_path.is_none() || put_block_size == 0 {
608 : // Use put_block_blob directly.
609 72 : let from: Pin<
610 72 : Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>,
611 72 : > = Box::pin(from);
612 72 : let from = NonSeekableStream::new(from, data_size_bytes);
613 72 : let body = azure_core::Body::SeekableStream(Box::new(from));
614 :
615 72 : let mut builder = blob_client.put_block_blob(body);
616 72 : if !metadata_map.0.is_empty() {
617 0 : builder = builder.metadata(to_azure_metadata(metadata_map));
618 0 : }
619 72 : let fut = builder.into_future();
620 72 : let fut = tokio::time::timeout(self.timeout, fut);
621 72 : let result = fut.await;
622 72 : match result {
623 72 : Ok(Ok(_response)) => return Ok(()),
624 0 : Ok(Err(azure)) => return Err(azure.into()),
625 0 : Err(_timeout) => return Err(TimeoutOrCancel::Timeout.into()),
626 : };
627 0 : }
628 : // Upload chunks concurrently using Put Block.
629 : // Each PutBlock uploads put_block_size bytes of the file.
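: // Illustration (sizes made up): with put_block_size_mb = Some(64) and a 150 MiB file,
: // three blocks of 64, 64, and 22 MiB are uploaded in parallel and then committed with
: // a single Put Block List call below.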
630 21 : let mut upload_futures: Vec<tokio::task::JoinHandle<Result<(), azure_core::Error>>> =
631 21 : vec![];
632 21 : let mut block_list = BlockList::default();
633 21 : let mut start_bytes = 0u64;
634 21 : let mut remaining_bytes = data_size_bytes;
635 21 : let mut block_list_count = 0;
636 :
637 42 : while remaining_bytes > 0 {
638 21 : let block_size = std::cmp::min(remaining_bytes, put_block_size);
639 21 : let end_bytes = start_bytes + block_size as u64;
640 21 : let block_id = block_list_count;
641 21 : let timeout = self.timeout;
642 21 : let blob_client = blob_client.clone();
643 21 : let timeline_file = timeline_file_path.clone().unwrap();
644 :
645 21 : let mut encoded_block_id = [0u8; 8];
646 21 : BigEndian::write_u64(&mut encoded_block_id, block_id);
648 :
649 : // Put one block.
650 21 : let part_fut = async move {
651 21 : let mut file = File::open(Utf8Path::new(&timeline_file.clone())).await?;
652 21 : file.seek(io::SeekFrom::Start(start_bytes)).await?;
653 21 : let limited_reader = file.take(block_size as u64);
654 21 : let file_chunk_stream =
655 21 : tokio_util::io::ReaderStream::with_capacity(limited_reader, 1024 * 1024);
656 21 : let file_chunk_stream_pin: Pin<
657 21 : Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>,
658 21 : > = Box::pin(file_chunk_stream);
659 21 : let stream_wrapper = NonSeekableStream::new(file_chunk_stream_pin, block_size);
660 21 : let body = azure_core::Body::SeekableStream(Box::new(stream_wrapper));
661 : // Azure Put Block takes base64-encoded block ids, and all block ids within a blob must have
: // the same encoded length; the SDK encodes the raw 8-byte id when it builds the request.
662 : // https://learn.microsoft.com/en-us/rest/api/storageservices/put-block?tabs=microsoft-entra-id#uri-parameters
663 21 : let builder = blob_client.put_block(encoded_block_id.to_vec(), body);
664 21 : let fut = builder.into_future();
665 21 : let fut = tokio::time::timeout(timeout, fut);
666 21 : let result = fut.await;
667 21 : tracing::debug!(
668 0 : "azure put block id-{} size {} start {} end {} file {} response {:#?}",
669 : block_id,
670 : block_size,
671 : start_bytes,
672 : end_bytes,
673 : timeline_file,
674 : result
675 : );
676 21 : match result {
677 21 : Ok(Ok(_response)) => Ok(()),
678 0 : Ok(Err(azure)) => Err(azure),
679 0 : Err(_timeout) => Err(azure_core::Error::new(
680 0 : azure_core::error::ErrorKind::Io,
681 0 : std::io::Error::new(
682 0 : std::io::ErrorKind::TimedOut,
683 0 : "Operation timed out",
684 0 : ),
685 0 : )),
686 : }
687 0 : };
688 21 : upload_futures.push(tokio::spawn(part_fut));
689 :
690 21 : block_list_count += 1;
691 21 : remaining_bytes -= block_size;
692 21 : start_bytes += block_size as u64;
693 :
694 21 : block_list
695 21 : .blocks
696 21 : .push(BlobBlockType::Uncommitted(encoded_block_id.to_vec().into()));
697 : }
698 :
699 21 : tracing::debug!(
700 0 : "azure put blocks {} total MB: {} chunk size MB: {}",
701 : block_list_count,
702 0 : data_size_bytes / 1024 / 1024,
703 0 : put_block_size / 1024 / 1024
704 : );
705 : // Wait for all blocks to be uploaded. Note that try_join_all only fails on join
: // errors (e.g. a panicked upload task); per-block Azure errors are returned inside
: // each task's result, so check those explicitly as well.
706 21 : let upload_results = futures::future::try_join_all(upload_futures)
: .await
: .context("failed to join the block upload tasks")?;
707 21 : if let Some(azure_err) = upload_results.into_iter().find_map(|r| r.err()) {
708 0 : return Err(anyhow::anyhow!("Failed to upload all blocks: {azure_err:#?}"));
712 0 : }
713 :
714 : // Commit the blocks.
715 21 : let mut builder = blob_client.put_block_list(block_list);
716 21 : if !metadata_map.0.is_empty() {
717 0 : builder = builder.metadata(to_azure_metadata(metadata_map));
718 0 : }
719 21 : let fut = builder.into_future();
720 21 : let fut = tokio::time::timeout(self.timeout, fut);
721 21 : let result = fut.await;
722 21 : tracing::debug!("azure put block list response {:#?}", result);
723 :
724 21 : match result {
725 21 : Ok(Ok(_response)) => Ok(()),
726 0 : Ok(Err(azure)) => Err(azure.into()),
727 0 : Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
728 : }
729 0 : };
730 : /* END_HADRON */
731 :
732 93 : let res = tokio::select! {
733 93 : res = op => res,
734 93 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
735 : };
736 :
737 93 : let outcome = match res {
738 93 : Ok(_) => AttemptOutcome::Ok,
739 0 : Err(_) => AttemptOutcome::Err,
740 : };
741 93 : let started_at = ScopeGuard::into_inner(started_at);
742 93 : crate::metrics::BUCKET_METRICS
743 93 : .req_seconds
744 93 : .observe_elapsed(kind, outcome, started_at);
745 93 : res
746 0 : }
747 :
748 11 : async fn download(
749 11 : &self,
750 11 : from: &RemotePath,
751 11 : opts: &DownloadOpts,
752 11 : cancel: &CancellationToken,
753 11 : ) -> Result<Download, DownloadError> {
754 11 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
755 :
756 11 : let mut builder = blob_client.get();
757 :
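: // When an etag is given, this becomes a conditional GET: Azure replies 304 Not Modified
: // if the blob is unchanged, which to_download_error surfaces as DownloadError::Unmodified.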
758 11 : if let Some(ref etag) = opts.etag {
759 3 : builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()));
760 8 : }
761 :
762 11 : if let Some(ref version_id) = opts.version_id {
763 0 : let version_id = azure_storage_blobs::prelude::VersionId::new(version_id.0.clone());
764 0 : builder = builder.blob_versioning(version_id);
765 11 : }
766 :
767 11 : if let Some((start, end)) = opts.byte_range() {
768 5 : builder = builder.range(match end {
769 3 : Some(end) => Range::Range(start..end),
770 2 : None => Range::RangeFrom(start..),
771 : });
772 6 : }
773 :
774 11 : let timeout = match opts.kind {
775 0 : DownloadKind::Small => self.small_timeout,
776 11 : DownloadKind::Large => self.timeout,
777 : };
778 :
779 11 : self.download_for_builder(builder, timeout, cancel).await
780 11 : }
781 :
782 86 : async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
783 86 : self.delete_objects(std::array::from_ref(path), cancel)
784 86 : .await
785 86 : }
786 :
787 93 : async fn delete_objects(
788 93 : &self,
789 93 : paths: &[RemotePath],
790 93 : cancel: &CancellationToken,
791 93 : ) -> anyhow::Result<()> {
792 93 : let kind = RequestKind::Delete;
793 93 : let _permit = self.permit(kind, cancel).await?;
794 93 : let started_at = start_measuring_requests(kind);
795 :
796 93 : let op = async {
797 : // TODO batch requests are not supported by the SDK
798 : // https://github.com/Azure/azure-sdk-for-rust/issues/1068
799 205 : for path in paths {
800 : #[derive(Debug)]
801 : enum AzureOrTimeout {
802 : AzureError(azure_core::Error),
803 : Timeout,
804 : Cancel,
805 : }
806 : impl Display for AzureOrTimeout {
807 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
808 0 : write!(f, "{self:?}")
809 0 : }
810 : }
811 112 : let warn_threshold = 3;
812 112 : let max_retries = 5;
813 112 : backoff::retry(
814 112 : || async {
815 112 : let blob_client = self.client.blob_client(self.relative_path_to_name(path));
816 :
817 112 : let request = blob_client.delete().into_future();
818 :
819 112 : let res = tokio::time::timeout(self.timeout, request).await;
820 :
821 112 : match res {
822 90 : Ok(Ok(_v)) => Ok(()),
823 22 : Ok(Err(azure_err)) => {
824 22 : if let Some(http_err) = azure_err.as_http_error() {
825 22 : if http_err.status() == StatusCode::NotFound {
826 22 : return Ok(());
827 0 : }
828 0 : }
829 0 : Err(AzureOrTimeout::AzureError(azure_err))
830 : }
831 0 : Err(_elapsed) => Err(AzureOrTimeout::Timeout),
832 : }
833 224 : },
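: // Error classification for backoff::retry: Azure errors and timeouts are retried;
: // only cancellation is treated as permanent.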
834 0 : |err| match err {
835 0 : AzureOrTimeout::AzureError(_) | AzureOrTimeout::Timeout => false,
836 0 : AzureOrTimeout::Cancel => true,
837 0 : },
838 112 : warn_threshold,
839 112 : max_retries,
840 112 : "deleting remote object",
841 112 : cancel,
842 : )
843 112 : .await
844 112 : .ok_or_else(|| AzureOrTimeout::Cancel)
845 112 : .and_then(|x| x)
846 112 : .map_err(|e| match e {
847 0 : AzureOrTimeout::AzureError(err) => anyhow::Error::from(err),
848 0 : AzureOrTimeout::Timeout => TimeoutOrCancel::Timeout.into(),
849 0 : AzureOrTimeout::Cancel => TimeoutOrCancel::Cancel.into(),
850 0 : })?;
851 : }
852 93 : Ok(())
853 93 : };
854 :
855 93 : let res = tokio::select! {
856 93 : res = op => res,
857 93 : _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
858 : };
859 :
860 93 : let started_at = ScopeGuard::into_inner(started_at);
861 93 : crate::metrics::BUCKET_METRICS
862 93 : .req_seconds
863 93 : .observe_elapsed(kind, &res, started_at);
864 93 : res
865 93 : }
866 :
867 0 : fn max_keys_per_delete(&self) -> usize {
868 0 : super::MAX_KEYS_PER_DELETE_AZURE
869 0 : }
870 :
871 1 : async fn copy(
872 1 : &self,
873 1 : from: &RemotePath,
874 1 : to: &RemotePath,
875 1 : cancel: &CancellationToken,
876 1 : ) -> anyhow::Result<()> {
877 1 : let kind = RequestKind::Copy;
878 1 : let _permit = self.permit(kind, cancel).await?;
879 1 : let started_at = start_measuring_requests(kind);
880 :
881 1 : let timeout = tokio::time::sleep(self.timeout);
882 :
883 1 : let mut copy_status = None;
884 :
885 1 : let op = async {
886 1 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
887 :
888 1 : let source_url = format!(
889 1 : "{}/{}",
890 1 : self.client.url()?,
891 1 : self.relative_path_to_name(from)
892 : );
893 :
894 1 : let builder = blob_client.copy(Url::from_str(&source_url)?);
895 1 : let copy = builder.into_future();
896 :
897 1 : let result = copy.await?;
898 :
899 1 : copy_status = Some(result.copy_status);
900 : loop {
901 1 : match copy_status.as_ref().expect("we always set it to Some") {
902 : CopyStatus::Aborted => {
903 0 : anyhow::bail!("Received abort for copy from {from} to {to}.");
904 : }
905 : CopyStatus::Failed => {
906 0 : anyhow::bail!("Received failure response for copy from {from} to {to}.");
907 : }
908 1 : CopyStatus::Success => return Ok(()),
909 0 : CopyStatus::Pending => (),
910 : }
911 : // The copy is still pending. Wait a second and then poll its status again.
912 : // TODO estimate time based on copy_progress and adjust time based on that
913 0 : tokio::time::sleep(Duration::from_millis(1000)).await;
914 0 : let properties = blob_client.get_properties().into_future().await?;
915 0 : let Some(status) = properties.blob.properties.copy_status else {
916 0 : tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
917 0 : return Ok(());
918 : };
919 0 : copy_status = Some(status);
920 : }
921 1 : };
922 :
923 1 : let res = tokio::select! {
924 1 : res = op => res,
925 1 : _ = cancel.cancelled() => return Err(anyhow::Error::new(TimeoutOrCancel::Cancel)),
926 1 : _ = timeout => {
927 0 : let e = anyhow::Error::new(TimeoutOrCancel::Timeout);
928 0 : let e = e.context(format!("Timeout, last status: {copy_status:?}"));
929 0 : Err(e)
930 : },
931 : };
932 :
933 1 : let started_at = ScopeGuard::into_inner(started_at);
934 1 : crate::metrics::BUCKET_METRICS
935 1 : .req_seconds
936 1 : .observe_elapsed(kind, &res, started_at);
937 1 : res
938 1 : }
939 :
940 0 : async fn time_travel_recover(
941 0 : &self,
942 0 : prefix: Option<&RemotePath>,
943 0 : timestamp: SystemTime,
944 0 : done_if_after: SystemTime,
945 0 : cancel: &CancellationToken,
946 0 : _complexity_limit: Option<NonZeroU32>,
947 0 : ) -> Result<(), TimeTravelError> {
948 0 : let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
949 0 : .to_string()
950 0 : + "for some specific files. If a file gets deleted but then overwritten and we want to recover "
951 0 : + "to the time during the file was not present, this functionality will recover the file. Only "
952 0 : + "use the functionality for services that can tolerate this. For example, recovering a state of the "
953 0 : + "pageserver tenants.";
954 0 : tracing::error!("{}", msg);
955 :
956 0 : let kind = RequestKind::TimeTravel;
957 0 : let permit = self.permit(kind, cancel).await?;
958 :
959 0 : let mode = ListingMode::NoDelimiter;
960 0 : let version_listing = self
961 0 : .list_versions_with_permit(&permit, prefix, mode, None, cancel)
962 0 : .await
963 0 : .map_err(|err| match err {
964 0 : DownloadError::Other(e) => TimeTravelError::Other(e),
965 0 : DownloadError::Cancelled => TimeTravelError::Cancelled,
966 0 : other => TimeTravelError::Other(other.into()),
967 0 : })?;
968 0 : let versions_and_deletes = version_listing.versions;
969 :
970 0 : tracing::info!(
971 0 : "Built list for time travel with {} versions and deletions",
972 0 : versions_and_deletes.len()
973 : );
974 :
975 : // Work on the list of references instead of the objects directly,
976 : // otherwise we get lifetime errors in the sort_by_key call below.
977 0 : let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
978 :
979 0 : versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
980 :
981 0 : let mut vds_for_key = HashMap::<_, Vec<_>>::new();
982 :
983 0 : for vd in &versions_and_deletes {
984 0 : let Version { key, .. } = &vd;
985 0 : let version_id = vd.version_id().map(|v| v.0.as_str());
986 0 : if version_id == Some("null") {
987 0 : return Err(TimeTravelError::Other(anyhow!(
988 0 : "Received ListVersions response for key={key} with version_id='null', \
989 0 : indicating either disabled versioning, or legacy objects with null version id values"
990 0 : )));
991 0 : }
992 0 : tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
993 :
994 0 : vds_for_key.entry(key).or_default().push(vd);
995 : }
996 :
997 0 : let warn_threshold = 3;
998 0 : let max_retries = 10;
999 0 : let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
1000 :
1001 0 : for (key, versions) in vds_for_key {
1002 0 : let last_vd = versions.last().unwrap();
1003 0 : let key = self.relative_path_to_name(key);
1004 0 : if last_vd.last_modified > done_if_after {
1005 0 : tracing::debug!("Key {key} has version later than done_if_after, skipping");
1006 0 : continue;
1007 0 : }
1008 : // the version we want to restore to.
1009 0 : let version_to_restore_to =
1010 0 : match versions.binary_search_by_key(×tamp, |tpl| tpl.last_modified) {
1011 0 : Ok(v) => v,
1012 0 : Err(e) => e,
1013 : };
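: // Illustration: if a key has versions modified at t1 < t2 < t3 and t1 < timestamp < t2,
: // the search yields index 1, so versions[0] (the newest version at or before the target
: // timestamp) is the one promoted below via a server-side copy.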
1014 0 : if version_to_restore_to == versions.len() {
1015 0 : tracing::debug!("Key {key} has no changes since timestamp, skipping");
1016 0 : continue;
1017 0 : }
1018 0 : let mut do_delete = false;
1019 0 : if version_to_restore_to == 0 {
1020 : // All versions more recent, so the key didn't exist at the specified time point.
1021 0 : tracing::debug!(
1022 0 : "All {} versions more recent for {key}, deleting",
1023 0 : versions.len()
1024 : );
1025 0 : do_delete = true;
1026 : } else {
1027 0 : match &versions[version_to_restore_to - 1] {
1028 : Version {
1029 0 : kind: VersionKind::Version(version_id),
1030 : ..
1031 : } => {
1032 0 : let source_url = format!(
1033 0 : "{}/{}?versionid={}",
1034 0 : self.client
1035 0 : .url()
1036 0 : .map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
1037 : key,
1038 : version_id.0
1039 : );
1040 0 : tracing::debug!(
1041 0 : "Promoting old version {} for {key} at {}...",
1042 : version_id.0,
1043 : source_url
1044 : );
1045 0 : backoff::retry(
1046 0 : || async {
1047 0 : let blob_client = self.client.blob_client(key.clone());
1048 0 : let op = blob_client.copy(Url::from_str(&source_url).unwrap());
1049 0 : tokio::select! {
1050 0 : res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
1051 0 : _ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
1052 : }
1053 0 : },
1054 0 : is_permanent,
1055 0 : warn_threshold,
1056 0 : max_retries,
1057 0 : "copying object version for time_travel_recover",
1058 0 : cancel,
1059 : )
1060 0 : .await
1061 0 : .ok_or_else(|| TimeTravelError::Cancelled)
1062 0 : .and_then(|x| x)?;
1063 0 : tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
1064 : }
1065 : Version {
1066 : kind: VersionKind::DeletionMarker,
1067 : ..
1068 0 : } => {
1069 0 : do_delete = true;
1070 0 : }
1071 : }
1072 : };
1073 0 : if do_delete {
1074 0 : if matches!(last_vd.kind, VersionKind::DeletionMarker) {
1075 : // Key has since been deleted (but there was some history), no need to do anything
1076 0 : tracing::debug!("Key {key} already deleted, skipping.");
1077 : } else {
1078 0 : tracing::debug!("Deleting {key}...");
1079 :
1080 0 : self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
1081 0 : .await
1082 0 : .map_err(|e| {
1083 : // delete() will use TimeoutOrCancel
1084 0 : if TimeoutOrCancel::caused_by_cancel(&e) {
1085 0 : TimeTravelError::Cancelled
1086 : } else {
1087 0 : TimeTravelError::Other(e)
1088 : }
1089 0 : })?;
1090 : }
1091 0 : }
1092 : }
1093 :
1094 0 : Ok(())
1095 0 : }
1096 : }
1097 :
1098 : pin_project_lite::pin_project! {
1099 : /// Hack to work around not being able to stream once with azure sdk.
1100 : ///
1101 : /// Azure sdk clones streams around with the assumption that they are like
1102 : /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
1103 : /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
1104 : /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
1105 : /// seekable, but we can also just re-try the request easier.
1106 : #[project = NonSeekableStreamProj]
1107 : enum NonSeekableStream<S> {
1108 : /// A stream wrappers initial form.
1109 : ///
1110 : /// Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
1111 : /// clone before first request, then this must be changed.
1112 : Initial {
1113 : inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
1114 : len: usize,
1115 : },
1116 : /// The actually readable variant, produced by cloning the Initial variant.
1117 : ///
1118 : /// The sdk currently always clones once, even without retry policy.
1119 : Actual {
1120 : #[pin]
1121 : inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
1122 : len: usize,
1123 : read_any: bool,
1124 : },
1125 : /// Most likely unneeded, but left to make life easier, in case more clones are added.
1126 : Cloned {
1127 : len_was: usize,
1128 : }
1129 : }
1130 : }
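:
: // State transitions (see the Clone impl below): the SDK's single pre-request clone turns
: // Initial into the readable Actual variant; any further clone degrades to Cloned, which
: // only remembers the length.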
1131 :
1132 : impl<S> NonSeekableStream<S>
1133 : where
1134 : S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
1135 : {
1136 93 : fn new(inner: S, len: usize) -> NonSeekableStream<S> {
1137 : use tokio_util::compat::TokioAsyncReadCompatExt;
1138 :
1139 93 : let inner = tokio_util::io::StreamReader::new(inner).compat();
1140 93 : let inner = Some(inner);
1141 93 : let inner = std::sync::Mutex::new(inner);
1142 93 : NonSeekableStream::Initial { inner, len }
1143 0 : }
1144 : }
1145 :
1146 : impl<S> std::fmt::Debug for NonSeekableStream<S> {
1147 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1148 0 : match self {
1149 0 : Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
1150 0 : Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
1151 0 : Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
1152 : }
1153 0 : }
1154 : }
1155 :
1156 : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
1157 : where
1158 : S: Stream<Item = std::io::Result<Bytes>>,
1159 : {
1160 113 : fn poll_read(
1161 113 : self: std::pin::Pin<&mut Self>,
1162 113 : cx: &mut std::task::Context<'_>,
1163 113 : buf: &mut [u8],
1164 113 : ) -> std::task::Poll<std::io::Result<usize>> {
1165 113 : match self.project() {
1166 : NonSeekableStreamProj::Actual {
1167 113 : inner, read_any, ..
1168 : } => {
1169 113 : *read_any = true;
1170 113 : inner.poll_read(cx, buf)
1171 : }
1172 : // NonSeekableStream::Initial does not support reading: it is much easier to keep the
1173 : // mutex in a place where the contents are never polled, or at least that was the
1174 : // original reasoning. If a version upgrade changes the cloning behavior, then read
1175 : // support needs to be hacked in here.
1176 : //
1177 : // Including {self:?} in the message would be useful, but it is unclear how to un-project here.
1178 0 : _ => std::task::Poll::Ready(Err(std::io::Error::other(
1179 0 : "cloned or initial values cannot be read",
1180 0 : ))),
1181 : }
1182 0 : }
1183 : }
1184 :
1185 : impl<S> Clone for NonSeekableStream<S> {
1186 : /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
1187 : /// request, see type documentation.
1188 93 : fn clone(&self) -> Self {
1189 : use NonSeekableStream::*;
1190 :
1191 93 : match self {
1192 93 : Initial { inner, len } => {
1193 93 : if let Some(inner) = inner.lock().unwrap().take() {
1194 93 : Actual {
1195 93 : inner,
1196 93 : len: *len,
1197 93 : read_any: false,
1198 93 : }
1199 : } else {
1200 0 : Self::Cloned { len_was: *len }
1201 : }
1202 : }
1203 0 : Actual { len, .. } => Cloned { len_was: *len },
1204 0 : Cloned { len_was } => Cloned { len_was: *len_was },
1205 : }
1206 0 : }
1207 : }
1208 :
1209 : #[async_trait::async_trait]
1210 : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
1211 : where
1212 : S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
1213 : {
1214 0 : async fn reset(&mut self) -> azure_core::error::Result<()> {
1215 : use NonSeekableStream::*;
1216 :
1217 0 : let msg = match self {
1218 0 : Initial { inner, .. } => {
1219 0 : if inner.get_mut().unwrap().is_some() {
1220 0 : return Ok(());
1221 : } else {
1222 0 : "reset after first clone is not supported"
1223 : }
1224 : }
1225 0 : Actual { read_any, .. } if !*read_any => return Ok(()),
1226 0 : Actual { .. } => "reset after reading is not supported",
1227 0 : Cloned { .. } => "reset after second clone is not supported",
1228 : };
1229 0 : Err(azure_core::error::Error::new(
1230 0 : azure_core::error::ErrorKind::Io,
1231 0 : std::io::Error::other(msg),
1232 0 : ))
1233 0 : }
1234 :
1235 : // Note: it is not documented whether this should be the total or the remaining length;
1236 : // the total length passes the tests.
1237 93 : fn len(&self) -> usize {
1238 : use NonSeekableStream::*;
1239 93 : match self {
1240 93 : Initial { len, .. } => *len,
1241 0 : Actual { len, .. } => *len,
1242 0 : Cloned { len_was, .. } => *len_was,
1243 : }
1244 0 : }
1245 : }