Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for constant column in STWO backend #2112

Merged
merged 51 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
4a7f349
add support for constant column
ShuangWu121 Nov 19, 2024
d1becd0
building constant columns
ShuangWu121 Nov 21, 2024
53183d4
test pil
ShuangWu121 Nov 21, 2024
a3f6d0f
cleaner
ShuangWu121 Nov 21, 2024
eb6f2ec
constant support
ShuangWu121 Nov 21, 2024
7169334
Merge branch 'main' into stwo-constant-support
ShuangWu121 Nov 21, 2024
cd67a59
clean up
ShuangWu121 Nov 22, 2024
12146be
add setup
ShuangWu121 Nov 24, 2024
2c100ba
put setup into new function
ShuangWu121 Nov 24, 2024
c9606b5
clean up
ShuangWu121 Nov 24, 2024
9756ef4
clean up
ShuangWu121 Nov 25, 2024
9f9c20d
clean up
ShuangWu121 Nov 25, 2024
b99c1a9
add challenge channel to tableProvingkey
ShuangWu121 Nov 25, 2024
e688e3e
handle empty constant case
ShuangWu121 Nov 25, 2024
d9db236
add more test, and comments
ShuangWu121 Nov 25, 2024
3ea95b5
add test in pil
ShuangWu121 Nov 25, 2024
6fa35a7
remove prover channel from table key
ShuangWu121 Nov 27, 2024
648435d
avoid clone witness, using better API to do bit reverse order of the …
ShuangWu121 Nov 27, 2024
b4a2411
Update backend/src/stwo/circuit_builder.rs
ShuangWu121 Nov 28, 2024
fe6f220
cannot make setup work because of 'a
ShuangWu121 Nov 28, 2024
94c1ac2
avoid using refcell, and std::mem::take
ShuangWu121 Nov 28, 2024
f19f557
fix error with empty constant, simplified code
ShuangWu121 Nov 29, 2024
14f76ba
add enumerate to plonk(i)
ShuangWu121 Nov 29, 2024
07ceb5d
clean up
ShuangWu121 Nov 29, 2024
c422ffd
add more comment
ShuangWu121 Nov 29, 2024
d3f7782
add fail test
ShuangWu121 Nov 29, 2024
61a989e
merge main
ShuangWu121 Nov 29, 2024
d4e0bb1
fix test case
ShuangWu121 Nov 30, 2024
998a240
Update backend/src/stwo/prover.rs
ShuangWu121 Dec 3, 2024
26c8ce9
make gen_stwo_circle_column work on slice
ShuangWu121 Dec 4, 2024
567b0f8
support constant column with next reference
ShuangWu121 Dec 5, 2024
2e627dd
use a wrong clippy command, now fixed it
ShuangWu121 Dec 5, 2024
0f05f72
use identities, so no panic for lookups
ShuangWu121 Dec 5, 2024
bc4bf30
add test for fixed col with next reference, remove mc generic in prov…
ShuangWu121 Dec 6, 2024
50521e9
fix fmt
ShuangWu121 Dec 6, 2024
34e8963
merge to main Merge remote-tracking branch 'origin/main' into stwo-co…
ShuangWu121 Dec 6, 2024
410c68a
no intermidate panic
ShuangWu121 Dec 6, 2024
a1a663c
serialize and deserialize proving keys
ShuangWu121 Dec 9, 2024
9503341
add test file
ShuangWu121 Dec 9, 2024
66b1c1b
simplify serilization
ShuangWu121 Dec 9, 2024
b18e937
simplify serilization
ShuangWu121 Dec 9, 2024
524d9ff
clean up
ShuangWu121 Dec 9, 2024
5e592f7
refactor proof serilization and next reference on constant
ShuangWu121 Dec 10, 2024
30bd4ae
Update backend/src/stwo/prover.rs
ShuangWu121 Dec 10, 2024
9ab7a3b
clean up
ShuangWu121 Dec 10, 2024
3e2a078
clean up
ShuangWu121 Dec 10, 2024
2b56587
log size correction
ShuangWu121 Dec 10, 2024
8f6eb44
Update backend/src/stwo/circuit_builder.rs
ShuangWu121 Dec 10, 2024
af61413
simplify the function to create list of constant with next reference
ShuangWu121 Dec 11, 2024
f446e46
use all_children to create list
ShuangWu121 Dec 11, 2024
0557c60
create name list for constant with next
ShuangWu121 Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2
p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true }
p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true }
# TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3
stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" }
stwo-prover = { git = "https://github.com/ShuangWu121/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe",features = ["parallel"] }

