diff --git a/crates/brainfuck_prover/src/components/program/table.rs b/crates/brainfuck_prover/src/components/program/table.rs index 2840d6a..e202ae1 100644 --- a/crates/brainfuck_prover/src/components/program/table.rs +++ b/crates/brainfuck_prover/src/components/program/table.rs @@ -1,10 +1,15 @@ +use super::component::InteractionClaim; +use crate::components::{ProgramClaim, TraceColumn, TraceError, TraceEval}; use brainfuck_vm::{machine::ProgramMemory, registers::Registers}; use num_traits::{One, Zero}; use stwo_prover::{ - constraint_framework::{logup::LookupElements, Relation, RelationEFTraitBound}, + constraint_framework::{ + logup::{LogupTraceGenerator, LookupElements}, + Relation, RelationEFTraitBound, + }, core::{ backend::{ - simd::{column::BaseColumn, m31::LOG_N_LANES}, + simd::{column::BaseColumn, m31::LOG_N_LANES, qm31::PackedSecureField}, Column, }, channel::Channel, @@ -13,8 +18,6 @@ use stwo_prover::{ }, }; -use crate::components::{ProgramClaim, TraceColumn, TraceError, TraceEval}; - /// Represents a single row in the Program Table. /// /// The Program Table stores: @@ -271,6 +274,61 @@ impl> Relation for ProgramElements } } +/// Creates the interaction trace from the main trace evaluation +/// and the interaction elements for the Program component. +/// +/// The Program table is used to prove that the Instruction table (a subset of it actually) +/// contains the program that has been executed. To do so we make a lookup argument which uses the +/// Instruction lookup sum. Here, each fraction have a multiplicity of 1, while the counterpart in +/// the Instruction components will have a multiplicity of -1. +/// The order is kept by having the `ip` register in the denominator. +/// +/// Only the 'real' rows are impacting the logUp sum. +/// Dummy rows are padding rows. +/// +/// Here, the logUp has a single extension column, which will be used +/// by both the Processor and the Program components. +/// +/// # Returns +/// - Interaction trace evaluation, to be committed. +/// - Interaction claim: the total sum from the logUp protocol, +/// to be mixed into the Fiat-Shamir [`Channel`]. +#[allow(clippy::similar_names)] +pub fn interaction_trace_evaluation( + main_trace_eval: &TraceEval, + lookup_elements: &ProgramElements, +) -> Result<(TraceEval, InteractionClaim), TraceError> { + if main_trace_eval.is_empty() { + return Err(TraceError::EmptyTrace) + } + + let log_size = main_trace_eval[0].domain.log_size(); + + let mut logup_gen = LogupTraceGenerator::new(log_size); + let mut col_gen = logup_gen.new_col(); + + let ip_col = &main_trace_eval[ProgramColumn::Ip.index()].data; + let ci_col = &main_trace_eval[ProgramColumn::Ci.index()].data; + let ni_col = &main_trace_eval[ProgramColumn::Ni.index()].data; + let d_col = &main_trace_eval[ProgramColumn::D.index()].data; + + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let ip = ip_col[vec_row]; + let ci = ci_col[vec_row]; + let ni = ni_col[vec_row]; + let d = d_col[vec_row]; + + let num = PackedSecureField::one() - PackedSecureField::from(d); + let denom: PackedSecureField = lookup_elements.combine(&[ip, ci, ni]); + col_gen.write_frac(vec_row, num, denom); + } + + col_gen.finalize_col(); + let (trace, claimed_sum) = logup_gen.finalize_last(); + + Ok((trace, InteractionClaim { claimed_sum })) +} + #[cfg(test)] mod tests { use super::*; @@ -549,4 +607,79 @@ mod tests { ); } } + + #[test] + fn test_empty_interaction_trace_evaluation() { + let empty_eval = vec![]; + let lookup_elements = ProgramElements::dummy(); + let interaction_trace_eval = interaction_trace_evaluation(&empty_eval, &lookup_elements); + + assert!(matches!(interaction_trace_eval, Err(TraceError::EmptyTrace))); + } + + #[allow(clippy::similar_names)] + #[test] + fn test_interaction_trace_evaluation() { + let code = "+->[-]"; + let mut compiler = Compiler::new(code); + let instructions = compiler.compile(); + let (mut machine, _) = create_test_machine(&instructions, &[]); + let () = machine.execute().expect("Failed to execute machine"); + + let program_memory = machine.program(); + let program_table = ProgramTable::from(program_memory); + + let (trace_eval, claim) = program_table.trace_evaluation().unwrap(); + + let lookup_elements = ProgramElements::dummy(); + let (interaction_trace_eval, interaction_claim) = + interaction_trace_evaluation(&trace_eval, &lookup_elements).unwrap(); + + let log_size = trace_eval[0].domain.log_size(); + + let mut denoms = [PackedSecureField::zero(); 8]; + let ip_col = &trace_eval[ProgramColumn::Ip.index()].data; + let ci_col = &trace_eval[ProgramColumn::Ci.index()].data; + let ni_col = &trace_eval[ProgramColumn::Ni.index()].data; + + // Construct the denominator for each row of the logUp column, from the main trace + // evaluation. + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let ip = ip_col[vec_row]; + let ci = ci_col[vec_row]; + let ni = ni_col[vec_row]; + let denom: PackedSecureField = lookup_elements.combine(&[ip, ci, ni]); + denoms[vec_row] = denom; + } + + let mut logup_gen = LogupTraceGenerator::new(log_size); + + let mut col_gen = logup_gen.new_col(); + + col_gen.write_frac(0, PackedSecureField::one(), denoms[0]); + col_gen.write_frac(1, PackedSecureField::one(), denoms[1]); + col_gen.write_frac(2, PackedSecureField::one(), denoms[2]); + col_gen.write_frac(3, PackedSecureField::one(), denoms[3]); + col_gen.write_frac(4, PackedSecureField::one(), denoms[4]); + col_gen.write_frac(5, PackedSecureField::one(), denoms[5]); + col_gen.write_frac(6, PackedSecureField::one(), denoms[6]); + col_gen.write_frac(7, PackedSecureField::one(), denoms[7]); + + col_gen.finalize_col(); + + let (expected_interaction_trace_eval, expected_claimed_sum) = logup_gen.finalize_last(); + + assert_eq!(claim.log_size, log_size,); + for col_index in 0..expected_interaction_trace_eval.len() { + assert_eq!( + interaction_trace_eval[col_index].domain, + expected_interaction_trace_eval[col_index].domain + ); + assert_eq!( + interaction_trace_eval[col_index].to_cpu().values, + expected_interaction_trace_eval[col_index].to_cpu().values + ); + } + assert_eq!(interaction_claim.claimed_sum, expected_claimed_sum); + } }