Skip to content

Commit

Permalink
feat: add trace evaluation to Program table
Browse files Browse the repository at this point in the history
  • Loading branch information
zmalatrax committed Dec 23, 2024
1 parent 26d5dde commit b3a26da
Showing 1 changed file with 222 additions and 2 deletions.
224 changes: 222 additions & 2 deletions crates/brainfuck_prover/src/components/program/table.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
use brainfuck_vm::registers::Registers;
use num_traits::One;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::{
backend::{
simd::{column::BaseColumn, m31::LOG_N_LANES},
Column,
},
fields::m31::BaseField,
poly::circle::{CanonicCoset, CircleEvaluation},
};

use crate::components::TraceColumn;
use crate::components::{ProgramClaim, TraceColumn, TraceError, TraceEval};

/// Represents a single row in the Program Table.
///
Expand Down Expand Up @@ -95,6 +102,54 @@ impl ProgramTable {
}
}
}

/// Transforms the [`ProgramTable`] into a [`TraceEval`], to be committed when
/// generating a STARK proof.
///
/// The [`ProgramTable`] is transformed from an array of rows (one element = one step
/// of all registers) to an array of columns (one element = all steps of one register).
/// It is then evaluated on the circle domain.
///
/// # Returns
/// A tuple containing the evaluated trace and claim for STARK proof.
///
/// # Errors
/// Returns [`TraceError::EmptyTrace`] if the table is empty.
pub fn trace_evaluation(&self) -> Result<(TraceEval, ProgramClaim), TraceError> {
let n_rows = self.table.len() as u32;

// If the table is empty, there is no data to evaluate, so return an error.
if n_rows == 0 {
return Err(TraceError::EmptyTrace);
}

// Compute `log_n_rows`, the base-2 logarithm of the number of rows.
let log_n_rows = n_rows.ilog2();

// Add `LOG_N_LANES` to account for SIMD optimization.
let log_size = log_n_rows + LOG_N_LANES;

// Initialize a trace with 4 columns (`ip`, `ci`, `ni`, `d`), each column containing
// `2^log_size` entries initialized to zero.
let mut trace = vec![BaseColumn::zeros(1 << log_size); 4];

// Populate the column with data from the table rows.
for (index, row) in self.table.iter().enumerate().take(1 << log_n_rows) {
trace[ProgramColumn::Ip.index()].data[index] = row.ip.into();
trace[ProgramColumn::Ci.index()].data[index] = row.ci.into();
trace[ProgramColumn::Ni.index()].data[index] = row.ni.into();
trace[ProgramColumn::D.index()].data[index] = row.d.into();
}

// Create a circle domain using a canonical coset.
let domain = CanonicCoset::new(log_size).circle_domain();

// Map the column into the circle domain.
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Return the evaluated trace and a claim containing the log size of the domain.
Ok((trace, ProgramClaim::new(log_size)))
}
}

