LCOV - code coverage report
Current view: top level - libs/utils/src - bin_ser.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 87.6 % 209 183
Test Date: 2025-02-20 13:11:02 Functions: 44.3 % 271 120

            Line data    Source code
       1              : //! Utilities for binary serialization/deserialization.
       2              : //!
       3              : //! The [`BeSer`] trait allows us to define data structures
       4              : //! that can match data structures that are sent over the wire
       5              : //! in big-endian form with no packing.
       6              : //!
       7              : //! The [`LeSer`] trait does the same thing, in little-endian form.
       8              : //!
       9              : //! Note: you will get a compile error if you try to `use` both traits
      10              : //! in the same module or scope. This is intended to be a safety
      11              : //! mechanism: mixing big-endian and little-endian encoding in the same file
      12              : //! is error-prone.
      13              : 
      14              : #![warn(missing_docs)]
      15              : 
      16              : use bincode::Options;
      17              : use serde::{de::DeserializeOwned, Serialize};
      18              : use std::io::{self, Read, Write};
      19              : use thiserror::Error;
      20              : 
      21              : /// An error that occurred during a deserialize operation
      22              : ///
      23              : /// This could happen because the input data was too short,
      24              : /// or because an invalid value was encountered.
      25              : #[derive(Debug, Error)]
      26              : pub enum DeserializeError {
      27              :     /// The deserializer isn't able to deserialize the supplied data.
      28              :     #[error("deserialize error")]
      29              :     BadInput,
      30              :     /// While deserializing from a `Read` source, an `io::Error` occurred.
      31              :     #[error("deserialize error: {0}")]
      32              :     Io(io::Error),
      33              : }
      34              : 
      35              : impl From<bincode::Error> for DeserializeError {
      36            2 :     fn from(e: bincode::Error) -> Self {
      37            2 :         match *e {
      38            2 :             bincode::ErrorKind::Io(io_err) => DeserializeError::Io(io_err),
      39            0 :             _ => DeserializeError::BadInput,
      40              :         }
      41            2 :     }
      42              : }
      43              : 
      44              : /// An error that occurred during a serialize operation
      45              : ///
      46              : /// This probably means our [`Write`] failed, e.g. we tried
      47              : /// to write beyond the end of a buffer.
      48              : #[derive(Debug, Error)]
      49              : pub enum SerializeError {
      50              :     /// The serializer isn't able to serialize the supplied data.
      51              :     #[error("serialize error")]
      52              :     BadInput,
      53              :     /// While serializing into a `Write` sink, an `io::Error` occurred.
      54              :     #[error("serialize error: {0}")]
      55              :     Io(io::Error),
      56              : }
      57              : 
      58              : impl From<bincode::Error> for SerializeError {
      59            0 :     fn from(e: bincode::Error) -> Self {
      60            0 :         match *e {
      61            0 :             bincode::ErrorKind::Io(io_err) => SerializeError::Io(io_err),
      62            0 :             _ => SerializeError::BadInput,
      63              :         }
      64            0 :     }
      65              : }
      66              : 
      67              : /// A shortcut that configures big-endian binary serialization
      68              : ///
      69              : /// Properties:
      70              : /// - Big endian
      71              : /// - Fixed integer encoding (i.e. 1u32 is 00000001 not 01)
      72              : ///
      73              : /// Does not allow trailing bytes in deserialization. If this is desired, you
      74              : /// may set [`Options::allow_trailing_bytes`] to explicitly accommodate this.
      75     29223845 : pub fn be_coder() -> impl Options {
      76     29223845 :     bincode::DefaultOptions::new()
      77     29223845 :         .with_big_endian()
      78     29223845 :         .with_fixint_encoding()
      79     29223845 : }
      80              : 
      81              : /// A shortcut that configures little-ending binary serialization
      82              : ///
      83              : /// Properties:
      84              : /// - Little endian
      85              : /// - Fixed integer encoding (i.e. 1u32 is 00000001 not 01)
      86              : ///
      87              : /// Does not allow trailing bytes in deserialization. If this is desired, you
      88              : /// may set [`Options::allow_trailing_bytes`] to explicitly accommodate this.
      89       692562 : pub fn le_coder() -> impl Options {
      90       692562 :     bincode::DefaultOptions::new()
      91       692562 :         .with_little_endian()
      92       692562 :         .with_fixint_encoding()
      93       692562 : }
      94              : 
      95              : /// Binary serialize/deserialize helper functions (Big Endian)
      96              : ///
      97              : pub trait BeSer {
      98              :     /// Serialize into a byte slice
      99            0 :     fn ser_into_slice(&self, mut b: &mut [u8]) -> Result<(), SerializeError>
     100            0 :     where
     101            0 :         Self: Serialize,
     102            0 :     {
     103            0 :         // &mut [u8] implements Write, but `ser_into` needs a mutable
     104            0 :         // reference to that. So we need the slightly awkward "mutable
     105            0 :         // reference to a mutable reference.
     106            0 :         self.ser_into(&mut b)
     107            0 :     }
     108              : 
     109              :     /// Serialize into a borrowed writer
     110              :     ///
     111              :     /// This is useful for most `Write` types except `&mut [u8]`, which
     112              :     /// can more easily use [`ser_into_slice`](Self::ser_into_slice).
     113     10185122 :     fn ser_into<W: Write>(&self, w: &mut W) -> Result<(), SerializeError>
     114     10185122 :     where
     115     10185122 :         Self: Serialize,
     116     10185122 :     {
     117     10185122 :         be_coder().serialize_into(w, &self).map_err(|e| e.into())
     118     10185122 :     }
     119              : 
     120              :     /// Serialize into a new heap-allocated buffer
     121      4174982 :     fn ser(&self) -> Result<Vec<u8>, SerializeError>
     122      4174982 :     where
     123      4174982 :         Self: Serialize,
     124      4174982 :     {
     125      4174982 :         be_coder().serialize(&self).map_err(|e| e.into())
     126      4174982 :     }
     127              : 
     128              :     /// Deserialize from the full contents of a byte slice
     129              :     ///
     130              :     /// See also: [`BeSer::des_prefix`]
     131      5508598 :     fn des(buf: &[u8]) -> Result<Self, DeserializeError>
     132      5508598 :     where
     133      5508598 :         Self: DeserializeOwned,
     134      5508598 :     {
     135      5508598 :         be_coder()
     136      5508598 :             .deserialize(buf)
     137      5508598 :             .or(Err(DeserializeError::BadInput))
     138      5508598 :     }
     139              : 
     140              :     /// Deserialize from a prefix of the byte slice
     141              :     ///
     142              :     /// Uses as much of the byte slice as is necessary to deserialize the
     143              :     /// type, but does not guarantee that the entire slice is used.
     144              :     ///
     145              :     /// See also: [`BeSer::des`]
     146         2457 :     fn des_prefix(buf: &[u8]) -> Result<Self, DeserializeError>
     147         2457 :     where
     148         2457 :         Self: DeserializeOwned,
     149         2457 :     {
     150         2457 :         be_coder()
     151         2457 :             .allow_trailing_bytes()
     152         2457 :             .deserialize(buf)
     153         2457 :             .or(Err(DeserializeError::BadInput))
     154         2457 :     }
     155              : 
     156              :     /// Deserialize from a reader
     157            2 :     fn des_from<R: Read>(r: &mut R) -> Result<Self, DeserializeError>
     158            2 :     where
     159            2 :         Self: DeserializeOwned,
     160            2 :     {
     161            2 :         be_coder().deserialize_from(r).map_err(|e| e.into())
     162            2 :     }
     163              : 
     164              :     /// Compute the serialized size of a data structure
     165              :     ///
     166              :     /// Note: it may be faster to serialize to a buffer and then measure the
     167              :     /// buffer length, than to call `serialized_size` and then `ser_into`.
     168      9352684 :     fn serialized_size(&self) -> Result<u64, SerializeError>
     169      9352684 :     where
     170      9352684 :         Self: Serialize,
     171      9352684 :     {
     172      9352684 :         be_coder().serialized_size(self).map_err(|e| e.into())
     173      9352684 :     }
     174              : }
     175              : 
     176              : /// Binary serialize/deserialize helper functions (Little Endian)
     177              : ///
     178              : pub trait LeSer {
     179              :     /// Serialize into a byte slice
     180            0 :     fn ser_into_slice(&self, mut b: &mut [u8]) -> Result<(), SerializeError>
     181            0 :     where
     182            0 :         Self: Serialize,
     183            0 :     {
     184            0 :         // &mut [u8] implements Write, but `ser_into` needs a mutable
     185            0 :         // reference to that. So we need the slightly awkward "mutable
     186            0 :         // reference to a mutable reference.
     187            0 :         self.ser_into(&mut b)
     188            0 :     }
     189              : 
     190              :     /// Serialize into a borrowed writer
     191              :     ///
     192              :     /// This is useful for most `Write` types except `&mut [u8]`, which
     193              :     /// can more easily use [`ser_into_slice`](Self::ser_into_slice).
     194           23 :     fn ser_into<W: Write>(&self, w: &mut W) -> Result<(), SerializeError>
     195           23 :     where
     196           23 :         Self: Serialize,
     197           23 :     {
     198           23 :         le_coder().serialize_into(w, &self).map_err(|e| e.into())
     199           23 :     }
     200              : 
     201              :     /// Serialize into a new heap-allocated buffer
     202        36498 :     fn ser(&self) -> Result<Vec<u8>, SerializeError>
     203        36498 :     where
     204        36498 :         Self: Serialize,
     205        36498 :     {
     206        36498 :         le_coder().serialize(&self).map_err(|e| e.into())
     207        36498 :     }
     208              : 
     209              :     /// Deserialize from the full contents of a byte slice
     210              :     ///
     211              :     /// See also: [`LeSer::des_prefix`]
     212       316795 :     fn des(buf: &[u8]) -> Result<Self, DeserializeError>
     213       316795 :     where
     214       316795 :         Self: DeserializeOwned,
     215       316795 :     {
     216       316795 :         le_coder()
     217       316795 :             .deserialize(buf)
     218       316795 :             .or(Err(DeserializeError::BadInput))
     219       316795 :     }
     220              : 
     221              :     /// Deserialize from a prefix of the byte slice
     222              :     ///
     223              :     /// Uses as much of the byte slice as is necessary to deserialize the
     224              :     /// type, but does not guarantee that the entire slice is used.
     225              :     ///
     226              :     /// See also: [`LeSer::des`]
     227            9 :     fn des_prefix(buf: &[u8]) -> Result<Self, DeserializeError>
     228            9 :     where
     229            9 :         Self: DeserializeOwned,
     230            9 :     {
     231            9 :         le_coder()
     232            9 :             .allow_trailing_bytes()
     233            9 :             .deserialize(buf)
     234            9 :             .or(Err(DeserializeError::BadInput))
     235            9 :     }
     236              : 
     237              :     /// Deserialize from a reader
     238       339234 :     fn des_from<R: Read>(r: &mut R) -> Result<Self, DeserializeError>
     239       339234 :     where
     240       339234 :         Self: DeserializeOwned,
     241       339234 :     {
     242       339234 :         le_coder().deserialize_from(r).map_err(|e| e.into())
     243       339234 :     }
     244              : 
     245              :     /// Compute the serialized size of a data structure
     246              :     ///
     247              :     /// Note: it may be faster to serialize to a buffer and then measure the
     248              :     /// buffer length, than to call `serialized_size` and then `ser_into`.
     249            3 :     fn serialized_size(&self) -> Result<u64, SerializeError>
     250            3 :     where
     251            3 :         Self: Serialize,
     252            3 :     {
     253            3 :         le_coder().serialized_size(self).map_err(|e| e.into())
     254            3 :     }
     255              : }
     256              : 
     257              : // Because usage of `BeSer` or `LeSer` can be done with *either* a Serialize or
     258              : // DeserializeOwned implementation, the blanket implementation has to be for every type.
     259              : impl<T> BeSer for T {}
     260              : impl<T> LeSer for T {}
     261              : 
     262              : #[cfg(test)]
     263              : mod tests {
     264              :     use super::DeserializeError;
     265              :     use serde::{Deserialize, Serialize};
     266              :     use std::io::Cursor;
     267              : 
     268            2 :     #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
     269              :     pub struct ShortStruct {
     270              :         a: u8,
     271              :         b: u32,
     272              :     }
     273              : 
     274              :     const SHORT1: ShortStruct = ShortStruct { a: 7, b: 65536 };
     275              :     const SHORT1_ENC_BE: &[u8] = &[7, 0, 1, 0, 0];
     276              :     const SHORT1_ENC_BE_TRAILING: &[u8] = &[7, 0, 1, 0, 0, 255, 255, 255];
     277              :     const SHORT1_ENC_LE: &[u8] = &[7, 0, 0, 1, 0];
     278              :     const SHORT1_ENC_LE_TRAILING: &[u8] = &[7, 0, 0, 1, 0, 255, 255, 255];
     279              : 
     280              :     const SHORT2: ShortStruct = ShortStruct {
     281              :         a: 8,
     282              :         b: 0x07030000,
     283              :     };
     284              :     const SHORT2_ENC_BE: &[u8] = &[8, 7, 3, 0, 0];
     285              :     const SHORT2_ENC_BE_TRAILING: &[u8] = &[8, 7, 3, 0, 0, 0xff, 0xff, 0xff];
     286              :     const SHORT2_ENC_LE: &[u8] = &[8, 0, 0, 3, 7];
     287              :     const SHORT2_ENC_LE_TRAILING: &[u8] = &[8, 0, 0, 3, 7, 0xff, 0xff, 0xff];
     288              : 
     289            0 :     #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
     290              :     struct NewTypeStruct(u32);
     291              :     const NT1: NewTypeStruct = NewTypeStruct(414243);
     292              :     const NT1_INNER: u32 = 414243;
     293              : 
     294            0 :     #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
     295              :     pub struct LongMsg {
     296              :         pub tag: u8,
     297              :         pub blockpos: u32,
     298              :         pub last_flush_position: u64,
     299              :         pub apply: u64,
     300              :         pub timestamp: i64,
     301              :         pub reply_requested: u8,
     302              :     }
     303              : 
     304              :     const LONG1: LongMsg = LongMsg {
     305              :         tag: 42,
     306              :         blockpos: 0x1000_2000,
     307              :         last_flush_position: 0x1234_2345_3456_4567,
     308              :         apply: 0x9876_5432_10FE_DCBA,
     309              :         timestamp: 0x7788_99AA_BBCC_DDFF,
     310              :         reply_requested: 1,
     311              :     };
     312              : 
     313              :     #[test]
     314            1 :     fn be_short() {
     315              :         use super::BeSer;
     316              : 
     317            1 :         assert_eq!(SHORT1.serialized_size().unwrap(), 5);
     318              : 
     319            1 :         let encoded = SHORT1.ser().unwrap();
     320            1 :         assert_eq!(encoded, SHORT1_ENC_BE);
     321              : 
     322            1 :         let decoded = ShortStruct::des(SHORT2_ENC_BE).unwrap();
     323            1 :         assert_eq!(decoded, SHORT2);
     324              : 
     325              :         // with trailing data
     326            1 :         let decoded = ShortStruct::des_prefix(SHORT2_ENC_BE_TRAILING).unwrap();
     327            1 :         assert_eq!(decoded, SHORT2);
     328            1 :         let err = ShortStruct::des(SHORT2_ENC_BE_TRAILING).unwrap_err();
     329            1 :         assert!(matches!(err, DeserializeError::BadInput));
     330              : 
     331              :         // serialize into a `Write` sink.
     332            1 :         let mut buf = Cursor::new(vec![0xFF; 8]);
     333            1 :         SHORT1.ser_into(&mut buf).unwrap();
     334            1 :         assert_eq!(buf.into_inner(), SHORT1_ENC_BE_TRAILING);
     335              : 
     336              :         // deserialize from a `Write` sink.
     337            1 :         let mut buf = Cursor::new(SHORT2_ENC_BE);
     338            1 :         let decoded = ShortStruct::des_from(&mut buf).unwrap();
     339            1 :         assert_eq!(decoded, SHORT2);
     340              : 
     341              :         // deserialize from a `Write` sink that terminates early.
     342            1 :         let mut buf = Cursor::new([0u8; 4]);
     343            1 :         let err = ShortStruct::des_from(&mut buf).unwrap_err();
     344            1 :         assert!(matches!(err, DeserializeError::Io(_)));
     345            1 :     }
     346              : 
     347              :     #[test]
     348            1 :     fn le_short() {
     349              :         use super::LeSer;
     350              : 
     351            1 :         assert_eq!(SHORT1.serialized_size().unwrap(), 5);
     352              : 
     353            1 :         let encoded = SHORT1.ser().unwrap();
     354            1 :         assert_eq!(encoded, SHORT1_ENC_LE);
     355              : 
     356            1 :         let decoded = ShortStruct::des(SHORT2_ENC_LE).unwrap();
     357            1 :         assert_eq!(decoded, SHORT2);
     358              : 
     359              :         // with trailing data
     360            1 :         let decoded = ShortStruct::des_prefix(SHORT2_ENC_LE_TRAILING).unwrap();
     361            1 :         assert_eq!(decoded, SHORT2);
     362            1 :         let err = ShortStruct::des(SHORT2_ENC_LE_TRAILING).unwrap_err();
     363            1 :         assert!(matches!(err, DeserializeError::BadInput));
     364              : 
     365              :         // serialize into a `Write` sink.
     366            1 :         let mut buf = Cursor::new(vec![0xFF; 8]);
     367            1 :         SHORT1.ser_into(&mut buf).unwrap();
     368            1 :         assert_eq!(buf.into_inner(), SHORT1_ENC_LE_TRAILING);
     369              : 
     370              :         // deserialize from a `Write` sink.
     371            1 :         let mut buf = Cursor::new(SHORT2_ENC_LE);
     372            1 :         let decoded = ShortStruct::des_from(&mut buf).unwrap();
     373            1 :         assert_eq!(decoded, SHORT2);
     374              : 
     375              :         // deserialize from a `Write` sink that terminates early.
     376            1 :         let mut buf = Cursor::new([0u8; 4]);
     377            1 :         let err = ShortStruct::des_from(&mut buf).unwrap_err();
     378            1 :         assert!(matches!(err, DeserializeError::Io(_)));
     379            1 :     }
     380              : 
     381              :     #[test]
     382            1 :     fn be_long() {
     383              :         use super::BeSer;
     384              : 
     385            1 :         assert_eq!(LONG1.serialized_size().unwrap(), 30);
     386              : 
     387            1 :         let msg = LONG1;
     388            1 : 
     389            1 :         let encoded = msg.ser().unwrap();
     390            1 :         let expected = hex_literal::hex!(
     391            1 :             "2A 1000 2000 1234 2345 3456 4567 9876 5432 10FE DCBA 7788 99AA BBCC DDFF 01"
     392            1 :         );
     393            1 :         assert_eq!(encoded, expected);
     394              : 
     395            1 :         let msg2 = LongMsg::des(&encoded).unwrap();
     396            1 :         assert_eq!(msg, msg2);
     397            1 :     }
     398              : 
     399              :     #[test]
     400            1 :     fn le_long() {
     401              :         use super::LeSer;
     402              : 
     403            1 :         assert_eq!(LONG1.serialized_size().unwrap(), 30);
     404              : 
     405            1 :         let msg = LONG1;
     406            1 : 
     407            1 :         let encoded = msg.ser().unwrap();
     408            1 :         let expected = hex_literal::hex!(
     409            1 :             "2A 0020 0010 6745 5634 4523 3412 BADC FE10 3254 7698 FFDD CCBB AA99 8877 01"
     410            1 :         );
     411            1 :         assert_eq!(encoded, expected);
     412              : 
     413            1 :         let msg2 = LongMsg::des(&encoded).unwrap();
     414            1 :         assert_eq!(msg, msg2);
     415            1 :     }
     416              : 
     417              :     #[test]
     418              :     /// Ensure that newtype wrappers around u32 don't change the serialization format
     419            1 :     fn be_nt() {
     420              :         use super::BeSer;
     421              : 
     422            1 :         assert_eq!(NT1.serialized_size().unwrap(), 4);
     423              : 
     424            1 :         let msg = NT1;
     425            1 : 
     426            1 :         let encoded = msg.ser().unwrap();
     427            1 :         let expected = hex_literal::hex!("0006 5223");
     428            1 :         assert_eq!(encoded, expected);
     429              : 
     430            1 :         assert_eq!(encoded, NT1_INNER.ser().unwrap());
     431              : 
     432            1 :         let msg2 = NewTypeStruct::des(&encoded).unwrap();
     433            1 :         assert_eq!(msg, msg2);
     434            1 :     }
     435              : 
     436              :     #[test]
     437              :     /// Ensure that newtype wrappers around u32 don't change the serialization format
     438            1 :     fn le_nt() {
     439              :         use super::LeSer;
     440              : 
     441            1 :         assert_eq!(NT1.serialized_size().unwrap(), 4);
     442              : 
     443            1 :         let msg = NT1;
     444            1 : 
     445            1 :         let encoded = msg.ser().unwrap();
     446            1 :         let expected = hex_literal::hex!("2352 0600");
     447            1 :         assert_eq!(encoded, expected);
     448              : 
     449            1 :         assert_eq!(encoded, NT1_INNER.ser().unwrap());
     450              : 
     451            1 :         let msg2 = NewTypeStruct::des(&encoded).unwrap();
     452            1 :         assert_eq!(msg, msg2);
     453            1 :     }
     454              : }
        

Generated by: LCOV version 2.1-beta