Line data Source code
1 : use anyhow::Result;
2 : use axum::extract::{FromRequestParts, Path};
3 : use axum::response::{IntoResponse, Response};
4 : use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
5 : use axum_extra::TypedHeader;
6 : use axum_extra::headers::{Authorization, authorization::Bearer};
7 : use camino::Utf8PathBuf;
8 : use jsonwebtoken::{DecodingKey, Validation};
9 : use remote_storage::{GenericRemoteStorage, RemotePath};
10 : use serde::{Deserialize, Serialize};
11 : use std::fmt::Display;
12 : use std::result::Result as StdResult;
13 : use std::sync::Arc;
14 : use tokio_util::sync::CancellationToken;
15 : use tracing::{debug, error};
16 : use utils::id::{TenantId, TimelineId};
17 :
18 : // simplified version of utils::auth::JwtAuth
/// Verifies and decodes JWTs using one decoding key and one validation policy.
19 : pub struct JwtAuth {
/// Key handed to `jsonwebtoken::decode` when verifying a token.
20 : decoding_key: DecodingKey,
/// Validation settings; `JwtAuth::new` builds them from `VALIDATION_ALGO`.
21 : validation: Validation,
22 : }
23 :
/// Signature algorithm that `JwtAuth::new` configures its `Validation` with.
24 : pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
25 : impl JwtAuth {
26 35 : pub fn new(key: &[u8]) -> Result<Self> {
27 35 : Ok(Self {
28 35 : decoding_key: DecodingKey::from_ed_pem(key)?,
29 35 : validation: Validation::new(VALIDATION_ALGO),
30 : })
31 35 : }
32 :
33 117 : pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
34 117 : Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
35 117 : }
36 : }
37 :
38 50 : fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
39 50 : let key = clean_utf8(&Utf8PathBuf::from(key));
40 50 : if key.starts_with("..") || key == "." || key == "/" {
41 9 : return Err(format!("invalid key {key}"));
42 41 : }
43 41 : match key.strip_prefix("/").map(Utf8PathBuf::from) {
44 1 : Ok(p) => Ok(p),
45 40 : _ => Ok(key),
46 : }
47 50 : }
48 :
49 : // Copied from path_clean crate with PathBuf->Utf8PathBuf
50 50 : fn clean_utf8(path: &camino::Utf8Path) -> Utf8PathBuf {
51 : use camino::Utf8Component as Comp;
52 50 : let mut out = Vec::new();
53 82 : for comp in path.components() {
54 82 : match comp {
55 1 : Comp::CurDir => (),
56 18 : Comp::ParentDir => match out.last() {
57 1 : Some(Comp::RootDir) => (),
58 12 : Some(Comp::Normal(_)) => {
59 12 : out.pop();
60 12 : }
61 : None | Some(Comp::CurDir) | Some(Comp::ParentDir) | Some(Comp::Prefix(_)) => {
62 5 : out.push(comp)
63 : }
64 : },
65 63 : comp => out.push(comp),
66 : }
67 : }
68 50 : if !out.is_empty() {
69 48 : out.iter().collect()
70 : } else {
71 2 : Utf8PathBuf::from(".")
72 : }
73 50 : }
74 :
/// Shared state that the request extractors in this file receive as `Arc<Storage>`.
75 : pub struct Storage {
/// Verifier the extractors use to decode the bearer token of each request.
76 : pub auth: JwtAuth,
/// Remote storage handle; not referenced in this section of the file.
77 : pub storage: GenericRemoteStorage,
/// Cancellation token; not referenced in this section (presumably passed to storage operations — confirm).
78 : pub cancel: CancellationToken,
/// Upload size limit; presumably in bytes and enforced by the upload handler — TODO confirm.
79 : pub max_upload_file_limit: usize,
80 : }
81 :
82 : pub type EndpointId = String; // If needed, reuse small string from proxy/src/types.rs
83 :
/// Claims carried by the bearer token of a single-object request.
/// The `S3Path` extractor rebuilds a `Claims` value from the request route and
/// rejects the request unless it equals the claims decoded from the token.
84 444 : #[derive(Deserialize, Serialize, PartialEq)]
85 : pub struct Claims {
86 : pub tenant_id: TenantId,
87 : pub timeline_id: TimelineId,
88 : pub endpoint_id: EndpointId,
/// Copied from the token when the route-side value is built, so it never causes a mismatch.
/// Presumably the standard JWT expiry timestamp — confirm.
89 : pub exp: u64,
90 : }
91 :
92 : impl Display for Claims {
93 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 0 : write!(
95 0 : f,
96 0 : "Claims(tenant_id {} timeline_id {} endpoint_id {} exp {})",
97 0 : self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
98 0 : )
99 0 : }
100 : }
101 :
/// Route parameters of a single-object request, extracted via `Path<KeyRequest>`.
102 450 : #[derive(Deserialize, Serialize)]
103 : struct KeyRequest {
104 : tenant_id: TenantId,
105 : timeline_id: TimelineId,
106 : endpoint_id: EndpointId,
/// Caller-supplied object key; passed through `normalize_key` before it is used.
107 : path: String,
108 : }
109 :
/// Remote-storage location of one object, built as
/// `{tenant_id}/{timeline_id}/{endpoint_id}/{normalized key}`.
110 : #[derive(Debug, PartialEq)]
111 : pub struct S3Path {
112 : pub path: RemotePath,
113 : }
114 :
115 : impl TryFrom<&KeyRequest> for S3Path {
116 : type Error = String;
117 41 : fn try_from(req: &KeyRequest) -> StdResult<Self, Self::Error> {
118 41 : let KeyRequest {
119 41 : tenant_id,
120 41 : timeline_id,
121 41 : endpoint_id,
122 41 : path,
123 41 : } = &req;
124 41 : let prefix = format!("{tenant_id}/{timeline_id}/{endpoint_id}",);
125 41 : let path = Utf8PathBuf::from(prefix).join(normalize_key(path)?);
126 38 : let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
127 38 : Ok(S3Path { path })
128 41 : }
129 : }
130 :
131 73 : fn unauthorized(route: impl Display, claims: impl Display) -> Response {
132 73 : debug!(%route, %claims, "route doesn't match claims");
133 73 : StatusCode::UNAUTHORIZED.into_response()
134 73 : }
135 :
136 14 : pub fn bad_request(err: impl Display, desc: &'static str) -> Response {
137 14 : debug!(%err, desc);
138 14 : (StatusCode::BAD_REQUEST, err.to_string()).into_response()
139 14 : }
140 :
141 21 : pub fn ok() -> Response {
142 21 : StatusCode::OK.into_response()
143 21 : }
144 :
145 0 : pub fn internal_error(err: impl Display, path: impl Display, desc: &'static str) -> Response {
146 0 : error!(%err, %path, desc);
147 0 : StatusCode::INTERNAL_SERVER_ERROR.into_response()
148 0 : }
149 :
150 15 : pub fn not_found(key: impl ToString) -> Response {
151 15 : (StatusCode::NOT_FOUND, key.to_string()).into_response()
152 15 : }
153 :
154 : impl FromRequestParts<Arc<Storage>> for S3Path {
155 : type Rejection = Response;
156 117 : async fn from_request_parts(
157 117 : parts: &mut Parts,
158 117 : state: &Arc<Storage>,
159 117 : ) -> Result<Self, Self::Rejection> {
160 117 : let Path(path): Path<KeyRequest> = parts
161 117 : .extract()
162 117 : .await
163 117 : .map_err(|e| bad_request(e, "invalid route"))?;
164 111 : let TypedHeader(Authorization(bearer)) = parts
165 111 : .extract::<TypedHeader<Authorization<Bearer>>>()
166 111 : .await
167 111 : .map_err(|e| bad_request(e, "invalid token"))?;
168 111 : let claims: Claims = state
169 111 : .auth
170 111 : .decode(bearer.token())
171 111 : .map_err(|e| bad_request(e, "decoding token"))?;
172 :
173 : // Read paths may have different endpoint ids. For readonly -> readwrite replica
174 : // prewarming, endpoint must read other endpoint's data.
175 111 : let endpoint_id = if parts.method == axum::http::Method::GET {
176 46 : claims.endpoint_id.clone()
177 : } else {
178 65 : path.endpoint_id.clone()
179 : };
180 :
181 111 : let route = Claims {
182 111 : tenant_id: path.tenant_id,
183 111 : timeline_id: path.timeline_id,
184 111 : endpoint_id,
185 111 : exp: claims.exp,
186 111 : };
187 111 : if route != claims {
188 73 : return Err(unauthorized(route, claims));
189 38 : }
190 38 : (&path)
191 38 : .try_into()
192 38 : .map_err(|e| bad_request(e, "invalid route"))
193 117 : }
194 : }
195 :
/// Route parameters of a prefix request; the same type is also decoded from the
/// bearer token, and the `PrefixS3Path` extractor requires both values to be equal.
196 44 : #[derive(Deserialize, Serialize, PartialEq)]
197 : pub struct PrefixKeyPath {
198 : pub tenant_id: TenantId,
/// Optional; contributes an empty path segment when absent.
199 : pub timeline_id: Option<TimelineId>,
/// Optional; contributes an empty path segment when absent.
200 : pub endpoint_id: Option<EndpointId>,
201 : }
202 :
203 : impl Display for PrefixKeyPath {
204 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 0 : write!(
206 0 : f,
207 0 : "PrefixKeyPath(tenant_id {} timeline_id {} endpoint_id {})",
208 0 : self.tenant_id,
209 0 : self.timeline_id
210 0 : .as_ref()
211 0 : .map(ToString::to_string)
212 0 : .unwrap_or("".to_string()),
213 0 : self.endpoint_id
214 0 : .as_ref()
215 0 : .map(ToString::to_string)
216 0 : .unwrap_or("".to_string())
217 0 : )
218 0 : }
219 : }
220 :
/// Remote-storage prefix built from a `PrefixKeyPath`:
/// the tenant id followed by the timeline and endpoint ids when present.
221 : #[derive(Debug, PartialEq)]
222 : pub struct PrefixS3Path {
223 : pub path: RemotePath,
224 : }
225 :
226 : impl From<&PrefixKeyPath> for PrefixS3Path {
227 9 : fn from(path: &PrefixKeyPath) -> Self {
228 9 : let timeline_id = path
229 9 : .timeline_id
230 9 : .as_ref()
231 9 : .map(ToString::to_string)
232 9 : .unwrap_or("".to_string());
233 9 : let endpoint_id = path
234 9 : .endpoint_id
235 9 : .as_ref()
236 9 : .map(ToString::to_string)
237 9 : .unwrap_or("".to_string());
238 9 : let path = Utf8PathBuf::from(path.tenant_id.to_string())
239 9 : .join(timeline_id)
240 9 : .join(endpoint_id);
241 9 : let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
242 9 : PrefixS3Path { path }
243 9 : }
244 : }
245 :
246 : impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
247 : type Rejection = Response;
248 12 : async fn from_request_parts(
249 12 : parts: &mut Parts,
250 12 : state: &Arc<Storage>,
251 12 : ) -> Result<Self, Self::Rejection> {
252 12 : let Path(path) = parts
253 12 : .extract::<Path<PrefixKeyPath>>()
254 12 : .await
255 12 : .map_err(|e| bad_request(e, "invalid route"))?;
256 6 : let TypedHeader(Authorization(bearer)) = parts
257 6 : .extract::<TypedHeader<Authorization<Bearer>>>()
258 6 : .await
259 6 : .map_err(|e| bad_request(e, "invalid token"))?;
260 6 : let claims: PrefixKeyPath = state
261 6 : .auth
262 6 : .decode(bearer.token())
263 6 : .map_err(|e| bad_request(e, "invalid token"))?;
264 6 : if path != claims {
265 0 : return Err(unauthorized(path, claims));
266 6 : }
267 6 : Ok((&path).into())
268 12 : }
269 : }
270 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys are cleaned lexically; keys that resolve to the root, to `.`, or
    /// above the starting directory are rejected.
    #[test]
    fn normalize_key() {
        let normalize = super::normalize_key;

        let accepted = [
            ("hello/world/..", "hello"),
            ("ololo/1/../../not_ololo", "not_ololo"),
            ("/1/2/3", "1/2/3"),
        ];
        for (raw, cleaned) in accepted {
            assert_eq!(normalize(raw).unwrap(), Utf8PathBuf::from(cleaned));
        }

        let rejected = [
            "ololo/1/../../../",
            ".",
            "../",
            "",
            "/1/2/3/../../../",
            "/1/2/3/../../../../",
        ];
        for raw in rejected {
            assert!(normalize(raw).is_err(), "{raw:?} should be rejected");
        }
    }

    const TENANT_ID: TenantId =
        TenantId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
    const TIMELINE_ID: TimelineId =
        TimelineId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
    const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";

    /// Valid keys land under `{tenant}/{timeline}/{endpoint}/`; escaping keys fail.
    #[test]
    fn s3_path() {
        let expected = |key: &str| {
            let full = format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/{key}");
            S3Path {
                path: RemotePath::from_string(&full).unwrap(),
            }
        };

        let mut request = KeyRequest {
            tenant_id: TENANT_ID,
            timeline_id: TIMELINE_ID,
            endpoint_id: ENDPOINT_ID.into(),
            path: String::new(),
        };

        for key in ["cache_key", "we/can/have/nested/paths"] {
            request.path = key.to_string();
            assert_eq!(S3Path::try_from(&request).unwrap(), expected(key));
        }

        request.path = "../error/hello/../".to_string();
        assert!(S3Path::try_from(&request).is_err());
    }

    /// The prefix grows by one segment for each optional id that is set.
    #[test]
    fn prefix_s3_path() {
        let expected = |raw: String| RemotePath::from_string(&raw).unwrap();

        let mut route = PrefixKeyPath {
            tenant_id: TENANT_ID,
            timeline_id: None,
            endpoint_id: None,
        };
        assert_eq!(
            PrefixS3Path::from(&route).path,
            expected(format!("{TENANT_ID}"))
        );

        route.timeline_id = Some(TIMELINE_ID);
        assert_eq!(
            PrefixS3Path::from(&route).path,
            expected(format!("{TENANT_ID}/{TIMELINE_ID}"))
        );

        route.endpoint_id = Some(ENDPOINT_ID.into());
        assert_eq!(
            PrefixS3Path::from(&route).path,
            expected(format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}"))
        );
    }
}
|