LCOV - code coverage report
Current view: top level - endpoint_storage/src - lib.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 97.0 % 198 192
Test Date: 2025-07-16 12:29:03 Functions: 63.4 % 41 26

            Line data    Source code
       1              : pub mod claims;
       2              : use crate::claims::{DeletePrefixClaims, EndpointStorageClaims};
       3              : use anyhow::Result;
       4              : use axum::extract::{FromRequestParts, Path};
       5              : use axum::response::{IntoResponse, Response};
       6              : use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
       7              : use axum_extra::TypedHeader;
       8              : use axum_extra::headers::{Authorization, authorization::Bearer};
       9              : use camino::Utf8PathBuf;
      10              : use jsonwebtoken::{DecodingKey, Validation};
      11              : use remote_storage::{GenericRemoteStorage, RemotePath};
      12              : use serde::{Deserialize, Serialize};
      13              : use std::fmt::Display;
      14              : use std::result::Result as StdResult;
      15              : use std::sync::Arc;
      16              : use tokio_util::sync::CancellationToken;
      17              : use tracing::{debug, error};
      18              : use utils::id::{EndpointId, TenantId, TimelineId};
      19              : 
      20              : // simplified version of utils::auth::JwtAuth
      21              : pub struct JwtAuth {
      22              :     decoding_key: DecodingKey,
      23              :     validation: Validation,
      24              : }
      25              : 
      26              : pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
      27              : impl JwtAuth {
      28           35 :     pub fn new(key: &[u8]) -> Result<Self> {
      29              :         Ok(Self {
      30           35 :             decoding_key: DecodingKey::from_ed_pem(key)?,
      31           35 :             validation: Validation::new(VALIDATION_ALGO),
      32              :         })
      33           35 :     }
      34              : 
      35          117 :     pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
      36          117 :         Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
      37          117 :     }
      38              : }
      39              : 
      40           50 : fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
      41           50 :     let key = clean_utf8(&Utf8PathBuf::from(key));
      42           50 :     if key.starts_with("..") || key == "." || key == "/" {
      43            9 :         return Err(format!("invalid key {key}"));
      44           41 :     }
      45           41 :     match key.strip_prefix("/").map(Utf8PathBuf::from) {
      46            1 :         Ok(p) => Ok(p),
      47           40 :         _ => Ok(key),
      48              :     }
      49           50 : }
      50              : 
      51              : // Copied from path_clean crate with PathBuf->Utf8PathBuf
      52           50 : fn clean_utf8(path: &camino::Utf8Path) -> Utf8PathBuf {
      53              :     use camino::Utf8Component as Comp;
      54           50 :     let mut out = Vec::new();
      55           82 :     for comp in path.components() {
      56           82 :         match comp {
      57            1 :             Comp::CurDir => (),
      58           18 :             Comp::ParentDir => match out.last() {
      59            1 :                 Some(Comp::RootDir) => (),
      60           12 :                 Some(Comp::Normal(_)) => {
      61           12 :                     out.pop();
      62           12 :                 }
      63              :                 None | Some(Comp::CurDir) | Some(Comp::ParentDir) | Some(Comp::Prefix(_)) => {
      64            5 :                     out.push(comp)
      65              :                 }
      66              :             },
      67           63 :             comp => out.push(comp),
      68              :         }
      69              :     }
      70           50 :     if !out.is_empty() {
      71           48 :         out.iter().collect()
      72              :     } else {
      73            2 :         Utf8PathBuf::from(".")
      74              :     }
      75           50 : }
      76              : 
      77              : pub struct Storage {
      78              :     pub auth: JwtAuth,
      79              :     pub storage: GenericRemoteStorage,
      80              :     pub cancel: CancellationToken,
      81              :     pub max_upload_file_limit: usize,
      82              : }
      83              : 
      84            6 : #[derive(Deserialize, Serialize)]
      85              : struct KeyRequest {
      86              :     tenant_id: TenantId,
      87              :     timeline_id: TimelineId,
      88              :     endpoint_id: EndpointId,
      89              :     path: String,
      90              : }
      91              : 
      92            6 : #[derive(Deserialize, Serialize, PartialEq)]
      93              : struct PrefixKeyRequest {
      94              :     tenant_id: TenantId,
      95              :     timeline_id: Option<TimelineId>,
      96              :     endpoint_id: Option<EndpointId>,
      97              : }
      98              : 
      99              : #[derive(Debug, PartialEq)]
     100              : pub struct S3Path {
     101              :     pub path: RemotePath,
     102              : }
     103              : 
     104              : impl TryFrom<&KeyRequest> for S3Path {
     105              :     type Error = String;
     106           41 :     fn try_from(req: &KeyRequest) -> StdResult<Self, Self::Error> {
     107              :         let KeyRequest {
     108           41 :             tenant_id,
     109           41 :             timeline_id,
     110           41 :             endpoint_id,
     111           41 :             path,
     112           41 :         } = &req;
     113           41 :         let prefix = format!("{tenant_id}/{timeline_id}/{endpoint_id}",);
     114           41 :         let path = Utf8PathBuf::from(prefix).join(normalize_key(path)?);
     115           38 :         let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
     116           38 :         Ok(S3Path { path })
     117           41 :     }
     118              : }
     119              : 
     120           73 : fn unauthorized(route: impl Display, claims: impl Display) -> Response {
     121           73 :     debug!(%route, %claims, "route doesn't match claims");
     122           73 :     StatusCode::UNAUTHORIZED.into_response()
     123           73 : }
     124              : 
     125           14 : pub fn bad_request(err: impl Display, desc: &'static str) -> Response {
     126           14 :     debug!(%err, desc);
     127           14 :     (StatusCode::BAD_REQUEST, err.to_string()).into_response()
     128           14 : }
     129              : 
     130           21 : pub fn ok() -> Response {
     131           21 :     StatusCode::OK.into_response()
     132           21 : }
     133              : 
     134            0 : pub fn internal_error(err: impl Display, path: impl Display, desc: &'static str) -> Response {
     135            0 :     error!(%err, %path, desc);
     136            0 :     StatusCode::INTERNAL_SERVER_ERROR.into_response()
     137            0 : }
     138              : 
     139           15 : pub fn not_found(key: impl ToString) -> Response {
     140           15 :     (StatusCode::NOT_FOUND, key.to_string()).into_response()
     141            0 : }
     142              : 
     143              : impl FromRequestParts<Arc<Storage>> for S3Path {
     144              :     type Rejection = Response;
     145          117 :     async fn from_request_parts(
     146          117 :         parts: &mut Parts,
     147          117 :         state: &Arc<Storage>,
     148          117 :     ) -> Result<Self, Self::Rejection> {
     149          117 :         let Path(path): Path<KeyRequest> = parts
     150          117 :             .extract()
     151          117 :             .await
     152          117 :             .map_err(|e| bad_request(e, "invalid route"))?;
     153          111 :         let TypedHeader(Authorization(bearer)) = parts
     154          111 :             .extract::<TypedHeader<Authorization<Bearer>>>()
     155          111 :             .await
     156          111 :             .map_err(|e| bad_request(e, "invalid token"))?;
     157          111 :         let claims: EndpointStorageClaims = state
     158          111 :             .auth
     159          111 :             .decode(bearer.token())
     160          111 :             .map_err(|e| bad_request(e, "decoding token"))?;
     161              : 
     162              :         // Read paths may have different endpoint ids. For readonly -> readwrite replica
     163              :         // prewarming, endpoint must read other endpoint's data.
     164          111 :         let endpoint_id = if parts.method == axum::http::Method::GET {
     165           46 :             claims.endpoint_id.clone()
     166              :         } else {
     167           65 :             path.endpoint_id.clone()
     168              :         };
     169              : 
     170          111 :         let route = EndpointStorageClaims {
     171          111 :             tenant_id: path.tenant_id,
     172          111 :             timeline_id: path.timeline_id,
     173          111 :             endpoint_id,
     174          111 :             exp: claims.exp,
     175          111 :         };
     176          111 :         if route != claims {
     177           73 :             return Err(unauthorized(route, claims));
     178           38 :         }
     179           38 :         (&path)
     180           38 :             .try_into()
     181           38 :             .map_err(|e| bad_request(e, "invalid route"))
     182          117 :     }
     183              : }
     184              : 
     185              : #[derive(Debug, PartialEq)]
     186              : pub struct PrefixS3Path {
     187              :     pub path: RemotePath,
     188              : }
     189              : 
     190              : impl From<&DeletePrefixClaims> for PrefixS3Path {
     191            9 :     fn from(path: &DeletePrefixClaims) -> Self {
     192            9 :         let timeline_id = path
     193            9 :             .timeline_id
     194            9 :             .as_ref()
     195            9 :             .map(ToString::to_string)
     196            9 :             .unwrap_or("".to_string());
     197            9 :         let endpoint_id = path
     198            9 :             .endpoint_id
     199            9 :             .as_ref()
     200            9 :             .map(ToString::to_string)
     201            9 :             .unwrap_or("".to_string());
     202            9 :         let path = Utf8PathBuf::from(path.tenant_id.to_string())
     203            9 :             .join(timeline_id)
     204            9 :             .join(endpoint_id);
     205            9 :         let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
     206            9 :         PrefixS3Path { path }
     207            9 :     }
     208              : }
     209              : 
     210              : impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
     211              :     type Rejection = Response;
     212           12 :     async fn from_request_parts(
     213           12 :         parts: &mut Parts,
     214           12 :         state: &Arc<Storage>,
     215           12 :     ) -> Result<Self, Self::Rejection> {
     216           12 :         let Path(path) = parts
     217           12 :             .extract::<Path<PrefixKeyRequest>>()
     218           12 :             .await
     219           12 :             .map_err(|e| bad_request(e, "invalid route"))?;
     220            6 :         let TypedHeader(Authorization(bearer)) = parts
     221            6 :             .extract::<TypedHeader<Authorization<Bearer>>>()
     222            6 :             .await
     223            6 :             .map_err(|e| bad_request(e, "invalid token"))?;
     224            6 :         let claims: DeletePrefixClaims = state
     225            6 :             .auth
     226            6 :             .decode(bearer.token())
     227            6 :             .map_err(|e| bad_request(e, "invalid token"))?;
     228            6 :         let route = DeletePrefixClaims {
     229            6 :             tenant_id: path.tenant_id,
     230            6 :             timeline_id: path.timeline_id,
     231            6 :             endpoint_id: path.endpoint_id,
     232            6 :             exp: claims.exp,
     233            6 :         };
     234            6 :         if route != claims {
     235            0 :             return Err(unauthorized(route, claims));
     236            6 :         }
     237            6 :         Ok((&route).into())
     238           12 :     }
     239              : }
     240              : 
     241              : #[cfg(test)]
     242              : mod tests {
     243              :     use super::*;
     244              : 
     245              :     #[test]
     246            1 :     fn normalize_key() {
     247            1 :         let f = super::normalize_key;
     248            1 :         assert_eq!(f("hello/world/..").unwrap(), Utf8PathBuf::from("hello"));
     249            1 :         assert_eq!(
     250            1 :             f("ololo/1/../../not_ololo").unwrap(),
     251            1 :             Utf8PathBuf::from("not_ololo")
     252              :         );
     253            1 :         assert!(f("ololo/1/../../../").is_err());
     254            1 :         assert!(f(".").is_err());
     255            1 :         assert!(f("../").is_err());
     256            1 :         assert!(f("").is_err());
     257            1 :         assert_eq!(f("/1/2/3").unwrap(), Utf8PathBuf::from("1/2/3"));
     258            1 :         assert!(f("/1/2/3/../../../").is_err());
     259            1 :         assert!(f("/1/2/3/../../../../").is_err());
     260            1 :     }
     261              : 
     262              :     const TENANT_ID: TenantId =
     263              :         TenantId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
     264              :     const TIMELINE_ID: TimelineId =
     265              :         TimelineId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
     266              :     const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
     267              : 
     268              :     #[test]
     269            1 :     fn s3_path() {
     270            1 :         let auth = EndpointStorageClaims {
     271            1 :             tenant_id: TENANT_ID,
     272            1 :             timeline_id: TIMELINE_ID,
     273            1 :             endpoint_id: ENDPOINT_ID.into(),
     274            1 :             exp: u64::MAX,
     275            1 :         };
     276            2 :         let s3_path = |key| {
     277            2 :             let path = &format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/{key}");
     278            2 :             let path = RemotePath::from_string(path).unwrap();
     279            2 :             S3Path { path }
     280            2 :         };
     281              : 
     282            1 :         let path = "cache_key".to_string();
     283            1 :         let mut key_path = KeyRequest {
     284            1 :             path,
     285            1 :             tenant_id: auth.tenant_id,
     286            1 :             timeline_id: auth.timeline_id,
     287            1 :             endpoint_id: auth.endpoint_id,
     288            1 :         };
     289            1 :         assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
     290              : 
     291            1 :         key_path.path = "we/can/have/nested/paths".to_string();
     292            1 :         assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
     293              : 
     294            1 :         key_path.path = "../error/hello/../".to_string();
     295            1 :         assert!(S3Path::try_from(&key_path).is_err());
     296            1 :     }
     297              : 
     298              :     #[test]
     299            1 :     fn prefix_s3_path() {
     300            1 :         let mut path = DeletePrefixClaims {
     301            1 :             tenant_id: TENANT_ID,
     302            1 :             timeline_id: None,
     303            1 :             endpoint_id: None,
     304            1 :             exp: 0,
     305            1 :         };
     306            3 :         let prefix_path = |s: String| RemotePath::from_string(&s).unwrap();
     307            1 :         assert_eq!(
     308            1 :             PrefixS3Path::from(&path).path,
     309            1 :             prefix_path(format!("{TENANT_ID}"))
     310              :         );
     311              : 
     312            1 :         path.timeline_id = Some(TIMELINE_ID);
     313            1 :         assert_eq!(
     314            1 :             PrefixS3Path::from(&path).path,
     315            1 :             prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}"))
     316              :         );
     317              : 
     318            1 :         path.endpoint_id = Some(ENDPOINT_ID.into());
     319            1 :         assert_eq!(
     320            1 :             PrefixS3Path::from(&path).path,
     321            1 :             prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}"))
     322              :         );
     323            1 :     }
     324              : }
        

Generated by: LCOV version 2.1-beta