LCOV - code coverage report
Current view: top level - libs/remote_storage/src - config.rs (source / functions) Coverage Total Hit
Test: f5f94ec0366b63fd2cbbe02edc2087dbd893d04d.info Lines: 77.2 % 171 132
Test Date: 2024-11-20 05:34:23 Functions: 27.6 % 127 35

            Line data    Source code
       1              : use std::{fmt::Debug, num::NonZeroUsize, str::FromStr, time::Duration};
       2              : 
       3              : use aws_sdk_s3::types::StorageClass;
       4              : use camino::Utf8PathBuf;
       5              : 
       6              : use serde::{Deserialize, Serialize};
       7              : 
       8              : use crate::{
       9              :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
      10              :     DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
      11              : };
      12              : 
/// External backup storage configuration, enough for creating a client for that storage.
///
/// (De)serialized with serde. The storage-specific keys are flattened into the same table as
/// `timeout`, e.g. `local_path = '.'` next to `timeout = '5s'`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct RemoteStorageConfig {
    /// The storage connection configuration.
    // Flattened: the keys of the selected `RemoteStorageKind` variant sit at the same level as
    // `timeout` instead of in a nested table.
    #[serde(flatten)]
    pub storage: RemoteStorageKind,
    /// A common timeout enforced for all requests after concurrency limiter permit has been
    /// acquired.
    ///
    /// Read and written as a human-readable duration string (for example `'5s'`). When the key
    /// is absent, [`RemoteStorageConfig::DEFAULT_TIMEOUT`] is used; when the value equals that
    /// default, the key is omitted from serialized output.
    #[serde(
        with = "humantime_serde",
        default = "default_timeout",
        skip_serializing_if = "is_default_timeout"
    )]
    pub timeout: Duration,
}
      28              : 
      29              : impl RemoteStorageKind {
      30            0 :     pub fn bucket_name(&self) -> Option<&str> {
      31            0 :         match self {
      32            0 :             RemoteStorageKind::LocalFs { .. } => None,
      33            0 :             RemoteStorageKind::AwsS3(config) => Some(&config.bucket_name),
      34            0 :             RemoteStorageKind::AzureContainer(config) => Some(&config.container_name),
      35              :         }
      36            0 :     }
      37              : }
      38              : 
      39            1 : fn default_timeout() -> Duration {
      40            1 :     RemoteStorageConfig::DEFAULT_TIMEOUT
      41            1 : }
      42              : 
      43            0 : fn is_default_timeout(d: &Duration) -> bool {
      44            0 :     *d == RemoteStorageConfig::DEFAULT_TIMEOUT
      45            0 : }
      46              : 
/// A kind of a remote storage to connect to, with its connection configuration.
///
/// The enum is `untagged`: the input carries no discriminator key. Serde tries the variants in
/// declaration order and keeps the first one that deserializes successfully, so the order of
/// the variants below is part of the parsing behavior.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum RemoteStorageKind {
    /// Storage based on local file system.
    /// Specify a root folder to place all stored files into.
    LocalFs { local_path: Utf8PathBuf },
    /// AWS S3 based storage, storing all files in the S3 bucket
    /// specified by the config
    AwsS3(S3Config),
    /// Azure Blob based storage, storing all files in the container
    /// specified by the config
    AzureContainer(AzureConfig),
}
      61              : 
/// AWS S3 bucket coordinates and client tuning options for managing the bucket contents
/// (read and write).
///
/// No credential fields are part of this struct. `Debug` is implemented by hand rather than
/// derived.
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct S3Config {
    /// Name of the bucket to connect to.
    pub bucket_name: String,
    /// The region the bucket is located in.
    pub bucket_region: String,
    /// A "subfolder" in the bucket, to use the same bucket separately by multiple remote storage users at once.
    pub prefix_in_bucket: Option<String>,
    /// A base URL to send S3 requests to.
    /// By default, the endpoint is derived from a region name, assuming it's
    /// an AWS S3 region name, erroring on wrong region name.
    /// Endpoint provides a way to support other S3 flavors and their regions.
    ///
    /// Example: `http://127.0.0.1:5000`
    pub endpoint: Option<String>,
    /// AWS S3 has various limits on its API calls that we must not exceed.
    /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
    ///
    /// Defaults to that constant when the key is omitted.
    #[serde(default = "default_remote_storage_s3_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    /// Defaults to [`DEFAULT_MAX_KEYS_PER_LIST_RESPONSE`] when the key is omitted.
    ///
    /// NOTE(review): presumably caps how many keys are requested per list call; confirm
    /// against the client code.
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
    /// Optional storage class, spelled as the SDK's string name (for example
    /// `INTELLIGENT_TIERING`). Names the SDK does not recognize are rejected during
    /// deserialization; an omitted key yields `None`.
    ///
    /// NOTE(review): presumably applied to uploaded objects; confirm against the client code.
    #[serde(
        deserialize_with = "deserialize_storage_class",
        serialize_with = "serialize_storage_class",
        default
    )]
    pub upload_storage_class: Option<StorageClass>,
}
      91              : 
      92            7 : fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
      93            7 :     DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
      94            7 :         .try_into()
      95            7 :         .unwrap()
      96            7 : }
      97              : 
      98            7 : fn default_max_keys_per_list_response() -> Option<i32> {
      99            7 :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
     100            7 : }
     101              : 
     102              : impl Debug for S3Config {
     103            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     104            0 :         f.debug_struct("S3Config")
     105            0 :             .field("bucket_name", &self.bucket_name)
     106            0 :             .field("bucket_region", &self.bucket_region)
     107            0 :             .field("prefix_in_bucket", &self.prefix_in_bucket)
     108            0 :             .field("concurrency_limit", &self.concurrency_limit)
     109            0 :             .field(
     110            0 :                 "max_keys_per_list_response",
     111            0 :                 &self.max_keys_per_list_response,
     112            0 :             )
     113            0 :             .finish()
     114            0 :     }
     115              : }
     116              : 
