Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

primitives: add unit tests and small refactor #1751

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 59 additions & 17 deletions crates/primitives/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,35 @@ pub const WORD_SIZE: usize = 4;

/// Converts a slice of words to a byte vector in little endian.
pub fn words_to_bytes_le_vec(words: &[u32]) -> Vec<u8> {
words.iter().flat_map(|word| word.to_le_bytes().to_vec()).collect::<Vec<_>>()
words.iter().flat_map(|&word| word.to_le_bytes()).collect()
}

/// Converts a slice of words to a slice of bytes in little endian.
pub fn words_to_bytes_le<const B: usize>(words: &[u32]) -> [u8; B] {
debug_assert_eq!(words.len() * 4, B);
words
.iter()
.flat_map(|word| word.to_le_bytes().to_vec())
.collect::<Vec<_>>()
.try_into()
.unwrap()
debug_assert_eq!(words.len() * WORD_SIZE, B);
let mut bytes = [0u8; B];
words.iter().enumerate().for_each(|(i, &word)| {
bytes[i * WORD_SIZE..(i + 1) * WORD_SIZE].copy_from_slice(&word.to_le_bytes());
});
bytes
}

/// Converts a byte array in little endian to a slice of words.
pub fn bytes_to_words_le<const W: usize>(bytes: &[u8]) -> [u32; W] {
debug_assert_eq!(bytes.len(), W * 4);
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<_>>()
.try_into()
.unwrap()
debug_assert_eq!(bytes.len(), W * WORD_SIZE);
let mut words = [0u32; W];
bytes.chunks_exact(WORD_SIZE).enumerate().for_each(|(i, chunk)| {
words[i] = u32::from_le_bytes(chunk.try_into().unwrap());
});
words
}

/// Converts a byte array in little endian to a vector of words.
pub fn bytes_to_words_le_vec(bytes: &[u8]) -> Vec<u32> {
bytes
.chunks_exact(4)
.chunks_exact(WORD_SIZE)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<_>>()
.collect()
}

// Converts a num to a string with commas every 3 digits.
Expand All @@ -54,3 +52,47 @@ pub fn num_to_comma_separated<T: ToString>(value: T) -> String {
.rev()
.collect()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_words_to_bytes_le_vec() {
let words = [0x12345678, 0x90ABCDEF];
let expected = vec![0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90];
assert_eq!(words_to_bytes_le_vec(&words), expected);
}

#[test]
fn test_words_to_bytes_le() {
let words = [0x12345678, 0x90ABCDEF];
let expected: [u8; 8] = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90];
assert_eq!(words_to_bytes_le::<8>(&words), expected);
}

#[test]
fn test_bytes_to_words_le() {
let bytes = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90];
let expected = [0x12345678, 0x90ABCDEF];
assert_eq!(bytes_to_words_le::<2>(&bytes), expected);
}

#[test]
fn test_bytes_to_words_le_vec() {
let bytes = [0x78, 0x56, 0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x90];
let expected = vec![0x12345678, 0x90ABCDEF];
assert_eq!(bytes_to_words_le_vec(&bytes), expected);
}

#[test]
fn test_num_to_comma_separated() {
assert_eq!(num_to_comma_separated(1000), "1,000");
assert_eq!(num_to_comma_separated(1000000), "1,000,000");
assert_eq!(num_to_comma_separated(987654321), "987,654,321");

// Test with a large number as BigUint
let large_num = num_bigint::BigUint::from(12345678901234567890u64);
assert_eq!(num_to_comma_separated(large_num), "12,345,678,901,234,567,890");
}
}
80 changes: 73 additions & 7 deletions crates/primitives/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ pub struct SP1PublicValues {
}

