use std::{fmt::Debug, num::NonZeroUsize, str::FromStr, time::Duration};

use aws_sdk_s3::types::StorageClass;
use camino::Utf8PathBuf;

use serde::{Deserialize, Serialize};

use crate::{
    DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
    DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
};

/// External backup storage configuration, enough for creating a client for that storage.
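///
/// A minimal example in TOML, as accepted by [`RemoteStorageConfig::from_toml_str`]
/// (the path value is illustrative only; `timeout` falls back to
/// [`RemoteStorageConfig::DEFAULT_TIMEOUT`] when omitted):
///
/// ```toml
/// local_path = '/some/local/path'
/// timeout = '5s'
/// ```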
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct RemoteStorageConfig {
    /// The storage connection configuration.
    #[serde(flatten)]
    pub storage: RemoteStorageKind,
    /// A common timeout enforced for all requests after a concurrency limiter
    /// permit has been acquired.
    #[serde(
        with = "humantime_serde",
        default = "default_timeout",
        skip_serializing_if = "is_default_timeout"
    )]
    pub timeout: Duration,
}

impl RemoteStorageKind {
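    /// Returns the bucket or container name for remote backends, and `None`
    /// for local file system storage.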
    pub fn bucket_name(&self) -> Option<&str> {
        match self {
            RemoteStorageKind::LocalFs { .. } => None,
            RemoteStorageKind::AwsS3(config) => Some(&config.bucket_name),
            RemoteStorageKind::AzureContainer(config) => Some(&config.container_name),
        }
    }
}

fn default_timeout() -> Duration {
    RemoteStorageConfig::DEFAULT_TIMEOUT
}

fn is_default_timeout(d: &Duration) -> bool {
    *d == RemoteStorageConfig::DEFAULT_TIMEOUT
}

/// A kind of remote storage to connect to, with its connection configuration.
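///
/// Deserialized with `#[serde(untagged)]`: the variant is selected by which
/// fields are present in the input, e.g. `local_path` vs. `bucket_name` vs.
/// `container_name`.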
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum RemoteStorageKind {
    /// Storage based on the local file system.
    /// Specify a root folder to place all stored files into.
    LocalFs { local_path: Utf8PathBuf },
    /// AWS S3 based storage, storing all files in the S3 bucket
    /// specified by the config.
    AwsS3(S3Config),
    /// Azure Blob based storage, storing all files in the container
    /// specified by the config.
    AzureContainer(AzureConfig),
}

/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
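///
/// An illustrative TOML sketch (bucket name, region, and prefix are made up):
///
/// ```toml
/// bucket_name = 'some-bucket'
/// bucket_region = 'eu-central-1'
/// prefix_in_bucket = 'some/prefix'
/// ```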
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct S3Config {
    /// Name of the bucket to connect to.
    pub bucket_name: String,
    /// The region the bucket is located in.
    pub bucket_region: String,
    /// A "subfolder" in the bucket, allowing multiple remote storage users to share the same bucket.
    pub prefix_in_bucket: Option<String>,
    /// A base URL to send S3 requests to.
    /// By default, the endpoint is derived from the region name, assuming it is
    /// an AWS S3 region name; an invalid region name results in an error.
    /// The endpoint provides a way to support other S3 flavors and their regions.
    ///
    /// Example: `http://127.0.0.1:5000`
    pub endpoint: Option<String>,
    /// AWS S3 imposes various limits on its API calls; we must not exceed them.
    /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
    #[serde(default = "default_remote_storage_s3_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
    #[serde(
        deserialize_with = "deserialize_storage_class",
        serialize_with = "serialize_storage_class",
        default
    )]
    pub upload_storage_class: Option<StorageClass>,
}

fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
    DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
        .try_into()
        .unwrap()
}

fn default_max_keys_per_list_response() -> Option<i32> {
    DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
}

impl Debug for S3Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("S3Config")
            .field("bucket_name", &self.bucket_name)
            .field("bucket_region", &self.bucket_region)
            .field("prefix_in_bucket", &self.prefix_in_bucket)
            .field("concurrency_limit", &self.concurrency_limit)
            .field(
                "max_keys_per_list_response",
                &self.max_keys_per_list_response,
            )
            .finish()
    }
}

/// Azure container coordinates and access credentials to manage the container contents (read and write).
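///
/// An illustrative TOML sketch (container name, region, and prefix are made up):
///
/// ```toml
/// container_name = 'some-container'
/// container_region = 'westeurope'
/// prefix_in_container = 'some/prefix'
/// ```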
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AzureConfig {
    /// Name of the container to connect to.
    pub container_name: String,
    /// Name of the storage account the container is located in.
    pub storage_account: Option<String>,
    /// The region the container is located in.
    pub container_region: String,
    /// A "subfolder" in the container, allowing multiple remote storage users to share the same container.
    pub prefix_in_container: Option<String>,
    /// Azure imposes various limits on its API calls; we must not exceed them.
    /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
    #[serde(default = "default_remote_storage_azure_concurrency_limit")]
    pub concurrency_limit: NonZeroUsize,
    #[serde(default = "default_max_keys_per_list_response")]
    pub max_keys_per_list_response: Option<i32>,
}

fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
    NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
}

impl Debug for AzureConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureConfig")
            // Field labels fixed to match the actual field names; the originals
            // said "bucket_name"/"bucket_region", apparently a copy-paste slip
            // from the S3 impl.
            .field("container_name", &self.container_name)
            .field("storage_account", &self.storage_account)
            .field("container_region", &self.container_region)
            .field("prefix_in_container", &self.prefix_in_container)
            .field("concurrency_limit", &self.concurrency_limit)
            .field(
                "max_keys_per_list_response",
                &self.max_keys_per_list_response,
            )
            .finish()
    }
}

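/// Deserializes an optional S3 storage class, rejecting values unknown to the
/// AWS SDK so that typos fail at config-parse time rather than at upload time.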
fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<StorageClass>, D::Error> {
    Option::<String>::deserialize(deserializer).and_then(|s| {
        if let Some(s) = s {
            use serde::de::Error;
            let storage_class = StorageClass::from_str(&s).expect("infallible");
            #[allow(deprecated)]
            if matches!(storage_class, StorageClass::Unknown(_)) {
                return Err(D::Error::custom(format!(
                    "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
                    StorageClass::values()
                )));
            }
            Ok(Some(storage_class))
        } else {
            Ok(None)
        }
    })
}

fn serialize_storage_class<S: serde::Serializer>(
    val: &Option<StorageClass>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    let val = val.as_ref().map(StorageClass::as_str);
    Option::<&str>::serialize(&val, serializer)
}

impl RemoteStorageConfig {
    pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);

    pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<RemoteStorageConfig> {
        Ok(utils::toml_edit_ext::deserialize_item(toml)?)
    }

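    /// Parses a config from a TOML string, accepting either a document with a
    /// top-level `remote_storage` table or the bare config fields.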
    pub fn from_toml_str(input: &str) -> anyhow::Result<RemoteStorageConfig> {
        let toml_document = toml_edit::DocumentMut::from_str(input)?;
        if let Some(item) = toml_document.get("remote_storage") {
            return Self::from_toml(item);
        }
        Self::from_toml(toml_document.as_item())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn parse(input: &str) -> anyhow::Result<RemoteStorageConfig> {
        RemoteStorageConfig::from_toml_str(input)
    }

    #[test]
    fn parse_localfs_config_with_timeout() {
        let input = "local_path = '.'
timeout = '5s'";

        let config = parse(input).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::LocalFs {
                    local_path: Utf8PathBuf::from(".")
                },
                timeout: Duration::from_secs(5)
            }
        );
    }

    #[test]
    fn test_s3_parsing() {
        let toml = "\
bucket_name = 'foo-bar'
bucket_region = 'eu-central-1'
upload_storage_class = 'INTELLIGENT_TIERING'
timeout = '7s'
";

        let config = parse(toml).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::AwsS3(S3Config {
                    bucket_name: "foo-bar".into(),
                    bucket_region: "eu-central-1".into(),
                    prefix_in_bucket: None,
                    endpoint: None,
                    concurrency_limit: default_remote_storage_s3_concurrency_limit(),
                    max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                    upload_storage_class: Some(StorageClass::IntelligentTiering),
                }),
                timeout: Duration::from_secs(7)
            }
        );
    }
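
    /// A sketch of a defaults check: when `timeout` is omitted from the input,
    /// serde falls back to `default_timeout`, i.e.
    /// [`RemoteStorageConfig::DEFAULT_TIMEOUT`].
    #[test]
    fn parse_localfs_config_default_timeout() {
        let config = parse("local_path = '.'").unwrap();
        assert_eq!(config.timeout, RemoteStorageConfig::DEFAULT_TIMEOUT);
    }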

    #[test]
    fn test_storage_class_serde_roundtrip() {
        let classes = [
            None,
            Some(StorageClass::Standard),
            Some(StorageClass::IntelligentTiering),
        ];
        for class in classes {
            #[derive(Serialize, Deserialize)]
            struct Wrapper {
                #[serde(
                    deserialize_with = "deserialize_storage_class",
                    serialize_with = "serialize_storage_class"
                )]
                class: Option<StorageClass>,
            }
            let wrapped = Wrapper {
                class: class.clone(),
            };
            let serialized = serde_json::to_string(&wrapped).unwrap();
            let deserialized: Wrapper = serde_json::from_str(&serialized).unwrap();
            assert_eq!(class, deserialized.class);
        }
    }

    #[test]
    fn test_azure_parsing() {
        let toml = "\
container_name = 'foo-bar'
container_region = 'westeurope'
upload_storage_class = 'INTELLIGENT_TIERING'
timeout = '7s'
";

        let config = parse(toml).unwrap();

        assert_eq!(
            config,
            RemoteStorageConfig {
                storage: RemoteStorageKind::AzureContainer(AzureConfig {
                    container_name: "foo-bar".into(),
                    storage_account: None,
                    container_region: "westeurope".into(),
                    prefix_in_container: None,
                    concurrency_limit: default_remote_storage_azure_concurrency_limit(),
                    max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                }),
                timeout: Duration::from_secs(7)
            }
        );
    }
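
    /// A sketch of a negative case: `deserialize_storage_class` rejects storage
    /// classes the SDK does not recognize, so parsing the whole config fails
    /// (the class name below is deliberately bogus).
    #[test]
    fn test_unknown_storage_class_is_rejected() {
        let toml = "\
bucket_name = 'foo-bar'
bucket_region = 'eu-central-1'
upload_storage_class = 'NOT_A_REAL_CLASS'
";
        assert!(parse(toml).is_err());
    }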
}