LCOV - code coverage report
Current view: top level - endpoint_storage/src - lib.rs (source / functions) Coverage Total Hit
Test: aca806cab4756d7eb6a304846130f4a73a5d5393.info Lines: 87.7 % 220 193
Test Date: 2025-04-24 20:31:15 Functions: 55.2 % 58 32

            Line data    Source code
       1              : use anyhow::Result;
       2              : use axum::extract::{FromRequestParts, Path};
       3              : use axum::response::{IntoResponse, Response};
       4              : use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
       5              : use axum_extra::TypedHeader;
       6              : use axum_extra::headers::{Authorization, authorization::Bearer};
       7              : use camino::Utf8PathBuf;
       8              : use jsonwebtoken::{DecodingKey, Validation};
       9              : use remote_storage::{GenericRemoteStorage, RemotePath};
      10              : use serde::{Deserialize, Serialize};
      11              : use std::fmt::Display;
      12              : use std::result::Result as StdResult;
      13              : use std::sync::Arc;
      14              : use tokio_util::sync::CancellationToken;
      15              : use tracing::{debug, error};
      16              : use utils::id::{TenantId, TimelineId};
      17              : 
      18              : // simplified version of utils::auth::JwtAuth
      19              : pub struct JwtAuth {
      20              :     decoding_key: DecodingKey,
      21              :     validation: Validation,
      22              : }
      23              : 
      24              : pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
      25              : impl JwtAuth {
      26           35 :     pub fn new(key: &[u8]) -> Result<Self> {
      27           35 :         Ok(Self {
      28           35 :             decoding_key: DecodingKey::from_ed_pem(key)?,
      29           35 :             validation: Validation::new(VALIDATION_ALGO),
      30              :         })
      31           35 :     }
      32              : 
      33          117 :     pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
      34          117 :         Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
      35          117 :     }
      36              : }
      37              : 
      38           50 : fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
      39           50 :     let key = clean_utf8(&Utf8PathBuf::from(key));
      40           50 :     if key.starts_with("..") || key == "." || key == "/" {
      41            9 :         return Err(format!("invalid key {key}"));
      42           41 :     }
      43           41 :     match key.strip_prefix("/").map(Utf8PathBuf::from) {
      44            1 :         Ok(p) => Ok(p),
      45           40 :         _ => Ok(key),
      46              :     }
      47           50 : }
      48              : 
      49              : // Copied from path_clean crate with PathBuf->Utf8PathBuf
      50           50 : fn clean_utf8(path: &camino::Utf8Path) -> Utf8PathBuf {
      51              :     use camino::Utf8Component as Comp;
      52           50 :     let mut out = Vec::new();
      53           82 :     for comp in path.components() {
      54           82 :         match comp {
      55            1 :             Comp::CurDir => (),
      56           18 :             Comp::ParentDir => match out.last() {
      57            1 :                 Some(Comp::RootDir) => (),
      58           12 :                 Some(Comp::Normal(_)) => {
      59           12 :                     out.pop();
      60           12 :                 }
      61              :                 None | Some(Comp::CurDir) | Some(Comp::ParentDir) | Some(Comp::Prefix(_)) => {
      62            5 :                     out.push(comp)
      63              :                 }
      64              :             },
      65           63 :             comp => out.push(comp),
      66              :         }
      67              :     }
      68           50 :     if !out.is_empty() {
      69           48 :         out.iter().collect()
      70              :     } else {
      71            2 :         Utf8PathBuf::from(".")
      72              :     }
      73           50 : }
      74              : 
      75              : pub struct Storage {
      76              :     pub auth: JwtAuth,
      77              :     pub storage: GenericRemoteStorage,
      78              :     pub cancel: CancellationToken,
      79              :     pub max_upload_file_limit: usize,
      80              : }
      81              : 
      82              : pub type EndpointId = String; // If needed, reuse small string from proxy/src/types.rc
      83              : 
      84          444 : #[derive(Deserialize, Serialize, PartialEq)]
      85              : pub struct Claims {
      86              :     pub tenant_id: TenantId,
      87              :     pub timeline_id: TimelineId,
      88              :     pub endpoint_id: EndpointId,
      89              :     pub exp: u64,
      90              : }
      91              : 
      92              : impl Display for Claims {
      93            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      94            0 :         write!(
      95            0 :             f,
      96            0 :             "Claims(tenant_id {} timeline_id {} endpoint_id {} exp {})",
      97            0 :             self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
      98            0 :         )
      99            0 :     }
     100              : }
     101              : 
     102          450 : #[derive(Deserialize, Serialize)]
     103              : struct KeyRequest {
     104              :     tenant_id: TenantId,
     105              :     timeline_id: TimelineId,
     106              :     endpoint_id: EndpointId,
     107              :     path: String,
     108              : }
     109              : 
     110              : #[derive(Debug, PartialEq)]
     111              : pub struct S3Path {
     112              :     pub path: RemotePath,
     113              : }
     114              : 
     115              : impl TryFrom<&KeyRequest> for S3Path {
     116              :     type Error = String;
     117           41 :     fn try_from(req: &KeyRequest) -> StdResult<Self, Self::Error> {
     118           41 :         let KeyRequest {
     119           41 :             tenant_id,
     120           41 :             timeline_id,
     121           41 :             endpoint_id,
     122           41 :             path,
     123           41 :         } = &req;
     124           41 :         let prefix = format!("{tenant_id}/{timeline_id}/{endpoint_id}",);
     125           41 :         let path = Utf8PathBuf::from(prefix).join(normalize_key(path)?);
     126           38 :         let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
     127           38 :         Ok(S3Path { path })
     128           41 :     }
     129              : }
     130              : 
     131           73 : fn unauthorized(route: impl Display, claims: impl Display) -> Response {
     132           73 :     debug!(%route, %claims, "route doesn't match claims");
     133           73 :     StatusCode::UNAUTHORIZED.into_response()
     134           73 : }
     135              : 
     136           14 : pub fn bad_request(err: impl Display, desc: &'static str) -> Response {
     137           14 :     debug!(%err, desc);
     138           14 :     (StatusCode::BAD_REQUEST, err.to_string()).into_response()
     139           14 : }
     140              : 
     141           21 : pub fn ok() -> Response {
     142           21 :     StatusCode::OK.into_response()
     143           21 : }
     144              : 
     145            0 : pub fn internal_error(err: impl Display, path: impl Display, desc: &'static str) -> Response {
     146            0 :     error!(%err, %path, desc);
     147            0 :     StatusCode::INTERNAL_SERVER_ERROR.into_response()
     148            0 : }
     149              : 
     150           15 : pub fn not_found(key: impl ToString) -> Response {
     151           15 :     (StatusCode::NOT_FOUND, key.to_string()).into_response()
     152           15 : }
     153              : 
     154              : impl FromRequestParts<Arc<Storage>> for S3Path {
     155              :     type Rejection = Response;
     156          117 :     async fn from_request_parts(
     157          117 :         parts: &mut Parts,
     158          117 :         state: &Arc<Storage>,
     159          117 :     ) -> Result<Self, Self::Rejection> {
     160          117 :         let Path(path): Path<KeyRequest> = parts
     161          117 :             .extract()
     162          117 :             .await
     163          117 :             .map_err(|e| bad_request(e, "invalid route"))?;
     164          111 :         let TypedHeader(Authorization(bearer)) = parts
     165          111 :             .extract::<TypedHeader<Authorization<Bearer>>>()
     166          111 :             .await
     167          111 :             .map_err(|e| bad_request(e, "invalid token"))?;
     168          111 :         let claims: Claims = state
     169          111 :             .auth
     170          111 :             .decode(bearer.token())
     171          111 :             .map_err(|e| bad_request(e, "decoding token"))?;
     172              : 
     173              :         // Read paths may have different endpoint ids. For readonly -> readwrite replica
     174              :         // prewarming, endpoint must read other endpoint's data.
     175          111 :         let endpoint_id = if parts.method == axum::http::Method::GET {
     176           46 :             claims.endpoint_id.clone()
     177              :         } else {
     178           65 :             path.endpoint_id.clone()
     179              :         };
     180              : 
     181          111 :         let route = Claims {
     182          111 :             tenant_id: path.tenant_id,
     183          111 :             timeline_id: path.timeline_id,
     184          111 :             endpoint_id,
     185          111 :             exp: claims.exp,
     186          111 :         };
     187          111 :         if route != claims {
     188           73 :             return Err(unauthorized(route, claims));
     189           38 :         }
     190           38 :         (&path)
     191           38 :             .try_into()
     192           38 :             .map_err(|e| bad_request(e, "invalid route"))
     193          117 :     }
     194              : }
     195              : 
     196           44 : #[derive(Deserialize, Serialize, PartialEq)]
     197              : pub struct PrefixKeyPath {
     198              :     pub tenant_id: TenantId,
     199              :     pub timeline_id: Option<TimelineId>,
     200              :     pub endpoint_id: Option<EndpointId>,
     201              : }
     202              : 
     203              : impl Display for PrefixKeyPath {
     204            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     205            0 :         write!(
     206            0 :             f,
     207            0 :             "PrefixKeyPath(tenant_id {} timeline_id {} endpoint_id {})",
     208            0 :             self.tenant_id,
     209            0 :             self.timeline_id
     210            0 :                 .as_ref()
     211            0 :                 .map(ToString::to_string)
     212            0 :                 .unwrap_or("".to_string()),
     213            0 :             self.endpoint_id
     214            0 :                 .as_ref()
     215            0 :                 .map(ToString::to_string)
     216            0 :                 .unwrap_or("".to_string())
     217            0 :         )
     218            0 :     }
     219              : }
     220              : 
     221              : #[derive(Debug, PartialEq)]
     222              : pub struct PrefixS3Path {
     223              :     pub path: RemotePath,
     224              : }
     225              : 
     226              : impl From<&PrefixKeyPath> for PrefixS3Path {
     227            9 :     fn from(path: &PrefixKeyPath) -> Self {
     228            9 :         let timeline_id = path
     229            9 :             .timeline_id
     230            9 :             .as_ref()
     231            9 :             .map(ToString::to_string)
     232            9 :             .unwrap_or("".to_string());
     233            9 :         let endpoint_id = path
     234            9 :             .endpoint_id
     235            9 :             .as_ref()
     236            9 :             .map(ToString::to_string)
     237            9 :             .unwrap_or("".to_string());
     238            9 :         let path = Utf8PathBuf::from(path.tenant_id.to_string())
     239            9 :             .join(timeline_id)
     240            9 :             .join(endpoint_id);
     241            9 :         let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
     242            9 :         PrefixS3Path { path }
     243            9 :     }
     244              : }
     245              : 
     246              : impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
     247              :     type Rejection = Response;
     248           12 :     async fn from_request_parts(
     249           12 :         parts: &mut Parts,
     250           12 :         state: &Arc<Storage>,
     251           12 :     ) -> Result<Self, Self::Rejection> {
     252           12 :         let Path(path) = parts
     253           12 :             .extract::<Path<PrefixKeyPath>>()
     254           12 :             .await
     255           12 :             .map_err(|e| bad_request(e, "invalid route"))?;
     256            6 :         let TypedHeader(Authorization(bearer)) = parts
     257            6 :             .extract::<TypedHeader<Authorization<Bearer>>>()
     258            6 :             .await
     259            6 :             .map_err(|e| bad_request(e, "invalid token"))?;
     260            6 :         let claims: PrefixKeyPath = state
     261            6 :             .auth
     262            6 :             .decode(bearer.token())
     263            6 :             .map_err(|e| bad_request(e, "invalid token"))?;
     264            6 :         if path != claims {
     265            0 :             return Err(unauthorized(path, claims));
     266            6 :         }
     267            6 :         Ok((&path).into())
     268           12 :     }
     269              : }
     270              : 
     271              : #[cfg(test)]
     272              : mod tests {
     273              :     use super::*;
     274              : 
     275              :     #[test]
     276            1 :     fn normalize_key() {
     277            1 :         let f = super::normalize_key;
     278            1 :         assert_eq!(f("hello/world/..").unwrap(), Utf8PathBuf::from("hello"));
     279            1 :         assert_eq!(
     280            1 :             f("ololo/1/../../not_ololo").unwrap(),
     281            1 :             Utf8PathBuf::from("not_ololo")
     282            1 :         );
     283            1 :         assert!(f("ololo/1/../../../").is_err());
     284            1 :         assert!(f(".").is_err());
     285            1 :         assert!(f("../").is_err());
     286            1 :         assert!(f("").is_err());
     287            1 :         assert_eq!(f("/1/2/3").unwrap(), Utf8PathBuf::from("1/2/3"));
     288            1 :         assert!(f("/1/2/3/../../../").is_err());
     289            1 :         assert!(f("/1/2/3/../../../../").is_err());
     290            1 :     }
     291              : 
     292              :     const TENANT_ID: TenantId =
     293              :         TenantId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
     294              :     const TIMELINE_ID: TimelineId =
     295              :         TimelineId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
     296              :     const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
     297              : 
     298              :     #[test]
     299            1 :     fn s3_path() {
     300            1 :         let auth = Claims {
     301            1 :             tenant_id: TENANT_ID,
     302            1 :             timeline_id: TIMELINE_ID,
     303            1 :             endpoint_id: ENDPOINT_ID.into(),
     304            1 :             exp: u64::MAX,
     305            1 :         };
     306            2 :         let s3_path = |key| {
     307            2 :             let path = &format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/{key}");
     308            2 :             let path = RemotePath::from_string(path).unwrap();
     309            2 :             S3Path { path }
     310            2 :         };
     311              : 
     312            1 :         let path = "cache_key".to_string();
     313            1 :         let mut key_path = KeyRequest {
     314            1 :             path,
     315            1 :             tenant_id: auth.tenant_id,
     316            1 :             timeline_id: auth.timeline_id,
     317            1 :             endpoint_id: auth.endpoint_id,
     318            1 :         };
     319            1 :         assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
     320              : 
     321            1 :         key_path.path = "we/can/have/nested/paths".to_string();
     322            1 :         assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
     323              : 
     324            1 :         key_path.path = "../error/hello/../".to_string();
     325            1 :         assert!(S3Path::try_from(&key_path).is_err());
     326            1 :     }
     327              : 
     328              :     #[test]
     329            1 :     fn prefix_s3_path() {
     330            1 :         let mut path = PrefixKeyPath {
     331            1 :             tenant_id: TENANT_ID,
     332            1 :             timeline_id: None,
     333            1 :             endpoint_id: None,
     334            1 :         };
     335            3 :         let prefix_path = |s: String| RemotePath::from_string(&s).unwrap();
     336            1 :         assert_eq!(
     337            1 :             PrefixS3Path::from(&path).path,
     338            1 :             prefix_path(format!("{TENANT_ID}"))
     339            1 :         );
     340              : 
     341            1 :         path.timeline_id = Some(TIMELINE_ID);
     342            1 :         assert_eq!(
     343            1 :             PrefixS3Path::from(&path).path,
     344            1 :             prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}"))
     345            1 :         );
     346              : 
     347            1 :         path.endpoint_id = Some(ENDPOINT_ID.into());
     348            1 :         assert_eq!(
     349            1 :             PrefixS3Path::from(&path).path,
     350            1 :             prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}"))
     351            1 :         );
     352            1 :     }
     353              : }
        

Generated by: LCOV version 2.1-beta