impl SP1PublicValues {
/// Create a new `SP1PublicValues`.
pub const fn new() -> Self {
Self { buffer: Buffer::new() }
}

pub fn raw(&self) -> String {
format!("0x{}", hex::encode(self.buffer.data.clone()))
}
Expand All @@ -32,7 +27,7 @@ impl SP1PublicValues {
self.buffer.data.clone()
}

/// Read a value from the buffer.
/// Read a value from the buffer.
pub fn read<T: Serialize + DeserializeOwned>(&mut self) -> T {
self.buffer.read()
}
Expand Down Expand Up @@ -89,13 +84,84 @@ impl AsRef<[u8]> for SP1PublicValues {
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct TestStruct {
a: u32,
b: String,
}

#[test]
fn test_new() {
let public_values = SP1PublicValues::default();
assert!(public_values.buffer.data.is_empty());
}

#[test]
fn test_from_slice() {
let data = b"test data";
let public_values = SP1PublicValues::from(data.as_ref());
assert_eq!(public_values.buffer.data, data);
}

#[test]
fn test_raw() {
let data = b"test raw data";
let public_values = SP1PublicValues::from(data.as_ref());
let expected_hex = format!("0x{}", hex::encode(data));
assert_eq!(public_values.raw(), expected_hex);
}

#[test]
fn test_as_slice() {
let data = b"test slice data";
let public_values = SP1PublicValues::from(data.as_ref());
assert_eq!(public_values.as_slice(), data);
}

#[test]
fn test_to_vec() {
let data = b"test vec data";
let public_values = SP1PublicValues::from(data.as_ref());
assert_eq!(public_values.to_vec(), data.to_vec());
}

#[test]
fn test_write_and_read() {
let mut public_values = SP1PublicValues::default();
let obj = TestStruct { a: 123, b: "test".to_string() };

public_values.write(&obj);
let read_obj: TestStruct = public_values.read();
assert_eq!(read_obj, obj);
}

#[test]
fn test_write_slice_and_read_slice() {
let mut public_values = SP1PublicValues::default();
let slice = [1, 2, 3, 4, 5];
public_values.write_slice(&slice);

let mut read_slice = [0; 5];
public_values.read_slice(&mut read_slice);
assert_eq!(read_slice, slice);
}

#[test]
fn test_hash() {
let data = b"some data to hash";
let public_values = SP1PublicValues::from(data.as_ref());
let expected_hash = Sha256::digest(data).to_vec();
assert_eq!(public_values.hash(), expected_hash);
}

#[test]
fn test_hash_public_values() {
let test_hex = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";
let test_bytes = hex::decode(test_hex).unwrap();

let mut public_values = SP1PublicValues::new();
let mut public_values = SP1PublicValues::default();
public_values.write_slice(&test_bytes);
let hash = public_values.hash_bn254();

Expand Down
120 changes: 112 additions & 8 deletions crates/primitives/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub enum RecursionProgramType {
Wrap,
}

/// A buffer of serializable/deserializable objects.
/// A buffer of serializable/deserializable objects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Buffer {
pub data: Vec<u8>,
Expand All @@ -26,12 +26,12 @@ impl Buffer {
Self { data: data.to_vec(), ptr: 0 }
}

/// Set the position ptr to the beginning of the buffer.
/// Set the position ptr to the beginning of the buffer.
pub fn head(&mut self) {
self.ptr = 0;
}

/// Read the serializable object from the buffer.
/// Read the serializable object from the buffer.
pub fn read<T: Serialize + DeserializeOwned>(&mut self) -> T {
let result: T =
bincode::deserialize(&self.data[self.ptr..]).expect("failed to deserialize");
Expand All @@ -45,14 +45,12 @@ impl Buffer {
self.ptr += slice.len();
}

/// Write the serializable object from the buffer.
/// Write the serializable object from the buffer.
pub fn write<T: Serialize>(&mut self, data: &T) {
let mut tmp = Vec::new();
bincode::serialize_into(&mut tmp, data).expect("serialization failed");
self.data.extend(tmp);
bincode::serialize_into(&mut self.data, data).expect("serialization failed");
}

/// Write the slice of bytes to the buffer.
/// Write the slice of bytes to the buffer.
pub fn write_slice(&mut self, slice: &[u8]) {
self.data.extend_from_slice(slice);
}
Expand All @@ -63,3 +61,109 @@ impl Default for Buffer {
Self::new()
}
}

#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct TestStruct {
a: u32,
b: String,
}

#[test]
fn test_new() {
let buffer = Buffer::new();
assert!(buffer.data.is_empty());
assert_eq!(buffer.ptr, 0);
}

#[test]
fn test_from_slice() {
let data: &[u8] = &[1, 2, 3, 4];
let buffer = Buffer::from(data);
assert_eq!(buffer.data, data);
assert_eq!(buffer.ptr, 0);
}

#[test]
fn test_head() {
let data: &[u8] = &[1, 2, 3, 4];
let mut buffer = Buffer::from(data);
buffer.ptr = 2;
buffer.head();
assert_eq!(buffer.ptr, 0);
}

#[test]
fn test_write_and_read() {
let mut buffer = Buffer::new();
let obj = TestStruct { a: 123, b: "test".to_string() };

// Serialize `obj` using bincode for comparison
let mut expected_data = Vec::new();
bincode::serialize_into(&mut expected_data, &obj).expect("serialization failed");

// Write `obj` to buffer and check if `buffer.data` matches `expected_data`
buffer.write(&obj);
assert_eq!(buffer.data, expected_data);
assert_eq!(buffer.ptr, 0);

let read_obj: TestStruct = buffer.read();
assert_eq!(read_obj, obj);
assert_eq!(buffer.ptr, buffer.data.len());
}

#[test]
fn test_write_slice_and_read_slice() {
let mut buffer = Buffer::new();
let slice = [1, 2, 3, 4, 5];

buffer.write_slice(&slice);
assert_eq!(buffer.data, slice);

let mut read_slice = [0; 5];
buffer.head();
buffer.read_slice(&mut read_slice);
assert_eq!(read_slice, slice);
}

#[test]
fn test_multiple_writes_and_reads() {
let mut buffer = Buffer::new();
let obj1 = TestStruct { a: 123, b: "first".to_string() };
let obj2 = TestStruct { a: 456, b: "second".to_string() };

buffer.write(&obj1);
buffer.write(&obj2);

buffer.head();
let read_obj1: TestStruct = buffer.read();
let read_obj2: TestStruct = buffer.read();

assert_eq!(read_obj1, obj1);
assert_eq!(read_obj2, obj2);
}

#[test]
fn test_default() {
let buffer: Buffer = Default::default();
assert!(buffer.data.is_empty());
assert_eq!(buffer.ptr, 0);
}

#[test]
fn test_pointer_after_read() {
let mut buffer = Buffer::new();
let obj = TestStruct { a: 789, b: "pointer_test".to_string() };

buffer.write(&obj);
buffer.head();
let start_ptr = buffer.ptr;

let _read_obj: TestStruct = buffer.read();
assert!(buffer.ptr > start_ptr, "Pointer should have advanced after read");
}
}
Loading