impl From<&Vec<Registers>> for ProgramTable {
Expand Down Expand Up @@ -253,4 +308,169 @@ mod tests {

assert_eq!(ProgramTable::from(&trace), expected_program_table);
}

#[test]
fn test_trace_evaluation_empty_table() {
let program_table = ProgramTable::new();

let result = program_table.trace_evaluation();

assert!(matches!(result, Err(TraceError::EmptyTrace)));
}

#[allow(clippy::similar_names)]
#[test]
fn test_trace_evaluation_single_row() {
let mut program_table = ProgramTable::new();
program_table.add_row(ProgramTableRow::new(
BaseField::one(),
BaseField::from(44),
BaseField::from(42),
));

let (trace, claim) = program_table.trace_evaluation().unwrap();

assert_eq!(claim.log_size, LOG_N_LANES, "Log size should include SIMD lanes.");
assert_eq!(
trace.len(),
ProgramColumn::count().0,
"Trace should contain one column per register."
);

// Expected values for the single row
let expected_ip_column = vec![BaseField::one(); 1 << LOG_N_LANES];
let expected_ci_column = vec![BaseField::from(44); 1 << LOG_N_LANES];
let expected_ni_column = vec![BaseField::from(42); 1 << LOG_N_LANES];
let expected_d_column = vec![BaseField::zero(); 1 << LOG_N_LANES];

// Check that the entire column matches expected values
assert_eq!(
trace[ProgramColumn::Ip.index()].to_cpu().values,
expected_ip_column,
"Ip column should match expected values."
);
assert_eq!(
trace[ProgramColumn::Ci.index()].to_cpu().values,
expected_ci_column,
"Ci column should match expected values."
);
assert_eq!(
trace[ProgramColumn::Ni.index()].to_cpu().values,
expected_ni_column,
"Ni column should match expected values."
);
assert_eq!(
trace[ProgramColumn::D.index()].to_cpu().values,
expected_d_column,
"D column should match expected values."
);
}

#[test]
fn test_program_table_trace_evaluation() {
let mut program_table = ProgramTable::new();

// Add rows to the Program table.
let rows = vec![
ProgramTableRow::new(BaseField::zero(), BaseField::from(44), BaseField::one()),
ProgramTableRow::new(BaseField::one(), BaseField::from(44), BaseField::from(2)),
ProgramTableRow::new_dummy(BaseField::from(2)),
ProgramTableRow::new_dummy(BaseField::from(3)),
];
program_table.add_rows(rows);

// Perform the trace evaluation.
let (trace, claim) = program_table.trace_evaluation().unwrap();

// Calculate the expected parameters.
let expected_log_n_rows: u32 = 2; // log2(2 rows)
let expected_log_size = expected_log_n_rows + LOG_N_LANES;
let expected_size = 1 << expected_log_size;

// Construct the expected trace column for `ip`, `ci`, `ni` and `d`.
let mut expected_columns = vec![BaseColumn::zeros(expected_size); ProgramColumn::count().0];

// Populate the `ip` column with row data and padding.
expected_columns[ProgramColumn::Ip.index()].data[0] = BaseField::zero().into();
expected_columns[ProgramColumn::Ip.index()].data[1] = BaseField::one().into();
expected_columns[ProgramColumn::Ip.index()].data[2] = BaseField::from(2).into();
expected_columns[ProgramColumn::Ip.index()].data[3] = BaseField::from(3).into();

// Populate the `ci` column with row data and padding.
expected_columns[ProgramColumn::Ci.index()].data[0] = BaseField::from(44).into();
expected_columns[ProgramColumn::Ci.index()].data[1] = BaseField::from(44).into();
expected_columns[ProgramColumn::Ci.index()].data[2] = BaseField::zero().into();
expected_columns[ProgramColumn::Ci.index()].data[3] = BaseField::zero().into();

// Populate the `ni` column with row data and padding.
expected_columns[ProgramColumn::Ni.index()].data[0] = BaseField::one().into();
expected_columns[ProgramColumn::Ni.index()].data[1] = BaseField::from(2).into();
expected_columns[ProgramColumn::Ni.index()].data[2] = BaseField::zero().into();
expected_columns[ProgramColumn::Ni.index()].data[3] = BaseField::zero().into();

// Populate the `d` column with row data and padding.
expected_columns[ProgramColumn::D.index()].data[0] = BaseField::zero().into();
expected_columns[ProgramColumn::D.index()].data[1] = BaseField::zero().into();
expected_columns[ProgramColumn::D.index()].data[2] = BaseField::one().into();
expected_columns[ProgramColumn::D.index()].data[3] = BaseField::one().into();

// Create the expected domain for evaluation.
let domain = CanonicCoset::new(expected_log_size).circle_domain();

// Transform expected columns into CircleEvaluation.
let expected_trace: TraceEval =
expected_columns.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Create the expected claim.
let expected_claim = ProgramClaim::new(expected_log_size);

// Assert equality of the claim.
assert_eq!(claim, expected_claim, "The claim should match the expected claim.");

// Assert equality of the trace for all columns.
for (actual, expected) in trace.iter().zip(expected_trace.iter()) {
assert_eq!(
actual.domain, expected.domain,
"The domain of the trace column should match the expected domain."
);
assert_eq!(
actual.to_cpu().values,
expected.to_cpu().values,
"The values of the trace column should match the expected values."
);
}
}

#[test]
fn test_trace_evaluation_circle_domain() {
let mut program_table = ProgramTable::new();
program_table.add_rows(vec![
ProgramTableRow::new(
BaseField::zero(),
InstructionType::ReadChar.to_base_field(),
BaseField::one(),
),
ProgramTableRow::new(
BaseField::one(),
InstructionType::ReadChar.to_base_field(),
BaseField::from(2),
),
ProgramTableRow::new(
BaseField::from(3),
InstructionType::ReadChar.to_base_field(),
BaseField::from(7),
),
]);

let (trace, claim) = program_table.trace_evaluation().unwrap();

// Verify the domain of the trace matches the expected circle domain.
let domain = CanonicCoset::new(claim.log_size).circle_domain();
for column in trace {
assert_eq!(
column.domain, domain,
"Trace column domain should match the expected circle domain."
);
}
}
}

0 comments on commit b3a26da

Please sign in to comment.