diff --git a/crates/primitives/src/consts.rs b/crates/primitives/src/consts.rs index 396905274b..008068f3a0 100644 --- a/crates/primitives/src/consts.rs +++ b/crates/primitives/src/consts.rs @@ -6,37 +6,35 @@ pub const WORD_SIZE: usize = 4; /// Converts a slice of words to a byte vector in little endian. pub fn words_to_bytes_le_vec(words: &[u32]) -> Vec { - words.iter().flat_map(|word| word.to_le_bytes().to_vec()).collect::>() + words.iter().flat_map(|&word| word.to_le_bytes()).collect() } /// Converts a slice of words to a slice of bytes in little endian. pub fn words_to_bytes_le(words: &[u32]) -> [u8; B] { - debug_assert_eq!(words.len() * 4, B); - words - .iter() - .flat_map(|word| word.to_le_bytes().to_vec()) - .collect::>() - .try_into() - .unwrap() + debug_assert_eq!(words.len() * WORD_SIZE, B); + let mut bytes = [0u8; B]; + words.iter().enumerate().for_each(|(i, &word)| { + bytes[i * WORD_SIZE..(i + 1) * WORD_SIZE].copy_from_slice(&word.to_le_bytes()); + }); + bytes } /// Converts a byte array in little endian to a slice of words. pub fn bytes_to_words_le(bytes: &[u8]) -> [u32; W] { - debug_assert_eq!(bytes.len(), W * 4); - bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .collect::>() - .try_into() - .unwrap() + debug_assert_eq!(bytes.len(), W * WORD_SIZE); + let mut words = [0u32; W]; + bytes.chunks_exact(WORD_SIZE).enumerate().for_each(|(i, chunk)| { + words[i] = u32::from_le_bytes(chunk.try_into().unwrap()); + }); + words } /// Converts a byte array in little endian to a vector of words. pub fn bytes_to_words_le_vec(bytes: &[u8]) -> Vec { bytes - .chunks_exact(4) + .chunks_exact(WORD_SIZE) .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .collect::>() + .collect() } // Converts a num to a string with commas every 3 digits. @@ -54,3 +52,47 @@ pub fn num_to_comma_separated(value: T) -> String { .rev() .collect() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_words_to_bytes_le_vec() { + let words = [0x12345678, 0x90ABCDEF]; + let expected = vec![0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90]; + assert_eq!(words_to_bytes_le_vec(&words), expected); + } + + #[test] + fn test_words_to_bytes_le() { + let words = [0x12345678, 0x90ABCDEF]; + let expected: [u8; 8] = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90]; + assert_eq!(words_to_bytes_le::<8>(&words), expected); + } + + #[test] + fn test_bytes_to_words_le() { + let bytes = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90]; + let expected = [0x12345678, 0x90ABCDEF]; + assert_eq!(bytes_to_words_le::<2>(&bytes), expected); + } + + #[test] + fn test_bytes_to_words_le_vec() { + let bytes = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90]; + let expected = vec![0x12345678, 0x90ABCDEF]; + assert_eq!(bytes_to_words_le_vec(&bytes), expected); + } + + #[test] + fn test_num_to_comma_separated() { + assert_eq!(num_to_comma_separated(1000), "1,000"); + assert_eq!(num_to_comma_separated(1000000), "1,000,000"); + assert_eq!(num_to_comma_separated(987654321), "987,654,321"); + + // Test with a large number as BigUint + let large_num = num_bigint::BigUint::from(12345678901234567890u64); + assert_eq!(num_to_comma_separated(large_num), "12,345,678,901,234,567,890"); + } +} diff --git a/crates/primitives/src/io.rs b/crates/primitives/src/io.rs index 0d4d89e957..afc7cfb722 100644 --- a/crates/primitives/src/io.rs +++ b/crates/primitives/src/io.rs @@ -10,11 +10,6 @@ pub struct SP1PublicValues { } impl SP1PublicValues { - /// Create a new `SP1PublicValues`. - pub const fn new() -> Self { - Self { buffer: Buffer::new() } - } - pub fn raw(&self) -> String { format!("0x{}", hex::encode(self.buffer.data.clone())) } @@ -32,7 +27,7 @@ impl SP1PublicValues { self.buffer.data.clone() } - /// Read a value from the buffer. + /// Read a value from the buffer. pub fn read(&mut self) -> T { self.buffer.read() } @@ -89,13 +84,84 @@ impl AsRef<[u8]> for SP1PublicValues { #[cfg(test)] mod tests { use super::*; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestStruct { + a: u32, + b: String, + } + + #[test] + fn test_new() { + let public_values = SP1PublicValues::default(); + assert!(public_values.buffer.data.is_empty()); + } + + #[test] + fn test_from_slice() { + let data = b"test data"; + let public_values = SP1PublicValues::from(data.as_ref()); + assert_eq!(public_values.buffer.data, data); + } + + #[test] + fn test_raw() { + let data = b"test raw data"; + let public_values = SP1PublicValues::from(data.as_ref()); + let expected_hex = format!("0x{}", hex::encode(data)); + assert_eq!(public_values.raw(), expected_hex); + } + + #[test] + fn test_as_slice() { + let data = b"test slice data"; + let public_values = SP1PublicValues::from(data.as_ref()); + assert_eq!(public_values.as_slice(), data); + } + + #[test] + fn test_to_vec() { + let data = b"test vec data"; + let public_values = SP1PublicValues::from(data.as_ref()); + assert_eq!(public_values.to_vec(), data.to_vec()); + } + + #[test] + fn test_write_and_read() { + let mut public_values = SP1PublicValues::default(); + let obj = TestStruct { a: 123, b: "test".to_string() }; + + public_values.write(&obj); + let read_obj: TestStruct = public_values.read(); + assert_eq!(read_obj, obj); + } + + #[test] + fn test_write_slice_and_read_slice() { + let mut public_values = SP1PublicValues::default(); + let slice = [1, 2, 3, 4, 5]; + public_values.write_slice(&slice); + + let mut read_slice = [0; 5]; + public_values.read_slice(&mut read_slice); + assert_eq!(read_slice, slice); + } + + #[test] + fn test_hash() { + let data = b"some data to hash"; + let public_values = SP1PublicValues::from(data.as_ref()); + let expected_hash = Sha256::digest(data).to_vec(); + assert_eq!(public_values.hash(), expected_hash); + } #[test] fn test_hash_public_values() { let test_hex = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"; let test_bytes = hex::decode(test_hex).unwrap(); - let mut public_values = SP1PublicValues::new(); + let mut public_values = SP1PublicValues::default(); public_values.write_slice(&test_bytes); let hash = public_values.hash_bn254(); diff --git a/crates/primitives/src/types.rs b/crates/primitives/src/types.rs index 6e0e2d326d..62dc49b59b 100644 --- a/crates/primitives/src/types.rs +++ b/crates/primitives/src/types.rs @@ -9,7 +9,7 @@ pub enum RecursionProgramType { Wrap, } -/// A buffer of serializable/deserializable objects. +/// A buffer of serializable/deserializable objects. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Buffer { pub data: Vec, @@ -26,12 +26,12 @@ impl Buffer { Self { data: data.to_vec(), ptr: 0 } } - /// Set the position ptr to the beginning of the buffer. + /// Set the position ptr to the beginning of the buffer. pub fn head(&mut self) { self.ptr = 0; } - /// Read the serializable object from the buffer. + /// Read the serializable object from the buffer. pub fn read(&mut self) -> T { let result: T = bincode::deserialize(&self.data[self.ptr..]).expect("failed to deserialize"); @@ -45,14 +45,12 @@ impl Buffer { self.ptr += slice.len(); } - /// Write the serializable object from the buffer. + /// Write the serializable object from the buffer. pub fn write(&mut self, data: &T) { - let mut tmp = Vec::new(); - bincode::serialize_into(&mut tmp, data).expect("serialization failed"); - self.data.extend(tmp); + bincode::serialize_into(&mut self.data, data).expect("serialization failed"); } - /// Write the slice of bytes to the buffer. + /// Write the slice of bytes to the buffer. pub fn write_slice(&mut self, slice: &[u8]) { self.data.extend_from_slice(slice); } @@ -63,3 +61,109 @@ impl Default for Buffer { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestStruct { + a: u32, + b: String, + } + + #[test] + fn test_new() { + let buffer = Buffer::new(); + assert!(buffer.data.is_empty()); + assert_eq!(buffer.ptr, 0); + } + + #[test] + fn test_from_slice() { + let data: &[u8] = &[1, 2, 3, 4]; + let buffer = Buffer::from(data); + assert_eq!(buffer.data, data); + assert_eq!(buffer.ptr, 0); + } + + #[test] + fn test_head() { + let data: &[u8] = &[1, 2, 3, 4]; + let mut buffer = Buffer::from(data); + buffer.ptr = 2; + buffer.head(); + assert_eq!(buffer.ptr, 0); + } + + #[test] + fn test_write_and_read() { + let mut buffer = Buffer::new(); + let obj = TestStruct { a: 123, b: "test".to_string() }; + + // Serialize `obj` using bincode for comparison + let mut expected_data = Vec::new(); + bincode::serialize_into(&mut expected_data, &obj).expect("serialization failed"); + + // Write `obj` to buffer and check if `buffer.data` matches `expected_data` + buffer.write(&obj); + assert_eq!(buffer.data, expected_data); + assert_eq!(buffer.ptr, 0); + + let read_obj: TestStruct = buffer.read(); + assert_eq!(read_obj, obj); + assert_eq!(buffer.ptr, buffer.data.len()); + } + + #[test] + fn test_write_slice_and_read_slice() { + let mut buffer = Buffer::new(); + let slice = [1, 2, 3, 4, 5]; + + buffer.write_slice(&slice); + assert_eq!(buffer.data, slice); + + let mut read_slice = [0; 5]; + buffer.head(); + buffer.read_slice(&mut read_slice); + assert_eq!(read_slice, slice); + } + + #[test] + fn test_multiple_writes_and_reads() { + let mut buffer = Buffer::new(); + let obj1 = TestStruct { a: 123, b: "first".to_string() }; + let obj2 = TestStruct { a: 456, b: "second".to_string() }; + + buffer.write(&obj1); + buffer.write(&obj2); + + buffer.head(); + let read_obj1: TestStruct = buffer.read(); + let read_obj2: TestStruct = buffer.read(); + + assert_eq!(read_obj1, obj1); + assert_eq!(read_obj2, obj2); + } + + #[test] + fn test_default() { + let buffer: Buffer = Default::default(); + assert!(buffer.data.is_empty()); + assert_eq!(buffer.ptr, 0); + } + + #[test] + fn test_pointer_after_read() { + let mut buffer = Buffer::new(); + let obj = TestStruct { a: 789, b: "pointer_test".to_string() }; + + buffer.write(&obj); + buffer.head(); + let start_ptr = buffer.ptr; + + let _read_obj: TestStruct = buffer.read(); + assert!(buffer.ptr > start_ptr, "Pointer should have advanced after read"); + } +}