strum = { version = "0.24.1", features = ["derive"] }
log = "0.4.17"
Expand Down
200 changes: 163 additions & 37 deletions backend/src/stwo/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,65 @@
use num_traits::Zero;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
use std::sync::Arc;

extern crate alloc;
use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec};
use alloc::collections::btree_map::BTreeMap;
use powdr_ast::analyzed::{
AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity,
};
use powdr_number::{FieldElement, LargeInt};
use std::sync::Arc;

use powdr_ast::analyzed::{
AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType,
};
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use stwo_prover::core::backend::ColumnOps;
use stwo_prover::constraint_framework::preprocessed_columns::PreprocessedColumn;
use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, ORIGINAL_TRACE_IDX,
};
use stwo_prover::core::backend::{Column, ColumnOps};
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps};
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::circle::{CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::ColumnVec;
use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

pub type PowdrComponent<'a, F> = FrameworkComponent<PowdrEval<F>>;

pub(crate) fn gen_stwo_circuit_trace<T, B, F>(
witness: &[(String, Vec<T>)],
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>
pub fn gen_stwo_circle_column<T, B, F>(
domain: CircleDomain,
slice: &[T],
) -> CircleEvaluation<B, BaseField, BitReversedOrder>
where
T: FieldElement, //only Merenne31Field is supported, checked in runtime
B: FieldOps<M31> + ColumnOps<F>, // Ensure B implements FieldOps for M31
T: FieldElement,
B: FieldOps<M31> + ColumnOps<F>,

F: ExtensionOf<BaseField>,
{
assert!(
witness
.iter()
.all(|(_name, vec)| vec.len() == witness[0].1.len()),
"All Vec<T> in witness must have the same length. Mismatch found!"
slice.len().ilog2() == domain.size().ilog2(),
"column size must be equal to domain size"
);
let domain = CanonicCoset::new(witness[0].1.len().ilog2()).circle_domain();
witness
.iter()
.map(|(_name, values)| {
let values = values
.iter()
.map(|v| v.try_into_i32().unwrap().into())
.collect();
CircleEvaluation::new(domain, values)
})
.collect()
let mut column: <B as ColumnOps<M31>>::Column =
<B as ColumnOps<M31>>::Column::zeros(slice.len());
slice.iter().enumerate().for_each(|(i, v)| {
column.set(
bit_reverse_index(
coset_index_to_circle_domain_index(i, slice.len().ilog2()),
slice.len().ilog2(),
),
v.try_into_i32().unwrap().into(),
);
});

CircleEvaluation::new(domain, column)
}

pub struct PowdrEval<T> {
analyzed: Arc<Analyzed<T>>,
witness_columns: BTreeMap<PolyID, usize>,
constant_shifted: BTreeMap<PolyID, usize>,
constant_columns: BTreeMap<PolyID, usize>,
}

impl<T: FieldElement> PowdrEval<T> {
Expand All @@ -63,10 +70,29 @@ impl<T: FieldElement> PowdrEval<T> {
.enumerate()
.map(|(index, (_, id))| (id, index))
.collect();
// create a list of indexs of the constant polynomials that have next references constraint
let constant_with_next_list = get_constant_with_next_list(&analyzed);

let constant_shifted: BTreeMap<PolyID, usize> = analyzed
.definitions_in_source_order(PolynomialType::Constant)
.flat_map(|(symbol, _)| symbol.array_elements())
.enumerate()
.filter(|(_, (_, id))| constant_with_next_list.contains(&(id.id as usize)))
.map(|(index, (_, id))| (id, index))
.collect();

let constant_columns: BTreeMap<PolyID, usize> = analyzed
.definitions_in_source_order(PolynomialType::Constant)
.flat_map(|(symbol, _)| symbol.array_elements())
.enumerate()
.map(|(index, (_, id))| (id, index))
.collect();

Self {
analyzed,
witness_columns,
constant_shifted,
constant_columns,
}
}
}
Expand All @@ -80,14 +106,46 @@ impl<T: FieldElement> FrameworkEval for PowdrEval<T> {
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
assert!(
self.analyzed.constant_count() == 0 && self.analyzed.publics_count() == 0,
"Error: Expected no fixed columns nor public inputs, as they are not supported yet.",
self.analyzed.publics_count() == 0,
"Error: Expected no public inputs, as they are not supported yet.",
);

let witness_eval: BTreeMap<PolyID, [<E as EvalAtRow>::F; 2]> = self
.witness_columns
.keys()
.map(|poly_id| (*poly_id, eval.next_interaction_mask(0, [0, 1])))
.map(|poly_id| {
(
*poly_id,
eval.next_interaction_mask(ORIGINAL_TRACE_IDX, [0, 1]),
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
)
})
.collect();

let constant_eval: BTreeMap<_, _> = self
.constant_columns
.keys()
.enumerate()
.map(|(i, poly_id)| {
(
*poly_id,
// PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column
eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)),
)
})
.collect();

let constant_shifted_eval: BTreeMap<_, _> = self
.constant_shifted
.keys()
.enumerate()
.map(|(i, poly_id)| {
(
*poly_id,
eval.get_preprocessed_column(PreprocessedColumn::Plonk(
i + constant_eval.len(),
)),
)
})
.collect();

for id in self
Expand All @@ -96,7 +154,12 @@ impl<T: FieldElement> FrameworkEval for PowdrEval<T> {
{
match id {
Identity::Polynomial(identity) => {
let expr = to_stwo_expression(&identity.expression, &witness_eval);
let expr = to_stwo_expression(
&identity.expression,
&witness_eval,
&constant_shifted_eval,
&constant_eval,
);
eval.add_constraint(expr);
}
Identity::Connect(..) => {
Expand All @@ -119,6 +182,8 @@ impl<T: FieldElement> FrameworkEval for PowdrEval<T> {
fn to_stwo_expression<T: FieldElement, F>(
expr: &AlgebraicExpression<T>,
witness_eval: &BTreeMap<PolyID, [F; 2]>,
constant_shifted_eval: &BTreeMap<PolyID, F>,
constant_eval: &BTreeMap<PolyID, F>,
) -> F
where
F: FieldExpOps
Expand All @@ -144,9 +209,10 @@ where
false => witness_eval[&poly_id][0].clone(),
true => witness_eval[&poly_id][1].clone(),
},
PolynomialType::Constant => {
unimplemented!("Constant polynomials are not supported in stwo yet")
}
PolynomialType::Constant => match r.next {
false => constant_eval[&poly_id].clone(),
true => constant_shifted_eval[&poly_id].clone(),
},
PolynomialType::Intermediate => {
unimplemented!("Intermediate polynomials are not supported in stwo yet")
}
Expand All @@ -162,15 +228,17 @@ where
right,
}) => match **right {
AlgebraicExpression::Number(n) => {
let left = to_stwo_expression(left, witness_eval);
let left =
to_stwo_expression(left, witness_eval, constant_shifted_eval, constant_eval);
(0u32..n.to_integer().try_into_u32().unwrap())
.fold(F::one(), |acc, _| acc * left.clone())
}
_ => unimplemented!("pow with non-constant exponent"),
},
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => {
let left = to_stwo_expression(left, witness_eval);
let right = to_stwo_expression(right, witness_eval);
let left = to_stwo_expression(left, witness_eval, constant_shifted_eval, constant_eval);
let right =
to_stwo_expression(right, witness_eval, constant_shifted_eval, constant_eval);

match op {
Add => left + right,
Expand All @@ -180,7 +248,7 @@ where
}
}
AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => {
let expr = to_stwo_expression(expr, witness_eval);
let expr = to_stwo_expression(expr, witness_eval, constant_shifted_eval, constant_eval);

match op {
AlgebraicUnaryOperator::Minus => -expr,
Expand All @@ -191,3 +259,61 @@ where
}
}
}

pub fn constant_with_next_to_witness_col<T: FieldElement>(
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
expr: &AlgebraicExpression<T>,
constant_with_next_list: &mut Vec<usize>,
) {
use AlgebraicBinaryOperator::*;
match expr {
AlgebraicExpression::Reference(r) => {
let poly_id = r.poly_id;

match poly_id.ptype {
PolynomialType::Committed => {}
PolynomialType::Constant => match r.next {
false => {}
true => {
constant_with_next_list.push(r.poly_id.id as usize);
}
},
PolynomialType::Intermediate => {}
}
}
AlgebraicExpression::PublicReference(..) => {}
AlgebraicExpression::Number(_) => {}
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: Pow,
right,
}) => match **right {
AlgebraicExpression::Number(_) => {
constant_with_next_to_witness_col::<T>(left, constant_with_next_list);
}
_ => unimplemented!("pow with non-constant exponent"),
},
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op: _, right }) => {
constant_with_next_to_witness_col::<T>(left, constant_with_next_list);
constant_with_next_to_witness_col::<T>(right, constant_with_next_list);
}
AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op: _, expr }) => {
constant_with_next_to_witness_col::<T>(expr, constant_with_next_list);
}
AlgebraicExpression::Challenge(_challenge) => {}
}
}

