Line data Source code
1 : use std::{fmt::Debug, num::NonZeroUsize, str::FromStr, time::Duration};
2 :
3 : use anyhow::bail;
4 : use aws_sdk_s3::types::StorageClass;
5 : use camino::Utf8PathBuf;
6 :
7 : use serde::{Deserialize, Serialize};
8 :
9 : use crate::{
10 : DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT,
11 : DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
12 : };
13 :
14 : /// External backup storage configuration, enough for creating a client for that storage.
15 114 : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
16 : pub struct RemoteStorageConfig {
17 : /// The storage connection configuration.
18 : #[serde(flatten)]
19 : pub storage: RemoteStorageKind,
20 : /// A common timeout enforced for all requests after concurrency limiter permit has been
21 : /// acquired.
22 : #[serde(
23 : with = "humantime_serde",
24 : default = "default_timeout",
25 : skip_serializing_if = "is_default_timeout"
26 : )]
27 : pub timeout: Duration,
28 : }
29 :
30 10 : fn default_timeout() -> Duration {
31 10 : RemoteStorageConfig::DEFAULT_TIMEOUT
32 10 : }
33 :
34 0 : fn is_default_timeout(d: &Duration) -> bool {
35 0 : *d == RemoteStorageConfig::DEFAULT_TIMEOUT
36 0 : }
37 :
38 : /// A kind of a remote storage to connect to, with its connection configuration.
39 90 : #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
40 : #[serde(untagged)]
41 : pub enum RemoteStorageKind {
42 : /// Storage based on local file system.
43 : /// Specify a root folder to place all stored files into.
44 : LocalFs { local_path: Utf8PathBuf },
45 : /// AWS S3 based storage, storing all files in the S3 bucket
46 : /// specified by the config
47 : AwsS3(S3Config),
48 : /// Azure Blob based storage, storing all files in the container
49 : /// specified by the config
50 : AzureContainer(AzureConfig),
51 : }
52 :
53 : /// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
54 82 : #[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
55 : pub struct S3Config {
56 : /// Name of the bucket to connect to.
57 : pub bucket_name: String,
58 : /// The region where the bucket is located at.
59 : pub bucket_region: String,
60 : /// A "subfolder" in the bucket, to use the same bucket separately by multiple remote storage users at once.
61 : pub prefix_in_bucket: Option<String>,
62 : /// A base URL to send S3 requests to.
63 : /// By default, the endpoint is derived from a region name, assuming it's
64 : /// an AWS S3 region name, erroring on wrong region name.
65 : /// Endpoint provides a way to support other S3 flavors and their regions.
66 : ///
67 : /// Example: `http://127.0.0.1:5000`
68 : pub endpoint: Option<String>,
69 : /// AWS S3 has various limits on its API calls, we need not to exceed those.
70 : /// See [`DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT`] for more details.
71 : #[serde(default = "default_remote_storage_s3_concurrency_limit")]
72 : pub concurrency_limit: NonZeroUsize,
73 : #[serde(default = "default_max_keys_per_list_response")]
74 : pub max_keys_per_list_response: Option<i32>,
75 : #[serde(
76 : deserialize_with = "deserialize_storage_class",
77 : serialize_with = "serialize_storage_class",
78 : default
79 : )]
80 : pub upload_storage_class: Option<StorageClass>,
81 : }
82 :
83 10 : fn default_remote_storage_s3_concurrency_limit() -> NonZeroUsize {
84 10 : DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
85 10 : .try_into()
86 10 : .unwrap()
87 10 : }
88 :
89 14 : fn default_max_keys_per_list_response() -> Option<i32> {
90 14 : DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
91 14 : }
92 :
93 : impl Debug for S3Config {
94 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 0 : f.debug_struct("S3Config")
96 0 : .field("bucket_name", &self.bucket_name)
97 0 : .field("bucket_region", &self.bucket_region)
98 0 : .field("prefix_in_bucket", &self.prefix_in_bucket)
99 0 : .field("concurrency_limit", &self.concurrency_limit)
100 0 : .field(
101 0 : "max_keys_per_list_response",
102 0 : &self.max_keys_per_list_response,
103 0 : )
104 0 : .finish()
105 0 : }
106 : }
107 :
108 : /// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
109 16 : #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
110 : pub struct AzureConfig {
111 : /// Name of the container to connect to.
112 : pub container_name: String,
113 : /// Name of the storage account the container is inside of
114 : pub storage_account: Option<String>,
115 : /// The region where the bucket is located at.
116 : pub container_region: String,
117 : /// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
118 : pub prefix_in_container: Option<String>,
119 : /// Azure has various limits on its API calls, we need not to exceed those.
120 : /// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
121 : #[serde(default = "default_remote_storage_azure_concurrency_limit")]
122 : pub concurrency_limit: NonZeroUsize,
123 : #[serde(default = "default_max_keys_per_list_response")]
124 : pub max_keys_per_list_response: Option<i32>,
125 : }
126 :
127 8 : fn default_remote_storage_azure_concurrency_limit() -> NonZeroUsize {
128 8 : NonZeroUsize::new(DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT).unwrap()
129 8 : }
130 :
131 : impl Debug for AzureConfig {
132 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 0 : f.debug_struct("AzureConfig")
134 0 : .field("bucket_name", &self.container_name)
135 0 : .field("storage_account", &self.storage_account)
136 0 : .field("bucket_region", &self.container_region)
137 0 : .field("prefix_in_container", &self.prefix_in_container)
138 0 : .field("concurrency_limit", &self.concurrency_limit)
139 0 : .field(
140 0 : "max_keys_per_list_response",
141 0 : &self.max_keys_per_list_response,
142 0 : )
143 0 : .finish()
144 0 : }
145 : }
146 :
147 8 : fn deserialize_storage_class<'de, D: serde::Deserializer<'de>>(
148 8 : deserializer: D,
149 8 : ) -> Result<Option<StorageClass>, D::Error> {
150 8 : Option::<String>::deserialize(deserializer).and_then(|s| {
151 8 : if let Some(s) = s {
152 : use serde::de::Error;
153 8 : let storage_class = StorageClass::from_str(&s).expect("infallible");
154 : #[allow(deprecated)]
155 8 : if matches!(storage_class, StorageClass::Unknown(_)) {
156 0 : return Err(D::Error::custom(format!(
157 0 : "Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}",
158 0 : StorageClass::values()
159 0 : )));
160 8 : }
161 8 : Ok(Some(storage_class))
162 : } else {
163 0 : Ok(None)
164 : }
165 8 : })
166 8 : }
167 :
168 0 : fn serialize_storage_class<S: serde::Serializer>(
169 0 : val: &Option<StorageClass>,
170 0 : serializer: S,
171 0 : ) -> Result<S::Ok, S::Error> {
172 0 : let val = val.as_ref().map(StorageClass::as_str);
173 0 : Option::<&str>::serialize(&val, serializer)
174 0 : }
175 :
176 : impl RemoteStorageConfig {
177 : pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);
178 :
179 44 : pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
180 44 : let document: toml_edit::Document = match toml {
181 16 : toml_edit::Item::Table(toml) => toml.clone().into(),
182 28 : toml_edit::Item::Value(toml_edit::Value::InlineTable(toml)) => {
183 28 : toml.clone().into_table().into()
184 : }
185 0 : _ => bail!("toml not a table or inline table"),
186 : };
187 :
188 44 : if document.is_empty() {
189 22 : return Ok(None);
190 22 : }
191 22 :
192 22 : Ok(Some(toml_edit::de::from_document(document)?))
193 44 : }
194 : }
195 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a TOML document and passes its root item to
    /// [`RemoteStorageConfig::from_toml`]. Panics if `input` is not valid TOML.
    fn parse(input: &str) -> anyhow::Result<Option<RemoteStorageConfig>> {
        let document: toml_edit::Document = input.parse().unwrap();
        RemoteStorageConfig::from_toml(document.as_item())
    }

    /// A local-fs config with an explicit timeout is parsed into `LocalFs`.
    #[test]
    fn parse_localfs_config_with_timeout() {
        let input = "local_path = '.'\ntimeout = '5s'";
        let expected = RemoteStorageConfig {
            storage: RemoteStorageKind::LocalFs {
                local_path: Utf8PathBuf::from("."),
            },
            timeout: Duration::from_secs(5),
        };

        let config = parse(input).unwrap().expect("it exists");

        assert_eq!(config, expected);
    }

    /// An S3 config gets defaults for every omitted key and a parsed storage class.
    #[test]
    fn test_s3_parsing() {
        let toml = "bucket_name = 'foo-bar'\n\
                    bucket_region = 'eu-central-1'\n\
                    upload_storage_class = 'INTELLIGENT_TIERING'\n\
                    timeout = '7s'\n";
        let expected = RemoteStorageConfig {
            storage: RemoteStorageKind::AwsS3(S3Config {
                bucket_name: "foo-bar".into(),
                bucket_region: "eu-central-1".into(),
                prefix_in_bucket: None,
                endpoint: None,
                concurrency_limit: default_remote_storage_s3_concurrency_limit(),
                max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                upload_storage_class: Some(StorageClass::IntelligentTiering),
            }),
            timeout: Duration::from_secs(7),
        };

        let config = parse(toml).unwrap().expect("it exists");

        assert_eq!(config, expected);
    }

    /// An Azure config gets defaults for every omitted key.
    #[test]
    fn test_azure_parsing() {
        // NOTE(review): `upload_storage_class` is not a field of `AzureConfig`; it looks
        // copied over from the S3 test. Parsing still succeeds with the key present —
        // confirm whether tolerating unknown keys is what this test is meant to cover.
        let toml = "container_name = 'foo-bar'\n\
                    container_region = 'westeurope'\n\
                    upload_storage_class = 'INTELLIGENT_TIERING'\n\
                    timeout = '7s'\n";
        let expected = RemoteStorageConfig {
            storage: RemoteStorageKind::AzureContainer(AzureConfig {
                container_name: "foo-bar".into(),
                storage_account: None,
                container_region: "westeurope".into(),
                prefix_in_container: None,
                concurrency_limit: default_remote_storage_azure_concurrency_limit(),
                max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
            }),
            timeout: Duration::from_secs(7),
        };

        let config = parse(toml).unwrap().expect("it exists");

        assert_eq!(config, expected);
    }
}
|