Skip to content

Commit

Permalink
State machine AIR (#841)
Browse files Browse the repository at this point in the history
State machine single component example
  • Loading branch information
shaharsamocha7 authored Sep 26, 2024
2 parents 6e649fc + 7372a06 commit d7c7997
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 0 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod blake;
pub mod plonk;
pub mod poseidon;
pub mod state_machine;
pub mod wide_fibonacci;
pub mod xor;
49 changes: 49 additions & 0 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use num_traits::One;

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::core::lookups::utils::Fraction;

const LOG_CONSTRAINT_DEGREE: u32 = 1;
pub const STATE_SIZE: usize = 2;
/// Random elements to combine the StateMachine state.
pub type StateMachineElements = LookupElements<STATE_SIZE>;

/// State machine with state of size `STATE_SIZE`.
/// Transition `COORDINATE` of state increments the state by 1 at that offset.
#[derive(Clone)]
pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
fn log_size(&self) -> u32 {
self.log_n_rows
}
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_n_rows + LOG_CONSTRAINT_DEGREE
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first);

let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);

let mut output_state = input_state;
output_state[COORDINATE] += E::F::one();
let output_denom: E::EF = self.lookup_elements.combine(&output_state);

logup.write_frac(
&mut eval,
Fraction::new(E::EF::one(), input_denom)
+ Fraction::new(-E::EF::one(), output_denom.clone()),
);

logup.finalize(&mut eval);
eval
}
}
135 changes: 135 additions & 0 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use itertools::Itertools;
use num_traits::{One, Zero};

use super::components::STATE_SIZE;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedQM31;
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub type State = [M31; STATE_SIZE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
pub fn gen_trace(
log_size: u32,
initial_state: State,
inc_index: usize,
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut trace = (0..STATE_SIZE)
.map(|_| vec![M31::zero(); 1 << log_size])
.collect_vec();

let mut curr_state = initial_state;
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j][i] = curr_state[j];
}
// Increment the state to the next state row.
curr_state[inc_index] += M31::one();
}

trace
.into_iter()
.map(|col| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
domain,
BaseColumn::from_iter(col),
)
})
.collect_vec()
}

pub fn gen_interaction_trace(
log_size: u32,
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
inc_index: usize,
lookup_elements: &LookupElements<STATE_SIZE>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let ones = PackedM31::broadcast(M31::one());
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut packed_state: [PackedM31; STATE_SIZE] = trace
.iter()
.map(|col| col.data[vec_row])
.collect_vec()
.try_into()
.unwrap();
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += ones;
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state);
col_gen.write_frac(
vec_row,
output_denom - input_denom,
input_denom * output_denom,
);
}
col_gen.finalize_col();

logup_gen.finalize_last()
}

#[cfg(test)]
mod tests {
use crate::core::backend::Column;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

#[test]
fn test_gen_trace() {
let log_size = 8;
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)];
let inc_index = 1;
let row = 123;

let trace = gen_trace(log_size, initial_state, inc_index);

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
assert_eq!(
trace[1].at(row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}

#[test]
fn test_gen_interaction_trace() {
let log_size = 8;
let inc_index = 1;
// Prepare the first and the last states.
let first_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)];
let mut last_state = first_state;
last_state[inc_index] += M31::from_u32_unchecked(1 << log_size);

let trace = gen_trace(log_size, first_state, inc_index);
let lookup_elements = StateMachineElements::dummy();
let first_state_comb: QM31 = lookup_elements.combine(&first_state);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (interaction_trace, total_sum) =
gen_interaction_trace(log_size, &trace, inc_index, &lookup_elements);

assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column.
assert_eq!(
total_sum,
first_state_comb.inverse() - last_state_comb.inverse()
);
}
}
162 changes: 162 additions & 0 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
pub mod components;
pub mod gen;

use components::{StateMachineElements, StateTransitionEval};
use gen::{gen_interaction_trace, gen_trace, State};
use itertools::Itertools;

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator};
use crate::core::air::Component;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps};
use crate::core::prover::{prove, verify, StarkProof, VerificationError};
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;

#[allow(unused)]
pub fn prove_state_machine(
log_n_rows: u32,
initial_state: State,
config: PcsConfig,
channel: &mut Blake2sChannel,
) -> (
StateMachineOp0Component,
StarkProof<Blake2sMerkleHasher>,
TreeVec<Vec<CirclePoly<SimdBackend>>>,
) {
assert!(log_n_rows >= LOG_N_LANES);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1)
.circle_domain()
.half_coset,
);

// Setup protocol.
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let trace_op0 = gen_trace(log_n_rows, initial_state, 0);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace_op0.clone());
tree_builder.commit(channel);

// Draw lookup element.
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(log_n_rows, &trace_op0, 0, &lookup_elements);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(interaction_trace_op0);
tree_builder.commit(channel);

// Constant trace.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]);
tree_builder.commit(channel);

let trace_polys = commitment_scheme
.trees
.as_ref()
.map(|t| t.polynomials.iter().cloned().collect_vec());

// Prove constraints.
let component_op0 = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum: total_sum_op0,
},
);

let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap();

(component_op0, proof, trace_polys)
}

pub fn verify_state_machine(
config: PcsConfig,
channel: &mut Blake2sChannel,
component: StateMachineOp0Component,
proof: StarkProof<Blake2sMerkleHasher>,
) -> Result<(), VerificationError> {
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config);

// Decommit.
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
// Trace columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Interaction columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
// Constant columns.
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel);

verify(&[&component], channel, commitment_scheme, proof)
}

#[cfg(test)]
mod tests {
use num_traits::Zero;

use super::components::STATE_SIZE;
use super::{prove_state_machine, verify_state_machine};
use crate::constraint_framework::{assert_constraints, FrameworkEval};
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::pcs::PcsConfig;
use crate::core::poly::circle::CanonicCoset;

#[test]
fn test_state_machine_constraints() {
let log_n_rows = 8;
let config = PcsConfig::default();

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let (component, _, trace_polys) =
prove_state_machine(log_n_rows, initial_state, config, channel);

let interaction_elements = component.lookup_elements.clone();
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state);
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`.
assert_eq!(
component.total_sum * initial_state_comb * last_state_comb,
last_state_comb - initial_state_comb
);

// Assert constraints.
assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| {
component.evaluate(eval);
});
}

#[test]
fn test_state_machine_prove() {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let prover_channel = &mut Blake2sChannel::default();
let (component_op0, proof, _) =
prove_state_machine(log_n_rows, initial_state, config, prover_channel);

let verifier_channel = &mut Blake2sChannel::default();
verify_state_machine(config, verifier_channel, component_op0, proof).unwrap();
}
}

0 comments on commit d7c7997

Please sign in to comment.