// This function creates a list of indices of the constant polynomials that have next references constraint
pub fn get_constant_with_next_list<T: FieldElement>(analyzed: &Analyzed<T>) -> Vec<usize> {
let mut all_constant_with_next: Vec<usize> = Vec::new();
for id in analyzed.identities_with_inlined_intermediate_polynomials() {
if let Identity::Polynomial(identity) = id {
let mut constant_with_next: Vec<usize> = Vec::new();
constant_with_next_to_witness_col::<T>(&identity.expression, &mut constant_with_next);
all_constant_with_next.extend(constant_with_next)
}
}
all_constant_with_next.sort_unstable();
all_constant_with_next.dedup();
all_constant_with_next
}
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
31 changes: 20 additions & 11 deletions backend/src/stwo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
field_filter::generalize_factory, Backend, BackendFactory, BackendOptions, Error, Proof,
};
use powdr_ast::analyzed::Analyzed;
use powdr_executor::constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn};
use powdr_executor::constant_evaluator::VariablySizedColumn;
use powdr_executor::witgen::WitgenCallback;
use powdr_number::{FieldElement, Mersenne31Field};
use prover::StwoProver;
Expand All @@ -17,13 +17,12 @@ use stwo_prover::core::channel::{Blake2sChannel, Channel, MerkleChannel};
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

