-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
State machine single component example
- Loading branch information
Showing
4 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod blake; | ||
pub mod plonk; | ||
pub mod poseidon; | ||
pub mod state_machine; | ||
pub mod wide_fibonacci; | ||
pub mod xor; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use num_traits::One; | ||
|
||
use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; | ||
use crate::constraint_framework::{EvalAtRow, FrameworkEval}; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::lookups::utils::Fraction; | ||
|
||
const LOG_CONSTRAINT_DEGREE: u32 = 1; | ||
pub const STATE_SIZE: usize = 2; | ||
/// Random elements to combine the StateMachine state. | ||
pub type StateMachineElements = LookupElements<STATE_SIZE>; | ||
|
||
/// State machine with state of size `STATE_SIZE`. | ||
/// Transition `COORDINATE` of state increments the state by 1 at that offset. | ||
#[derive(Clone)] | ||
pub struct StateTransitionEval<const COORDINATE: usize> { | ||
pub log_n_rows: u32, | ||
pub lookup_elements: StateMachineElements, | ||
pub total_sum: QM31, | ||
} | ||
|
||
impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> { | ||
fn log_size(&self) -> u32 { | ||
self.log_n_rows | ||
} | ||
fn max_constraint_log_degree_bound(&self) -> u32 { | ||
self.log_n_rows + LOG_CONSTRAINT_DEGREE | ||
} | ||
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E { | ||
let [is_first] = eval.next_interaction_mask(2, [0]); | ||
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first); | ||
|
||
let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); | ||
let input_denom: E::EF = self.lookup_elements.combine(&input_state); | ||
|
||
let mut output_state = input_state; | ||
output_state[COORDINATE] += E::F::one(); | ||
let output_denom: E::EF = self.lookup_elements.combine(&output_state); | ||
|
||
logup.write_frac( | ||
&mut eval, | ||
Fraction::new(E::EF::one(), input_denom) | ||
+ Fraction::new(-E::EF::one(), output_denom.clone()), | ||
); | ||
|
||
logup.finalize(&mut eval); | ||
eval | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
use itertools::Itertools; | ||
use num_traits::{One, Zero}; | ||
|
||
use super::components::STATE_SIZE; | ||
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; | ||
use crate::core::backend::simd::column::BaseColumn; | ||
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES}; | ||
use crate::core::backend::simd::qm31::PackedQM31; | ||
use crate::core::backend::simd::SimdBackend; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; | ||
use crate::core::poly::BitReversedOrder; | ||
use crate::core::ColumnVec; | ||
|
||
pub type State = [M31; STATE_SIZE]; | ||
|
||
// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the | ||
// `inc_index` dimension. | ||
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size]. | ||
pub fn gen_trace( | ||
log_size: u32, | ||
initial_state: State, | ||
inc_index: usize, | ||
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> { | ||
let domain = CanonicCoset::new(log_size).circle_domain(); | ||
let mut trace = (0..STATE_SIZE) | ||
.map(|_| vec![M31::zero(); 1 << log_size]) | ||
.collect_vec(); | ||
|
||
let mut curr_state = initial_state; | ||
for i in 0..1 << log_size { | ||
for j in 0..STATE_SIZE { | ||
trace[j][i] = curr_state[j]; | ||
} | ||
// Increment the state to the next state row. | ||
curr_state[inc_index] += M31::one(); | ||
} | ||
|
||
trace | ||
.into_iter() | ||
.map(|col| { | ||
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new( | ||
domain, | ||
BaseColumn::from_iter(col), | ||
) | ||
}) | ||
.collect_vec() | ||
} | ||
|
||
pub fn gen_interaction_trace( | ||
log_size: u32, | ||
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>, | ||
inc_index: usize, | ||
lookup_elements: &LookupElements<STATE_SIZE>, | ||
) -> ( | ||
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>, | ||
QM31, | ||
) { | ||
let ones = PackedM31::broadcast(M31::one()); | ||
let mut logup_gen = LogupTraceGenerator::new(log_size); | ||
let mut col_gen = logup_gen.new_col(); | ||
|
||
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { | ||
let mut packed_state: [PackedM31; STATE_SIZE] = trace | ||
.iter() | ||
.map(|col| col.data[vec_row]) | ||
.collect_vec() | ||
.try_into() | ||
.unwrap(); | ||
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state); | ||
packed_state[inc_index] += ones; | ||
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state); | ||
col_gen.write_frac( | ||
vec_row, | ||
output_denom - input_denom, | ||
input_denom * output_denom, | ||
); | ||
} | ||
col_gen.finalize_col(); | ||
|
||
logup_gen.finalize_last() | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::core::backend::Column; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; | ||
use crate::core::fields::FieldExpOps; | ||
use crate::examples::state_machine::components::StateMachineElements; | ||
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace}; | ||
|
||
#[test] | ||
fn test_gen_trace() { | ||
let log_size = 8; | ||
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)]; | ||
let inc_index = 1; | ||
let row = 123; | ||
|
||
let trace = gen_trace(log_size, initial_state, inc_index); | ||
|
||
assert_eq!(trace.len(), 2); | ||
assert_eq!(trace[0].at(row), initial_state[0]); | ||
assert_eq!( | ||
trace[1].at(row), | ||
initial_state[1] + M31::from_u32_unchecked(row as u32) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_gen_interaction_trace() { | ||
let log_size = 8; | ||
let inc_index = 1; | ||
// Prepare the first and the last states. | ||
let first_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)]; | ||
let mut last_state = first_state; | ||
last_state[inc_index] += M31::from_u32_unchecked(1 << log_size); | ||
|
||
let trace = gen_trace(log_size, first_state, inc_index); | ||
let lookup_elements = StateMachineElements::dummy(); | ||
let first_state_comb: QM31 = lookup_elements.combine(&first_state); | ||
let last_state_comb: QM31 = lookup_elements.combine(&last_state); | ||
|
||
let (interaction_trace, total_sum) = | ||
gen_interaction_trace(log_size, &trace, inc_index, &lookup_elements); | ||
|
||
assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column. | ||
assert_eq!( | ||
total_sum, | ||
first_state_comb.inverse() - last_state_comb.inverse() | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
pub mod components; | ||
pub mod gen; | ||
|
||
use components::{StateMachineElements, StateTransitionEval}; | ||
use gen::{gen_interaction_trace, gen_trace, State}; | ||
use itertools::Itertools; | ||
|
||
use crate::constraint_framework::constant_columns::gen_is_first; | ||
use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator}; | ||
use crate::core::air::Component; | ||
use crate::core::backend::simd::m31::LOG_N_LANES; | ||
use crate::core::backend::simd::SimdBackend; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; | ||
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps}; | ||
use crate::core::prover::{prove, verify, StarkProof, VerificationError}; | ||
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; | ||
|
||
pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>; | ||
|
||
#[allow(unused)] | ||
pub fn prove_state_machine( | ||
log_n_rows: u32, | ||
initial_state: State, | ||
config: PcsConfig, | ||
channel: &mut Blake2sChannel, | ||
) -> ( | ||
StateMachineOp0Component, | ||
StarkProof<Blake2sMerkleHasher>, | ||
TreeVec<Vec<CirclePoly<SimdBackend>>>, | ||
) { | ||
assert!(log_n_rows >= LOG_N_LANES); | ||
|
||
// Precompute twiddles. | ||
let twiddles = SimdBackend::precompute_twiddles( | ||
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) | ||
.circle_domain() | ||
.half_coset, | ||
); | ||
|
||
// Setup protocol. | ||
let commitment_scheme = | ||
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); | ||
|
||
// Trace. | ||
let trace_op0 = gen_trace(log_n_rows, initial_state, 0); | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(trace_op0.clone()); | ||
tree_builder.commit(channel); | ||
|
||
// Draw lookup element. | ||
let lookup_elements = StateMachineElements::draw(channel); | ||
|
||
// Interaction trace. | ||
let (interaction_trace_op0, total_sum_op0) = | ||
gen_interaction_trace(log_n_rows, &trace_op0, 0, &lookup_elements); | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(interaction_trace_op0); | ||
tree_builder.commit(channel); | ||
|
||
// Constant trace. | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]); | ||
tree_builder.commit(channel); | ||
|
||
let trace_polys = commitment_scheme | ||
.trees | ||
.as_ref() | ||
.map(|t| t.polynomials.iter().cloned().collect_vec()); | ||
|
||
// Prove constraints. | ||
let component_op0 = StateMachineOp0Component::new( | ||
&mut TraceLocationAllocator::default(), | ||
StateTransitionEval { | ||
log_n_rows, | ||
lookup_elements, | ||
total_sum: total_sum_op0, | ||
}, | ||
); | ||
|
||
let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap(); | ||
|
||
(component_op0, proof, trace_polys) | ||
} | ||
|
||
pub fn verify_state_machine( | ||
config: PcsConfig, | ||
channel: &mut Blake2sChannel, | ||
component: StateMachineOp0Component, | ||
proof: StarkProof<Blake2sMerkleHasher>, | ||
) -> Result<(), VerificationError> { | ||
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config); | ||
|
||
// Decommit. | ||
// Retrieve the expected column sizes in each commitment interaction, from the AIR. | ||
let sizes = component.trace_log_degree_bounds(); | ||
// Trace columns. | ||
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); | ||
// Interaction columns. | ||
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); | ||
// Constant columns. | ||
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); | ||
|
||
verify(&[&component], channel, commitment_scheme, proof) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use num_traits::Zero; | ||
|
||
use super::components::STATE_SIZE; | ||
use super::{prove_state_machine, verify_state_machine}; | ||
use crate::constraint_framework::{assert_constraints, FrameworkEval}; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::pcs::PcsConfig; | ||
use crate::core::poly::circle::CanonicCoset; | ||
|
||
#[test] | ||
fn test_state_machine_constraints() { | ||
let log_n_rows = 8; | ||
let config = PcsConfig::default(); | ||
|
||
// Initial and last state. | ||
let initial_state = [M31::zero(); STATE_SIZE]; | ||
let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()]; | ||
|
||
// Setup protocol. | ||
let channel = &mut Blake2sChannel::default(); | ||
let (component, _, trace_polys) = | ||
prove_state_machine(log_n_rows, initial_state, config, channel); | ||
|
||
let interaction_elements = component.lookup_elements.clone(); | ||
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); | ||
let last_state_comb: QM31 = interaction_elements.combine(&last_state); | ||
|
||
// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. | ||
assert_eq!( | ||
component.total_sum * initial_state_comb * last_state_comb, | ||
last_state_comb - initial_state_comb | ||
); | ||
|
||
// Assert constraints. | ||
assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { | ||
component.evaluate(eval); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn test_state_machine_prove() { | ||
let log_n_rows = 8; | ||
let config = PcsConfig::default(); | ||
let initial_state = [M31::zero(); STATE_SIZE]; | ||
let prover_channel = &mut Blake2sChannel::default(); | ||
let (component_op0, proof, _) = | ||
prove_state_machine(log_n_rows, initial_state, config, prover_channel); | ||
|
||
let verifier_channel = &mut Blake2sChannel::default(); | ||
verify_state_machine(config, verifier_channel, component_op0, proof).unwrap(); | ||
} | ||
} |