Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add IO traces to entrypoints #144

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 108 additions & 13 deletions crates/brainfuck_prover/src/brainfuck_air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::components::{
instruction::table::InstructionElements,
io::table::IoElements,
io::{
self,
component::{InputComponent, InputEval, OutputComponent, OutputEval},
table::{InputTable, IoElements, OutputTable},
},
memory::{
self,
component::{MemoryComponent, MemoryEval},
table::{interaction_trace_evaluation, MemoryElements, MemoryTable},
table::{MemoryElements, MemoryTable},
},
MemoryClaim,
IoClaim, MemoryClaim,
};
use brainfuck_vm::machine::Machine;
use stwo_prover::{
Expand Down Expand Up @@ -42,16 +46,22 @@ pub struct BrainfuckProof<H: MerkleHasher> {
/// It includes the common claim values such as the initial and final states
/// and the claim of each component.
pub struct BrainfuckClaim {
pub input: IoClaim,
pub output: IoClaim,
pub memory: MemoryClaim,
}

impl BrainfuckClaim {
pub fn mix_into(&self, channel: &mut impl Channel) {
self.input.mix_into(channel);
self.output.mix_into(channel);
self.memory.mix_into(channel);
}

pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let mut log_sizes = self.memory.log_sizes();
let mut log_sizes = TreeVec::concat_cols(
[self.input.log_sizes(), self.output.log_sizes(), self.memory.log_sizes()].into_iter(),
);

// We overwrite the preprocessed column claim to have all log sizes
// in the merkle root for the verification.
Expand Down Expand Up @@ -87,12 +97,16 @@ impl BrainfuckInteractionElements {
///
/// Mainly the claims on global relations (lookup, permutation, evaluation).
pub struct BrainfuckInteractionClaim {
input: io::component::InteractionClaim,
output: io::component::InteractionClaim,
memory: memory::component::InteractionClaim,
}

impl BrainfuckInteractionClaim {
/// Mix the claimed sums of every components in the Fiat-Shamir [`Channel`].
pub fn mix_into(&self, channel: &mut impl Channel) {
self.input.mix_into(channel);
self.output.mix_into(channel);
self.memory.mix_into(channel);
}
}
Expand All @@ -111,6 +125,8 @@ pub fn lookup_sum_valid(
/// Components are used by the prover as a `ComponentProver`,
/// and by the verifier as a `Component`.
pub struct BrainfuckComponents {
input: InputComponent,
output: OutputComponent,
memory: MemoryComponent,
}

Expand All @@ -129,18 +145,28 @@ impl BrainfuckComponents {
.collect::<Vec<_>>(),
);

let input = InputComponent::new(
tree_span_provider,
InputEval::new(&claim.input, interaction_elements.input_lookup_elements.clone()),
(interaction_claim.input.claimed_sum, None),
);
let output = OutputComponent::new(
tree_span_provider,
OutputEval::new(&claim.output, interaction_elements.output_lookup_elements.clone()),
(interaction_claim.output.claimed_sum, None),
);
let memory = MemoryComponent::new(
tree_span_provider,
MemoryEval::new(&claim.memory, interaction_elements.memory_lookup_elements.clone()),
(interaction_claim.memory.claimed_sum, None),
);

Self { memory }
Self { input, output, memory }
}

/// Returns the `ComponentProver` of each components, used by the prover.
pub fn provers(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![&self.memory]
vec![&self.input, &self.output, &self.memory]
}

/// Returns the `Component` of each components, used by the verifier.
Expand Down Expand Up @@ -168,7 +194,7 @@ const LOG_MAX_ROWS: u32 = 20;
///
/// Ideally, we should cover all possible log sizes, between
/// 1 and `LOG_MAX_ROW`
const IS_FIRST_LOG_SIZES: [u32; 4] = [10, 9, 8, 5];
const IS_FIRST_LOG_SIZES: [u32; 8] = [15, 10, 9, 8, 7, 6, 5, 4];

/// Generate a STARK proof of the given Brainfuck program execution.
///
Expand Down Expand Up @@ -211,11 +237,15 @@ pub fn prove_brainfuck(
let mut tree_builder = commitment_scheme.tree_builder();

let vm_trace = inputs.trace();
let (memory_trace, memory_claim) = MemoryTable::from(vm_trace).trace_evaluation().unwrap();
let (input_trace, input_claim) = InputTable::from(&vm_trace).trace_evaluation();
let (output_trace, output_claim) = OutputTable::from(&vm_trace).trace_evaluation();
let (memory_trace, memory_claim) = MemoryTable::from(&vm_trace).trace_evaluation().unwrap();

tree_builder.extend_evals(input_trace.clone());
tree_builder.extend_evals(output_trace.clone());
tree_builder.extend_evals(memory_trace.clone());

let claim = BrainfuckClaim { memory: memory_claim };
let claim = BrainfuckClaim { input: input_claim, output: output_claim, memory: memory_claim };

// Mix the claim into the Fiat-Shamir channel.
claim.mix_into(channel);
Expand All @@ -232,13 +262,36 @@ pub fn prove_brainfuck(
// Generate the interaction trace and the BrainfuckInteractionClaim
let mut tree_builder = commitment_scheme.tree_builder();

let (memory_interaction_trace_eval, memory_interaction_claim) =
interaction_trace_evaluation(&memory_trace, &interaction_elements.memory_lookup_elements)
.unwrap();
let (input_interaction_trace_eval, input_interaction_claim) =
io::table::interaction_trace_evaluation(
&input_trace,
&interaction_elements.input_lookup_elements,
)
.unwrap();

let (output_interaction_trace_eval, output_interaction_claim) =
io::table::interaction_trace_evaluation(
&output_trace,
&interaction_elements.output_lookup_elements,
)
.unwrap();

let (memory_interaction_trace_eval, memory_interaction_claim) =
memory::table::interaction_trace_evaluation(
&memory_trace,
&interaction_elements.memory_lookup_elements,
)
.unwrap();

tree_builder.extend_evals(input_interaction_trace_eval);
tree_builder.extend_evals(output_interaction_trace_eval);
tree_builder.extend_evals(memory_interaction_trace_eval);

let interaction_claim = BrainfuckInteractionClaim { memory: memory_interaction_claim };
let interaction_claim = BrainfuckInteractionClaim {
input: input_interaction_claim,
output: output_interaction_claim,
memory: memory_interaction_claim,
};

// Mix the interaction claim into the Fiat-Shamir channel.
interaction_claim.mix_into(channel);
Expand Down Expand Up @@ -323,6 +376,20 @@ mod tests {

use super::{prove_brainfuck, verify_brainfuck};

#[test]
fn test_proof_cpu() {
// Get an execution trace from a valid Brainfuck program
let code = "+>,.";
let mut compiler = Compiler::new(code);
let instructions = compiler.compile();
let (mut machine, _) = create_test_machine(&instructions, &[1]);
let () = machine.execute().expect("Failed to execute machine");

let brainfuck_proof = prove_brainfuck(&machine).unwrap();

verify_brainfuck(brainfuck_proof).unwrap();
}

#[test]
fn test_proof() {
// Get an execution trace from a valid Brainfuck program
Expand All @@ -336,4 +403,32 @@ mod tests {

verify_brainfuck(brainfuck_proof).unwrap();
}

#[test]
fn test_proof_no_input() {
// Get an execution trace from a valid Brainfuck program
let code = "+++><[>+<-]";
let mut compiler = Compiler::new(code);
let instructions = compiler.compile();
let (mut machine, _) = create_test_machine(&instructions, &[1]);
let () = machine.execute().expect("Failed to execute machine");

let brainfuck_proof = prove_brainfuck(&machine).unwrap();

verify_brainfuck(brainfuck_proof).unwrap();
}

#[test]
fn test_proof_hello_world() {
// Get an execution trace from a valid Brainfuck program
let code = "++++++++++[>+++++++>++++++++++>+++>+<<<<-]>++.>+.+++++++..+++.>++.<<+++++++++++++++.>.+++.------.--------.>+.>.";
let mut compiler = Compiler::new(code);
let instructions = compiler.compile();
let (mut machine, _) = create_test_machine(&instructions, &[]);
let () = machine.execute().expect("Failed to execute machine");

let brainfuck_proof = prove_brainfuck(&machine).unwrap();

verify_brainfuck(brainfuck_proof).unwrap();
}
}
4 changes: 2 additions & 2 deletions crates/brainfuck_prover/src/components/io/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ mod test {
let preprocessed_trace = vec![is_first_col];

// Construct the main trace from the execution trace
let table = InputTable::from(trace_vm);
let table = InputTable::from(&trace_vm);
let (main_trace, claim) = table.trace_evaluation();

// Draw Interaction elements
Expand Down Expand Up @@ -204,7 +204,7 @@ mod test {
let preprocessed_trace = vec![is_first_col];

// Construct the main trace from the execution trace
let table = OutputTable::from(trace_vm);
let table = OutputTable::from(&trace_vm);
let (main_trace, claim) = table.trace_evaluation();

// Draw Interaction elements
Expand Down
10 changes: 5 additions & 5 deletions crates/brainfuck_prover/src/components/io/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,11 @@ impl<const N: u32> IOTable<N> {
}
}

impl<const N: u32> From<Vec<Registers>> for IOTable<N> {
fn from(registers: Vec<Registers>) -> Self {
impl<const N: u32> From<&Vec<Registers>> for IOTable<N> {
fn from(registers: &Vec<Registers>) -> Self {
let mut io_table = Self::new();
let rows = registers
.into_iter()
.iter()
.filter(|register| register.ci == BaseField::from_u32_unchecked(N))
.map(|x| IOTableRow::new(x.clk, x.ci, x.mv))
.collect();
Expand Down Expand Up @@ -453,7 +453,7 @@ mod tests {
let mut expected_io_table: InputTable = IOTable::new();
expected_io_table.add_row(row);

assert_eq!(IOTable::from(registers), expected_io_table);
assert_eq!(IOTable::from(&registers), expected_io_table);
}

#[test]
Expand All @@ -480,7 +480,7 @@ mod tests {
let mut expected_io_table: OutputTable = IOTable::new();
expected_io_table.add_row(row);

assert_eq!(IOTable::from(registers), expected_io_table);
assert_eq!(IOTable::from(&registers), expected_io_table);
}

#[test]
Expand Down
22 changes: 11 additions & 11 deletions crates/brainfuck_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ mod tests {
let preprocessed_trace = vec![is_first_col, is_first_col_2];

// Construct the main trace from the execution trace
let table = MemoryTable::from(trace_vm);
let table = MemoryTable::from(&trace_vm);
let (main_trace, claim) = table.trace_evaluation().unwrap();

// Draw Interaction elements
Expand Down Expand Up @@ -226,7 +226,7 @@ mod tests {
let preprocessed_trace_eval = vec![is_first_col];

let registers = vec![Registers { clk: BaseField::one(), ..Default::default() }];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

// Required to use `MemoryElements::draw(channel: &mut impl Channel)`
Expand Down Expand Up @@ -266,7 +266,7 @@ mod tests {
let preprocessed_trace_eval = vec![is_first_col];

let registers = vec![Registers { mp: BaseField::one(), ..Default::default() }];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

let channel = &mut Blake2sChannel::default();
Expand Down Expand Up @@ -303,7 +303,7 @@ mod tests {
let preprocessed_trace_eval = vec![is_first_col];

let registers = vec![Registers { mv: BaseField::one(), ..Default::default() }];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

let channel = &mut Blake2sChannel::default();
Expand Down Expand Up @@ -340,7 +340,7 @@ mod tests {
let preprocessed_trace_eval = vec![is_first_col];

let registers = vec![Default::default()];
let mut memory_table = MemoryTable::from(registers);
let mut memory_table = MemoryTable::from(&registers);
// We must manually modify the value of d as rows from `Vec<Registers>` are assumed real.
memory_table.table[0].d = BaseField::one();
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();
Expand Down Expand Up @@ -381,7 +381,7 @@ mod tests {
// `mp` should increase by 0 or 1, here it increases by 2.
let registers =
vec![Default::default(), Registers { mp: BaseField::from(2), ..Default::default() }];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

let memory_lookup_elements = MemoryElements::dummy();
Expand Down Expand Up @@ -418,7 +418,7 @@ mod tests {

// `mp` remains the same, but `clk` is not increased by 1.
let registers = vec![Default::default(), Default::default()];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

let channel = &mut Blake2sChannel::default();
Expand Down Expand Up @@ -459,7 +459,7 @@ mod tests {
Default::default(),
Registers { mp: BaseField::one(), mv: BaseField::one(), ..Default::default() },
];
let memory_table = MemoryTable::from(registers);
let memory_table = MemoryTable::from(&registers);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

let channel = &mut Blake2sChannel::default();
Expand Down Expand Up @@ -498,7 +498,7 @@ mod tests {
// The next dummy register flag `next_d` is set to 2.
let registers =
vec![Default::default(), Registers { mp: BaseField::one(), ..Default::default() }];
let mut memory_table = MemoryTable::from(registers);
let mut memory_table = MemoryTable::from(&registers);
memory_table.table[0].next_d = BaseField::from(2);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();

Expand Down Expand Up @@ -542,7 +542,7 @@ mod tests {
// And we modify the second entry of the second row to have `next_mp` different from `mp`
let registers =
vec![Default::default(), Registers { mp: BaseField::one(), ..Default::default() }];
let mut memory_table = MemoryTable::from(registers);
let mut memory_table = MemoryTable::from(&registers);
memory_table.table[1].d = BaseField::one();
memory_table.table[1].next_mp = BaseField::from(2);
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();
Expand Down Expand Up @@ -587,7 +587,7 @@ mod tests {
// And we modify the second entry of the second row to have `next_mv` different from `mv`
let registers =
vec![Default::default(), Registers { mp: BaseField::one(), ..Default::default() }];
let mut memory_table = MemoryTable::from(registers);
let mut memory_table = MemoryTable::from(&registers);
memory_table.table[1].d = BaseField::one();
memory_table.table[1].next_mv = BaseField::one();
let (main_trace_eval, claim) = memory_table.trace_evaluation().unwrap();
Expand Down
Loading
Loading