From b7ca18281793a07cceeb5c4ab368d0e949e53cf3 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 14:16:09 -0700 Subject: [PATCH] Simplify borsh serialize and deserialize --- src/lib.rs | 35 ++++++++++++++++++----------------- tests/tests.rs | 6 ++++++ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 95ec417..8357560 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1080,18 +1080,11 @@ where Shr + TryInto + From + BitAnd, { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { - let value = self.value(); - let total_bytes = (BITS + 7) / 8; - let mut byte_count_written = 0; - let byte_mask: T = u8::MAX.into(); - while byte_count_written < total_bytes { - let le_byte: u8 = ((value >> (byte_count_written << 3)) & byte_mask) - .try_into() - .ok() - .expect("we cut to u8 via mask"); - writer.write(&[le_byte])?; - byte_count_written += 1; - } + let serialized_byte_count = (BITS + 7) / 8; + let mut buffer = [0u8; 16]; + self.value.serialize(&mut &mut buffer[..])?; + writer.write(&buffer[0..serialized_byte_count])?; + Ok(()) } } @@ -1106,12 +1099,20 @@ where { fn deserialize_reader(reader: &mut R) -> borsh::io::Result { // Ideally, we'd want a buffer of size `BITS >> 3` or `size_of::`, but that's not possible - // with arrays. - // So instead we'll do a 16 byte buffer which handles the largest arbitrary-ints possible - - // not ideal, but still pretty small and better than going through an allocator. + // with arrays at present (feature(generic_const_exprs), once stable, will allow this). + // vec! would be an option, but an allocation is not expected at this level. + // Therefore, allocate a 16 byte buffer and take a slice out of it. + let serialized_byte_count = (BITS + 7) / 8; + let underlying_byte_count = core::mem::size_of::(); let mut buf = [0u8; 16]; - reader.read(&mut buf)?; - let value = T::deserialize(&mut &buf[..])?; + + // Read from the source, advancing cursor by the exact right number of bytes + reader.read(&mut buf[..serialized_byte_count])?; + + // Deserialize the underlying type. We have to pass in the correct number of bytes of the + // underlying type (or more, but let's be precise). The unused bytes are all still zero + let value = T::deserialize(&mut &buf[..underlying_byte_count])?; + if value >= Self::MIN.value() && value <= Self::MAX.value() { Ok(Self { value }) } else { diff --git a/tests/tests.rs b/tests/tests.rs index 9d8ffaf..e674586 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1929,9 +1929,15 @@ mod borsh_tests { input.serialize(&mut buf).unwrap(); assert_eq!(buf, expected_buffer); + // Add to the buffer a second time - this is a better test for the deserialization + // as it ensures we request the correct number of bytes + input.serialize(&mut buf).unwrap(); + // Deserialize back and compare against input let output = T::deserialize(&mut buf.as_ref()).unwrap(); + let output2 = T::deserialize(&mut &buf[buf.len() / 2..]).unwrap(); assert_eq!(input, output); + assert_eq!(input, output2); } #[test]