Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MachineError custom enum type #28

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ clap = { version = "4.3.10", features = ["derive"] }
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "f6214d1" }
tracing = "0.1"
tracing-subscriber = "0.3"
thiserror = "2.0"
1 change: 1 addition & 0 deletions crates/brainfuck_vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ num-traits = "0.2.19"
stwo-prover.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
thiserror.workspace = true
13 changes: 8 additions & 5 deletions crates/brainfuck_vm/src/bin/brainfuck_vm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Adapted from rkdud007 brainfuck-zkvm https://github.com/rkdud007/brainfuck-zkvm/blob/main/src/main.rs

use brainfuck_vm::{compiler::Compiler, machine::Machine};
use brainfuck_vm::{
compiler::Compiler,
machine::{Machine, MachineError},
};
use clap::{Parser, ValueHint};
use std::{
fs,
Expand All @@ -23,7 +26,7 @@ struct Args {
ram_size: Option<usize>,
}

fn main() {
fn main() -> Result<(), MachineError> {
let args = Args::parse();

tracing_subscriber::fmt().with_env_filter(args.log).init();
Expand All @@ -36,8 +39,8 @@ fn main() {
let stdin = stdin();
let stdout = stdout();
let mut bf_vm = match args.ram_size {
Some(size) => Machine::new_with_config(&ins, stdin, stdout, size),
None => Machine::new(&ins, stdin, stdout),
Some(size) => Machine::new_with_config(&ins, stdin, stdout, size)?,
None => Machine::new(&ins, stdin, stdout)?,
};
tracing::info!("Provide inputs separated by linefeeds: ");
bf_vm.execute().unwrap();
Expand All @@ -48,5 +51,5 @@ fn main() {
let trace = bf_vm.get_trace();
tracing::info!("Execution trace: {:#?}", trace);
}
// Ok(())
Ok(())
}
40 changes: 26 additions & 14 deletions crates/brainfuck_vm/src/instruction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
// Taken from rkdud007 brainfuck-zkvm https://github.com/rkdud007/brainfuck-zkvm/blob/main/src/instruction.rs

use std::{fmt::Display, str::FromStr};
use thiserror::Error;

/// Custom error type for instructions
#[derive(Debug, Error, PartialEq, Eq)]
pub enum InstructionError {
/// Error when converting a character to an instruction
#[error("Value `{0}` is not a valid instruction")]
Conversion(char),
}

#[derive(Debug, Clone)]
pub struct Instruction {
Expand Down Expand Up @@ -66,9 +75,12 @@ impl Display for InstructionType {
}
}

impl From<u8> for InstructionType {
fn from(value: u8) -> Self {
Self::from_str(&(value as char).to_string()).expect("Invalid instruction")
impl TryFrom<u8> for InstructionType {
type Error = InstructionError;

fn try_from(value: u8) -> Result<Self, Self::Error> {
Self::from_str(&(value as char).to_string())
.map_err(|()| InstructionError::Conversion(value as char))
}
}

Expand Down Expand Up @@ -114,21 +126,21 @@ mod tests {
// Test from_u8 implementation
#[test]
fn test_instruction_type_from_u8() {
assert_eq!(InstructionType::from(b'>'), InstructionType::Right);
assert_eq!(InstructionType::from(b'<'), InstructionType::Left);
assert_eq!(InstructionType::from(b'+'), InstructionType::Plus);
assert_eq!(InstructionType::from(b'-'), InstructionType::Minus);
assert_eq!(InstructionType::from(b'.'), InstructionType::PutChar);
assert_eq!(InstructionType::from(b','), InstructionType::ReadChar);
assert_eq!(InstructionType::from(b'['), InstructionType::JumpIfZero);
assert_eq!(InstructionType::from(b']'), InstructionType::JumpIfNotZero);
assert_eq!(InstructionType::try_from(b'>').unwrap(), InstructionType::Right);
assert_eq!(InstructionType::try_from(b'<').unwrap(), InstructionType::Left);
assert_eq!(InstructionType::try_from(b'+').unwrap(), InstructionType::Plus);
assert_eq!(InstructionType::try_from(b'-').unwrap(), InstructionType::Minus);
assert_eq!(InstructionType::try_from(b'.').unwrap(), InstructionType::PutChar);
assert_eq!(InstructionType::try_from(b',').unwrap(), InstructionType::ReadChar);
assert_eq!(InstructionType::try_from(b'[').unwrap(), InstructionType::JumpIfZero);
assert_eq!(InstructionType::try_from(b']').unwrap(), InstructionType::JumpIfNotZero);
}

// Test from_u8 with invalid input (should panic)
// Test to ensure invalid input as u8 returns an error
#[test]
#[should_panic(expected = "Invalid instruction")]
fn test_instruction_type_from_u8_invalid() {
let _ = InstructionType::from(b'x');
let result = InstructionType::try_from(b'x');
assert_eq!(result, Err(InstructionError::Conversion('x')));
}

// Test Instruction struct creation
Expand Down
78 changes: 48 additions & 30 deletions crates/brainfuck_vm/src/machine.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
// Adapted from rkdud007 brainfuck-zkvm https://github.com/rkdud007/brainfuck-zkvm/blob/main/src/machine.rs

use crate::{instruction::InstructionType, registers::Registers};
use num_traits::identities::{One, Zero};
use std::{
error::Error,
io::{Read, Write},
use crate::{
instruction::{InstructionError, InstructionType},
registers::Registers,
};
use num_traits::identities::{One, Zero};
use std::io::{Read, Write};
use stwo_prover::core::fields::{m31::BaseField, FieldExpOps};
use thiserror::Error;

/// Custom error type for the Machine.
#[derive(Debug, Error)]
pub enum MachineError {
/// I/O operation failed.
#[error("I/O operation failed: {0}")]
IoError(#[from] std::io::Error),

/// Instructions related error.
#[error("Instruction error: {0}")]
Instruction(#[from] InstructionError),
}

pub struct MachineBuilder {
code: Vec<BaseField>,
Expand Down Expand Up @@ -43,9 +56,13 @@ impl MachineBuilder {
}

/// Builds the [`Machine`] instance with the provided configuration.
pub fn build(self) -> Result<Machine, &'static str> {
pub fn build(self) -> Result<Machine, MachineError> {
if self.input.is_none() || self.output.is_none() {
return Err("Input and output streams must be provided");
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Input and output streams must be provided",
)
.into());
zmalatrax marked this conversation as resolved.
Show resolved Hide resolved
}

Ok(Machine {
Expand Down Expand Up @@ -86,7 +103,12 @@ pub struct Machine {
impl Machine {
pub const DEFAULT_RAM_SIZE: usize = 30000;

pub fn new_with_config<R, W>(code: &[BaseField], input: R, output: W, ram_size: usize) -> Self
pub fn new_with_config<R, W>(
code: &[BaseField],
input: R,
output: W,
ram_size: usize,
) -> Result<Self, MachineError>
where
R: Read + 'static,
W: Write + 'static,
Expand All @@ -96,22 +118,17 @@ impl Machine {
.with_output(output)
.with_ram_size(ram_size)
.build()
.expect("Failed to build Machine")
}

pub fn new<R, W>(code: &[BaseField], input: R, output: W) -> Self
pub fn new<R, W>(code: &[BaseField], input: R, output: W) -> Result<Self, MachineError>
where
R: Read + 'static,
W: Write + 'static,
{
MachineBuilder::new(code)
.with_input(input)
.with_output(output)
.build()
.expect("Failed to build Machine")
MachineBuilder::new(code).with_input(input).with_output(output).build()
}

pub fn execute(&mut self) -> Result<(), Box<dyn Error>> {
pub fn execute(&mut self) -> Result<(), MachineError> {
while self.state.registers.ip < BaseField::from(self.program.code.len()) {
self.state.registers.ci = self.program.code[self.state.registers.ip.0 as usize];
self.state.registers.ni =
Expand All @@ -121,7 +138,7 @@ impl Machine {
self.program.code[(self.state.registers.ip + BaseField::one()).0 as usize]
};
self.write_trace();
let ins_type = InstructionType::from(self.state.registers.ci.0 as u8);
let ins_type = InstructionType::try_from(self.state.registers.ci.0 as u8)?;
self.execute_instruction(&ins_type)?;
self.next_clock_cycle();
}
Expand All @@ -133,21 +150,21 @@ impl Machine {
Ok(())
}

fn read_char(&mut self) -> Result<(), std::io::Error> {
fn read_char(&mut self) -> Result<(), MachineError> {
let mut buf = [0; 1];
self.io.input.read_exact(&mut buf)?;
let input_char = buf[0] as usize;
self.state.ram[self.state.registers.mp.0 as usize] = BaseField::from(input_char as u32);
Ok(())
}

fn write_char(&mut self) -> Result<(), std::io::Error> {
fn write_char(&mut self) -> Result<(), MachineError> {
let char_to_write = self.state.ram[self.state.registers.mp.0 as usize].0 as u8;
self.io.output.write_all(&[char_to_write])?;
Ok(())
}

fn execute_instruction(&mut self, ins: &InstructionType) -> Result<(), Box<dyn Error>> {
fn execute_instruction(&mut self, ins: &InstructionType) -> Result<(), MachineError> {
match ins {
InstructionType::Right => {
self.state.registers.mp += BaseField::one();
Expand Down Expand Up @@ -255,7 +272,8 @@ mod tests {
let input: &[u8] = &[];
let output = TestWriter::new();
let ram_size = 55000;
let machine = Machine::new_with_config(&code, input, output, ram_size);
let machine = Machine::new_with_config(&code, input, output, ram_size)
.expect("Machine creation failed");

assert_eq!(machine.program, ProgramMemory { code });
assert_eq!(
Expand All @@ -265,7 +283,7 @@ mod tests {
}

#[test]
fn test_right_instruction() -> Result<(), Box<dyn Error>> {
fn test_right_instruction() -> Result<(), MachineError> {
// '>>'
let code = vec![BaseField::from(62), BaseField::from(62)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand All @@ -277,7 +295,7 @@ mod tests {
}

#[test]
fn test_left_instruction() -> Result<(), Box<dyn Error>> {
fn test_left_instruction() -> Result<(), MachineError> {
// '>><'
let code = vec![BaseField::from(62), BaseField::from(62), BaseField::from(60)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand All @@ -288,7 +306,7 @@ mod tests {
}

#[test]
fn test_plus_instruction() -> Result<(), Box<dyn Error>> {
fn test_plus_instruction() -> Result<(), MachineError> {
// '+'
let code = vec![BaseField::from(43)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand All @@ -300,7 +318,7 @@ mod tests {
}

#[test]
fn test_minus_instruction() -> Result<(), Box<dyn Error>> {
fn test_minus_instruction() -> Result<(), MachineError> {
// '--'
let code = vec![BaseField::from(45), BaseField::from(45)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand All @@ -312,7 +330,7 @@ mod tests {
}

#[test]
fn test_read_write_char() -> Result<(), Box<dyn Error>> {
fn test_read_write_char() -> Result<(), MachineError> {
// ',.'
let code = vec![BaseField::from(44), BaseField::from(46)];
let input = b"a";
Expand All @@ -326,7 +344,7 @@ mod tests {
}

#[test]
fn test_skip_loop() -> Result<(), Box<dyn Error>> {
fn test_skip_loop() -> Result<(), MachineError> {
// Skip the loop
// '[-]+'
let code = vec![
Expand All @@ -346,7 +364,7 @@ mod tests {
}

#[test]
fn test_enter_loop() -> Result<(), Box<dyn Error>> {
fn test_enter_loop() -> Result<(), MachineError> {
// Enter the loop
// '+[+>]'
let code = vec![
Expand All @@ -368,7 +386,7 @@ mod tests {
}

#[test]
fn test_get_trace() -> Result<(), Box<dyn Error>> {
fn test_get_trace() -> Result<(), MachineError> {
// '++'
let code = vec![BaseField::from(43), BaseField::from(43)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand Down Expand Up @@ -409,7 +427,7 @@ mod tests {
}

#[test]
fn test_pad_trace() -> Result<(), Box<dyn Error>> {
fn test_pad_trace() -> Result<(), MachineError> {
// '++'
let code = vec![BaseField::from(43), BaseField::from(43)];
let (mut machine, _) = create_test_machine(&code, &[]);
Expand Down
2 changes: 1 addition & 1 deletion crates/brainfuck_vm/src/test_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ pub fn create_test_machine(code: &[BaseField], input: &[u8]) -> (Machine, TestWr
let input = Cursor::new(input.to_vec());
let output = TestWriter::new();
let test_output = output.clone();
let machine = Machine::new(code, input, output);
let machine = Machine::new(code, input, output).expect("Failed to create machine");
(machine, test_output)
}