From d4f6c7cc5a10e5722e76da8f56c3298a9d055e21 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 13:25:56 -0700 Subject: [PATCH] Rewrite borsh tests to catch more errors In particular, test at byte boundaries and try some unusual (large) uints --- src/lib.rs | 15 +++--- tests/tests.rs | 134 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 107 insertions(+), 42 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 51ea79c..7dc16bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1066,9 +1066,6 @@ where } } -// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde. -// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits. -// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway). #[cfg(feature = "borsh")] impl borsh::BorshSerialize for UInt where @@ -1084,16 +1081,16 @@ where { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { let value = self.value(); - let length = (BITS + 7) / 8; - let mut bytes = 0; - let mask: T = u8::MAX.into(); - while bytes < length { - let le_byte: u8 = ((value >> (bytes << 3)) & mask) + let total_bytes = (BITS + 7) / 8; + let mut byte_count_written = 0; + let byte_mask: T = u8::MAX.into(); + while byte_count_written < total_bytes { + let le_byte: u8 = ((value >> (byte_count_written << 3)) & byte_mask) .try_into() .ok() .expect("we cut to u8 via mask"); writer.write(&[le_byte])?; - bytes += 1; + byte_count_written += 1; } Ok(()) } diff --git a/tests/tests.rs b/tests/tests.rs index 3687875..ac84078 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1912,42 +1912,110 @@ fn serde() { ); } -#[cfg(all(feature = "borsh", feature = "std"))] -#[test] -fn borsh() { +#[cfg(feature = "borsh")] +mod borsh_tests { + use arbitrary_int::{ + u1, u14, u15, u50, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt, + }; use borsh::schema::BorshSchemaContainer; - use borsh::{BorshDeserialize, BorshSerialize}; - let mut buf = Vec::new(); - let base_input: u8 = 42; - let input = u9::new(base_input.into()); - input.serialize(&mut buf).unwrap(); - let output = u9::deserialize(&mut buf.as_ref()).unwrap(); - let fits = u16::new(base_input.into()); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let input = u63::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let schema = BorshSchemaContainer::for_type::(); - match schema.get_definition("u9").expect("exists") { - borsh::schema::Definition::Primitive(2) => {} - _ => panic!("unexpected schema"), + use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; + use std::fmt::Debug; + + fn test_roundtrip( + input: T, + expected_buffer: &[u8], + ) { + let mut buf = Vec::new(); + + // Serialize and compare against expected + input.serialize(&mut buf).unwrap(); + assert_eq!(buf, expected_buffer); + + // Deserialize back and compare against input + let output = T::deserialize(&mut buf.as_ref()).unwrap(); + assert_eq!(input, output); + } + + #[test] + fn test_serialize_deserialize() { + // Run against plain u64 first (not an arbitrary_int) + test_roundtrip( + 0x12345678_9ABCDEF0u64, + &[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12], + ); + + // Now try various arbitrary ints + test_roundtrip(u1::new(0b0), &[0]); + test_roundtrip(u1::new(0b1), &[1]); + test_roundtrip(u6::new(0b101101), &[0b101101]); + test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]); + test_roundtrip( + u72::new(0x36_01234567_89ABCDEF), + &[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36], + ); + + // Pick a byte boundary (80; test one below and one above to ensure we get the right number + // of bytes) + test_roundtrip( + u79::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + test_roundtrip( + u80::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF], + ); + test_roundtrip( + u81::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, + ], + ); + + // Test actual u128 and arbitrary u128 (which is a legal one, though not a predefined) + test_roundtrip( + u128::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + test_roundtrip( + UInt::::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + } + + fn verify_byte_count_in_schema(expected_byte_count: u8, name: &str) { + let schema = BorshSchemaContainer::for_type::(); + match schema.get_definition(name).expect("exists") { + borsh::schema::Definition::Primitive(byte_count) => { + assert_eq!(*byte_count, expected_byte_count); + } + _ => panic!("unexpected schema"), + } } - let input = u50::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - assert!(buf.len() < fits.to_le_bytes().len()); - assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]); - let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(input, output); + #[test] + fn test_schema_byte_count() { + verify_byte_count_in_schema::(1, "u1"); + + verify_byte_count_in_schema::(1, "u7"); + + verify_byte_count_in_schema::>(1, "u8"); + verify_byte_count_in_schema::>(1, "u8"); + + verify_byte_count_in_schema::(2, "u9"); + + verify_byte_count_in_schema::(2, "u15"); + verify_byte_count_in_schema::>(2, "u15"); + + verify_byte_count_in_schema::(8, "u63"); + + verify_byte_count_in_schema::(9, "u65"); + } } #[cfg(feature = "schemars")]