Skip to content

Commit

Permalink
Rewrite borsh tests to catch more errors
Browse files Browse the repository at this point in the history
In particular, test at byte boundaries and try some unusual (large) uints
  • Loading branch information
danlehmann committed Jul 23, 2024
1 parent 2f638e7 commit d4f6c7c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 42 deletions.
15 changes: 6 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,9 +1066,6 @@ where
}
}

// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde.
// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits.
// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway).
#[cfg(feature = "borsh")]
impl<T, const BITS: usize> borsh::BorshSerialize for UInt<T, BITS>
where
Expand All @@ -1084,16 +1081,16 @@ where
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
let value = self.value();
let length = (BITS + 7) / 8;
let mut bytes = 0;
let mask: T = u8::MAX.into();
while bytes < length {
let le_byte: u8 = ((value >> (bytes << 3)) & mask)
let total_bytes = (BITS + 7) / 8;
let mut byte_count_written = 0;
let byte_mask: T = u8::MAX.into();
while byte_count_written < total_bytes {
let le_byte: u8 = ((value >> (byte_count_written << 3)) & byte_mask)
.try_into()
.ok()
.expect("we cut to u8 via mask");
writer.write(&[le_byte])?;
bytes += 1;
byte_count_written += 1;
}
Ok(())
}
Expand Down
134 changes: 101 additions & 33 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1912,42 +1912,110 @@ fn serde() {
);
}

#[cfg(all(feature = "borsh", feature = "std"))]
#[test]
fn borsh() {
#[cfg(feature = "borsh")]
mod borsh_tests {
use arbitrary_int::{
u1, u14, u15, u50, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt,

Check failure on line 1918 in tests/tests.rs

View workflow job for this annotation

GitHub Actions / build-and-test

unused import: `u50`
};
use borsh::schema::BorshSchemaContainer;
use borsh::{BorshDeserialize, BorshSerialize};
let mut buf = Vec::new();
let base_input: u8 = 42;
let input = u9::new(base_input.into());
input.serialize(&mut buf).unwrap();
let output = u9::deserialize(&mut buf.as_ref()).unwrap();
let fits = u16::new(base_input.into());
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let input = u63::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let schema = BorshSchemaContainer::for_type::<u9>();
match schema.get_definition("u9").expect("exists") {
borsh::schema::Definition::Primitive(2) => {}
_ => panic!("unexpected schema"),
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use std::fmt::Debug;

fn test_roundtrip<T: Number + BorshSerialize + BorshDeserialize + PartialEq + Eq + Debug>(
input: T,
expected_buffer: &[u8],
) {
let mut buf = Vec::new();

// Serialize and compare against expected
input.serialize(&mut buf).unwrap();
assert_eq!(buf, expected_buffer);

// Deserialize back and compare against input
let output = T::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(input, output);
}

#[test]
fn test_serialize_deserialize() {
// Run against plain u64 first (not an arbitrary_int)
test_roundtrip(
0x12345678_9ABCDEF0u64,
&[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12],
);

// Now try various arbitrary ints
test_roundtrip(u1::new(0b0), &[0]);
test_roundtrip(u1::new(0b1), &[1]);
test_roundtrip(u6::new(0b101101), &[0b101101]);
test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]);
test_roundtrip(
u72::new(0x36_01234567_89ABCDEF),
&[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36],
);

// Pick a byte boundary (80; test one below and one above to ensure we get the right number
// of bytes)
test_roundtrip(
u79::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
test_roundtrip(
u80::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
);
test_roundtrip(
u81::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
],
);

// Test actual u128 and arbitrary u128 (which is a legal one, though not a predefined)
test_roundtrip(
u128::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
test_roundtrip(
UInt::<u128, 128>::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
}

fn verify_byte_count_in_schema<T: BorshSchema + ?Sized>(expected_byte_count: u8, name: &str) {
let schema = BorshSchemaContainer::for_type::<T>();
match schema.get_definition(name).expect("exists") {
borsh::schema::Definition::Primitive(byte_count) => {
assert_eq!(*byte_count, expected_byte_count);
}
_ => panic!("unexpected schema"),
}
}

let input = u50::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
assert!(buf.len() < fits.to_le_bytes().len());
assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]);
let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(input, output);
#[test]
fn test_schema_byte_count() {
verify_byte_count_in_schema::<u1>(1, "u1");

verify_byte_count_in_schema::<u7>(1, "u7");

verify_byte_count_in_schema::<UInt<u8, 8>>(1, "u8");
verify_byte_count_in_schema::<UInt<u32, 8>>(1, "u8");

verify_byte_count_in_schema::<u9>(2, "u9");

verify_byte_count_in_schema::<u15>(2, "u15");
verify_byte_count_in_schema::<UInt<u128, 15>>(2, "u15");

verify_byte_count_in_schema::<u63>(8, "u63");

verify_byte_count_in_schema::<u65>(9, "u65");
}
}

#[cfg(feature = "schemars")]
Expand Down

0 comments on commit d4f6c7c

Please sign in to comment.