From b3a26da13ec0750417ca5e9ce195e8e50efedd0a Mon Sep 17 00:00:00 2001 From: malatrax Date: Mon, 23 Dec 2024 12:17:24 +0100 Subject: [PATCH] feat: add trace evaluation to Program table --- .../src/components/program/table.rs | 224 +++++++++++++++++- 1 file changed, 222 insertions(+), 2 deletions(-) diff --git a/crates/brainfuck_prover/src/components/program/table.rs b/crates/brainfuck_prover/src/components/program/table.rs index 8ffbd04..54b78e6 100644 --- a/crates/brainfuck_prover/src/components/program/table.rs +++ b/crates/brainfuck_prover/src/components/program/table.rs @@ -1,8 +1,15 @@ use brainfuck_vm::registers::Registers; use num_traits::One; -use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::{ + backend::{ + simd::{column::BaseColumn, m31::LOG_N_LANES}, + Column, + }, + fields::m31::BaseField, + poly::circle::{CanonicCoset, CircleEvaluation}, +}; -use crate::components::TraceColumn; +use crate::components::{ProgramClaim, TraceColumn, TraceError, TraceEval}; /// Represents a single row in the Program Table. /// @@ -95,6 +102,54 @@ impl ProgramTable { } } } + + /// Transforms the [`ProgramTable`] into a [`TraceEval`], to be committed when + /// generating a STARK proof. + /// + /// The [`ProgramTable`] is transformed from an array of rows (one element = one step + /// of all registers) to an array of columns (one element = all steps of one register). + /// It is then evaluated on the circle domain. + /// + /// # Returns + /// A tuple containing the evaluated trace and claim for STARK proof. + /// + /// # Errors + /// Returns [`TraceError::EmptyTrace`] if the table is empty. + pub fn trace_evaluation(&self) -> Result<(TraceEval, ProgramClaim), TraceError> { + let n_rows = self.table.len() as u32; + + // If the table is empty, there is no data to evaluate, so return an error. + if n_rows == 0 { + return Err(TraceError::EmptyTrace); + } + + // Compute `log_n_rows`, the base-2 logarithm of the number of rows. + let log_n_rows = n_rows.ilog2(); + + // Add `LOG_N_LANES` to account for SIMD optimization. + let log_size = log_n_rows + LOG_N_LANES; + + // Initialize a trace with 4 columns (`ip`, `ci`, `ni`, `d`), each column containing + // `2^log_size` entries initialized to zero. + let mut trace = vec![BaseColumn::zeros(1 << log_size); 4]; + + // Populate the column with data from the table rows. + for (index, row) in self.table.iter().enumerate().take(1 << log_n_rows) { + trace[ProgramColumn::Ip.index()].data[index] = row.ip.into(); + trace[ProgramColumn::Ci.index()].data[index] = row.ci.into(); + trace[ProgramColumn::Ni.index()].data[index] = row.ni.into(); + trace[ProgramColumn::D.index()].data[index] = row.d.into(); + } + + // Create a circle domain using a canonical coset. + let domain = CanonicCoset::new(log_size).circle_domain(); + + // Map the column into the circle domain. + let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect(); + + // Return the evaluated trace and a claim containing the log size of the domain. + Ok((trace, ProgramClaim::new(log_size))) + } } impl From<&Vec> for ProgramTable { @@ -253,4 +308,169 @@ mod tests { assert_eq!(ProgramTable::from(&trace), expected_program_table); } + + #[test] + fn test_trace_evaluation_empty_table() { + let program_table = ProgramTable::new(); + + let result = program_table.trace_evaluation(); + + assert!(matches!(result, Err(TraceError::EmptyTrace))); + } + + #[allow(clippy::similar_names)] + #[test] + fn test_trace_evaluation_single_row() { + let mut program_table = ProgramTable::new(); + program_table.add_row(ProgramTableRow::new( + BaseField::one(), + BaseField::from(44), + BaseField::from(42), + )); + + let (trace, claim) = program_table.trace_evaluation().unwrap(); + + assert_eq!(claim.log_size, LOG_N_LANES, "Log size should include SIMD lanes."); + assert_eq!( + trace.len(), + ProgramColumn::count().0, + "Trace should contain one column per register." + ); + + // Expected values for the single row + let expected_ip_column = vec![BaseField::one(); 1 << LOG_N_LANES]; + let expected_ci_column = vec![BaseField::from(44); 1 << LOG_N_LANES]; + let expected_ni_column = vec![BaseField::from(42); 1 << LOG_N_LANES]; + let expected_d_column = vec![BaseField::zero(); 1 << LOG_N_LANES]; + + // Check that the entire column matches expected values + assert_eq!( + trace[ProgramColumn::Ip.index()].to_cpu().values, + expected_ip_column, + "Ip column should match expected values." + ); + assert_eq!( + trace[ProgramColumn::Ci.index()].to_cpu().values, + expected_ci_column, + "Ci column should match expected values." + ); + assert_eq!( + trace[ProgramColumn::Ni.index()].to_cpu().values, + expected_ni_column, + "Ni column should match expected values." + ); + assert_eq!( + trace[ProgramColumn::D.index()].to_cpu().values, + expected_d_column, + "D column should match expected values." + ); + } + + #[test] + fn test_program_table_trace_evaluation() { + let mut program_table = ProgramTable::new(); + + // Add rows to the Program table. + let rows = vec![ + ProgramTableRow::new(BaseField::zero(), BaseField::from(44), BaseField::one()), + ProgramTableRow::new(BaseField::one(), BaseField::from(44), BaseField::from(2)), + ProgramTableRow::new_dummy(BaseField::from(2)), + ProgramTableRow::new_dummy(BaseField::from(3)), + ]; + program_table.add_rows(rows); + + // Perform the trace evaluation. + let (trace, claim) = program_table.trace_evaluation().unwrap(); + + // Calculate the expected parameters. + let expected_log_n_rows: u32 = 2; // log2(2 rows) + let expected_log_size = expected_log_n_rows + LOG_N_LANES; + let expected_size = 1 << expected_log_size; + + // Construct the expected trace column for `ip`, `ci`, `ni` and `d`. + let mut expected_columns = vec![BaseColumn::zeros(expected_size); ProgramColumn::count().0]; + + // Populate the `ip` column with row data and padding. + expected_columns[ProgramColumn::Ip.index()].data[0] = BaseField::zero().into(); + expected_columns[ProgramColumn::Ip.index()].data[1] = BaseField::one().into(); + expected_columns[ProgramColumn::Ip.index()].data[2] = BaseField::from(2).into(); + expected_columns[ProgramColumn::Ip.index()].data[3] = BaseField::from(3).into(); + + // Populate the `ci` column with row data and padding. + expected_columns[ProgramColumn::Ci.index()].data[0] = BaseField::from(44).into(); + expected_columns[ProgramColumn::Ci.index()].data[1] = BaseField::from(44).into(); + expected_columns[ProgramColumn::Ci.index()].data[2] = BaseField::zero().into(); + expected_columns[ProgramColumn::Ci.index()].data[3] = BaseField::zero().into(); + + // Populate the `ni` column with row data and padding. + expected_columns[ProgramColumn::Ni.index()].data[0] = BaseField::one().into(); + expected_columns[ProgramColumn::Ni.index()].data[1] = BaseField::from(2).into(); + expected_columns[ProgramColumn::Ni.index()].data[2] = BaseField::zero().into(); + expected_columns[ProgramColumn::Ni.index()].data[3] = BaseField::zero().into(); + + // Populate the `d` column with row data and padding. + expected_columns[ProgramColumn::D.index()].data[0] = BaseField::zero().into(); + expected_columns[ProgramColumn::D.index()].data[1] = BaseField::zero().into(); + expected_columns[ProgramColumn::D.index()].data[2] = BaseField::one().into(); + expected_columns[ProgramColumn::D.index()].data[3] = BaseField::one().into(); + + // Create the expected domain for evaluation. + let domain = CanonicCoset::new(expected_log_size).circle_domain(); + + // Transform expected columns into CircleEvaluation. + let expected_trace: TraceEval = + expected_columns.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect(); + + // Create the expected claim. + let expected_claim = ProgramClaim::new(expected_log_size); + + // Assert equality of the claim. + assert_eq!(claim, expected_claim, "The claim should match the expected claim."); + + // Assert equality of the trace for all columns. + for (actual, expected) in trace.iter().zip(expected_trace.iter()) { + assert_eq!( + actual.domain, expected.domain, + "The domain of the trace column should match the expected domain." + ); + assert_eq!( + actual.to_cpu().values, + expected.to_cpu().values, + "The values of the trace column should match the expected values." + ); + } + } + + #[test] + fn test_trace_evaluation_circle_domain() { + let mut program_table = ProgramTable::new(); + program_table.add_rows(vec![ + ProgramTableRow::new( + BaseField::zero(), + InstructionType::ReadChar.to_base_field(), + BaseField::one(), + ), + ProgramTableRow::new( + BaseField::one(), + InstructionType::ReadChar.to_base_field(), + BaseField::from(2), + ), + ProgramTableRow::new( + BaseField::from(3), + InstructionType::ReadChar.to_base_field(), + BaseField::from(7), + ), + ]); + + let (trace, claim) = program_table.trace_evaluation().unwrap(); + + // Verify the domain of the trace matches the expected circle domain. + let domain = CanonicCoset::new(claim.log_size).circle_domain(); + for column in trace { + assert_eq!( + column.domain, domain, + "Trace column domain should match the expected circle domain." + ); + } + } }