TLA Line data Source code
1 : //! Azure Blob Storage wrapper
2 :
3 : use std::borrow::Cow;
4 : use std::collections::HashMap;
5 : use std::env;
6 : use std::num::NonZeroU32;
7 : use std::pin::Pin;
8 : use std::sync::Arc;
9 :
10 : use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
11 : use anyhow::Result;
12 : use azure_core::request_options::{MaxResults, Metadata, Range};
13 : use azure_core::RetryOptions;
14 : use azure_identity::DefaultAzureCredential;
15 : use azure_storage::StorageCredentials;
16 : use azure_storage_blobs::prelude::ClientBuilder;
17 : use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
18 : use bytes::Bytes;
19 : use futures::stream::Stream;
20 : use futures_util::StreamExt;
21 : use http_types::StatusCode;
22 : use tracing::debug;
23 :
24 : use crate::s3_bucket::RequestKind;
25 : use crate::{
26 : AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath,
27 : RemoteStorage, StorageMetadata,
28 : };
29 :
/// Remote storage backed by a single Azure Blob Storage container.
pub struct AzureBlobStorage {
    /// Container client; per-blob clients and listing requests are created from it.
    client: ContainerClient,
    /// Optional prefix that `relative_path_to_name` prepends to every blob name and
    /// `name_to_relative_path` strips again.
    prefix_in_container: Option<String>,
    /// When set, passed as `max_results` on list requests (presumably caps the number of
    /// entries per response page).
    max_keys_per_list_response: Option<NonZeroU32>,
    /// Hands out semaphore permits per `RequestKind`, bounding concurrent requests.
    concurrency_limiter: ConcurrencyLimiter,
}
36 :
37 : impl AzureBlobStorage {
38 CBC 5 : pub fn new(azure_config: &AzureConfig) -> Result<Self> {
39 5 : debug!(
40 UBC 0 : "Creating azure remote storage for azure container {}",
41 0 : azure_config.container_name
42 0 : );
43 :
44 CBC 5 : let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
45 :
46 : // If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
47 : // otherwise try the token based credentials.
48 5 : let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
49 5 : StorageCredentials::access_key(account.clone(), access_key)
50 : } else {
51 UBC 0 : let token_credential = DefaultAzureCredential::default();
52 0 : StorageCredentials::token_credential(Arc::new(token_credential))
53 : };
54 :
55 : // we have an outer retry
56 CBC 5 : let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
57 5 :
58 5 : let client = builder.container_client(azure_config.container_name.to_owned());
59 :
60 5 : let max_keys_per_list_response =
61 5 : if let Some(limit) = azure_config.max_keys_per_list_response {
62 : Some(
63 2 : NonZeroU32::new(limit as u32)
64 2 : .ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
65 : )
66 : } else {
67 3 : None
68 : };
69 :
70 5 : Ok(AzureBlobStorage {
71 5 : client,
72 5 : prefix_in_container: azure_config.prefix_in_container.to_owned(),
73 5 : max_keys_per_list_response,
74 5 : concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
75 5 : })
76 5 : }
77 :
78 101 : pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
79 101 : assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
80 101 : let path_string = path
81 101 : .get_path()
82 101 : .as_str()
83 101 : .trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
84 101 : match &self.prefix_in_container {
85 101 : Some(prefix) => {
86 101 : if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
87 101 : prefix.clone() + path_string
88 : } else {
89 UBC 0 : format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
90 : }
91 : }
92 0 : None => path_string.to_string(),
93 : }
94 CBC 101 : }
95 :
96 51 : fn name_to_relative_path(&self, key: &str) -> RemotePath {
97 51 : let relative_path =
98 51 : match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
99 51 : Some(stripped) => stripped,
100 : // we rely on Azure to return properly prefixed paths
101 : // for requests with a certain prefix
102 UBC 0 : None => panic!(
103 0 : "Key {key} does not start with container prefix {:?}",
104 0 : self.prefix_in_container
105 0 : ),
106 : };
107 CBC 51 : RemotePath(
108 51 : relative_path
109 51 : .split(REMOTE_STORAGE_PREFIX_SEPARATOR)
110 51 : .collect(),
111 51 : )
112 51 : }
113 :
114 6 : async fn download_for_builder(
115 6 : &self,
116 6 : builder: GetBlobBuilder,
117 6 : ) -> Result<Download, DownloadError> {
118 6 : let mut response = builder.into_stream();
119 6 :
120 6 : let mut etag = None;
121 6 : let mut last_modified = None;
122 6 : let mut metadata = HashMap::new();
123 6 : // TODO give proper streaming response instead of buffering into RAM
124 6 : // https://github.com/neondatabase/neon/issues/5563
125 6 :
126 6 : let mut bufs = Vec::new();
127 12 : while let Some(part) = response.next().await {
128 6 : let part = part.map_err(to_download_error)?;
129 6 : let etag_str: &str = part.blob.properties.etag.as_ref();
130 6 : if etag.is_none() {
131 6 : etag = Some(etag.unwrap_or_else(|| etag_str.to_owned()));
132 6 : }
133 6 : if last_modified.is_none() {
134 6 : last_modified = Some(part.blob.properties.last_modified.into());
135 6 : }
136 6 : if let Some(blob_meta) = part.blob.metadata {
137 UBC 0 : metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
138 CBC 6 : }
139 6 : let data = part
140 6 : .data
141 6 : .collect()
142 UBC 0 : .await
143 CBC 6 : .map_err(|e| DownloadError::Other(e.into()))?;
144 6 : bufs.push(data);
145 : }
146 6 : Ok(Download {
147 6 : download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))),
148 6 : etag,
149 6 : last_modified,
150 6 : metadata: Some(StorageMetadata(metadata)),
151 6 : })
152 6 : }
153 :
154 99 : async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
155 99 : self.concurrency_limiter
156 99 : .acquire(kind)
157 UBC 0 : .await
158 CBC 99 : .expect("semaphore is never closed")
159 99 : }
160 : }
161 :
162 UBC 0 : fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
163 0 : let mut res = Metadata::new();
164 0 : for (k, v) in metadata.0.into_iter() {
165 0 : res.insert(k, v);
166 0 : }
167 0 : res
168 0 : }
169 :
170 0 : fn to_download_error(error: azure_core::Error) -> DownloadError {
171 0 : if let Some(http_err) = error.as_http_error() {
172 0 : match http_err.status() {
173 0 : StatusCode::NotFound => DownloadError::NotFound,
174 0 : StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
175 0 : _ => DownloadError::Other(anyhow::Error::new(error)),
176 : }
177 : } else {
178 0 : DownloadError::Other(error.into())
179 : }
180 0 : }
181 :
182 : #[async_trait::async_trait]
183 : impl RemoteStorage for AzureBlobStorage {
184 CBC 5 : async fn list(
185 5 : &self,
186 5 : prefix: Option<&RemotePath>,
187 5 : mode: ListingMode,
188 5 : ) -> anyhow::Result<Listing, DownloadError> {
189 : // get the passed prefix or if it is not set use prefix_in_bucket value
190 5 : let list_prefix = prefix
191 5 : .map(|p| self.relative_path_to_name(p))
192 5 : .or_else(|| self.prefix_in_container.clone())
193 5 : .map(|mut p| {
194 : // required to end with a separator
195 : // otherwise request will return only the entry of a prefix
196 5 : if matches!(mode, ListingMode::WithDelimiter)
197 3 : && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
198 1 : {
199 1 : p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
200 4 : }
201 5 : p
202 5 : });
203 5 :
204 5 : let mut builder = self.client.list_blobs();
205 5 :
206 5 : if let ListingMode::WithDelimiter = mode {
207 3 : builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
208 3 : }
209 :
210 5 : if let Some(prefix) = list_prefix {
211 5 : builder = builder.prefix(Cow::from(prefix.to_owned()));
212 5 : }
213 :
214 5 : if let Some(limit) = self.max_keys_per_list_response {
215 4 : builder = builder.max_results(MaxResults::new(limit));
216 4 : }
217 :
218 5 : let mut response = builder.into_stream();
219 5 : let mut res = Listing::default();
220 18 : while let Some(l) = response.next().await {
221 9 : let entry = l.map_err(to_download_error)?;
222 9 : let prefix_iter = entry
223 9 : .blobs
224 9 : .prefixes()
225 23 : .map(|prefix| self.name_to_relative_path(&prefix.name));
226 9 : res.prefixes.extend(prefix_iter);
227 9 :
228 9 : let blob_iter = entry
229 9 : .blobs
230 9 : .blobs()
231 28 : .map(|k| self.name_to_relative_path(&k.name));
232 9 : res.keys.extend(blob_iter);
233 : }
234 5 : Ok(res)
235 10 : }
236 :
237 UBC 0 : async fn upload(
238 0 : &self,
239 0 : from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
240 0 : data_size_bytes: usize,
241 0 : to: &RemotePath,
242 0 : metadata: Option<StorageMetadata>,
243 0 : ) -> anyhow::Result<()> {
244 0 : let _permit = self.permit(RequestKind::Put).await;
245 0 : let blob_client = self.client.blob_client(self.relative_path_to_name(to));
246 0 :
247 0 : let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
248 0 : Box::pin(from);
249 0 :
250 0 : let from = NonSeekableStream::new(from, data_size_bytes);
251 0 :
252 0 : let body = azure_core::Body::SeekableStream(Box::new(from));
253 0 :
254 0 : let mut builder = blob_client.put_block_blob(body);
255 :
256 0 : if let Some(metadata) = metadata {
257 0 : builder = builder.metadata(to_azure_metadata(metadata));
258 0 : }
259 :
260 0 : let _response = builder.into_future().await?;
261 :
262 0 : Ok(())
263 0 : }
264 :
265 CBC 1 : async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
266 1 : let _permit = self.permit(RequestKind::Get).await;
267 1 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
268 1 :
269 1 : let builder = blob_client.get();
270 1 :
271 1 : self.download_for_builder(builder).await
272 2 : }
273 :
274 5 : async fn download_byte_range(
275 5 : &self,
276 5 : from: &RemotePath,
277 5 : start_inclusive: u64,
278 5 : end_exclusive: Option<u64>,
279 5 : ) -> Result<Download, DownloadError> {
280 5 : let _permit = self.permit(RequestKind::Get).await;
281 5 : let blob_client = self.client.blob_client(self.relative_path_to_name(from));
282 5 :
283 5 : let mut builder = blob_client.get();
284 :
285 5 : let range: Range = if let Some(end_exclusive) = end_exclusive {
286 3 : (start_inclusive..end_exclusive).into()
287 : } else {
288 2 : (start_inclusive..).into()
289 : };
290 5 : builder = builder.range(range);
291 5 :
292 5 : self.download_for_builder(builder).await
293 10 : }
294 :
295 47 : async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
296 47 : let _permit = self.permit(RequestKind::Delete).await;
297 47 : let blob_client = self.client.blob_client(self.relative_path_to_name(path));
298 47 :
299 47 : let builder = blob_client.delete();
300 47 :
301 53 : match builder.into_future().await {
302 46 : Ok(_response) => Ok(()),
303 1 : Err(e) => {
304 1 : if let Some(http_err) = e.as_http_error() {
305 1 : if http_err.status() == StatusCode::NotFound {
306 1 : return Ok(());
307 UBC 0 : }
308 0 : }
309 0 : Err(anyhow::Error::new(e))
310 : }
311 : }
312 CBC 94 : }
313 :
314 2 : async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
315 : // Permit is already obtained by inner delete function
316 :
317 : // TODO batch requests are also not supported by the SDK
318 : // https://github.com/Azure/azure-sdk-for-rust/issues/1068
319 : // https://github.com/Azure/azure-sdk-for-rust/issues/1249
320 5 : for path in paths {
321 3 : self.delete(path).await?;
322 : }
323 2 : Ok(())
324 4 : }
325 :
326 UBC 0 : async fn copy(&self, _from: &RemotePath, _to: &RemotePath) -> anyhow::Result<()> {
327 0 : Err(anyhow::anyhow!(
328 0 : "copy for azure blob storage is not implemented"
329 0 : ))
330 0 : }
331 : }
332 :
pin_project_lite::pin_project! {
    /// Hack to work around the azure sdk not supporting a body that can only be streamed once.
    ///
    /// The azure sdk clones streams around, assuming they are cheap to clone like
    /// `Arc<tokio::fs::File>` (except without tokio support). Our streams are not like that.
    /// For example, an `index_part.json` is a single chunk of [`Bytes`] holding the whole
    /// serialized vec. It could be made trivially cloneable and "semi-trivially" seekable, but
    /// it is simpler to just retry the whole request.
    ///
    /// State transitions happen in `Clone::clone`. The first clone of `Initial` moves the
    /// reader out into `Actual`. Any later clone yields `Cloned`. Only `Actual` can be read.
    #[project = NonSeekableStreamProj]
    enum NonSeekableStream<S> {
        /// A stream wrapper's initial form.
        ///
        /// The mutex lets `Clone::clone`, which only has `&self`, move the reader out. If the
        /// sdk changes to do fewer than one clone before the first request, this must be
        /// changed.
        Initial {
            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
            len: usize,
        },
        /// The actually readable variant, produced by the first clone of the `Initial` variant.
        ///
        /// The sdk currently always clones once, even without a retry policy.
        Actual {
            #[pin]
            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
            len: usize,
            read_any: bool,
        },
        /// Produced by any clone after the first one; it cannot be read from.
        ///
        /// Most likely unneeded, but kept to make life easier in case more clones are added.
        Cloned {
            len_was: usize,
        }
    }
}
366 :
367 : impl<S> NonSeekableStream<S>
368 : where
369 : S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
370 : {
371 0 : fn new(inner: S, len: usize) -> NonSeekableStream<S> {
372 0 : use tokio_util::compat::TokioAsyncReadCompatExt;
373 0 :
374 0 : let inner = tokio_util::io::StreamReader::new(inner).compat();
375 0 : let inner = Some(inner);
376 0 : let inner = std::sync::Mutex::new(inner);
377 0 : NonSeekableStream::Initial { inner, len }
378 0 : }
379 : }
380 :
381 : impl<S> std::fmt::Debug for NonSeekableStream<S> {
382 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 0 : match self {
384 0 : Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
385 0 : Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
386 0 : Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
387 : }
388 0 : }
389 : }
390 :
391 : impl<S> futures::io::AsyncRead for NonSeekableStream<S>
392 : where
393 : S: Stream<Item = std::io::Result<Bytes>>,
394 : {
395 0 : fn poll_read(
396 0 : self: std::pin::Pin<&mut Self>,
397 0 : cx: &mut std::task::Context<'_>,
398 0 : buf: &mut [u8],
399 0 : ) -> std::task::Poll<std::io::Result<usize>> {
400 0 : match self.project() {
401 : NonSeekableStreamProj::Actual {
402 0 : inner, read_any, ..
403 0 : } => {
404 0 : *read_any = true;
405 0 : inner.poll_read(cx, buf)
406 : }
407 : // NonSeekableStream::Initial does not support reading because it is just much easier
408 : // to have the mutex in place where one does not poll the contents, or that's how it
409 : // seemed originally. If there is a version upgrade which changes the cloning, then
410 : // that support needs to be hacked in.
411 : //
412 : // including {self:?} into the message would be useful, but unsure how to unproject.
413 0 : _ => std::task::Poll::Ready(Err(std::io::Error::new(
414 0 : std::io::ErrorKind::Other,
415 0 : "cloned or initial values cannot be read",
416 0 : ))),
417 : }
418 0 : }
419 : }
420 :
421 : impl<S> Clone for NonSeekableStream<S> {
422 : /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
423 : /// request, see type documentation.
424 0 : fn clone(&self) -> Self {
425 0 : use NonSeekableStream::*;
426 0 :
427 0 : match self {
428 0 : Initial { inner, len } => {
429 0 : if let Some(inner) = inner.lock().unwrap().take() {
430 0 : Actual {
431 0 : inner,
432 0 : len: *len,
433 0 : read_any: false,
434 0 : }
435 : } else {
436 0 : Self::Cloned { len_was: *len }
437 : }
438 : }
439 0 : Actual { len, .. } => Cloned { len_was: *len },
440 0 : Cloned { len_was } => Cloned { len_was: *len_was },
441 : }
442 0 : }
443 : }
444 :
445 : #[async_trait::async_trait]
446 : impl<S> azure_core::SeekableStream for NonSeekableStream<S>
447 : where
448 : S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
449 : {
450 0 : async fn reset(&mut self) -> azure_core::error::Result<()> {
451 : use NonSeekableStream::*;
452 :
453 0 : let msg = match self {
454 0 : Initial { inner, .. } => {
455 0 : if inner.get_mut().unwrap().is_some() {
456 0 : return Ok(());
457 : } else {
458 0 : "reset after first clone is not supported"
459 : }
460 : }
461 0 : Actual { read_any, .. } if !*read_any => return Ok(()),
462 0 : Actual { .. } => "reset after reading is not supported",
463 0 : Cloned { .. } => "reset after second clone is not supported",
464 : };
465 0 : Err(azure_core::error::Error::new(
466 0 : azure_core::error::ErrorKind::Io,
467 0 : std::io::Error::new(std::io::ErrorKind::Other, msg),
468 0 : ))
469 0 : }
470 :
471 : // Note: it is not documented if this should be the total or remaining length, total passes the
472 : // tests.
473 0 : fn len(&self) -> usize {
474 0 : use NonSeekableStream::*;
475 0 : match self {
476 0 : Initial { len, .. } => *len,
477 0 : Actual { len, .. } => *len,
478 0 : Cloned { len_was, .. } => *len_was,
479 : }
480 0 : }
481 : }
|