Line data Source code
1 : pub mod claims;
2 : use crate::claims::{DeletePrefixClaims, EndpointStorageClaims};
3 : use anyhow::Result;
4 : use axum::extract::{FromRequestParts, Path};
5 : use axum::response::{IntoResponse, Response};
6 : use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
7 : use axum_extra::TypedHeader;
8 : use axum_extra::headers::{Authorization, authorization::Bearer};
9 : use camino::Utf8PathBuf;
10 : use jsonwebtoken::{DecodingKey, Validation};
11 : use remote_storage::{GenericRemoteStorage, RemotePath};
12 : use serde::{Deserialize, Serialize};
13 : use std::fmt::Display;
14 : use std::result::Result as StdResult;
15 : use std::sync::Arc;
16 : use tokio_util::sync::CancellationToken;
17 : use tracing::{debug, error};
18 : use utils::id::{EndpointId, TenantId, TimelineId};
19 :
20 : // simplified version of utils::auth::JwtAuth
21 : pub struct JwtAuth {
22 : decoding_key: DecodingKey,
23 : validation: Validation,
24 : }
25 :
26 : pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
27 : impl JwtAuth {
28 35 : pub fn new(key: &[u8]) -> Result<Self> {
29 : Ok(Self {
30 35 : decoding_key: DecodingKey::from_ed_pem(key)?,
31 35 : validation: Validation::new(VALIDATION_ALGO),
32 : })
33 35 : }
34 :
35 117 : pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
36 117 : Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
37 117 : }
38 : }
39 :
40 50 : fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
41 50 : let key = clean_utf8(&Utf8PathBuf::from(key));
42 50 : if key.starts_with("..") || key == "." || key == "/" {
43 9 : return Err(format!("invalid key {key}"));
44 41 : }
45 41 : match key.strip_prefix("/").map(Utf8PathBuf::from) {
46 1 : Ok(p) => Ok(p),
47 40 : _ => Ok(key),
48 : }
49 50 : }
50 :
51 : // Copied from path_clean crate with PathBuf->Utf8PathBuf
52 50 : fn clean_utf8(path: &camino::Utf8Path) -> Utf8PathBuf {
53 : use camino::Utf8Component as Comp;
54 50 : let mut out = Vec::new();
55 82 : for comp in path.components() {
56 82 : match comp {
57 1 : Comp::CurDir => (),
58 18 : Comp::ParentDir => match out.last() {
59 1 : Some(Comp::RootDir) => (),
60 12 : Some(Comp::Normal(_)) => {
61 12 : out.pop();
62 12 : }
63 : None | Some(Comp::CurDir) | Some(Comp::ParentDir) | Some(Comp::Prefix(_)) => {
64 5 : out.push(comp)
65 : }
66 : },
67 63 : comp => out.push(comp),
68 : }
69 : }
70 50 : if !out.is_empty() {
71 48 : out.iter().collect()
72 : } else {
73 2 : Utf8PathBuf::from(".")
74 : }
75 50 : }
76 :
77 : pub struct Storage {
78 : pub auth: JwtAuth,
79 : pub storage: GenericRemoteStorage,
80 : pub cancel: CancellationToken,
81 : pub max_upload_file_limit: usize,
82 : }
83 :
84 6 : #[derive(Deserialize, Serialize)]
85 : struct KeyRequest {
86 : tenant_id: TenantId,
87 : timeline_id: TimelineId,
88 : endpoint_id: EndpointId,
89 : path: String,
90 : }
91 :
92 6 : #[derive(Deserialize, Serialize, PartialEq)]
93 : struct PrefixKeyRequest {
94 : tenant_id: TenantId,
95 : timeline_id: Option<TimelineId>,
96 : endpoint_id: Option<EndpointId>,
97 : }
98 :
99 : #[derive(Debug, PartialEq)]
100 : pub struct S3Path {
101 : pub path: RemotePath,
102 : }
103 :
104 : impl TryFrom<&KeyRequest> for S3Path {
105 : type Error = String;
106 41 : fn try_from(req: &KeyRequest) -> StdResult<Self, Self::Error> {
107 : let KeyRequest {
108 41 : tenant_id,
109 41 : timeline_id,
110 41 : endpoint_id,
111 41 : path,
112 41 : } = &req;
113 41 : let prefix = format!("{tenant_id}/{timeline_id}/{endpoint_id}",);
114 41 : let path = Utf8PathBuf::from(prefix).join(normalize_key(path)?);
115 38 : let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
116 38 : Ok(S3Path { path })
117 41 : }
118 : }
119 :
120 73 : fn unauthorized(route: impl Display, claims: impl Display) -> Response {
121 73 : debug!(%route, %claims, "route doesn't match claims");
122 73 : StatusCode::UNAUTHORIZED.into_response()
123 73 : }
124 :
125 14 : pub fn bad_request(err: impl Display, desc: &'static str) -> Response {
126 14 : debug!(%err, desc);
127 14 : (StatusCode::BAD_REQUEST, err.to_string()).into_response()
128 14 : }
129 :
130 21 : pub fn ok() -> Response {
131 21 : StatusCode::OK.into_response()
132 21 : }
133 :
134 0 : pub fn internal_error(err: impl Display, path: impl Display, desc: &'static str) -> Response {
135 0 : error!(%err, %path, desc);
136 0 : StatusCode::INTERNAL_SERVER_ERROR.into_response()
137 0 : }
138 :
139 15 : pub fn not_found(key: impl ToString) -> Response {
140 15 : (StatusCode::NOT_FOUND, key.to_string()).into_response()
141 0 : }
142 :
143 : impl FromRequestParts<Arc<Storage>> for S3Path {
144 : type Rejection = Response;
145 117 : async fn from_request_parts(
146 117 : parts: &mut Parts,
147 117 : state: &Arc<Storage>,
148 117 : ) -> Result<Self, Self::Rejection> {
149 117 : let Path(path): Path<KeyRequest> = parts
150 117 : .extract()
151 117 : .await
152 117 : .map_err(|e| bad_request(e, "invalid route"))?;
153 111 : let TypedHeader(Authorization(bearer)) = parts
154 111 : .extract::<TypedHeader<Authorization<Bearer>>>()
155 111 : .await
156 111 : .map_err(|e| bad_request(e, "invalid token"))?;
157 111 : let claims: EndpointStorageClaims = state
158 111 : .auth
159 111 : .decode(bearer.token())
160 111 : .map_err(|e| bad_request(e, "decoding token"))?;
161 :
162 : // Read paths may have different endpoint ids. For readonly -> readwrite replica
163 : // prewarming, endpoint must read other endpoint's data.
164 111 : let endpoint_id = if parts.method == axum::http::Method::GET {
165 46 : claims.endpoint_id.clone()
166 : } else {
167 65 : path.endpoint_id.clone()
168 : };
169 :
170 111 : let route = EndpointStorageClaims {
171 111 : tenant_id: path.tenant_id,
172 111 : timeline_id: path.timeline_id,
173 111 : endpoint_id,
174 111 : exp: claims.exp,
175 111 : };
176 111 : if route != claims {
177 73 : return Err(unauthorized(route, claims));
178 38 : }
179 38 : (&path)
180 38 : .try_into()
181 38 : .map_err(|e| bad_request(e, "invalid route"))
182 117 : }
183 : }
184 :
185 : #[derive(Debug, PartialEq)]
186 : pub struct PrefixS3Path {
187 : pub path: RemotePath,
188 : }
189 :
190 : impl From<&DeletePrefixClaims> for PrefixS3Path {
191 9 : fn from(path: &DeletePrefixClaims) -> Self {
192 9 : let timeline_id = path
193 9 : .timeline_id
194 9 : .as_ref()
195 9 : .map(ToString::to_string)
196 9 : .unwrap_or("".to_string());
197 9 : let endpoint_id = path
198 9 : .endpoint_id
199 9 : .as_ref()
200 9 : .map(ToString::to_string)
201 9 : .unwrap_or("".to_string());
202 9 : let path = Utf8PathBuf::from(path.tenant_id.to_string())
203 9 : .join(timeline_id)
204 9 : .join(endpoint_id);
205 9 : let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
206 9 : PrefixS3Path { path }
207 9 : }
208 : }
209 :
210 : impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
211 : type Rejection = Response;
212 12 : async fn from_request_parts(
213 12 : parts: &mut Parts,
214 12 : state: &Arc<Storage>,
215 12 : ) -> Result<Self, Self::Rejection> {
216 12 : let Path(path) = parts
217 12 : .extract::<Path<PrefixKeyRequest>>()
218 12 : .await
219 12 : .map_err(|e| bad_request(e, "invalid route"))?;
220 6 : let TypedHeader(Authorization(bearer)) = parts
221 6 : .extract::<TypedHeader<Authorization<Bearer>>>()
222 6 : .await
223 6 : .map_err(|e| bad_request(e, "invalid token"))?;
224 6 : let claims: DeletePrefixClaims = state
225 6 : .auth
226 6 : .decode(bearer.token())
227 6 : .map_err(|e| bad_request(e, "invalid token"))?;
228 6 : let route = DeletePrefixClaims {
229 6 : tenant_id: path.tenant_id,
230 6 : timeline_id: path.timeline_id,
231 6 : endpoint_id: path.endpoint_id,
232 6 : exp: claims.exp,
233 6 : };
234 6 : if route != claims {
235 0 : return Err(unauthorized(route, claims));
236 6 : }
237 6 : Ok((&route).into())
238 12 : }
239 : }
240 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_key() {
        let normalize = super::normalize_key;

        // (input, expected normalized key)
        let accepted = [
            ("hello/world/..", "hello"),
            ("ololo/1/../../not_ololo", "not_ololo"),
            ("/1/2/3", "1/2/3"),
        ];
        for (input, expected) in accepted {
            assert_eq!(normalize(input).unwrap(), Utf8PathBuf::from(expected));
        }

        // Keys that escape the namespace or name nothing at all.
        let rejected = [
            "ololo/1/../../../",
            ".",
            "../",
            "",
            "/1/2/3/../../../",
            "/1/2/3/../../../../",
        ];
        for input in rejected {
            assert!(normalize(input).is_err(), "{input:?} should be rejected");
        }
    }

    const TENANT_ID: TenantId =
        TenantId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
    const TIMELINE_ID: TimelineId =
        TimelineId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
    const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";

    #[test]
    fn s3_path() {
        let claims = EndpointStorageClaims {
            tenant_id: TENANT_ID,
            timeline_id: TIMELINE_ID,
            endpoint_id: ENDPOINT_ID.into(),
            exp: u64::MAX,
        };
        let expected = |key: &str| S3Path {
            path: RemotePath::from_string(&format!(
                "{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/{key}"
            ))
            .unwrap(),
        };

        let mut request = KeyRequest {
            path: "cache_key".to_string(),
            tenant_id: claims.tenant_id,
            timeline_id: claims.timeline_id,
            endpoint_id: claims.endpoint_id,
        };
        assert_eq!(S3Path::try_from(&request).unwrap(), expected(&request.path));

        request.path = "we/can/have/nested/paths".to_string();
        assert_eq!(S3Path::try_from(&request).unwrap(), expected(&request.path));

        request.path = "../error/hello/../".to_string();
        assert!(S3Path::try_from(&request).is_err());
    }

    #[test]
    fn prefix_s3_path() {
        let mut claims = DeletePrefixClaims {
            tenant_id: TENANT_ID,
            timeline_id: None,
            endpoint_id: None,
            exp: 0,
        };
        let remote = |s: String| RemotePath::from_string(&s).unwrap();

        assert_eq!(
            PrefixS3Path::from(&claims).path,
            remote(TENANT_ID.to_string())
        );

        claims.timeline_id = Some(TIMELINE_ID);
        assert_eq!(
            PrefixS3Path::from(&claims).path,
            remote(format!("{TENANT_ID}/{TIMELINE_ID}"))
        );

        claims.endpoint_id = Some(ENDPOINT_ID.into());
        assert_eq!(
            PrefixS3Path::from(&claims).path,
            remote(format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}"))
        );
    }
}
|