//! Azure Blob Storage wrapper

use std::borrow::Cow;
use std::collections::HashMap;
use std::env;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;

use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::RetryOptions;
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::CopyStatus;
use azure_storage_blobs::prelude::ClientBuilder;
use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
use bytes::Bytes;
use futures::stream::Stream;
use futures_util::StreamExt;
use http_types::{StatusCode, Url};
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::debug;

use crate::s3_bucket::RequestKind;
use crate::TimeTravelError;
use crate::{
    AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath,
    RemoteStorage, StorageMetadata,
};

pub struct AzureBlobStorage {
    client: ContainerClient,
    prefix_in_container: Option<String>,
    max_keys_per_list_response: Option<NonZeroU32>,
    concurrency_limiter: ConcurrencyLimiter,
}

impl AzureBlobStorage {
    pub fn new(azure_config: &AzureConfig) -> Result<Self> {
        debug!(
            "Creating azure remote storage for azure container {}",
            azure_config.container_name
        );

        let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");

        // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
        // otherwise try the token based credentials.
        let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
            StorageCredentials::access_key(account.clone(), access_key)
        } else {
            let token_credential = DefaultAzureCredential::default();
            StorageCredentials::token_credential(Arc::new(token_credential))
        };

        // we have an outer retry
        let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());

        let client = builder.container_client(azure_config.container_name.to_owned());

        let max_keys_per_list_response =
            if let Some(limit) = azure_config.max_keys_per_list_response {
                Some(
                    NonZeroU32::new(limit as u32)
                        .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
                )
            } else {
                None
            };

        Ok(AzureBlobStorage {
            client,
            prefix_in_container: azure_config.prefix_in_container.to_owned(),
            max_keys_per_list_response,
            concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
        })
    }

    pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
        assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
        let path_string = path
            .get_path()
            .as_str()
            .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
        match &self.prefix_in_container {
            Some(prefix) => {
                if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
                    prefix.clone() + path_string
                } else {
                    format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
                }
            }
            None => path_string.to_string(),
        }
    }

    fn name_to_relative_path(&self, key: &str) -> RemotePath {
        let relative_path =
            match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
                Some(stripped) => stripped,
                // we rely on Azure to return properly prefixed paths
                // for requests with a certain prefix
                None => panic!(
                    "Key {key} does not start with container prefix {:?}",
                    self.prefix_in_container
                ),
            };
        RemotePath(
            relative_path
                .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
                .collect(),
        )
    }

    async fn download_for_builder(
        &self,
        builder: GetBlobBuilder,
    ) -> Result<Download, DownloadError> {
        let mut response = builder.into_stream();

        let mut etag = None;
        let mut last_modified = None;
        let mut metadata = HashMap::new();
        // TODO give proper streaming response instead of buffering into RAM
        // https://github.com/neondatabase/neon/issues/5563

        let mut bufs = Vec::new();
        while let Some(part) = response.next().await {
            let part = part.map_err(to_download_error)?;
            let etag_str: &str = part.blob.properties.etag.as_ref();
            // Record the etag and last_modified from the first response part only.
            if etag.is_none() {
                etag = Some(etag_str.to_owned());
            }
            if last_modified.is_none() {
                last_modified = Some(part.blob.properties.last_modified.into());
            }
            if let Some(blob_meta) = part.blob.metadata {
                metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
            }
            let data = part
                .data
                .collect()
                .await
                .map_err(|e| DownloadError::Other(e.into()))?;
            bufs.push(data);
        }
        Ok(Download {
            download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))),
            etag,
            last_modified,
            metadata: Some(StorageMetadata(metadata)),
        })
    }

    async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
        self.concurrency_limiter
            .acquire(kind)
            .await
            .expect("semaphore is never closed")
    }
}

fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
    let mut res = Metadata::new();
    for (k, v) in metadata.0.into_iter() {
        res.insert(k, v);
    }
    res
}

fn to_download_error(error: azure_core::Error) -> DownloadError {
    if let Some(http_err) = error.as_http_error() {
        match http_err.status() {
            StatusCode::NotFound => DownloadError::NotFound,
            StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
            _ => DownloadError::Other(anyhow::Error::new(error)),
        }
    } else {
        DownloadError::Other(error.into())
    }
}

impl RemoteStorage for AzureBlobStorage {
    async fn list(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> anyhow::Result<Listing, DownloadError> {
        // use the passed prefix or, if it is not set, the prefix_in_container value
        let list_prefix = prefix
            .map(|p| self.relative_path_to_name(p))
            .or_else(|| self.prefix_in_container.clone())
            .map(|mut p| {
                // required to end with a separator,
                // otherwise the request will return only the entry of the prefix itself
                if matches!(mode, ListingMode::WithDelimiter)
                    && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
                {
                    p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
                }
                p
            });

        let mut builder = self.client.list_blobs();

        if let ListingMode::WithDelimiter = mode {
            builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
        }

        if let Some(prefix) = list_prefix {
            builder = builder.prefix(Cow::from(prefix));
        }

        if let Some(limit) = self.max_keys_per_list_response {
            builder = builder.max_results(MaxResults::new(limit));
        }

        let mut response = builder.into_stream();
        let mut res = Listing::default();
        // NonZeroU32 doesn't support subtraction, so track the remaining count as a plain u32
        let mut max_keys = max_keys.map(|mk| mk.get());
        while let Some(l) = response.next().await {
            let entry = l.map_err(to_download_error)?;
            let prefix_iter = entry
                .blobs
                .prefixes()
                .map(|prefix| self.name_to_relative_path(&prefix.name));
            res.prefixes.extend(prefix_iter);

            let blob_iter = entry
                .blobs
                .blobs()
                .map(|k| self.name_to_relative_path(&k.name));

            for key in blob_iter {
                res.keys.push(key);
                if let Some(mut mk) = max_keys {
                    assert!(mk > 0);
                    mk -= 1;
                    if mk == 0 {
                        return Ok(res); // limit reached
                    }
                    max_keys = Some(mk);
                }
            }
        }
        Ok(res)
    }

    async fn upload(
        &self,
        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
        data_size_bytes: usize,
        to: &RemotePath,
        metadata: Option<StorageMetadata>,
    ) -> anyhow::Result<()> {
        let _permit = self.permit(RequestKind::Put).await;
        let blob_client = self.client.blob_client(self.relative_path_to_name(to));

        let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
            Box::pin(from);

        let from = NonSeekableStream::new(from, data_size_bytes);

        let body = azure_core::Body::SeekableStream(Box::new(from));

        let mut builder = blob_client.put_block_blob(body);

        if let Some(metadata) = metadata {
            builder = builder.metadata(to_azure_metadata(metadata));
        }

        let _response = builder.into_future().await?;

        Ok(())
    }

    async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
        let _permit = self.permit(RequestKind::Get).await;
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let builder = blob_client.get();

        self.download_for_builder(builder).await
    }

    async fn download_byte_range(
        &self,
        from: &RemotePath,
        start_inclusive: u64,
        end_exclusive: Option<u64>,
    ) -> Result<Download, DownloadError> {
        let _permit = self.permit(RequestKind::Get).await;
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let mut builder = blob_client.get();

        let range: Range = if let Some(end_exclusive) = end_exclusive {
            (start_inclusive..end_exclusive).into()
        } else {
            (start_inclusive..).into()
        };
        builder = builder.range(range);

        self.download_for_builder(builder).await
    }

    async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
        let _permit = self.permit(RequestKind::Delete).await;
        let blob_client = self.client.blob_client(self.relative_path_to_name(path));

        let builder = blob_client.delete();

        match builder.into_future().await {
            Ok(_response) => Ok(()),
            Err(e) => {
                // Deleting an already-absent blob counts as success.
                if let Some(http_err) = e.as_http_error() {
                    if http_err.status() == StatusCode::NotFound {
                        return Ok(());
                    }
                }
                Err(anyhow::Error::new(e))
            }
        }
    }

    async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
        // The permit is already obtained by the inner delete function.

        // TODO: batch requests are also not supported by the SDK
        // https://github.com/Azure/azure-sdk-for-rust/issues/1068
        // https://github.com/Azure/azure-sdk-for-rust/issues/1249
        for path in paths {
            self.delete(path).await?;
        }
        Ok(())
    }

    async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
        let _permit = self.permit(RequestKind::Copy).await;
        let blob_client = self.client.blob_client(self.relative_path_to_name(to));

        let source_url = format!(
            "{}/{}",
            self.client.url()?,
            self.relative_path_to_name(from)
        );
        let builder = blob_client.copy(Url::from_str(&source_url)?);

        let result = builder.into_future().await?;

        let mut copy_status = result.copy_status;
        let start_time = Instant::now();
        const MAX_WAIT_TIME: Duration = Duration::from_secs(60);
        loop {
            match copy_status {
                CopyStatus::Aborted => {
                    anyhow::bail!("Received abort for copy from {from} to {to}.");
                }
                CopyStatus::Failed => {
                    anyhow::bail!("Received failure response for copy from {from} to {to}.");
                }
                CopyStatus::Success => return Ok(()),
                CopyStatus::Pending => (),
            }
            // The copy is still pending: wait a second and then poll again.
            // TODO estimate the wait based on copy_progress and adjust accordingly
            tokio::time::sleep(Duration::from_millis(1000)).await;
            let properties = blob_client.get_properties().into_future().await?;
            let Some(status) = properties.blob.properties.copy_status else {
                tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
                return Ok(());
            };
            if start_time.elapsed() > MAX_WAIT_TIME {
                anyhow::bail!(
                    "Copy from {from} to {to} took longer than limit MAX_WAIT_TIME={}s. copy_progress={:?}.",
                    MAX_WAIT_TIME.as_secs_f32(),
                    properties.blob.properties.copy_progress,
                );
            }
            copy_status = status;
        }
    }

    async fn time_travel_recover(
        &self,
        _prefix: Option<&RemotePath>,
        _timestamp: SystemTime,
        _done_if_after: SystemTime,
        _cancel: &CancellationToken,
    ) -> Result<(), TimeTravelError> {
        // TODO use the Azure point-in-time restore feature for this
        // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
        Err(TimeTravelError::Unimplemented)
    }
}
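
// A hedged usage sketch (not called anywhere): upload a small blob through the
// `RemoteStorage` impl above and read it back via `download_stream`. The function
// name, the blob contents, and the error-to-anyhow conversion are illustrative
// assumptions; `storage` and `path` are assumed to come from the caller. It reuses
// the `StreamExt` import at the top of the file for `next()`.
#[allow(dead_code)]
async fn upload_download_roundtrip_example(
    storage: &AzureBlobStorage,
    path: &RemotePath,
) -> anyhow::Result<()> {
    let body = Bytes::from_static(b"example");
    let len = body.len();
    // `upload` takes any `Stream<Item = std::io::Result<Bytes>>` plus the exact size.
    let stream = futures::stream::iter([std::io::Result::Ok(body)]);
    storage.upload(stream, len, path, None).await?;

    // `download_for_builder` buffers the blob into chunks behind `download_stream`.
    let mut download = storage.download(path).await?;
    let mut buf = Vec::new();
    while let Some(chunk) = download.download_stream.next().await {
        let chunk = chunk.map_err(|e| anyhow::anyhow!("download stream failed: {e}"))?;
        buf.extend_from_slice(&chunk);
    }
    assert_eq!(buf, b"example");
    Ok(())
}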

pin_project_lite::pin_project! {
    /// Hack to work around not being able to stream once with the azure sdk.
    ///
    /// The azure sdk clones streams around with the assumption that they are like
    /// `Arc<tokio::fs::File>` (except not supporting tokio), however our streams are not like
    /// that. For example, for an `index_part.json` we just have a single chunk of [`Bytes`]
    /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
    /// seekable, but it is easier to just retry the request.
    #[project = NonSeekableStreamProj]
    enum NonSeekableStream<S> {
        /// A stream wrapper's initial form.
        ///
        /// The Mutex exists to allow moving when cloning. If the sdk changes to do fewer than
        /// one clone before the first request, then this must be changed.
        Initial {
            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
            len: usize,
        },
        /// The actually readable variant, produced by cloning the Initial variant.
        ///
        /// The sdk currently always clones once, even without a retry policy.
        Actual {
            #[pin]
            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
            len: usize,
            read_any: bool,
        },
        /// Most likely unneeded, but left to make life easier in case more clones are added.
        Cloned {
            len_was: usize,
        }
    }
}

impl<S> NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
{
    fn new(inner: S, len: usize) -> NonSeekableStream<S> {
        use tokio_util::compat::TokioAsyncReadCompatExt;

        let inner = tokio_util::io::StreamReader::new(inner).compat();
        let inner = Some(inner);
        let inner = std::sync::Mutex::new(inner);
        NonSeekableStream::Initial { inner, len }
    }
}

impl<S> std::fmt::Debug for NonSeekableStream<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
            Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
            Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
        }
    }
}

impl<S> futures::io::AsyncRead for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>>,
{
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.project() {
            NonSeekableStreamProj::Actual {
                inner, read_any, ..
            } => {
                *read_any = true;
                inner.poll_read(cx, buf)
            }
            // NonSeekableStream::Initial does not support reading, because it is much easier
            // to keep the mutex in a place where its contents are never polled, or at least
            // that is how it seemed originally. If a version upgrade changes the cloning
            // behaviour, then this support needs to be hacked in.
            //
            // Including {self:?} in the message would be useful, but it is unclear how to
            // access it after projecting.
            _ => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "cloned or initial values cannot be read",
            ))),
        }
    }
}

impl<S> Clone for NonSeekableStream<S> {
    /// Weird clone implementation exists to support the sdk doing cloning before issuing the
    /// first request, see the type documentation.
    fn clone(&self) -> Self {
        use NonSeekableStream::*;

        match self {
            Initial { inner, len } => {
                if let Some(inner) = inner.lock().unwrap().take() {
                    Actual {
                        inner,
                        len: *len,
                        read_any: false,
                    }
                } else {
                    Self::Cloned { len_was: *len }
                }
            }
            Actual { len, .. } => Cloned { len_was: *len },
            Cloned { len_was } => Cloned { len_was: *len_was },
        }
    }
}

#[async_trait::async_trait]
impl<S> azure_core::SeekableStream for NonSeekableStream<S>
where
    S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
{
    async fn reset(&mut self) -> azure_core::error::Result<()> {
        use NonSeekableStream::*;

        let msg = match self {
            Initial { inner, .. } => {
                if inner.get_mut().unwrap().is_some() {
                    return Ok(());
                } else {
                    "reset after first clone is not supported"
                }
            }
            Actual { read_any, .. } if !*read_any => return Ok(()),
            Actual { .. } => "reset after reading is not supported",
            Cloned { .. } => "reset after second clone is not supported",
        };
        Err(azure_core::error::Error::new(
            azure_core::error::ErrorKind::Io,
            std::io::Error::new(std::io::ErrorKind::Other, msg),
        ))
    }

    // Note: it is not documented whether this should be the total or the remaining length;
    // the total passes the tests.
    fn len(&self) -> usize {
        use NonSeekableStream::*;
        match self {
            Initial { len, .. } => *len,
            Actual { len, .. } => *len,
            Cloned { len_was, .. } => *len_was,
        }
    }
}
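
// A minimal test sketch of the clone/reset lifecycle described in the
// `NonSeekableStream` documentation above, assuming the sdk clones the stream
// exactly once before the first request. The module, test name, and the
// `#[tokio::test]` harness are illustrative assumptions, not part of the
// original code.
#[cfg(test)]
mod non_seekable_stream_tests {
    use super::*;

    #[tokio::test]
    async fn clone_once_then_read() {
        let data = Bytes::from_static(b"hello");
        let len = data.len();
        let stream = futures::stream::iter([std::io::Result::Ok(data)]);
        let initial = NonSeekableStream::new(stream, len);

        // The first clone takes the inner reader out of `Initial`, producing `Actual`.
        let mut actual = initial.clone();

        // Before any bytes are read, `reset` is a no-op and succeeds.
        azure_core::SeekableStream::reset(&mut actual)
            .await
            .expect("reset before reading should succeed");

        // Reading drains the wrapped stream through the `AsyncRead` adapter.
        use futures::io::AsyncReadExt;
        let mut buf = Vec::new();
        actual.read_to_end(&mut buf).await.expect("read should succeed");
        assert_eq!(buf, b"hello");

        // After reading, `reset` is unsupported and must fail.
        assert!(azure_core::SeekableStream::reset(&mut actual).await.is_err());
    }
}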