LCOV - code coverage report
Current view: top level - libs/utils/src - shard.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 74.8 % 238 178
Test Date: 2025-07-31 15:59:03 Functions: 35.1 % 114 40

            Line data    Source code
       1              : //! See `pageserver_api::shard` for description on sharding.
       2              : 
       3              : use std::ops::RangeInclusive;
       4              : use std::str::FromStr;
       5              : 
       6              : use hex::FromHex;
       7              : use serde::{Deserialize, Serialize};
       8              : 
       9              : use crate::id::TenantId;
      10              : 
      11            0 : #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)]
      12              : pub struct ShardNumber(pub u8);
      13              : 
      14            0 : #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)]
      15              : pub struct ShardCount(pub u8);
      16              : 
      17              : /// Combination of ShardNumber and ShardCount.
      18              : ///
      19              : /// For use within the context of a particular tenant, when we need to know which shard we're
      20              : /// dealing with, but do not need to know the full ShardIdentity (because we won't be doing
      21              : /// any page->shard mapping), and do not need to know the fully qualified TenantShardId.
      22              : #[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
      23              : pub struct ShardIndex {
      24              :     pub shard_number: ShardNumber,
      25              :     pub shard_count: ShardCount,
      26              : }
      27              : 
      28              : /// Stripe size as number of pages.
      29              : ///
      30              : /// NB: don't implement Default, so callers don't lazily use it by mistake. See DEFAULT_STRIPE_SIZE.
      31            0 : #[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
      32              : pub struct ShardStripeSize(pub u32);
      33              : 
      34              : /// Formatting helper, for generating the `shard_id` label in traces.
      35              : pub struct ShardSlug<'a>(&'a TenantShardId);
      36              : 
      37              : /// TenantShardId globally identifies a particular shard in a particular tenant.
      38              : ///
      39              : /// These are written as `<TenantId>-<ShardSlug>`, for example:
      40              : ///   # The second shard in a two-shard tenant
      41              : ///   072f1291a5310026820b2fe4b2968934-0102
      42              : ///
      43              : /// If the `ShardCount` is _unsharded_, the `TenantShardId` is written without
      44              : /// a shard suffix and is equivalent to the encoding of a `TenantId`: this enables
      45              : /// an unsharded [`TenantShardId`] to be used interchangeably with a [`TenantId`].
      46              : ///
      47              : /// The human-readable encoding of an unsharded TenantShardId, such as used in API URLs,
      48              : /// is both forward and backward compatible with TenantId: a legacy TenantId can be
      49              : /// decoded as a TenantShardId, and when re-encoded it will be parseable
      50              : /// as a TenantId.
      51              : #[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
      52              : pub struct TenantShardId {
      53              :     pub tenant_id: TenantId,
      54              :     pub shard_number: ShardNumber,
      55              :     pub shard_count: ShardCount,
      56              : }
      57              : 
      58              : impl ShardCount {
      59              :     pub const MAX: Self = Self(u8::MAX);
      60              :     pub const MIN: Self = Self(0);
      61              : 
      62            1 :     pub fn unsharded() -> Self {
      63            1 :         ShardCount(0)
      64            1 :     }
      65              : 
      66              :     /// The internal value of a ShardCount may be zero, which means "1 shard, but use
      67              :     /// legacy format for TenantShardId that excludes the shard suffix", also known
      68              :     /// as [`TenantShardId::unsharded`].
      69              :     ///
      70              :     /// This method returns the actual number of shards, i.e. if our internal value is
      71              :     /// zero, we return 1 (unsharded tenants have 1 shard).
      72      2407224 :     pub fn count(&self) -> u8 {
      73      2407224 :         if self.0 > 0 { self.0 } else { 1 }
      74      2407224 :     }
      75              : 
      76              :     /// The literal internal value: this is **not** the number of shards in the
      77              :     /// tenant, as we have a special zero value for legacy unsharded tenants.  Use
      78              :     /// [`Self::count`] if you want to know the cardinality of shards.
      79            2 :     pub fn literal(&self) -> u8 {
      80            2 :         self.0
      81            2 :     }
      82              : 
      83              :     /// Whether the `ShardCount` is for an unsharded tenant, so uses one shard but
      84              :     /// uses the legacy format for `TenantShardId`. See also the documentation for
      85              :     /// [`Self::count`].
      86            0 :     pub fn is_unsharded(&self) -> bool {
      87            0 :         self.0 == 0
      88            0 :     }
      89              : 
      90              :     /// `val` may be zero, or the number of shards in the tenant.  `val` is what
      91              :     /// [`Self::literal`] would return.
      92        10496 :     pub const fn new(val: u8) -> Self {
      93        10496 :         Self(val)
      94        10496 :     }
      95              : }
      96              : 
      97              : impl ShardNumber {
      98              :     pub const MAX: Self = Self(u8::MAX);
      99              : }
     100              : 
     101              : impl TenantShardId {
     102           48 :     pub fn unsharded(tenant_id: TenantId) -> Self {
     103           48 :         Self {
     104           48 :             tenant_id,
     105           48 :             shard_number: ShardNumber(0),
     106           48 :             shard_count: ShardCount(0),
     107           48 :         }
     108           48 :     }
     109              : 
     110              :     /// The range of all TenantShardId that belong to a particular TenantId.  This is useful when
     111              :     /// you have a BTreeMap of TenantShardId, and are querying by TenantId.
     112            0 :     pub fn tenant_range(tenant_id: TenantId) -> RangeInclusive<Self> {
     113            0 :         RangeInclusive::new(
     114            0 :             Self {
     115            0 :                 tenant_id,
     116            0 :                 shard_number: ShardNumber(0),
     117            0 :                 shard_count: ShardCount(0),
     118            0 :             },
     119            0 :             Self {
     120            0 :                 tenant_id,
     121            0 :                 shard_number: ShardNumber::MAX,
     122            0 :                 shard_count: ShardCount::MAX,
     123            0 :             },
     124              :         )
     125            0 :     }
     126              : 
     127            0 :     pub fn range(&self) -> RangeInclusive<Self> {
     128            0 :         RangeInclusive::new(*self, *self)
     129            0 :     }
     130              : 
     131        19002 :     pub fn shard_slug(&self) -> impl std::fmt::Display + '_ {
     132        19002 :         ShardSlug(self)
     133        19002 :     }
     134              : 
     135              :     /// Convenience for code that has special behavior on the 0th shard.
     136          310 :     pub fn is_shard_zero(&self) -> bool {
     137          310 :         self.shard_number == ShardNumber(0)
     138          310 :     }
     139              : 
     140              :     /// The "unsharded" value is distinct from simply having a single shard: it represents
     141              :     /// a tenant which is not shard-aware at all, and whose storage paths will not include
     142              :     /// a shard suffix.
     143            0 :     pub fn is_unsharded(&self) -> bool {
     144            0 :         self.shard_number == ShardNumber(0) && self.shard_count.is_unsharded()
     145            0 :     }
     146              : 
     147              :     /// Convenience for dropping the tenant_id and just getting the ShardIndex: this
     148              :     /// is useful when logging from code that is already in a span that includes tenant ID, to
     149              :     /// keep messages reasonably terse.
     150            0 :     pub fn to_index(&self) -> ShardIndex {
     151            0 :         ShardIndex {
     152            0 :             shard_number: self.shard_number,
     153            0 :             shard_count: self.shard_count,
     154            0 :         }
     155            0 :     }
     156              : 
     157              :     /// Calculate the children of this TenantShardId when splitting the overall tenant into
     158              :     /// the given number of shards.
     159            5 :     pub fn split(&self, new_shard_count: ShardCount) -> Vec<TenantShardId> {
     160            5 :         let effective_old_shard_count = std::cmp::max(self.shard_count.0, 1);
     161            5 :         let mut child_shards = Vec::new();
     162           24 :         for shard_number in 0..ShardNumber(new_shard_count.0).0 {
     163              :             // Key mapping is based on a round robin mapping of key hash modulo shard count,
     164              :             // so our child shards are the ones which the same keys would map to.
     165           24 :             if shard_number % effective_old_shard_count == self.shard_number.0 {
     166           20 :                 child_shards.push(TenantShardId {
     167           20 :                     tenant_id: self.tenant_id,
     168           20 :                     shard_number: ShardNumber(shard_number),
     169           20 :                     shard_count: new_shard_count,
     170           20 :                 })
     171            4 :             }
     172              :         }
     173              : 
     174            5 :         child_shards
     175            5 :     }
     176              : }
     177              : 
     178              : impl std::fmt::Display for ShardNumber {
     179            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     180            0 :         self.0.fmt(f)
     181            0 :     }
     182              : }
     183              : 
     184              : impl std::fmt::Display for ShardCount {
     185            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     186            0 :         self.0.fmt(f)
     187            0 :     }
     188              : }
     189              : 
     190              : impl std::fmt::Display for ShardStripeSize {
     191            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     192            0 :         self.0.fmt(f)
     193            0 :     }
     194              : }
     195              : 
     196              : impl std::fmt::Display for ShardSlug<'_> {
     197         6099 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     198         6099 :         write!(
     199         6099 :             f,
     200         6099 :             "{:02x}{:02x}",
     201              :             self.0.shard_number.0, self.0.shard_count.0
     202              :         )
     203         6099 :     }
     204              : }
     205              : 
     206              : impl std::fmt::Display for TenantShardId {
     207        11740 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     208        11740 :         if self.shard_count != ShardCount(0) {
     209          118 :             write!(f, "{}-{}", self.tenant_id, self.shard_slug())
     210              :         } else {
     211              :             // Legacy case (shard_count == 0) -- format as just the tenant id.  Note that this
     212              :             // is distinct from the normal single shard case (shard count == 1).
     213        11622 :             self.tenant_id.fmt(f)
     214              :         }
     215        11740 :     }
     216              : }
     217              : 
     218              : impl std::fmt::Debug for TenantShardId {
     219         5681 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     220              :         // Debug is the same as Display: the compact hex representation
     221         5681 :         write!(f, "{self}")
     222         5681 :     }
     223              : }
     224              : 
     225              : impl std::str::FromStr for TenantShardId {
     226              :     type Err = hex::FromHexError;
     227              : 
     228           16 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     229              :         // Expect format: 16 byte TenantId, '-', 1 byte shard number, 1 byte shard count
     230           16 :         if s.len() == 32 {
     231              :             // Legacy case: no shard specified
     232              :             Ok(Self {
     233           13 :                 tenant_id: TenantId::from_str(s)?,
     234           13 :                 shard_number: ShardNumber(0),
     235           13 :                 shard_count: ShardCount(0),
     236              :             })
     237            3 :         } else if s.len() == 37 {
     238            3 :             let bytes = s.as_bytes();
     239            3 :             let tenant_id = TenantId::from_hex(&bytes[0..32])?;
     240            3 :             let mut shard_parts: [u8; 2] = [0u8; 2];
     241            3 :             hex::decode_to_slice(&bytes[33..37], &mut shard_parts)?;
     242            3 :             Ok(Self {
     243            3 :                 tenant_id,
     244            3 :                 shard_number: ShardNumber(shard_parts[0]),
     245            3 :                 shard_count: ShardCount(shard_parts[1]),
     246            3 :             })
     247              :         } else {
     248            0 :             Err(hex::FromHexError::InvalidStringLength)
     249              :         }
     250           16 :     }
     251              : }
     252              : 
     253              : impl From<[u8; 18]> for TenantShardId {
     254           25 :     fn from(b: [u8; 18]) -> Self {
     255           25 :         let tenant_id_bytes: [u8; 16] = b[0..16].try_into().unwrap();
     256              : 
     257           25 :         Self {
     258           25 :             tenant_id: TenantId::from(tenant_id_bytes),
     259           25 :             shard_number: ShardNumber(b[16]),
     260           25 :             shard_count: ShardCount(b[17]),
     261           25 :         }
     262           25 :     }
     263              : }
     264              : 
     265              : impl ShardIndex {
     266            7 :     pub fn new(number: ShardNumber, count: ShardCount) -> Self {
     267            7 :         Self {
     268            7 :             shard_number: number,
     269            7 :             shard_count: count,
     270            7 :         }
     271            7 :     }
     272           88 :     pub fn unsharded() -> Self {
     273           88 :         Self {
     274           88 :             shard_number: ShardNumber(0),
     275           88 :             shard_count: ShardCount(0),
     276           88 :         }
     277           88 :     }
     278              : 
     279              :     /// The "unsharded" value is distinct from simply having a single shard: it represents
     280              :     /// a tenant which is not shard-aware at all, and whose storage paths will not include
     281              :     /// a shard suffix.
     282        37554 :     pub fn is_unsharded(&self) -> bool {
     283        37554 :         self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
     284        37554 :     }
     285              : 
     286              :     /// For use in constructing remote storage paths: concatenate this with a TenantId
     287              :     /// to get a fully qualified TenantShardId.
     288              :     ///
     289              :     /// Backward compat: this function returns an empty string if Self::is_unsharded, such
     290              :     /// that the legacy pre-sharding remote key format is preserved.
     291         1098 :     pub fn get_suffix(&self) -> String {
     292         1098 :         if self.is_unsharded() {
     293         1092 :             "".to_string()
     294              :         } else {
     295            6 :             format!("-{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
     296              :         }
     297         1098 :     }
     298              : }
     299              : 
     300              : impl std::fmt::Display for ShardIndex {
     301         1171 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     302         1171 :         write!(f, "{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
     303         1171 :     }
     304              : }
     305              : 
     306              : impl std::fmt::Debug for ShardIndex {
     307          895 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     308              :         // Debug is the same as Display: the compact hex representation
     309          895 :         write!(f, "{self}")
     310          895 :     }
     311              : }
     312              : 
     313              : impl std::str::FromStr for ShardIndex {
     314              :     type Err = hex::FromHexError;
     315              : 
     316         1565 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     317              :         // Expect format: 1 byte shard number, 1 byte shard count
     318         1565 :         if s.len() == 4 {
     319         1565 :             let bytes = s.as_bytes();
     320         1565 :             let mut shard_parts: [u8; 2] = [0u8; 2];
     321         1565 :             hex::decode_to_slice(bytes, &mut shard_parts)?;
     322         1565 :             Ok(Self {
     323         1565 :                 shard_number: ShardNumber(shard_parts[0]),
     324         1565 :                 shard_count: ShardCount(shard_parts[1]),
     325         1565 :             })
     326              :         } else {
     327            0 :             Err(hex::FromHexError::InvalidStringLength)
     328              :         }
     329         1565 :     }
     330              : }
     331              : 
     332              : impl From<[u8; 2]> for ShardIndex {
     333            1 :     fn from(b: [u8; 2]) -> Self {
     334            1 :         Self {
     335            1 :             shard_number: ShardNumber(b[0]),
     336            1 :             shard_count: ShardCount(b[1]),
     337            1 :         }
     338            1 :     }
     339              : }
     340              : 
     341              : impl Serialize for TenantShardId {
     342           35 :     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     343           35 :     where
     344           35 :         S: serde::Serializer,
     345              :     {
     346           35 :         if serializer.is_human_readable() {
     347           31 :             serializer.collect_str(self)
     348              :         } else {
     349              :             // Note: while human encoding of [`TenantShardId`] is backward and forward
     350              :             // compatible, this binary encoding is not.
     351            4 :             let mut packed: [u8; 18] = [0; 18];
     352            4 :             packed[0..16].clone_from_slice(&self.tenant_id.as_arr());
     353            4 :             packed[16] = self.shard_number.0;
     354            4 :             packed[17] = self.shard_count.0;
     355              : 
     356            4 :             packed.serialize(serializer)
     357              :         }
     358            0 :     }
     359              : }
     360              : 
     361              : impl<'de> Deserialize<'de> for TenantShardId {
     362            6 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     363            6 :     where
     364            6 :         D: serde::Deserializer<'de>,
     365              :     {
     366              :         struct IdVisitor {
     367              :             is_human_readable_deserializer: bool,
     368              :         }
     369              : 
     370              :         impl<'de> serde::de::Visitor<'de> for IdVisitor {
     371              :             type Value = TenantShardId;
     372              : 
     373            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     374            0 :                 if self.is_human_readable_deserializer {
     375            0 :                     formatter.write_str("value in form of hex string")
     376              :                 } else {
     377            0 :                     formatter.write_str("value in form of integer array([u8; 18])")
     378              :                 }
     379            0 :             }
     380              : 
     381            2 :             fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
     382            2 :             where
     383            2 :                 A: serde::de::SeqAccess<'de>,
     384              :             {
     385            2 :                 let s = serde::de::value::SeqAccessDeserializer::new(seq);
     386            2 :                 let id: [u8; 18] = Deserialize::deserialize(s)?;
     387            2 :                 Ok(TenantShardId::from(id))
     388            0 :             }
     389              : 
     390            4 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     391            4 :             where
     392            4 :                 E: serde::de::Error,
     393              :             {
     394            4 :                 TenantShardId::from_str(v).map_err(E::custom)
     395            0 :             }
     396              :         }
     397              : 
     398            6 :         if deserializer.is_human_readable() {
     399            4 :             deserializer.deserialize_str(IdVisitor {
     400            4 :                 is_human_readable_deserializer: true,
     401            4 :             })
     402              :         } else {
     403            2 :             deserializer.deserialize_tuple(
     404              :                 18,
     405            2 :                 IdVisitor {
     406            2 :                     is_human_readable_deserializer: false,
     407            2 :                 },
     408              :             )
     409              :         }
     410            0 :     }
     411              : }
     412              : 
     413              : impl Serialize for ShardIndex {
     414           16 :     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     415           16 :     where
     416           16 :         S: serde::Serializer,
     417              :     {
     418           16 :         if serializer.is_human_readable() {
     419           14 :             serializer.collect_str(self)
     420              :         } else {
     421              :             // Binary encoding is not used in index_part.json, but is included in anticipation of
     422              :             // switching various structures (e.g. inter-process communication, remote metadata) to more
     423              :             // compact binary encodings in future.
     424            2 :             let mut packed: [u8; 2] = [0; 2];
     425            2 :             packed[0] = self.shard_number.0;
     426            2 :             packed[1] = self.shard_count.0;
     427            2 :             packed.serialize(serializer)
     428              :         }
     429            0 :     }
     430              : }
     431              : 
     432              : impl<'de> Deserialize<'de> for ShardIndex {
     433         1565 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     434         1565 :     where
     435         1565 :         D: serde::Deserializer<'de>,
     436              :     {
     437              :         struct IdVisitor {
     438              :             is_human_readable_deserializer: bool,
     439              :         }
     440              : 
     441              :         impl<'de> serde::de::Visitor<'de> for IdVisitor {
     442              :             type Value = ShardIndex;
     443              : 
     444            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
     445            0 :                 if self.is_human_readable_deserializer {
     446            0 :                     formatter.write_str("value in form of hex string")
     447              :                 } else {
     448            0 :                     formatter.write_str("value in form of integer array([u8; 2])")
     449              :                 }
     450            0 :             }
     451              : 
     452            1 :             fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
     453            1 :             where
     454            1 :                 A: serde::de::SeqAccess<'de>,
     455              :             {
     456            1 :                 let s = serde::de::value::SeqAccessDeserializer::new(seq);
     457            1 :                 let id: [u8; 2] = Deserialize::deserialize(s)?;
     458            1 :                 Ok(ShardIndex::from(id))
     459            0 :             }
     460              : 
     461         1564 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     462         1564 :             where
     463         1564 :                 E: serde::de::Error,
     464              :             {
     465         1564 :                 ShardIndex::from_str(v).map_err(E::custom)
     466            0 :             }
     467              :         }
     468              : 
     469         1565 :         if deserializer.is_human_readable() {
     470         1564 :             deserializer.deserialize_str(IdVisitor {
     471         1564 :                 is_human_readable_deserializer: true,
     472         1564 :             })
     473              :         } else {
     474            1 :             deserializer.deserialize_tuple(
     475              :                 2,
     476            1 :                 IdVisitor {
     477            1 :                     is_human_readable_deserializer: false,
     478            1 :                 },
     479              :             )
     480              :         }
     481            0 :     }
     482              : }
        

Generated by: LCOV version 2.1-beta