diff --git a/Cargo.toml b/Cargo.toml index ab12cb6..e1ce01e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,11 +32,11 @@ borsh = ["dep:borsh"] schemars = ["dep:schemars", "std"] [dependencies] -num-traits = { version = "0.2.17", default-features = false, optional = true } -defmt = { version = "0.3.5", optional = true } -serde = { version = "1.0", optional = true, default-features = false} +num-traits = { version = "0.2.19", default-features = false, optional = true } +defmt = { version = "0.3.8", optional = true } +serde = { version = "1.0", optional = true, default-features = false } borsh = { version = "1.5.1", optional = true, features = ["unstable__schema"], default-features = false } -schemars = { version = "0.8.1", optional = true, features = ["derive"], default-features = false } +schemars = { version = "0.8.21", optional = true, features = ["derive"], default-features = false } [dev-dependencies] serde_test = "1.0" diff --git a/src/lib.rs b/src/lib.rs index d937325..8880f46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,9 +21,6 @@ use core::ops::{ #[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[cfg(feature = "borsh")] -use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; - #[cfg(all(feature = "borsh", not(feature = "std")))] use alloc::{collections::BTreeMap, string::ToString}; @@ -1069,51 +1066,46 @@ where } } -// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde. -// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits. -// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway). #[cfg(feature = "borsh")] -impl BorshSerialize for UInt +impl borsh::BorshSerialize for UInt where Self: Number, - T: BorshSerialize - + From - + BitAnd - + TryInto - + Copy - + Shr, - as Number>::UnderlyingType: - Shr + TryInto + From + BitAnd, + T: borsh::BorshSerialize, { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { - let value = self.value(); - let length = (BITS + 7) / 8; - let mut bytes = 0; - let mask: T = u8::MAX.into(); - while bytes < length { - let le_byte: u8 = ((value >> (bytes << 3)) & mask) - .try_into() - .ok() - .expect("we cut to u8 via mask"); - writer.write(&[le_byte])?; - bytes += 1; - } + let serialized_byte_count = (BITS + 7) / 8; + let mut buffer = [0u8; 16]; + self.value.serialize(&mut &mut buffer[..])?; + writer.write(&buffer[0..serialized_byte_count])?; + Ok(()) } } #[cfg(feature = "borsh")] impl< - T: BorshDeserialize + core::cmp::PartialOrd< as Number>::UnderlyingType>, + T: borsh::BorshDeserialize + PartialOrd< as Number>::UnderlyingType>, const BITS: usize, - > BorshDeserialize for UInt + > borsh::BorshDeserialize for UInt where Self: Number, { fn deserialize_reader(reader: &mut R) -> borsh::io::Result { - let mut buf = vec![0u8; core::mem::size_of::()]; - reader.read(&mut buf)?; - let value = T::deserialize(&mut &buf[..])?; + // Ideally, we'd want a buffer of size `BITS >> 3` or `size_of::`, but that's not possible + // with arrays at present (feature(generic_const_exprs), once stable, will allow this). + // vec! would be an option, but an allocation is not expected at this level. + // Therefore, allocate a 16 byte buffer and take a slice out of it. + let serialized_byte_count = (BITS + 7) / 8; + let underlying_byte_count = core::mem::size_of::(); + let mut buf = [0u8; 16]; + + // Read from the source, advancing cursor by the exact right number of bytes + reader.read(&mut buf[..serialized_byte_count])?; + + // Deserialize the underlying type. We have to pass in the correct number of bytes of the + // underlying type (or more, but let's be precise). The unused bytes are all still zero + let value = T::deserialize(&mut &buf[..underlying_byte_count])?; + if value >= Self::MIN.value() && value <= Self::MAX.value() { Ok(Self { value }) } else { @@ -1126,7 +1118,7 @@ where } #[cfg(feature = "borsh")] -impl BorshSchema for UInt { +impl borsh::BorshSchema for UInt { fn add_definitions_recursively( definitions: &mut BTreeMap, ) { diff --git a/tests/tests.rs b/tests/tests.rs index 3687875..e674586 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1912,42 +1912,114 @@ fn serde() { ); } -#[cfg(all(feature = "borsh", feature = "std"))] -#[test] -fn borsh() { +#[cfg(feature = "borsh")] +mod borsh_tests { + use arbitrary_int::{u1, u14, u15, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt}; use borsh::schema::BorshSchemaContainer; - use borsh::{BorshDeserialize, BorshSerialize}; - let mut buf = Vec::new(); - let base_input: u8 = 42; - let input = u9::new(base_input.into()); - input.serialize(&mut buf).unwrap(); - let output = u9::deserialize(&mut buf.as_ref()).unwrap(); - let fits = u16::new(base_input.into()); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let input = u63::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let schema = BorshSchemaContainer::for_type::(); - match schema.get_definition("u9").expect("exists") { - borsh::schema::Definition::Primitive(2) => {} - _ => panic!("unexpected schema"), + use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; + use std::fmt::Debug; + + fn test_roundtrip( + input: T, + expected_buffer: &[u8], + ) { + let mut buf = Vec::new(); + + // Serialize and compare against expected + input.serialize(&mut buf).unwrap(); + assert_eq!(buf, expected_buffer); + + // Add to the buffer a second time - this is a better test for the deserialization + // as it ensures we request the correct number of bytes + input.serialize(&mut buf).unwrap(); + + // Deserialize back and compare against input + let output = T::deserialize(&mut buf.as_ref()).unwrap(); + let output2 = T::deserialize(&mut &buf[buf.len() / 2..]).unwrap(); + assert_eq!(input, output); + assert_eq!(input, output2); + } + + #[test] + fn test_serialize_deserialize() { + // Run against plain u64 first (not an arbitrary_int) + test_roundtrip( + 0x12345678_9ABCDEF0u64, + &[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12], + ); + + // Now try various arbitrary ints + test_roundtrip(u1::new(0b0), &[0]); + test_roundtrip(u1::new(0b1), &[1]); + test_roundtrip(u6::new(0b101101), &[0b101101]); + test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]); + test_roundtrip( + u72::new(0x36_01234567_89ABCDEF), + &[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36], + ); + + // Pick a byte boundary (80; test one below and one above to ensure we get the right number + // of bytes) + test_roundtrip( + u79::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + test_roundtrip( + u80::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF], + ); + test_roundtrip( + u81::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, + ], + ); + + // Test actual u128 and arbitrary u128 (which is a legal one, though not a predefined) + test_roundtrip( + u128::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + test_roundtrip( + UInt::::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + } + + fn verify_byte_count_in_schema(expected_byte_count: u8, name: &str) { + let schema = BorshSchemaContainer::for_type::(); + match schema.get_definition(name).expect("exists") { + borsh::schema::Definition::Primitive(byte_count) => { + assert_eq!(*byte_count, expected_byte_count); + } + _ => panic!("unexpected schema"), + } } - let input = u50::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - assert!(buf.len() < fits.to_le_bytes().len()); - assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]); - let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(input, output); + #[test] + fn test_schema_byte_count() { + verify_byte_count_in_schema::(1, "u1"); + + verify_byte_count_in_schema::(1, "u7"); + + verify_byte_count_in_schema::>(1, "u8"); + verify_byte_count_in_schema::>(1, "u8"); + + verify_byte_count_in_schema::(2, "u9"); + + verify_byte_count_in_schema::(2, "u15"); + verify_byte_count_in_schema::>(2, "u15"); + + verify_byte_count_in_schema::(8, "u63"); + + verify_byte_count_in_schema::(9, "u65"); + } } #[cfg(feature = "schemars")]