/// Azure Blob container coordinates and client tuning options for managing the container
/// contents (read and write).
///
/// `Debug` is implemented by hand rather than derived.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AzureConfig {
    /// Name of the container to connect to.
    pub container_name: String,
    /// Name of the storage account the container is inside of
    pub storage_account: Option<String>,
    /// The region the container is located in.
    pub container_region: String,
    /// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
    pub prefix_in_container: Option<String>,
    /// Azure has various limits on its API calls that we must not exceed.
    /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
    ///
    /// Defaults to that constant when the key is omitted.
    #[serde(default = "default_remote_storage_azure_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    /// Defaults to [`DEFAULT_MAX_KEYS_PER_LIST_RESPONSE`] when the key is omitted.
    ///
    /// NOTE(review): presumably caps how many keys are requested per list call; confirm
    /// against the client code.
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
}
     135              : 
     136            6 : fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
     137            6 :     NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
     138            6 : }
     139              : 
     140              : impl Debug for AzureConfig {
     141            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     142            0 :         f.debug_struct("AzureConfig")
     143            0 :             .field("bucket_name", &self.container_name)
     144            0 :             .field("storage_account", &self.storage_account)
     145            0 :             .field("bucket_region", &self.container_region)
     146            0 :             .field("prefix_in_container", &self.prefix_in_container)
     147            0 :             .field("concurrency_limit", &self.concurrency_limit)
     148            0 :             .field(
     149            0 :                 "max_keys_per_list_response",
     150            0 :                 &self.max_keys_per_list_response,
     151            0 :             )
     152            0 :             .finish()
     153            0 :     }
     154              : }
     155              : 
     156           15 : fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
     157           15 :     deserializer: D,
     158           15 : ) -> Result<Option<StorageClass>, D::Error> {
     159           15 :     Option::<String>::deserialize(deserializer).and_then(|s| {
     160           15 :         if let Some(s) = s {
     161              :             use serde::de::Error;
     162           12 :             let storage_class = StorageClass::from_str(&s).expect("infallible");
     163              :             #[allow(deprecated)]
     164           12 :             if matches!(storage_class, StorageClass::Unknown(_)) {
     165            0 :                 return Err(D::Error::custom(format!(
     166            0 :                     "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
     167            0 :                     StorageClass::values()
     168            0 :                 )));
     169           12 :             }
     170           12 :             Ok(Some(storage_class))
     171              :         } else {
     172            3 :             Ok(None)
     173              :         }
     174           15 :     })
     175           15 : }
     176              : 
     177            9 : fn serialize_storage_class<S: serde::Serializer>(
     178            9 :     val: &Option<StorageClass>,
     179            9 :     serializer: S,
     180            9 : ) -> Result<S::Ok, S::Error> {
     181            9 :     let val = val.as_ref().map(StorageClass::as_str);
     182            9 :     Option::<&str>::serialize(&val, serializer)
     183            9 : }
     184              : 
     185              : impl RemoteStorageConfig {
     186              :     pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);
     187              : 
     188           10 :     pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<RemoteStorageConfig> {
     189           10 :         Ok(utils::toml_edit_ext::deserialize_item(toml)?)
     190           10 :     }
     191              : 
     192            9 :     pub fn from_toml_str(input: &str) -> anyhow::Result<RemoteStorageConfig> {
     193            9 :         let toml_document = toml_edit::DocumentMut::from_str(input)?;
     194            9 :         if let Some(item) = toml_document.get("remote_storage") {
     195            0 :             return Self::from_toml(item);
     196            9 :         }
     197            9 :         Self::from_toml(toml_document.as_item())
     198            9 :     }
     199              : }
     200              : 
     201              : #[cfg(test)]
     202              : mod tests {
     203              :     use super::*;
     204              : 
     205            9 :     fn parse(input: &str) -> anyhow::Result<RemoteStorageConfig> {
     206            9 :         RemoteStorageConfig::from_toml_str(input)
     207            9 :     }
     208              : 
     209              :     #[test]
     210            3 :     fn parse_localfs_config_with_timeout() {
     211            3 :         let input = "local_path = '.'
     212            3 : timeout = '5s'";
     213            3 : 
     214            3 :         let config = parse(input).unwrap();
     215            3 : 
     216            3 :         assert_eq!(
     217            3 :             config,
     218            3 :             RemoteStorageConfig {
     219            3 :                 storage: RemoteStorageKind::LocalFs {
     220            3 :                     local_path: Utf8PathBuf::from(".")
     221            3 :                 },
     222            3 :                 timeout: Duration::from_secs(5)
     223            3 :             }
     224            3 :         );
     225            3 :     }
     226              : 
     227              :     #[test]
     228            3 :     fn test_s3_parsing() {
     229            3 :         let toml = "\
     230            3 :     bucket_name = 'foo-bar'
     231            3 :     bucket_region = 'eu-central-1'
     232            3 :     upload_storage_class = 'INTELLIGENT_TIERING'
     233            3 :     timeout = '7s'
     234            3 :     ";
     235            3 : 
     236            3 :         let config = parse(toml).unwrap();
     237            3 : 
     238            3 :         assert_eq!(
     239            3 :             config,
     240            3 :             RemoteStorageConfig {
     241            3 :                 storage: RemoteStorageKind::AwsS3(S3Config {
     242            3 :                     bucket_name: "foo-bar".into(),
     243            3 :                     bucket_region: "eu-central-1".into(),
     244            3 :                     prefix_in_bucket: None,
     245            3 :                     endpoint: None,
     246            3 :                     concurrency_limit: default_remote_storage_s3_concurrency_limit(),
     247            3 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     248            3 :                     upload_storage_class: Some(StorageClass::IntelligentTiering),
     249            3 :                 }),
     250            3 :                 timeout: Duration::from_secs(7)
     251            3 :             }
     252            3 :         );
     253            3 :     }
     254              : 
     255              :     #[test]
     256            3 :     fn test_storage_class_serde_roundtrip() {
     257            3 :         let classes = [
     258            3 :             None,
     259            3 :             Some(StorageClass::Standard),
     260            3 :             Some(StorageClass::IntelligentTiering),
     261            3 :         ];
     262           12 :         for class in classes {
     263           27 :             #[derive(Serialize, Deserialize)]
     264            9 :             struct Wrapper {
     265              :                 #[serde(
     266              :                     deserialize_with = "deserialize_storage_class",
     267              :                     serialize_with = "serialize_storage_class"
     268              :                 )]
     269              :                 class: Option<StorageClass>,
     270              :             }
     271            9 :             let wrapped = Wrapper {
     272            9 :                 class: class.clone(),
     273            9 :             };
     274            9 :             let serialized = serde_json::to_string(&wrapped).unwrap();
     275            9 :             let deserialized: Wrapper = serde_json::from_str(&serialized).unwrap();
     276            9 :             assert_eq!(class, deserialized.class);
     277              :         }
     278            3 :     }
     279              : 
     280              :     #[test]
     281            3 :     fn test_azure_parsing() {
     282            3 :         let toml = "\
     283            3 :     container_name = 'foo-bar'
     284            3 :     container_region = 'westeurope'
     285            3 :     upload_storage_class = 'INTELLIGENT_TIERING'
     286            3 :     timeout = '7s'
     287            3 :     ";
     288            3 : 
     289            3 :         let config = parse(toml).unwrap();
     290            3 : 
     291            3 :         assert_eq!(
     292            3 :             config,
     293            3 :             RemoteStorageConfig {
     294            3 :                 storage: RemoteStorageKind::AzureContainer(AzureConfig {
     295            3 :                     container_name: "foo-bar".into(),
     296            3 :                     storage_account: None,
     297            3 :                     container_region: "westeurope".into(),
     298            3 :                     prefix_in_container: None,
     299            3 :                     concurrency_limit: default_remote_storage_azure_concurrency_limit(),
     300            3 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     301            3 :                 }),
     302            3 :                 timeout: Duration::from_secs(7)
     303            3 :             }
     304            3 :         );
     305            3 :     }
     306              : }
        

Generated by: LCOV version 2.1-beta