LCOV - code coverage report
Current view: top level - libs/remote_storage/src - config.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 65.4 % 191 125
Test Date: 2025-07-16 12:29:03 Functions: 37.0 % 46 17

            Line data    Source code
       1              : use std::fmt::Debug;
       2              : use std::num::NonZeroUsize;
       3              : use std::str::FromStr;
       4              : use std::time::Duration;
       5              : 
       6              : use aws_sdk_s3::types::StorageClass;
       7              : use camino::Utf8PathBuf;
       8              : use serde::{Deserialize, Serialize};
       9              : 
      10              : use crate::{
      11              :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
      12              :     DEFAULT_REMOTE_STORAGE_LOCALFS_CONCURRENCY_LIMIT, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
      13              : };
      14              : 
      15              : /// External backup storage configuration, enough for creating a client for that storage.
      16              : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
      17              : pub struct RemoteStorageConfig {
      18              :     /// The storage connection configuration.
      19              :     #[serde(flatten)]
      20              :     pub storage: RemoteStorageKind,
      21              :     /// A common timeout enforced for all requests after concurrency limiter permit has been
      22              :     /// acquired.
      23              :     #[serde(
      24              :         with = "humantime_serde",
      25              :         default = "default_timeout",
      26              :         skip_serializing_if = "is_default_timeout"
      27              :     )]
      28              :     pub timeout: Duration,
      29              :     /// Alternative timeout used for metadata objects which are expected to be small
      30              :     #[serde(
      31              :         with = "humantime_serde",
      32              :         default = "default_small_timeout",
      33              :         skip_serializing_if = "is_default_small_timeout"
      34              :     )]
      35              :     pub small_timeout: Duration,
      36              : }
      37              : 
      38              : impl RemoteStorageKind {
      39            0 :     pub fn bucket_name(&self) -> Option<&str> {
      40            0 :         match self {
      41            0 :             RemoteStorageKind::LocalFs { .. } => None,
      42            0 :             RemoteStorageKind::AwsS3(config) => Some(&config.bucket_name),
      43            0 :             RemoteStorageKind::AzureContainer(config) => Some(&config.container_name),
      44              :         }
      45            0 :     }
      46              : }
      47              : 
      48              : impl RemoteStorageConfig {
      49              :     /// Helper to fetch the configured concurrency limit.
      50            0 :     pub fn concurrency_limit(&self) -> usize {
      51            0 :         match &self.storage {
      52            0 :             RemoteStorageKind::LocalFs { .. } => DEFAULT_REMOTE_STORAGE_LOCALFS_CONCURRENCY_LIMIT,
      53            0 :             RemoteStorageKind::AwsS3(c) => c.concurrency_limit.into(),
      54            0 :             RemoteStorageKind::AzureContainer(c) => c.concurrency_limit.into(),
      55              :         }
      56            0 :     }
      57              : }
      58              : 
      59            1 : fn default_timeout() -> Duration {
      60            1 :     RemoteStorageConfig::DEFAULT_TIMEOUT
      61            1 : }
      62              : 
      63           10 : fn default_small_timeout() -> Duration {
      64           10 :     RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
      65           10 : }
      66              : 
      67            0 : fn is_default_timeout(d: &Duration) -> bool {
      68            0 :     *d == RemoteStorageConfig::DEFAULT_TIMEOUT
      69            0 : }
      70              : 
      71            0 : fn is_default_small_timeout(d: &Duration) -> bool {
      72            0 :     *d == RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
      73            0 : }
      74              : 
      75              : /// A kind of remote storage to connect to, with its connection configuration.
      76            0 : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
      77              : #[serde(untagged)]
      78              : pub enum RemoteStorageKind {
      79              :     /// Storage based on local file system.
      80              :     /// Specify a root folder to place all stored files into.
      81              :     LocalFs { local_path: Utf8PathBuf },
      82              :     /// AWS S3 based storage, storing all files in the S3 bucket
      83              :     /// specified by the config
      84              :     AwsS3(S3Config),
      85              :     /// Azure Blob based storage, storing all files in the container
      86              :     /// specified by the config
      87              :     AzureContainer(AzureConfig),
      88              : }
      89              : 
      90            0 : #[derive(Deserialize)]
      91              : #[serde(tag = "type")]
      92              : /// Version of [`RemoteStorageKind`] that deserializes with an explicit `type` tag: `LocalFs` | `AwsS3` | `AzureContainer`.
      93              : /// Needed for the endpoint storage service.
      94              : pub enum TypedRemoteStorageKind {
      95              :     LocalFs { local_path: Utf8PathBuf },
      96              :     AwsS3(S3Config),
      97              :     AzureContainer(AzureConfig),
      98              : }
      99              : 
     100              : impl From<TypedRemoteStorageKind> for RemoteStorageKind {
     101            0 :     fn from(value: TypedRemoteStorageKind) -> Self {
     102            0 :         match value {
     103            0 :             TypedRemoteStorageKind::LocalFs { local_path } => {
     104            0 :                 RemoteStorageKind::LocalFs { local_path }
     105              :             }
     106            0 :             TypedRemoteStorageKind::AwsS3(v) => RemoteStorageKind::AwsS3(v),
     107            0 :             TypedRemoteStorageKind::AzureContainer(v) => RemoteStorageKind::AzureContainer(v),
     108              :         }
     109            0 :     }
     110              : }
     111              : 
     112              : /// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
     113              : #[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
     114              : pub struct S3Config {
     115              :     /// Name of the bucket to connect to.
     116              :     pub bucket_name: String,
     117              :     /// The region where the bucket is located.
     118              :     pub bucket_region: String,
     119              :     /// A "subfolder" in the bucket, so that multiple remote storage users can use the same bucket separately at once.
     120              :     pub prefix_in_bucket: Option<String>,
     121              :     /// A base URL to send S3 requests to.
     122              :     /// By default, the endpoint is derived from a region name, assuming it's
     123              :     /// an AWS S3 region name, erroring on wrong region name.
     124              :     /// Endpoint provides a way to support other S3 flavors and their regions.
     125              :     ///
     126              :     /// Example: `http://127.0.0.1:5000`
     127              :     pub endpoint: Option<String>,
     128              :     /// AWS S3 has various limits on its API calls, which we must not exceed.
     129              :     /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
     130              :     #[serde(default = "default_remote_storage_s3_concurrency_limit")]
     131              :     pub concurrency_limit: NonZeroUsize,
     132              :     #[serde(default = "default_max_keys_per_list_response")]
     133              :     pub max_keys_per_list_response: Option<i32>,
     134              :     #[serde(
     135              :         deserialize_with = "deserialize_storage_class",
     136              :         serialize_with = "serialize_storage_class",
     137              :         default
     138              :     )]
     139              :     pub upload_storage_class: Option<StorageClass>,
     140              : }
     141              : 
     142            7 : fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
     143              :     DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
     144            7 :         .try_into()
     145            7 :         .unwrap()
     146            7 : }
     147              : 
     148            7 : fn default_max_keys_per_list_response() -> Option<i32> {
     149            7 :     DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
     150            7 : }
     151              : 
     152            0 : fn default_azure_conn_pool_size() -> usize {
     153              :     // By default, the Azure SDK does no connection pooling, due to historic reports of hard-to-reproduce issues
     154              :     // (https://github.com/hyperium/hyper/issues/2312)
     155              :     //
     156              :     // However, using connection pooling is important to avoid exhausting client ports when
     157              :     // doing huge numbers of requests (https://github.com/neondatabase/cloud/issues/20971)
     158              :     //
     159              :     // We therefore enable a modest pool size by default: this may be configured to zero if
     160              :     // issues like the alleged upstream hyper issue appear.
     161            0 :     8
     162            0 : }
     163              : 
     164              : impl Debug for S3Config {
     165            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     166            0 :         f.debug_struct("S3Config")
     167            0 :             .field("bucket_name", &self.bucket_name)
     168            0 :             .field("bucket_region", &self.bucket_region)
     169            0 :             .field("prefix_in_bucket", &self.prefix_in_bucket)
     170            0 :             .field("concurrency_limit", &self.concurrency_limit)
     171            0 :             .field(
     172            0 :                 "max_keys_per_list_response",
     173            0 :                 &self.max_keys_per_list_response,
     174            0 :             )
     175            0 :             .finish()
     176            0 :     }
     177              : }
     178              : 
     179              : /// Azure container coordinates and access credentials to manage the container contents (read and write).
     180            0 : #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
     181              : pub struct AzureConfig {
     182              :     /// Name of the container to connect to.
     183              :     pub container_name: String,
     184              :     /// Name of the storage account that contains the container.
     185              :     pub storage_account: Option<String>,
     186              :     /// The region where the container is located.
     187              :     pub container_region: String,
     188              :     /// A "subfolder" in the container, so that multiple remote storage users can use the same container separately at once.
     189              :     pub prefix_in_container: Option<String>,
     190              :     /// Azure has various limits on its API calls, which we must not exceed.
     191              :     /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
     192              :     #[serde(default = "default_remote_storage_azure_concurrency_limit")]
     193              :     pub concurrency_limit: NonZeroUsize,
     194              :     #[serde(default = "default_max_keys_per_list_response")]
     195              :     pub max_keys_per_list_response: Option<i32>,
     196              :     #[serde(default = "default_azure_conn_pool_size")]
     197              :     pub conn_pool_size: usize,
     198              :     /* BEGIN_HADRON */
     199              :     #[serde(default = "default_azure_put_block_size_mb")]
     200              :     pub put_block_size_mb: Option<usize>,
     201              :     /* END_HADRON */
     202              : }
     203              : 
     204              : /* BEGIN_HADRON */
     205            0 : fn default_azure_put_block_size_mb() -> Option<usize> {
     206              :     // Disable parallel upload by default.
     207            0 :     Some(0)
     208            0 : }
     209              : /* END_HADRON */
     210              : 
     211            6 : fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
     212            6 :     NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
     213            6 : }
     214              : 
     215              : impl Debug for AzureConfig {
     216            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     217            0 :         f.debug_struct("AzureConfig")
     218            0 :             .field("bucket_name", &self.container_name)
     219            0 :             .field("storage_account", &self.storage_account)
     220            0 :             .field("bucket_region", &self.container_region)
     221            0 :             .field("prefix_in_container", &self.prefix_in_container)
     222            0 :             .field("concurrency_limit", &self.concurrency_limit)
     223            0 :             .field(
     224            0 :                 "max_keys_per_list_response",
     225            0 :                 &self.max_keys_per_list_response,
     226            0 :             )
     227            0 :             /* BEGIN_HADRON */
     228            0 :             .field("put_block_size_mb", &self.put_block_size_mb)
     229              :             /* END_HADRON */
     230            0 :             .finish()
     231            0 :     }
     232              : }
     233              : 
     234           15 : fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
     235           15 :     deserializer: D,
     236           15 : ) -> Result<Option<StorageClass>, D::Error> {
     237           15 :     Option::<String>::deserialize(deserializer).and_then(|s| {
     238           15 :         if let Some(s) = s {
     239              :             use serde::de::Error;
     240           12 :             let storage_class = StorageClass::from_str(&s).expect("infallible");
     241              :             #[allow(deprecated)]
     242           12 :             if matches!(storage_class, StorageClass::Unknown(_)) {
     243            0 :                 return Err(D::Error::custom(format!(
     244            0 :                     "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
     245            0 :                     StorageClass::values()
     246            0 :                 )));
     247           12 :             }
     248           12 :             Ok(Some(storage_class))
     249              :         } else {
     250            3 :             Ok(None)
     251              :         }
     252           15 :     })
     253           15 : }
     254              : 
     255            9 : fn serialize_storage_class<S: serde::Serializer>(
     256            9 :     val: &Option<StorageClass>,
     257            9 :     serializer: S,
     258            9 : ) -> Result<S::Ok, S::Error> {
     259            9 :     let val = val.as_ref().map(StorageClass::as_str);
     260            9 :     Option::<&str>::serialize(&val, serializer)
     261            9 : }
     262              : 
     263              : impl RemoteStorageConfig {
     264              :     pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);
     265              :     pub const DEFAULT_SMALL_TIMEOUT: Duration = std::time::Duration::from_secs(30);
     266              : 
     267           10 :     pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<RemoteStorageConfig> {
     268           10 :         Ok(utils::toml_edit_ext::deserialize_item(toml)?)
     269           10 :     }
     270              : 
     271            9 :     pub fn from_toml_str(input: &str) -> anyhow::Result<RemoteStorageConfig> {
     272            9 :         let toml_document = toml_edit::DocumentMut::from_str(input)?;
     273            9 :         if let Some(item) = toml_document.get("remote_storage") {
     274            0 :             return Self::from_toml(item);
     275            9 :         }
     276            9 :         Self::from_toml(toml_document.as_item())
     277            9 :     }
     278              : }
     279              : 
     280              : #[cfg(test)]
     281              : mod tests {
     282              :     use super::*;
     283              : 
     284            9 :     fn parse(input: &str) -> anyhow::Result<RemoteStorageConfig> {
     285            9 :         RemoteStorageConfig::from_toml_str(input)
     286            9 :     }
     287              : 
     288              :     #[test]
     289            3 :     fn parse_localfs_config_with_timeout() {
     290            3 :         let input = "local_path = '.'
     291            3 : timeout = '5s'";
     292              : 
     293            3 :         let config = parse(input).unwrap();
     294              : 
     295            3 :         assert_eq!(
     296              :             config,
     297            3 :             RemoteStorageConfig {
     298            3 :                 storage: RemoteStorageKind::LocalFs {
     299            3 :                     local_path: Utf8PathBuf::from(".")
     300            3 :                 },
     301            3 :                 timeout: Duration::from_secs(5),
     302            3 :                 small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
     303            3 :             }
     304              :         );
     305            3 :     }
     306              : 
     307              :     #[test]
     308            3 :     fn test_s3_parsing() {
     309            3 :         let toml = "\
     310            3 :     bucket_name = 'foo-bar'
     311            3 :     bucket_region = 'eu-central-1'
     312            3 :     upload_storage_class = 'INTELLIGENT_TIERING'
     313            3 :     timeout = '7s'
     314            3 :     ";
     315              : 
     316            3 :         let config = parse(toml).unwrap();
     317              : 
     318            3 :         assert_eq!(
     319              :             config,
     320            3 :             RemoteStorageConfig {
     321            3 :                 storage: RemoteStorageKind::AwsS3(S3Config {
     322            3 :                     bucket_name: "foo-bar".into(),
     323            3 :                     bucket_region: "eu-central-1".into(),
     324            3 :                     prefix_in_bucket: None,
     325            3 :                     endpoint: None,
     326            3 :                     concurrency_limit: default_remote_storage_s3_concurrency_limit(),
     327            3 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     328            3 :                     upload_storage_class: Some(StorageClass::IntelligentTiering),
     329            3 :                 }),
     330            3 :                 timeout: Duration::from_secs(7),
     331            3 :                 small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
     332            3 :             }
     333              :         );
     334            3 :     }
     335              : 
     336              :     #[test]
     337            3 :     fn test_storage_class_serde_roundtrip() {
     338            3 :         let classes = [
     339            3 :             None,
     340            3 :             Some(StorageClass::Standard),
     341            3 :             Some(StorageClass::IntelligentTiering),
     342            3 :         ];
     343           12 :         for class in classes {
     344              :             #[derive(Serialize, Deserialize)]
     345              :             struct Wrapper {
     346              :                 #[serde(
     347              :                     deserialize_with = "deserialize_storage_class",
     348              :                     serialize_with = "serialize_storage_class"
     349              :                 )]
     350              :                 class: Option<StorageClass>,
     351              :             }
     352            9 :             let wrapped = Wrapper {
     353            9 :                 class: class.clone(),
     354            9 :             };
     355            9 :             let serialized = serde_json::to_string(&wrapped).unwrap();
     356            9 :             let deserialized: Wrapper = serde_json::from_str(&serialized).unwrap();
     357            9 :             assert_eq!(class, deserialized.class);
     358              :         }
     359            3 :     }
     360              : 
     361              :     #[test]
     362            3 :     fn test_azure_parsing() {
     363            3 :         let toml = "\
     364            3 :     container_name = 'foo-bar'
     365            3 :     container_region = 'westeurope'
     366            3 :     upload_storage_class = 'INTELLIGENT_TIERING'
     367            3 :     timeout = '7s'
     368            3 :     conn_pool_size = 8
     369            3 :     put_block_size_mb = 1024
     370            3 :     ";
     371              : 
     372            3 :         let config = parse(toml).unwrap();
     373              : 
     374            3 :         assert_eq!(
     375              :             config,
     376            3 :             RemoteStorageConfig {
     377            3 :                 storage: RemoteStorageKind::AzureContainer(AzureConfig {
     378            3 :                     container_name: "foo-bar".into(),
     379            3 :                     storage_account: None,
     380            3 :                     container_region: "westeurope".into(),
     381            3 :                     prefix_in_container: None,
     382            3 :                     concurrency_limit: default_remote_storage_azure_concurrency_limit(),
     383            3 :                     max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
     384            3 :                     conn_pool_size: 8,
     385            3 :                     /* BEGIN_HADRON */
     386            3 :                     put_block_size_mb: Some(1024),
     387            3 :                     /* END_HADRON */
     388            3 :                 }),
     389            3 :                 timeout: Duration::from_secs(7),
     390            3 :                 small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
     391            3 :             }
     392              :         );
     393            3 :     }
     394              : }
        

Generated by: LCOV version 2.1-beta