LCOV - code coverage report
Current view: top level - libs/remote_storage/src - config.rs (source / functions) Coverage Total Hit
Test: 465a86b0c1fda0069b3e0f6c1c126e6b635a1f72.info Lines: 73.0 % 152 111
Test Date: 2024-06-25 15:47:26 Functions: 36.8 % 68 25

            Line data    Source code
       1              : use std::{fmt::Debug, num::NonZeroUsize, str::FromStr, time::Duration};
       2              : 
       3              : use anyhow::bail;
       4              : use aws_sdk_s3::types::StorageClass;
       5              : use camino::Utf8PathBuf;
       6              : 
       7              : use serde::{Deserialize, Serialize};
       8              : 
       9              : use crate::{
      10              :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
      11              :     DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
      12              : };
      13              : 
      14              : /// External backup storage configuration, enough for creating a client for that storage.
      15          114 : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
      16              : pub struct RemoteStorageConfig {
      17              :     /// The storage connection configuration.
      18              :     #[serde(flatten)]
      19              :     pub storage: RemoteStorageKind,
      20              :     /// A common timeout enforced for all requests after concurrency limiter permit has been
      21              :     /// acquired.
      22              :     #[serde(
      23              :         with = "humantime_serde",
      24              :         default = "default_timeout",
      25              :         skip_serializing_if = "is_default_timeout"
      26              :     )]
      27              :     pub timeout: Duration,
      28              : }
      29              : 
      30           10 : fn default_timeout() -> Duration {
      31           10 :     RemoteStorageConfig::DEFAULT_TIMEOUT
      32           10 : }
      33              : 
      34            0 : fn is_default_timeout(d: &Duration) -> bool {
      35            0 :     *d == RemoteStorageConfig::DEFAULT_TIMEOUT
      36            0 : }
      37              : 
      38              : /// A kind of a remote storage to connect to, with its connection configuration.
      39           90 : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
      40              : #[serde(untagged)]
      41              : pub enum RemoteStorageKind {
      42              :     /// Storage based on local file system.
      43              :     /// Specify a root folder to place all stored files into.
      44              :     LocalFs { local_path: Utf8PathBuf },
      45              :     /// AWS S3 based storage, storing all files in the S3 bucket
      46              :     /// specified by the config
      47              :     AwsS3(S3Config),
      48              :     /// Azure Blob based storage, storing all files in the container
      49              :     /// specified by the config
      50              :     AzureContainer(AzureConfig),
      51              : }
      52              : 
      53              : /// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
      54           82 : #[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
      55              : pub struct S3Config {
      56              :     /// Name of the bucket to connect to.
      57              :     pub bucket_name: String,
      58              :     /// The region where the bucket is located at.
      59              :     pub bucket_region: String,
      60              :     /// A "subfolder" in the bucket, to use the same bucket separately by multiple remote storage users at once.
      61              :     pub prefix_in_bucket: Option<String>,
      62              :     /// A base URL to send S3 requests to.
      63              :     /// By default, the endpoint is derived from a region name, assuming it's
      64              :     /// an AWS S3 region name, erroring on wrong region name.
      65              :     /// Endpoint provides a way to support other S3 flavors and their regions.
      66              :     ///
      67              :     /// Example: `http://127.0.0.1:5000`
      68              :     pub endpoint: Option<String>,
      69              :     /// AWS S3 has various limits on its API calls, we need not to exceed those.
      70              :     /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
      71              :     #[serde(default = "default_remote_storage_s3_concurrency_limit")]
      72              :     pub concurrency_limit: NonZeroUsize,
      73              :     #[serde(default = "default_max_keys_per_list_response")]
      74              :     pub max_keys_per_list_response: Option<i32>,
      75              :     #[serde(
      76              :         deserialize_with = "deserialize_storage_class",
      77              :         serialize_with = "serialize_storage_class",
      78              :         default
      79              :     )]
      80              :     pub upload_storage_class: Option<StorageClass>,
      81              : }
      82              : 
      83           10 : fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
      84           10 :     DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
      85           10 :         .try_into()
      86           10 :         .unwrap()
      87           10 : }
      88              : 
      89           14 : fn default_max_keys_per_list_response() -> Option<i32> {
      90           14 :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
      91           14 : }
      92              : 
      93              : impl Debug for S3Config {
      94            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      95            0 :         f.debug_struct("S3Config")
      96            0 :             .field("bucket_name", &self.bucket_name)
      97            0 :             .field("bucket_region", &self.bucket_region)
      98            0 :             .field("prefix_in_bucket", &self.prefix_in_bucket)
      99            0 :             .field("concurrency_limit", &self.concurrency_limit)
     100            0 :             .field(
     101            0 :                 "max_keys_per_list_response",
     102            0 :                 &self.max_keys_per_list_response,
     103            0 :             )
     104            0 :             .finish()
     105            0 :     }
     106              : }
     107              : 
     108              : /// Azure  bucket coordinates and access credentials to manage the bucket contents (read and write).
     109           16 : #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
     110              : pub struct AzureConfig {
     111              :     /// Name of the container to connect to.
     112              :     pub container_name: String,
     113              :     /// Name of the storage account the container is inside of
     114              :     pub storage_account: Option<String>,
     115              :     /// The region where the bucket is located at.
     116              :     pub container_region: String,
     117              :     /// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
     118              :     pub prefix_in_container: Option<String>,
     119              :     /// Azure has various limits on its API calls, we need not to exceed those.
     120              :     /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
     121              :     #[serde(default = "default_remote_storage_azure_concurrency_limit")]
     122              :     pub concurrency_limit: NonZeroUsize,
     123              :     #[serde(default = "default_max_keys_per_list_response")]
     124              :     pub max_keys_per_list_response: Option<i32>,
     125              : }
     126              : 
     127            8 : fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
     128            8 :     NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
     129            8 : }
     130              : 
     131              : impl Debug for AzureConfig {
     132            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     133            0 :         f.debug_struct("AzureConfig")
     134            0 :             .field("bucket_name", &self.container_name)
     135            0 :             .field("storage_account", &self.storage_account)
     136            0 :             .field("bucket_region", &self.container_region)
     137            0 :             .field("prefix_in_container", &self.prefix_in_container)
     138            0 :             .field("concurrency_limit", &self.concurrency_limit)
     139            0 :             .field(
     140            0 :                 "max_keys_per_list_response",
     141            0 :                 &self.max_keys_per_list_response,
     142            0 :             )
     143            0 :             .finish()
     144            0 :     }
     145              : }
     146              : 
     147            8 : fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
     148            8 :     deserializer: D,
     149            8 : ) -> Result<Option<StorageClass>, D::Error> {
     150            8 :     Option::<String>::deserialize(deserializer).and_then(|s| {
     151            8 :         if let Some(s) = s {
     152              :             use serde::de::Error;
     153            8 :             let storage_class = StorageClass::from_str(&s).expect("infallible");
     154              :             #[allow(deprecated)]
     155            8 :             if matches!(storage_class, StorageClass::Unknown(_)) {
     156            0 :                 return Err(D::Error::custom(format!(
     157            0 :                     "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
     158            0 :                     StorageClass::values()
     159            0 :                 )));
     160            8 :             }
     161            8 :             Ok(Some(storage_class))
     162              :         } else {
     163            0 :             Ok(None)
     164              :         }
     165            8 :     })
     166            8 : }
     167              : 
     168            0 : fn serialize_storage_class<S: serde::Serializer>(
     169            0 :     val: &Option<StorageClass>,
     170            0 :     serializer: S,
     171            0 : ) -> Result<S::Ok, S::Error> {
     172            0 :     let val = val.as_ref().map(StorageClass::as_str);
     173            0 :     Option::<&str>::serialize(&val, serializer)
     174            0 : }
     175              : 
     176              : impl RemoteStorageConfig {
     177              :     pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);
     178              : 
     179           44 :     pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
     180           44 :         let document: toml_edit::Document = match toml {
     181           16 :             toml_edit::Item::Table(toml) => toml.clone().into(),
     182           28 :             toml_edit::Item::Value(toml_edit::Value::InlineTable(toml)) => {
     183           28 :                 toml.clone().into_table().into()
     184              :             }
     185            0 :             _ => bail!("toml not a table or inline table"),
     186              :         };
     187              : 
     188           44 :         if document.is_empty() {
     189           22 :             return Ok(None);
     190           22 :         }
     191           22 : 
     192           22 :         Ok(Some(toml_edit::de::from_document(document)?))
     193           44 :     }
     194              : }
     195              : 
     196              : #[cfg(test)]
     197              : mod tests {
     198              :     use super::*;
     199              : 
     200           12 :     fn parse(input: &str) -> anyhow::Result<Option<RemoteStorageConfig>> {
     201           12 :         let toml = input.parse::<toml_edit::Document>().unwrap();
     202           12 :         RemoteStorageConfig::from_toml(toml.as_item())
     203           12 :     }
     204              : 
     205              :     #[test]
     206            4 :     fn parse_localfs_config_with_timeout() {
     207            4 :         let input = "local_path = '.'
     208            4 : timeout = '5s'";
     209            4 : 
     210            4 :         let config = parse(input).unwrap().expect("it exists");
     211            4 : 
     212            4 :         assert_eq!(
     213            4 :             config,
     214            4 :             RemoteStorageConfig {
     215            4 :                 storage: RemoteStorageKind::LocalFs {
     216            4 :                     local_path: Utf8PathBuf::from(".")
     217            4 :                 },
     218            4 :                 timeout: Duration::from_secs(5)
     219            4 :             }
     220            4 :         );
     221            4 :     }
     222              : 
     223              :     #[test]
     224            4 :     fn test_s3_parsing() {
     225            4 :         let toml = "\
     226            4 :     bucket_name = 'foo-bar'
     227            4 :     bucket_region = 'eu-central-1'
     228            4 :     upload_storage_class = 'INTELLIGENT_TIERING'
     229            4 :     timeout = '7s'
     230            4 :     ";
     231            4 : 
     232            4 :         let config = parse(toml).unwrap().expect("it exists");
     233            4 : 
     234            4 :         assert_eq!(
     235            4 :             config,
     236            4 :             RemoteStorageConfig {
     237            4 :                 storage: RemoteStorageKind::AwsS3(S3Config {
     238            4 :                     bucket_name: "foo-bar".into(),
     239            4 :                     bucket_region: "eu-central-1".into(),
     240            4 :                     prefix_in_bucket: None,
     241            4 :                     endpoint: None,
     242            4 :                     concurrency_limit: default_remote_storage_s3_concurrency_limit(),
     243            4 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     244            4 :                     upload_storage_class: Some(StorageClass::IntelligentTiering),
     245            4 :                 }),
     246            4 :                 timeout: Duration::from_secs(7)
     247            4 :             }
     248            4 :         );
     249            4 :     }
     250              : 
     251              :     #[test]
     252            4 :     fn test_azure_parsing() {
     253            4 :         let toml = "\
     254            4 :     container_name = 'foo-bar'
     255            4 :     container_region = 'westeurope'
     256            4 :     upload_storage_class = 'INTELLIGENT_TIERING'
     257            4 :     timeout = '7s'
     258            4 :     ";
     259            4 : 
     260            4 :         let config = parse(toml).unwrap().expect("it exists");
     261            4 : 
     262            4 :         assert_eq!(
     263            4 :             config,
     264            4 :             RemoteStorageConfig {
     265            4 :                 storage: RemoteStorageKind::AzureContainer(AzureConfig {
     266            4 :                     container_name: "foo-bar".into(),
     267            4 :                     storage_account: None,
     268            4 :                     container_region: "westeurope".into(),
     269            4 :                     prefix_in_container: None,
     270            4 :                     concurrency_limit: default_remote_storage_azure_concurrency_limit(),
     271            4 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     272            4 :                 }),
     273            4 :                 timeout: Duration::from_secs(7)
     274            4 :             }
     275            4 :         );
     276            4 :     }
     277              : }
        

Generated by: LCOV version 2.1-beta