mod circuit_builder;
mod proof;
mod prover;
#[allow(dead_code)]

struct RestrictedFactory;

impl<F: FieldElement> BackendFactory<F> for RestrictedFactory {
#[allow(unreachable_code)]
#[allow(unused_variables)]
fn create(
&self,
Expand All @@ -39,14 +38,24 @@ impl<F: FieldElement> BackendFactory<F> for RestrictedFactory {
if proving_key.is_some() {
return Err(Error::BackendError("Proving key unused".to_string()));
}

if pil.degrees().len() > 1 {
return Err(Error::NoVariableDegreeAvailable);
}
let fixed = Arc::new(
get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?,
);
let stwo: Box<StwoProver<F, SimdBackend, Blake2sMerkleChannel, Blake2sChannel>> =

let mut stwo: Box<StwoProver<F, SimdBackend, Blake2sMerkleChannel, Blake2sChannel>> =
Box::new(StwoProver::new(pil, fixed)?);

match (proving_key, verification_key) {
(Some(pk), Some(vk)) => {
stwo.set_proving_key(pk);
//stwo.set_verifying_key(vk);
}
_ => {
stwo.setup();
}
}

Ok(stwo)
}
}
Expand All @@ -68,7 +77,7 @@ where

Ok(self.verify(proof, instances)?)
}
#[allow(unreachable_code)]

#[allow(unused_variables)]
fn prove(
&self,
Expand All @@ -81,8 +90,8 @@ where
}
Ok(StwoProver::prove(self, witness)?)
}
#[allow(unused_variables)]
fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
unimplemented!()
fn export_proving_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
self.export_proving_key(output)
.map_err(|e| Error::BackendError(e.to_string()))
}
}
Loading