use std::{fmt::Debug, num::NonZeroUsize, str::FromStr, time::Duration};

use aws_sdk_s3::types::StorageClass;
use camino::Utf8PathBuf;

use serde::{Deserialize, Serialize};

use crate::{
    DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
    DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
};

/// External backup storage configuration, enough for creating a client for that storage.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct RemoteStorageConfig {
    /// The storage connection configuration.
    #[serde(flatten)]
    pub storage: RemoteStorageKind,
    /// A common timeout enforced for all requests after a concurrency limiter permit has been
    /// acquired.
    #[serde(
        with = "humantime_serde",
        default = "default_timeout",
        skip_serializing_if = "is_default_timeout"
    )]
    pub timeout: Duration,
}

fn default_timeout() -> Duration {
    RemoteStorageConfig::DEFAULT_TIMEOUT
}

fn is_default_timeout(d: &Duration) -> bool {
    *d == RemoteStorageConfig::DEFAULT_TIMEOUT
}

/// A kind of a remote storage to connect to, with its connection configuration.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum RemoteStorageKind {
    /// Storage based on the local file system.
    /// Specify a root folder to place all stored files into.
    LocalFs { local_path: Utf8PathBuf },
    /// AWS S3 based storage, storing all files in the S3 bucket
    /// specified by the config.
    AwsS3(S3Config),
    /// Azure Blob based storage, storing all files in the container
    /// specified by the config.
    AzureContainer(AzureConfig),
}

/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct S3Config {
    /// Name of the bucket to connect to.
    pub bucket_name: String,
    /// The region where the bucket is located.
    pub bucket_region: String,
    /// A "subfolder" in the bucket, so that multiple remote storage users can share the same
    /// bucket without interfering with each other.
    pub prefix_in_bucket: Option<String>,
    /// A base URL to send S3 requests to.
    /// By default, the endpoint is derived from the region name, assuming it is an AWS S3
    /// region; an unknown region name results in an error.
    /// Setting an explicit endpoint provides a way to support other S3 flavors and their regions.
    ///
    /// Example: `http://127.0.0.1:5000`
    pub endpoint: Option<String>,
    /// AWS S3 imposes various limits on its API calls, which must not be exceeded.
    /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
    #[serde(default = "default_remote_storage_s3_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
    #[serde(
        deserialize_with = "deserialize_storage_class",
        serialize_with = "serialize_storage_class",
        default
    )]
    pub upload_storage_class: Option<StorageClass>,
}

fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
    DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
        .try_into()
        .unwrap()
}

fn default_max_keys_per_list_response() -> Option<i32> {
    DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
}

impl Debug for S3Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("S3Config")
            .field("bucket_name", &self.bucket_name)
            .field("bucket_region", &self.bucket_region)
            .field("prefix_in_bucket", &self.prefix_in_bucket)
            .field("concurrency_limit", &self.concurrency_limit)
            .field(
                "max_keys_per_list_response",
                &self.max_keys_per_list_response,
            )
            .finish()
    }
}

/// Azure container coordinates and access credentials to manage the container contents (read and write).
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AzureConfig {
    /// Name of the container to connect to.
    pub container_name: String,
    /// Name of the storage account the container belongs to.
    pub storage_account: Option<String>,
    /// The region where the container is located.
    pub container_region: String,
    /// A "subfolder" in the container, so that multiple remote storage users can share the same
    /// container without interfering with each other.
    pub prefix_in_container: Option<String>,
    /// Azure imposes various limits on its API calls, which must not be exceeded.
    /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
    #[serde(default = "default_remote_storage_azure_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
}

fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
    NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
}

impl Debug for AzureConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureConfig")
            .field("container_name", &self.container_name)
            .field("storage_account", &self.storage_account)
            .field("container_region", &self.container_region)
            .field("prefix_in_container", &self.prefix_in_container)
            .field("concurrency_limit", &self.concurrency_limit)
            .field(
                "max_keys_per_list_response",
                &self.max_keys_per_list_response,
            )
            .finish()
    }
}

fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<StorageClass>, D::Error> {
    Option::<String>::deserialize(deserializer).and_then(|s| {
        if let Some(s) = s {
            use serde::de::Error;
            let storage_class = StorageClass::from_str(&s).expect("infallible");
            #[allow(deprecated)]
            if matches!(storage_class, StorageClass::Unknown(_)) {
                return Err(D::Error::custom(format!(
                    "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
                    StorageClass::values()
                )));
            }
            Ok(Some(storage_class))
        } else {
            Ok(None)
        }
    })
}

fn serialize_storage_class<S: serde::Serializer>(
    val: &Option<StorageClass>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    let val = val.as_ref().map(StorageClass::as_str);
    Option::<&str>::serialize(&val, serializer)
}

impl RemoteStorageConfig {
    pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);

    pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<RemoteStorageConfig> {
        Ok(utils::toml_edit_ext::deserialize_item(toml)?)
    }
}
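
// An illustrative sketch of how a caller might combine `toml_edit` parsing with
// `RemoteStorageConfig::from_toml` to load a config from a standalone TOML file.
// The helper name and the standalone-file scenario are assumptions made for the example;
// in practice the config is usually a fragment of a larger TOML document.
#[allow(dead_code)]
fn load_config_from_file_example(path: &camino::Utf8Path) -> anyhow::Result<RemoteStorageConfig> {
    // Read the raw TOML text, parse it into a `toml_edit::Document`, and hand the root
    // item to `from_toml`, which deserializes it into a `RemoteStorageConfig`.
    let raw = std::fs::read_to_string(path)?;
    let doc = raw.parse::<toml_edit::Document>()?;
    RemoteStorageConfig::from_toml(doc.as_item())
}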

#[cfg(test)]
mod tests {
    use super::*;

    fn parse(input: &str) -> anyhow::Result<RemoteStorageConfig> {
        let toml = input.parse::<toml_edit::Document>().unwrap();
        RemoteStorageConfig::from_toml(toml.as_item())
    }

    #[test]
    fn parse_localfs_config_with_timeout() {
        let input = "local_path = '.'
timeout = '5s'";

        let config = parse(input).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::LocalFs {
                    local_path: Utf8PathBuf::from(".")
                },
                timeout: Duration::from_secs(5)
            }
        );
    }
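
    // A minimal sketch of the default path (hedged: it relies on the
    // `default = "default_timeout"` serde attribute above): with the `timeout` key omitted,
    // parsing is expected to fall back to `RemoteStorageConfig::DEFAULT_TIMEOUT`.
    #[test]
    fn parse_localfs_config_without_timeout_uses_default() {
        let input = "local_path = '.'";

        let config = parse(input).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::LocalFs {
                    local_path: Utf8PathBuf::from(".")
                },
                timeout: RemoteStorageConfig::DEFAULT_TIMEOUT
            }
        );
    }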

    #[test]
    fn test_s3_parsing() {
        let toml = "\
    bucket_name = 'foo-bar'
    bucket_region = 'eu-central-1'
    upload_storage_class = 'INTELLIGENT_TIERING'
    timeout = '7s'
    ";

        let config = parse(toml).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::AwsS3(S3Config {
                    bucket_name: "foo-bar".into(),
                    bucket_region: "eu-central-1".into(),
                    prefix_in_bucket: None,
                    endpoint: None,
                    concurrency_limit: default_remote_storage_s3_concurrency_limit(),
                    max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                    upload_storage_class: Some(StorageClass::IntelligentTiering),
                }),
                timeout: Duration::from_secs(7)
            }
        );
    }
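
    // A sketch of the rejection path (hedged: assumes the usual serde untagged-enum behavior):
    // `deserialize_storage_class` should reject a storage class the SDK does not recognize,
    // and since no other `RemoteStorageKind` variant matches these keys either, the whole
    // config is expected to fail to parse. The storage class value here is deliberately fake.
    #[test]
    fn test_s3_parsing_rejects_unknown_storage_class() {
        let toml = "\
    bucket_name = 'foo-bar'
    bucket_region = 'eu-central-1'
    upload_storage_class = 'NOT_A_REAL_STORAGE_CLASS'
    ";

        assert!(parse(toml).is_err());
    }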

    #[test]
    fn test_azure_parsing() {
        let toml = "\
    container_name = 'foo-bar'
    container_region = 'westeurope'
    upload_storage_class = 'INTELLIGENT_TIERING'
    timeout = '7s'
    ";

        let config = parse(toml).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::AzureContainer(AzureConfig {
                    container_name: "foo-bar".into(),
                    storage_account: None,
                    container_region: "westeurope".into(),
                    prefix_in_container: None,
                    concurrency_limit: default_remote_storage_azure_concurrency_limit(),
                    max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                }),
                timeout: Duration::from_secs(7)
            }
        );
    }
}