From 4a7f349093e2fbba118faeb83bf93c82cbac42a7 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Tue, 19 Nov 2024 13:19:56 +0800 Subject: [PATCH 01/48] add support for constant column --- backend/Cargo.toml | 3 ++- backend/src/stwo/circuit_builder.rs | 31 ++++++++++++++++++-------- backend/src/stwo/prover.rs | 34 ++++++++++++++--------------- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 4210387c9a..3450e46fa1 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -62,7 +62,8 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } # TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3 -stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" } +# stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" } +stwo-prover = { git = "https://github.com/ShuangWu121/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe" } strum = { version = "0.24.1", features = ["derive"] } log = "0.4.17" diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 3b72ecf0b6..7825639b05 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -53,6 +53,7 @@ where pub struct PowdrEval { analyzed: Arc>, witness_columns: BTreeMap, + constant_columns: BTreeMap, } impl PowdrEval { @@ -63,10 +64,17 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); + let constant_columns: BTreeMap = analyzed + .definitions_in_source_order(PolynomialType::Constant) + .flat_map(|(symbol, _)| symbol.array_elements()) + .enumerate() + .map(|(index, (_, id))| (id, index)) + .collect(); Self { analyzed, witness_columns, + constant_columns, } } } @@ -87,7 +95,12 @@ impl FrameworkEval for PowdrEval { let witness_eval: BTreeMap::F; 2]> = self .witness_columns .keys() - .map(|poly_id| (*poly_id, eval.next_interaction_mask(0, [0, 1]))) + .map(|poly_id| (*poly_id, eval.next_interaction_mask(1, [0, 1]))) + .collect(); + let constant_eval: BTreeMap::F> = self + .constant_columns + .keys() + .map(|poly_id| (*poly_id, eval.next_interaction_mask(1, [0])[0].clone())) .collect(); for id in self @@ -96,7 +109,8 @@ impl FrameworkEval for PowdrEval { { match id { Identity::Polynomial(identity) => { - let expr = to_stwo_expression(&identity.expression, &witness_eval); + let expr = + to_stwo_expression(&identity.expression, &witness_eval, &constant_eval); eval.add_constraint(expr); } Identity::Connect(..) => { @@ -119,6 +133,7 @@ impl FrameworkEval for PowdrEval { fn to_stwo_expression( expr: &AlgebraicExpression, witness_eval: &BTreeMap, + constant_eval: &BTreeMap, ) -> F where F: FieldExpOps @@ -144,9 +159,7 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => { - unimplemented!("Constant polynomials are not supported in stwo yet") - } + PolynomialType::Constant => constant_eval[&poly_id].clone(), PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } @@ -162,15 +175,15 @@ where right, }) => match **right { AlgebraicExpression::Number(n) => { - let left = to_stwo_expression(left, witness_eval); + let left = to_stwo_expression(left, witness_eval, constant_eval); (0u32..n.to_integer().try_into_u32().unwrap()) .fold(F::one(), |acc, _| acc * left.clone()) } _ => unimplemented!("pow with non-constant exponent"), }, AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { - let left = to_stwo_expression(left, witness_eval); - let right = to_stwo_expression(right, witness_eval); + let left = to_stwo_expression(left, witness_eval, constant_eval); + let right = to_stwo_expression(right, witness_eval, constant_eval); match op { Add => left + right, @@ -180,7 +193,7 @@ where } } AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { - let expr = to_stwo_expression(expr, witness_eval); + let expr = to_stwo_expression(expr, witness_eval, constant_eval); match op { AlgebraicUnaryOperator::Minus => -expr, diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index ab79e93b03..fc2b3b824a 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -70,14 +70,20 @@ where ); // Setup protocol. - let mut prover_channel = ::C::default(); + let prover_channel = &mut ::C::default(); let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + // Preprocessed trace + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals([]); + tree_builder.commit(prover_channel); + + // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); - tree_builder.commit(&mut prover_channel); + tree_builder.commit(prover_channel); let component = PowdrComponent::new( &mut TraceLocationAllocator::default(), @@ -86,7 +92,7 @@ where let proof = stwo_prover::core::prover::prove::( &[&component], - &mut prover_channel, + prover_channel, commitment_scheme, ) .unwrap(); @@ -105,8 +111,8 @@ where let proof: StarkProof = bincode::deserialize(proof).map_err(|e| format!("Failed to deserialize proof: {e}"))?; - let mut verifier_channel = ::C::default(); - let mut commitment_scheme = CommitmentSchemeVerifier::::new(config); + let verifier_channel = &mut ::C::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); //Constraints that are to be proved let component = PowdrComponent::new( @@ -115,21 +121,13 @@ where ); // Retrieve the expected column sizes in each commitment interaction, from the AIR. - // TODO: When constant columns are supported, there will be more than one sizes and proof.commitments - // size[0] is for constant columns, size[1] is for witness columns, size[2] is for lookup columns - // pass size[1] for witness columns now is not doable due to this branch is outdated for the new feature of constant columns - // it will throw errors. let sizes = component.trace_log_degree_bounds(); - assert_eq!(sizes.len(), 1); - commitment_scheme.commit(proof.commitments[0], &sizes[0], &mut verifier_channel); - stwo_prover::core::prover::verify( - &[&component], - &mut verifier_channel, - &mut commitment_scheme, - proof, - ) - .map_err(|e| e.to_string()) + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); + + stwo_prover::core::prover::verify(&[&component], verifier_channel, commitment_scheme, proof) + .map_err(|e| e.to_string()) } } From d1becd048edf636ac69056b73678d4b964dfd132 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 21 Nov 2024 12:02:01 +0100 Subject: [PATCH 02/48] building constant columns --- backend/src/stwo/circuit_builder.rs | 38 ++++++++++++++++---- backend/src/stwo/mod.rs | 4 +-- backend/src/stwo/prover.rs | 56 +++++++++++++++++++++++++---- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 7825639b05..46c1443439 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -13,7 +13,10 @@ use std::sync::Arc; use powdr_ast::analyzed::{ AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType, }; -use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval}; +use stwo_prover::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, +}; use stwo_prover::core::backend::ColumnOps; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; @@ -64,12 +67,15 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); + + println!(" \n witness_columns is {:?} \n ", witness_columns); let constant_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); + println!("\n constant_columns is {:?} \n", constant_columns); Self { analyzed, @@ -88,20 +94,32 @@ impl FrameworkEval for PowdrEval { } fn evaluate(&self, mut eval: E) -> E { assert!( - self.analyzed.constant_count() == 0 && self.analyzed.publics_count() == 0, - "Error: Expected no fixed columns nor public inputs, as they are not supported yet.", + self.analyzed.publics_count() == 0, + "Error: Expected no public inputs, as they are not supported yet.", ); let witness_eval: BTreeMap::F; 2]> = self .witness_columns .keys() - .map(|poly_id| (*poly_id, eval.next_interaction_mask(1, [0, 1]))) + .map(|poly_id| { + ( + *poly_id, + eval.next_interaction_mask(ORIGINAL_TRACE_IDX, [0, 1]), + ) + }) .collect(); + println!("witness_eval is {:?}", witness_eval); let constant_eval: BTreeMap::F> = self .constant_columns .keys() - .map(|poly_id| (*poly_id, eval.next_interaction_mask(1, [0])[0].clone())) + .map(|poly_id| { + ( + *poly_id, + eval.get_preprocessed_column(PreprocessedColumn::Plonk(3)), + ) + }) .collect(); + println!("constant_eval is {:?}", constant_eval); for id in self .analyzed @@ -109,6 +127,7 @@ impl FrameworkEval for PowdrEval { { match id { Identity::Polynomial(identity) => { + println!("exprrrrrrrrrrrrrrrrrrrrr is {:?}",&identity.expression); let expr = to_stwo_expression(&identity.expression, &witness_eval, &constant_eval); eval.add_constraint(expr); @@ -150,6 +169,7 @@ where + From, { use AlgebraicBinaryOperator::*; + // println!("expr is {:?}", expr); match expr { AlgebraicExpression::Reference(r) => { let poly_id = r.poly_id; @@ -159,7 +179,13 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => constant_eval[&poly_id].clone(), + PolynomialType::Constant => { + println!( + "constant_eval[&poly_id].clone() is {:?}", + constant_eval[&poly_id].clone() + ); + constant_eval[&poly_id].clone() + } PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 794b66401d..7926950b91 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -18,6 +18,7 @@ use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; mod circuit_builder; mod prover; +//mod proof; #[allow(dead_code)] struct RestrictedFactory; @@ -42,9 +43,6 @@ impl BackendFactory for RestrictedFactory { if pil.degrees().len() > 1 { return Err(Error::NoVariableDegreeAvailable); } - let fixed = Arc::new( - get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, - ); let stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); Ok(stwo) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index fc2b3b824a..0776f01877 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,4 +1,6 @@ use powdr_ast::analyzed::Analyzed; +use powdr_backend_utils::machine_fixed_columns; +use powdr_executor::constant_evaluator::VariablySizedColumn; use serde::de::DeserializeOwned; use serde::ser::Serialize; use std::io; @@ -14,10 +16,12 @@ use powdr_number::FieldElement; use stwo_prover::core::air::{Component, ComponentProver}; use stwo_prover::core::backend::{Backend, BackendForChannel}; use stwo_prover::core::channel::{Channel, MerkleChannel}; -use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; -use stwo_prover::core::poly::circle::CanonicCoset; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::ColumnVec; const FRI_LOG_BLOWUP: usize = 1; const FRI_NUM_QUERIES: usize = 100; @@ -26,7 +30,7 @@ const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; pub struct StwoProver { pub analyzed: Arc>, - _fixed: Arc)>>, + fixed: Arc)>>, /// Proving key placeholder _proving_key: Option<()>, /// Verifying key placeholder @@ -46,11 +50,11 @@ where { pub fn new( analyzed: Arc>, - _fixed: Arc)>>, + fixed: Arc)>>, ) -> Result { Ok(Self { analyzed, - _fixed, + fixed, _proving_key: None, _verifying_key: None, _channel_marker: PhantomData, @@ -73,13 +77,43 @@ where let prover_channel = &mut ::C::default(); let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + // get fix_columns evaluations + let fixed_columns = machine_fixed_columns(&self.fixed, &self.analyzed); + + let domain = CanonicCoset::new( + fixed_columns + .keys() + .next() + .map(|&first_key| first_key.ilog2()) + .unwrap_or(0), + ) + .circle_domain(); + // println!("domain size is {}", *fixed_columns.keys().next().unwrap() as u32); + + let constant_trace: ColumnVec> = + fixed_columns + .values() + .flat_map(|vec| { + println!("vec is {:?}", vec); + vec.iter().map(|(_name, values)| { + let values = values + .iter() + .map(|v| v.try_into_i32().unwrap().into()) + .collect(); + CircleEvaluation::new(domain, values) + }) + }) + .collect(); + println!("constant trace is {:?}", constant_trace); + // Preprocessed trace let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals([]); + tree_builder.extend_evals(constant_trace); tree_builder.commit(prover_channel); // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); + print!("trace is {:?}", trace); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); @@ -90,12 +124,22 @@ where PowdrEval::new(self.analyzed.clone()), ); + let n_preprocessed_columns = commitment_scheme.trees[0] + .polynomials + .len(); + println!("n_preprocessed_columns is {}", n_preprocessed_columns); + + let trace = commitment_scheme.trace(); + println!("trace is {:?}", trace.evals); + let proof = stwo_prover::core::prover::prove::( &[&component], prover_channel, commitment_scheme, ) .unwrap(); + + Ok(bincode::serialize(&proof).unwrap()) } From 53183d494f2b4bea34cc0e66e3f49b1f867021d0 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 21 Nov 2024 12:02:50 +0100 Subject: [PATCH 03/48] test pil --- test_data/pil/fibo_no_public.pil | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 test_data/pil/fibo_no_public.pil diff --git a/test_data/pil/fibo_no_public.pil b/test_data/pil/fibo_no_public.pil new file mode 100644 index 0000000000..94c674b863 --- /dev/null +++ b/test_data/pil/fibo_no_public.pil @@ -0,0 +1,13 @@ +let N = 4; + +// This uses the alternative nomenclature as well. + +namespace Fibonacci(N); + col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; + col witness x, y; + + ISLAST * (y' - 1) = 0; + ISLAST * (x' - 1) = 0; + + (1-ISLAST) * (x' - y) = 0; + (1-ISLAST) * (y' - (x + y)) = 0; \ No newline at end of file From a3f6d0f4481a1c757f813f9d95434b5af66d4d02 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 21 Nov 2024 12:10:16 +0100 Subject: [PATCH 04/48] cleaner --- backend/src/stwo/circuit_builder.rs | 10 ---------- backend/src/stwo/prover.rs | 3 --- 2 files changed, 13 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 46c1443439..e6d0f2528b 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -68,14 +68,12 @@ impl PowdrEval { .map(|(index, (_, id))| (id, index)) .collect(); - println!(" \n witness_columns is {:?} \n ", witness_columns); let constant_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - println!("\n constant_columns is {:?} \n", constant_columns); Self { analyzed, @@ -108,7 +106,6 @@ impl FrameworkEval for PowdrEval { ) }) .collect(); - println!("witness_eval is {:?}", witness_eval); let constant_eval: BTreeMap::F> = self .constant_columns .keys() @@ -119,7 +116,6 @@ impl FrameworkEval for PowdrEval { ) }) .collect(); - println!("constant_eval is {:?}", constant_eval); for id in self .analyzed @@ -127,7 +123,6 @@ impl FrameworkEval for PowdrEval { { match id { Identity::Polynomial(identity) => { - println!("exprrrrrrrrrrrrrrrrrrrrr is {:?}",&identity.expression); let expr = to_stwo_expression(&identity.expression, &witness_eval, &constant_eval); eval.add_constraint(expr); @@ -169,7 +164,6 @@ where + From, { use AlgebraicBinaryOperator::*; - // println!("expr is {:?}", expr); match expr { AlgebraicExpression::Reference(r) => { let poly_id = r.poly_id; @@ -180,10 +174,6 @@ where true => witness_eval[&poly_id][1].clone(), }, PolynomialType::Constant => { - println!( - "constant_eval[&poly_id].clone() is {:?}", - constant_eval[&poly_id].clone() - ); constant_eval[&poly_id].clone() } PolynomialType::Intermediate => { diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 0776f01877..da1b0b93df 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -88,7 +88,6 @@ where .unwrap_or(0), ) .circle_domain(); - // println!("domain size is {}", *fixed_columns.keys().next().unwrap() as u32); let constant_trace: ColumnVec> = fixed_columns @@ -104,7 +103,6 @@ where }) }) .collect(); - println!("constant trace is {:?}", constant_trace); // Preprocessed trace let mut tree_builder = commitment_scheme.tree_builder(); @@ -113,7 +111,6 @@ where // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); - print!("trace is {:?}", trace); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); From eb6f2ec268116d14d93180c44f92df9365de499a Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 22 Nov 2024 00:49:58 +0100 Subject: [PATCH 05/48] constant support --- backend/Cargo.toml | 2 +- backend/src/stwo/circuit_builder.rs | 51 +++++++++--- backend/src/stwo/mod.rs | 6 +- backend/src/stwo/prover.rs | 124 ++++++++++++++++++++++------ pipeline/tests/pil.rs | 7 ++ test_data/pil/fibo_no_publics.pil | 13 +++ 6 files changed, 162 insertions(+), 41 deletions(-) create mode 100644 test_data/pil/fibo_no_publics.pil diff --git a/backend/Cargo.toml b/backend/Cargo.toml index ebac38e74c..d67d1af6e0 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -63,7 +63,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } # TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3 -stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" } +stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe" } strum = { version = "0.24.1", features = ["derive"] } log = "0.4.17" diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 3b72ecf0b6..f34b0672bd 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -13,7 +13,10 @@ use std::sync::Arc; use powdr_ast::analyzed::{ AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType, }; -use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval}; +use stwo_prover::constraint_framework::preprocessed_columns::PreprocessedColumn; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, ORIGINAL_TRACE_IDX, +}; use stwo_prover::core::backend::ColumnOps; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; @@ -53,10 +56,18 @@ where pub struct PowdrEval { analyzed: Arc>, witness_columns: BTreeMap, + constant_columns: BTreeMap, } impl PowdrEval { pub fn new(analyzed: Arc>) -> Self { + let constant_columns: BTreeMap = analyzed + .definitions_in_source_order(PolynomialType::Constant) + .flat_map(|(symbol, _)| symbol.array_elements()) + .enumerate() + .map(|(index, (_, id))| (id, index)) + .collect(); + let witness_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Committed) .flat_map(|(symbol, _)| symbol.array_elements()) @@ -67,6 +78,7 @@ impl PowdrEval { Self { analyzed, witness_columns, + constant_columns, } } } @@ -80,14 +92,29 @@ impl FrameworkEval for PowdrEval { } fn evaluate(&self, mut eval: E) -> E { assert!( - self.analyzed.constant_count() == 0 && self.analyzed.publics_count() == 0, - "Error: Expected no fixed columns nor public inputs, as they are not supported yet.", + self.analyzed.publics_count() == 0, + "Error: Expected no public inputs, as they are not supported yet.", ); let witness_eval: BTreeMap::F; 2]> = self .witness_columns .keys() - .map(|poly_id| (*poly_id, eval.next_interaction_mask(0, [0, 1]))) + .map(|poly_id| { + ( + *poly_id, + eval.next_interaction_mask(ORIGINAL_TRACE_IDX, [0, 1]), + ) + }) + .collect(); + let constant_eval: BTreeMap::F> = self + .constant_columns + .keys() + .map(|poly_id| { + ( + *poly_id, + eval.get_preprocessed_column(PreprocessedColumn::Plonk(3)), + ) + }) .collect(); for id in self @@ -96,7 +123,8 @@ impl FrameworkEval for PowdrEval { { match id { Identity::Polynomial(identity) => { - let expr = to_stwo_expression(&identity.expression, &witness_eval); + let expr = + to_stwo_expression(&identity.expression, &witness_eval, &constant_eval); eval.add_constraint(expr); } Identity::Connect(..) => { @@ -119,6 +147,7 @@ impl FrameworkEval for PowdrEval { fn to_stwo_expression( expr: &AlgebraicExpression, witness_eval: &BTreeMap, + constant_eval: &BTreeMap, ) -> F where F: FieldExpOps @@ -144,9 +173,7 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => { - unimplemented!("Constant polynomials are not supported in stwo yet") - } + PolynomialType::Constant => constant_eval[&poly_id].clone(), PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } @@ -162,15 +189,15 @@ where right, }) => match **right { AlgebraicExpression::Number(n) => { - let left = to_stwo_expression(left, witness_eval); + let left = to_stwo_expression(left, witness_eval, constant_eval); (0u32..n.to_integer().try_into_u32().unwrap()) .fold(F::one(), |acc, _| acc * left.clone()) } _ => unimplemented!("pow with non-constant exponent"), }, AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { - let left = to_stwo_expression(left, witness_eval); - let right = to_stwo_expression(right, witness_eval); + let left = to_stwo_expression(left, witness_eval, constant_eval); + let right = to_stwo_expression(right, witness_eval, constant_eval); match op { Add => left + right, @@ -180,7 +207,7 @@ where } } AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { - let expr = to_stwo_expression(expr, witness_eval); + let expr = to_stwo_expression(expr, witness_eval, constant_eval); match op { AlgebraicUnaryOperator::Minus => -expr, diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 794b66401d..99cd858418 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -8,7 +8,7 @@ use crate::{ field_filter::generalize_factory, Backend, BackendFactory, BackendOptions, Error, Proof, }; use powdr_ast::analyzed::Analyzed; -use powdr_executor::constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}; +use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_executor::witgen::WitgenCallback; use powdr_number::{FieldElement, Mersenne31Field}; use prover::StwoProver; @@ -18,6 +18,7 @@ use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; mod circuit_builder; mod prover; + #[allow(dead_code)] struct RestrictedFactory; @@ -42,9 +43,6 @@ impl BackendFactory for RestrictedFactory { if pil.degrees().len() > 1 { return Err(Error::NoVariableDegreeAvailable); } - let fixed = Arc::new( - get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, - ); let stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); Ok(stwo) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index ab79e93b03..bf4bb40739 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,6 +1,11 @@ + use powdr_ast::analyzed::Analyzed; +use powdr_backend_utils::machine_fixed_columns; +use powdr_executor::constant_evaluator::VariablySizedColumn; +use powdr_number::{DegreeType, FieldElement}; use serde::de::DeserializeOwned; use serde::ser::Serialize; +use std::collections::BTreeMap; use std::io; use std::marker::PhantomData; use std::sync::Arc; @@ -10,14 +15,16 @@ use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, Powdr use stwo_prover::constraint_framework::TraceLocationAllocator; use stwo_prover::core::prover::StarkProof; -use powdr_number::FieldElement; use stwo_prover::core::air::{Component, ComponentProver}; use stwo_prover::core::backend::{Backend, BackendForChannel}; use stwo_prover::core::channel::{Channel, MerkleChannel}; -use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; -use stwo_prover::core::poly::circle::CanonicCoset; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; +use stwo_prover::core::ColumnVec; const FRI_LOG_BLOWUP: usize = 1; const FRI_NUM_QUERIES: usize = 100; @@ -26,7 +33,7 @@ const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; pub struct StwoProver { pub analyzed: Arc>, - _fixed: Arc)>>, + fixed: Arc)>>, /// Proving key placeholder _proving_key: Option<()>, /// Verifying key placeholder @@ -46,11 +53,11 @@ where { pub fn new( analyzed: Arc>, - _fixed: Arc)>>, + fixed: Arc)>>, ) -> Result { Ok(Self { analyzed, - _fixed, + fixed, _proving_key: None, _verifying_key: None, _channel_marker: PhantomData, @@ -70,14 +77,91 @@ where ); // Setup protocol. - let mut prover_channel = ::C::default(); + let prover_channel = &mut ::C::default(); let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + // get fix_columns evaluations + let fixed_columns = machine_fixed_columns(&self.fixed, &self.analyzed); + + let domain = CanonicCoset::new( + fixed_columns + .keys() + .next() + .map(|&first_key| first_key.ilog2()) + .unwrap_or(0), + ) + .circle_domain(); + + let updated_fixed_columns: BTreeMap)>> = fixed_columns + .iter() + .map(|(key, vec)| { + let transformed_vec: Vec<(String, Vec)> = vec + .iter() + .map(|(name, slice)| { + let mut values: Vec = slice.to_vec(); // Clone the slice into a Vec + bit_reverse_coset_to_circle_domain_order(&mut values); // Apply bit reversal + (name.clone(), values) // Return the updated tuple + }) + .collect(); // Collect the updated vector + (*key, transformed_vec) // Rebuild the BTreeMap with transformed vectors + }) + .collect(); + + let constant_trace: ColumnVec> = + updated_fixed_columns + .values() + .flat_map(|vec| { + vec.iter().map(|(_name, values)| { + let values = values + .iter() + .map(|v| v.try_into_i32().unwrap().into()) + .collect(); + CircleEvaluation::new(domain, values) + }) + }) + .collect(); + + // Preprocessed trace + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(constant_trace.clone()); + tree_builder.commit(prover_channel); + + + + + let transformed_witness: Vec<(String, Vec)> = witness + .iter() + .map(|(name, vec)| { + ( + name.clone(), + vec.iter() + .map(|&elem| { + if elem == F::from(2147483647) { + F::zero() + } else { + elem + } + }) + .collect(), + ) + }) + .collect(); + + let witness: &Vec<(String, Vec)> = &transformed_witness + .into_iter() + .map(|(name, mut vec)| { + bit_reverse_coset_to_circle_domain_order(&mut vec); + (name, vec) + }) + .collect(); + + + // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); - tree_builder.commit(&mut prover_channel); + tree_builder.commit(prover_channel); let component = PowdrComponent::new( &mut TraceLocationAllocator::default(), @@ -86,7 +170,7 @@ where let proof = stwo_prover::core::prover::prove::( &[&component], - &mut prover_channel, + prover_channel, commitment_scheme, ) .unwrap(); @@ -105,8 +189,8 @@ where let proof: StarkProof = bincode::deserialize(proof).map_err(|e| format!("Failed to deserialize proof: {e}"))?; - let mut verifier_channel = ::C::default(); - let mut commitment_scheme = CommitmentSchemeVerifier::::new(config); + let verifier_channel = &mut ::C::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); //Constraints that are to be proved let component = PowdrComponent::new( @@ -115,21 +199,13 @@ where ); // Retrieve the expected column sizes in each commitment interaction, from the AIR. - // TODO: When constant columns are supported, there will be more than one sizes and proof.commitments - // size[0] is for constant columns, size[1] is for witness columns, size[2] is for lookup columns - // pass size[1] for witness columns now is not doable due to this branch is outdated for the new feature of constant columns - // it will throw errors. let sizes = component.trace_log_degree_bounds(); - assert_eq!(sizes.len(), 1); - commitment_scheme.commit(proof.commitments[0], &sizes[0], &mut verifier_channel); - stwo_prover::core::prover::verify( - &[&component], - &mut verifier_channel, - &mut commitment_scheme, - proof, - ) - .map_err(|e| e.to_string()) + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); + + stwo_prover::core::prover::verify(&[&component], verifier_channel, commitment_scheme, proof) + .map_err(|e| e.to_string()) } } diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index a4931ddbb8..27848f388b 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -268,6 +268,13 @@ fn stwo_add_and_equal() { let f = "pil/add_and_equal.pil"; test_stwo(f, Default::default()); } + +#[test] +fn stwo_fibonacci() { + let f = "pil/fibo_no_publics.pil"; + test_stwo(f, Default::default()); +} + #[test] fn simple_div() { let f = "pil/simple_div.pil"; diff --git a/test_data/pil/fibo_no_publics.pil b/test_data/pil/fibo_no_publics.pil new file mode 100644 index 0000000000..285cdf6b93 --- /dev/null +++ b/test_data/pil/fibo_no_publics.pil @@ -0,0 +1,13 @@ +let N = 512; + +// This uses the alternative nomenclature as well. + +namespace Fibonacci(N); + col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; + col witness x, y; + + ISLAST * (y' - 1) = 0; + ISLAST * (x' - 1) = 0; + + (1-ISLAST) * (x' - y) = 0; + (1-ISLAST) * (y' - (x + y)) = 0; \ No newline at end of file From cd67a59e49fad602401a77e010376b004fefdbc7 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 22 Nov 2024 01:08:16 +0100 Subject: [PATCH 06/48] clean up --- backend/src/stwo/circuit_builder.rs | 13 ++----------- backend/src/stwo/prover.rs | 15 --------------- test_data/pil/fibo_no_public.pil | 13 ------------- 3 files changed, 2 insertions(+), 39 deletions(-) delete mode 100644 test_data/pil/fibo_no_public.pil diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 7fd4f1fd99..b1f7422df1 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -61,20 +61,13 @@ pub struct PowdrEval { impl PowdrEval { pub fn new(analyzed: Arc>) -> Self { - let constant_columns: BTreeMap = analyzed - .definitions_in_source_order(PolynomialType::Constant) - .flat_map(|(symbol, _)| symbol.array_elements()) - .enumerate() - .map(|(index, (_, id))| (id, index)) - .collect(); - let witness_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Committed) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - + let constant_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) @@ -180,9 +173,7 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => { - constant_eval[&poly_id].clone() - } + PolynomialType::Constant => constant_eval[&poly_id].clone(), PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 8ee59e9eac..4a638fc19e 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,4 +1,3 @@ - use powdr_ast::analyzed::Analyzed; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; @@ -126,9 +125,6 @@ where tree_builder.extend_evals(constant_trace.clone()); tree_builder.commit(prover_channel); - - - let transformed_witness: Vec<(String, Vec)> = witness .iter() .map(|(name, vec)| { @@ -155,7 +151,6 @@ where }) .collect(); - // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); @@ -168,22 +163,12 @@ where PowdrEval::new(self.analyzed.clone()), ); - let n_preprocessed_columns = commitment_scheme.trees[0] - .polynomials - .len(); - println!("n_preprocessed_columns is {}", n_preprocessed_columns); - - let trace = commitment_scheme.trace(); - println!("trace is {:?}", trace.evals); - let proof = stwo_prover::core::prover::prove::( &[&component], prover_channel, commitment_scheme, ) .unwrap(); - - Ok(bincode::serialize(&proof).unwrap()) } diff --git a/test_data/pil/fibo_no_public.pil b/test_data/pil/fibo_no_public.pil deleted file mode 100644 index 94c674b863..0000000000 --- a/test_data/pil/fibo_no_public.pil +++ /dev/null @@ -1,13 +0,0 @@ -let N = 4; - -// This uses the alternative nomenclature as well. - -namespace Fibonacci(N); - col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; - col witness x, y; - - ISLAST * (y' - 1) = 0; - ISLAST * (x' - 1) = 0; - - (1-ISLAST) * (x' - y) = 0; - (1-ISLAST) * (y' - (x + y)) = 0; \ No newline at end of file From 12146becfc2cb0c52f44ea2baea02e0be103c75c Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Sun, 24 Nov 2024 10:04:20 +0100 Subject: [PATCH 07/48] add setup --- backend/src/stwo/circuit_builder.rs | 45 +++++++++++ backend/src/stwo/mod.rs | 4 +- backend/src/stwo/proof.rs | 21 +++++ backend/src/stwo/prover.rs | 116 ++++++++++++++++++++++++++-- pipeline/tests/pil.rs | 6 ++ 5 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 backend/src/stwo/proof.rs diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index b1f7422df1..2e2a9fc05a 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use num_traits::Zero; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; @@ -26,6 +27,50 @@ use stwo_prover::core::ColumnVec; pub type PowdrComponent<'a, F> = FrameworkComponent>; +/// A description of the constraint system. +/// All of the data is derived from the analyzed PIL, but is materialized +/// here for performance reasons. +pub struct ConstraintSystem { + // for each witness column, the stage and index of this column in this stage + witness_columns: BTreeMap, + // for each fixed column, the index of this column in the fixed columns + fixed_columns: BTreeMap, + identities: Vec>, +} + +impl From<&Analyzed> for ConstraintSystem { + fn from(analyzed: &Analyzed) -> Self { + let identities = analyzed.identities.clone(); + + let fixed_columns = analyzed + .definitions_in_source_order(PolynomialType::Constant) + .flat_map(|(symbol, _)| symbol.array_elements()) + .enumerate() + .map(|(index, (_, id))| (id, index)) + .collect(); + + let witness_columns = analyzed + .definitions_in_source_order(PolynomialType::Committed) + .into_group_map_by(|(s, _)| s.stage.unwrap_or_default()) + .into_iter() + .flat_map(|(stage, symbols)| { + symbols + .into_iter() + .flat_map(|(s, _)| s.array_elements()) + .enumerate() + .map(move |(index_in_stage, (_, poly_id))| { + (poly_id, (stage as usize, index_in_stage)) + }) + }) + .collect(); + Self { + identities, + witness_columns, + fixed_columns, + } + } +} + pub(crate) fn gen_stwo_circuit_trace( witness: &[(String, Vec)], ) -> ColumnVec> diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 99cd858418..46a6d0e2c4 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -17,6 +17,7 @@ use stwo_prover::core::channel::{Blake2sChannel, Channel, MerkleChannel}; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; mod circuit_builder; +mod proof; mod prover; #[allow(dead_code)] @@ -45,6 +46,7 @@ impl BackendFactory for RestrictedFactory { } let stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); + Ok(stwo) } } @@ -54,7 +56,7 @@ generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]); impl Backend for StwoProver where - SimdBackend: BackendForChannel, + SimdBackend: BackendForChannel + Send, MC: MerkleChannel, C: Channel, MC::H: DeserializeOwned + Serialize, diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs new file mode 100644 index 0000000000..7aef207675 --- /dev/null +++ b/backend/src/stwo/proof.rs @@ -0,0 +1,21 @@ +use std::collections::BTreeMap; +use stwo_prover::core::backend::BackendForChannel; +use stwo_prover::core::channel::{Channel, MerkleChannel}; +use stwo_prover::core::pcs::TreeVec; +use stwo_prover::core::pcs::{ + CommitmentSchemeProver, CommitmentSchemeVerifier, CommitmentTreeProver, PcsConfig, +}; + +/// For each possible size, the commitment and prover data +pub type TableProvingKeyCollection = BTreeMap>; + +pub struct TableProvingKey, MC: MerkleChannel> { + pub trees: TreeVec>, +} + +pub struct StarkProvingKey, MC: MerkleChannel> { + // for each table, the preprocessed data + pub preprocessed: BTreeMap>, +} + +unsafe impl, MC: MerkleChannel> Send for StarkProvingKey {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 4a638fc19e..6c535a6091 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -9,7 +9,8 @@ use std::io; use std::marker::PhantomData; use std::sync::Arc; -use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; +use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval, ConstraintSystem}; +use crate::stwo::proof::{TableProvingKeyCollection,TableProvingKey,StarkProvingKey}; use stwo_prover::constraint_framework::TraceLocationAllocator; use stwo_prover::core::prover::StarkProof; @@ -30,11 +31,15 @@ const FRI_NUM_QUERIES: usize = 100; const FRI_PROOF_OF_WORK_BITS: usize = 16; const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; -pub struct StwoProver { +pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { pub analyzed: Arc>, fixed: Arc)>>, + + /// The split analyzed PIL + split: BTreeMap, ConstraintSystem)>, + /// Proving key placeholder - _proving_key: Option<()>, + proving_key: Option>, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, @@ -49,15 +54,116 @@ where C: Channel + Send, MC::H: DeserializeOwned + Serialize, PowdrComponent<'a, F>: ComponentProver, -{ +{ + pub fn setup(&mut self) { + let preprocessed: BTreeMap> = self + .split + .iter() + .filter_map(|(namespace, (pil, _))| { + // if we have neither fixed columns nor publics, we don't need to commit to anything + if pil.constant_count() + pil.publics_count() == 0 { + None + } else { + Some(( + namespace.to_string(), + pil.committed_polys_in_source_order() + .find_map(|(s, _)| s.degree) + .unwrap() + .iter() + .map(|size| { + // get the config + let config = get_config(); + + // commit to the fixed columns + let twiddles = B::precompute_twiddles( + CanonicCoset::new(self.analyzed.degree().ilog2() + 1 + FRI_LOG_BLOWUP as u32) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut ::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + + // get fix_columns evaluations + let fixed_columns = machine_fixed_columns(&self.fixed, &self.analyzed); + + let domain = CanonicCoset::new( + fixed_columns + .keys() + .next() + .map(|&first_key| first_key.ilog2()) + .unwrap_or(0), + ) + .circle_domain(); + + let updated_fixed_columns: BTreeMap)>> = fixed_columns + .iter() + .map(|(key, vec)| { + let transformed_vec: Vec<(String, Vec)> = vec + .iter() + .map(|(name, slice)| { + let mut values: Vec = slice.to_vec(); // Clone the slice into a Vec + bit_reverse_coset_to_circle_domain_order(&mut values); // Apply bit reversal + (name.clone(), values) // Return the updated tuple + }) + .collect(); // Collect the updated vector + (*key, transformed_vec) // Rebuild the BTreeMap with transformed vectors + }) + .collect(); + + let constant_trace: ColumnVec> = + updated_fixed_columns + .values() + .flat_map(|vec| { + vec.iter().map(|(_name, values)| { + let values = values + .iter() + .map(|v| v.try_into_i32().unwrap().into()) + .collect(); + CircleEvaluation::new(domain, values) + }) + }) + .collect(); + + // Preprocessed trace + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(constant_trace.clone()); + tree_builder.commit(prover_channel); + let trees = std::mem::take(&mut commitment_scheme.trees); + + ( + self.analyzed.degree().ilog2() as usize, + TableProvingKey { + trees, + } + ) + }) + .collect(), + )) + } + }) + .collect(); + let proving_key = StarkProvingKey { preprocessed }; + + self.proving_key = Some(proving_key); + } pub fn new( analyzed: Arc>, fixed: Arc)>>, ) -> Result { + Ok(Self { + split: powdr_backend_utils::split_pil(&analyzed) + .into_iter() + .map(|(name, pil)| { + let constraint_system = ConstraintSystem::from(&pil); + (name, (pil, constraint_system)) + }) + .collect(), analyzed, fixed, - _proving_key: None, + proving_key: None, _verifying_key: None, _channel_marker: PhantomData, _backend_marker: PhantomData, diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 27848f388b..e0f97caf33 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -275,6 +275,12 @@ fn stwo_fibonacci() { test_stwo(f, Default::default()); } +#[test] +fn stwo_fixed_columns() { + let f = "pil/fixed_columns.pil"; + test_stwo(f, Default::default()); +} + #[test] fn simple_div() { let f = "pil/simple_div.pil"; From 2c100ba7a71f522cd22fdecd02671874011bb287 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Sun, 24 Nov 2024 22:06:41 +0100 Subject: [PATCH 08/48] put setup into new function --- backend/src/stwo/circuit_builder.rs | 45 ------- backend/src/stwo/proof.rs | 8 +- backend/src/stwo/prover.rs | 188 ++++++++++------------------ 3 files changed, 69 insertions(+), 172 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 2e2a9fc05a..b1f7422df1 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use num_traits::Zero; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; @@ -27,50 +26,6 @@ use stwo_prover::core::ColumnVec; pub type PowdrComponent<'a, F> = FrameworkComponent>; -/// A description of the constraint system. -/// All of the data is derived from the analyzed PIL, but is materialized -/// here for performance reasons. -pub struct ConstraintSystem { - // for each witness column, the stage and index of this column in this stage - witness_columns: BTreeMap, - // for each fixed column, the index of this column in the fixed columns - fixed_columns: BTreeMap, - identities: Vec>, -} - -impl From<&Analyzed> for ConstraintSystem { - fn from(analyzed: &Analyzed) -> Self { - let identities = analyzed.identities.clone(); - - let fixed_columns = analyzed - .definitions_in_source_order(PolynomialType::Constant) - .flat_map(|(symbol, _)| symbol.array_elements()) - .enumerate() - .map(|(index, (_, id))| (id, index)) - .collect(); - - let witness_columns = analyzed - .definitions_in_source_order(PolynomialType::Committed) - .into_group_map_by(|(s, _)| s.stage.unwrap_or_default()) - .into_iter() - .flat_map(|(stage, symbols)| { - symbols - .into_iter() - .flat_map(|(s, _)| s.array_elements()) - .enumerate() - .map(move |(index_in_stage, (_, poly_id))| { - (poly_id, (stage as usize, index_in_stage)) - }) - }) - .collect(); - Self { - identities, - witness_columns, - fixed_columns, - } - } -} - pub(crate) fn gen_stwo_circuit_trace( witness: &[(String, Vec)], ) -> ColumnVec> diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 7aef207675..8fe76fd03e 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,10 +1,8 @@ use std::collections::BTreeMap; use stwo_prover::core::backend::BackendForChannel; -use stwo_prover::core::channel::{Channel, MerkleChannel}; +use stwo_prover::core::channel::MerkleChannel; +use stwo_prover::core::pcs::CommitmentTreeProver; use stwo_prover::core::pcs::TreeVec; -use stwo_prover::core::pcs::{ - CommitmentSchemeProver, CommitmentSchemeVerifier, CommitmentTreeProver, PcsConfig, -}; /// For each possible size, the commitment and prover data pub type TableProvingKeyCollection = BTreeMap>; @@ -18,4 +16,4 @@ pub struct StarkProvingKey, MC: MerkleChannel> { pub preprocessed: BTreeMap>, } -unsafe impl, MC: MerkleChannel> Send for StarkProvingKey {} +unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 6c535a6091..ee6077dd7e 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -9,12 +9,13 @@ use std::io; use std::marker::PhantomData; use std::sync::Arc; -use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval, ConstraintSystem}; -use crate::stwo::proof::{TableProvingKeyCollection,TableProvingKey,StarkProvingKey}; +use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; +use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; use stwo_prover::constraint_framework::TraceLocationAllocator; use stwo_prover::core::prover::StarkProof; +use std::cell::RefCell; use stwo_prover::core::air::{Component, ComponentProver}; use stwo_prover::core::backend::{Backend, BackendForChannel}; use stwo_prover::core::channel::{Channel, MerkleChannel}; @@ -33,13 +34,9 @@ const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { pub analyzed: Arc>, - fixed: Arc)>>, - - /// The split analyzed PIL - split: BTreeMap, ConstraintSystem)>, /// Proving key placeholder - proving_key: Option>, + proving_key: RefCell>>, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, @@ -54,12 +51,18 @@ where C: Channel + Send, MC::H: DeserializeOwned + Serialize, PowdrComponent<'a, F>: ComponentProver, -{ - pub fn setup(&mut self) { - let preprocessed: BTreeMap> = self - .split +{ + pub fn new( + analyzed: Arc>, + fixed: Arc)>>, + ) -> Result { + let split: BTreeMap> = powdr_backend_utils::split_pil(&analyzed) + .into_iter() + .collect(); + + let preprocessed: BTreeMap> = split .iter() - .filter_map(|(namespace, (pil, _))| { + .filter_map(|(namespace, pil)| { // if we have neither fixed columns nor publics, we don't need to commit to anything if pil.constant_count() + pil.publics_count() == 0 { None @@ -75,19 +78,20 @@ where let config = get_config(); // commit to the fixed columns - let twiddles = B::precompute_twiddles( - CanonicCoset::new(self.analyzed.degree().ilog2() + 1 + FRI_LOG_BLOWUP as u32) + let twiddles = Arc::new(B::precompute_twiddles( + CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) .circle_domain() .half_coset, - ); - + )); + // Setup protocol. let prover_channel = &mut ::C::default(); - let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); - + let mut commitment_scheme = + CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); + // get fix_columns evaluations - let fixed_columns = machine_fixed_columns(&self.fixed, &self.analyzed); - + let fixed_columns = machine_fixed_columns(&fixed, &analyzed); + let domain = CanonicCoset::new( fixed_columns .keys() @@ -96,47 +100,51 @@ where .unwrap_or(0), ) .circle_domain(); - - let updated_fixed_columns: BTreeMap)>> = fixed_columns + + let updated_fixed_columns: BTreeMap< + DegreeType, + Vec<(String, Vec)>, + > = fixed_columns .iter() .map(|(key, vec)| { let transformed_vec: Vec<(String, Vec)> = vec .iter() .map(|(name, slice)| { let mut values: Vec = slice.to_vec(); // Clone the slice into a Vec - bit_reverse_coset_to_circle_domain_order(&mut values); // Apply bit reversal + bit_reverse_coset_to_circle_domain_order( + &mut values, + ); // Apply bit reversal (name.clone(), values) // Return the updated tuple }) .collect(); // Collect the updated vector (*key, transformed_vec) // Rebuild the BTreeMap with transformed vectors }) .collect(); - - let constant_trace: ColumnVec> = - updated_fixed_columns - .values() - .flat_map(|vec| { - vec.iter().map(|(_name, values)| { - let values = values - .iter() - .map(|v| v.try_into_i32().unwrap().into()) - .collect(); - CircleEvaluation::new(domain, values) - }) + + let constant_trace: ColumnVec< + CircleEvaluation, + > = updated_fixed_columns + .values() + .flat_map(|vec| { + vec.iter().map(|(_name, values)| { + let values = values + .iter() + .map(|v| v.try_into_i32().unwrap().into()) + .collect(); + CircleEvaluation::new(domain, values) }) - .collect(); - + }) + .collect(); + // Preprocessed trace let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(constant_trace.clone()); tree_builder.commit(prover_channel); - let trees = std::mem::take(&mut commitment_scheme.trees); - - ( - self.analyzed.degree().ilog2() as usize, - TableProvingKey { - trees, - } + let trees = commitment_scheme.trees; + + ( + analyzed.degree().ilog2() as usize, + TableProvingKey { trees }, ) }) .collect(), @@ -146,24 +154,9 @@ where .collect(); let proving_key = StarkProvingKey { preprocessed }; - self.proving_key = Some(proving_key); - } - pub fn new( - analyzed: Arc>, - fixed: Arc)>>, - ) -> Result { - Ok(Self { - split: powdr_backend_utils::split_pil(&analyzed) - .into_iter() - .map(|(name, pil)| { - let constraint_system = ConstraintSystem::from(&pil); - (name, (pil, constraint_system)) - }) - .collect(), analyzed, - fixed, - proving_key: None, + proving_key: RefCell::new(Some(proving_key)), _verifying_key: None, _channel_marker: PhantomData, _backend_marker: PhantomData, @@ -183,70 +176,21 @@ where // Setup protocol. let prover_channel = &mut ::C::default(); - let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); - - // get fix_columns evaluations - let fixed_columns = machine_fixed_columns(&self.fixed, &self.analyzed); - - let domain = CanonicCoset::new( - fixed_columns - .keys() - .next() - .map(|&first_key| first_key.ilog2()) - .unwrap_or(0), - ) - .circle_domain(); - - let updated_fixed_columns: BTreeMap)>> = fixed_columns - .iter() - .map(|(key, vec)| { - let transformed_vec: Vec<(String, Vec)> = vec - .iter() - .map(|(name, slice)| { - let mut values: Vec = slice.to_vec(); // Clone the slice into a Vec - bit_reverse_coset_to_circle_domain_order(&mut values); // Apply bit reversal - (name.clone(), values) // Return the updated tuple - }) - .collect(); // Collect the updated vector - (*key, transformed_vec) // Rebuild the BTreeMap with transformed vectors - }) - .collect(); - - let constant_trace: ColumnVec> = - updated_fixed_columns - .values() - .flat_map(|vec| { - vec.iter().map(|(_name, values)| { - let values = values - .iter() - .map(|v| v.try_into_i32().unwrap().into()) - .collect(); - CircleEvaluation::new(domain, values) - }) - }) - .collect(); - - // Preprocessed trace - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(constant_trace.clone()); - tree_builder.commit(prover_channel); + //let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + let trees = self + .proving_key + .borrow_mut() // Borrow as mutable using RefCell + .as_mut() + .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) + .and_then(|table_collection| table_collection.values_mut().next()) + .map(|table_proving_key| std::mem::take(&mut table_proving_key.trees)) // Take ownership + .expect("Expected to find commitment_scheme in proving key"); + let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); + commitment_scheme.trees = trees; let transformed_witness: Vec<(String, Vec)> = witness .iter() - .map(|(name, vec)| { - ( - name.clone(), - vec.iter() - .map(|&elem| { - if elem == F::from(2147483647) { - F::zero() - } else { - elem - } - }) - .collect(), - ) - }) + .map(|(name, vec)| (name.clone(), vec.to_vec())) .collect(); let witness: &Vec<(String, Vec)> = &transformed_witness @@ -272,7 +216,7 @@ where let proof = stwo_prover::core::prover::prove::( &[&component], prover_channel, - commitment_scheme, + &mut commitment_scheme, ) .unwrap(); From c9606b5d56d4f35c158c80772eccfc6fa6bdf163 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Sun, 24 Nov 2024 22:33:45 +0100 Subject: [PATCH 09/48] clean up --- backend/src/stwo/mod.rs | 4 +--- test_data/pil/fibo_no_publics.pil | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 46a6d0e2c4..181db3bdc5 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -20,12 +20,10 @@ mod circuit_builder; mod proof; mod prover; -#[allow(dead_code)] struct RestrictedFactory; impl BackendFactory for RestrictedFactory { - #[allow(unreachable_code)] #[allow(unused_variables)] fn create( &self, @@ -56,7 +54,7 @@ generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]); impl Backend for StwoProver where - SimdBackend: BackendForChannel + Send, + SimdBackend: BackendForChannel, MC: MerkleChannel, C: Channel, MC::H: DeserializeOwned + Serialize, diff --git a/test_data/pil/fibo_no_publics.pil b/test_data/pil/fibo_no_publics.pil index 285cdf6b93..3f092cb3a3 100644 --- a/test_data/pil/fibo_no_publics.pil +++ b/test_data/pil/fibo_no_publics.pil @@ -1,4 +1,4 @@ -let N = 512; +let N = 262144; // This uses the alternative nomenclature as well. From 9756ef4b23059ebb2bd49c0d823d5848592bb89c Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 09:01:50 +0100 Subject: [PATCH 10/48] clean up --- backend/src/stwo/mod.rs | 4 ---- backend/src/stwo/prover.rs | 10 ++-------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 181db3bdc5..31fa47104b 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -20,7 +20,6 @@ mod circuit_builder; mod proof; mod prover; - struct RestrictedFactory; impl BackendFactory for RestrictedFactory { @@ -39,9 +38,6 @@ impl BackendFactory for RestrictedFactory { if proving_key.is_some() { return Err(Error::BackendError("Proving key unused".to_string())); } - if pil.degrees().len() > 1 { - return Err(Error::NoVariableDegreeAvailable); - } let stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index ee6077dd7e..108813ae5e 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -67,6 +67,7 @@ where if pil.constant_count() + pil.publics_count() == 0 { None } else { + let fixed_columns = machine_fixed_columns(&fixed, &pil); Some(( namespace.to_string(), pil.committed_polys_in_source_order() @@ -76,7 +77,6 @@ where .map(|size| { // get the config let config = get_config(); - // commit to the fixed columns let twiddles = Arc::new(B::precompute_twiddles( CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) @@ -89,9 +89,6 @@ where let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); - // get fix_columns evaluations - let fixed_columns = machine_fixed_columns(&fixed, &analyzed); - let domain = CanonicCoset::new( fixed_columns .keys() @@ -142,10 +139,7 @@ where tree_builder.commit(prover_channel); let trees = commitment_scheme.trees; - ( - analyzed.degree().ilog2() as usize, - TableProvingKey { trees }, - ) + (size as usize, TableProvingKey { trees }) }) .collect(), )) From 9f9c20d0ba6086efd1e4e0610d8e376fc6845431 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 09:20:51 +0100 Subject: [PATCH 11/48] clean up --- backend/src/stwo/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 108813ae5e..1eaceb9ade 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -67,7 +67,7 @@ where if pil.constant_count() + pil.publics_count() == 0 { None } else { - let fixed_columns = machine_fixed_columns(&fixed, &pil); + let fixed_columns = machine_fixed_columns(&fixed, pil); Some(( namespace.to_string(), pil.committed_polys_in_source_order() From b99c1a9e1d70e7640787155276ca7a494d5c1cd9 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 10:50:01 +0100 Subject: [PATCH 12/48] add challenge channel to tableProvingkey --- backend/src/stwo/proof.rs | 1 + backend/src/stwo/prover.rs | 29 +++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 8fe76fd03e..69e974d15e 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -9,6 +9,7 @@ pub type TableProvingKeyCollection = BTreeMap, MC: MerkleChannel> { pub trees: TreeVec>, + pub prover_channel: ::C, } pub struct StarkProvingKey, MC: MerkleChannel> { diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 1eaceb9ade..c5e44029fb 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -77,7 +77,7 @@ where .map(|size| { // get the config let config = get_config(); - // commit to the fixed columns + let twiddles = Arc::new(B::precompute_twiddles( CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) .circle_domain() @@ -133,13 +133,19 @@ where }) .collect(); - // Preprocessed trace + // commit to the fixed columns let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(constant_trace.clone()); tree_builder.commit(prover_channel); let trees = commitment_scheme.trees; - (size as usize, TableProvingKey { trees }) + ( + size as usize, + TableProvingKey { + trees, + prover_channel: prover_channel.clone(), + }, + ) }) .collect(), )) @@ -168,9 +174,6 @@ where .half_coset, ); - // Setup protocol. - let prover_channel = &mut ::C::default(); - //let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); let trees = self .proving_key .borrow_mut() // Borrow as mutable using RefCell @@ -179,6 +182,16 @@ where .and_then(|table_collection| table_collection.values_mut().next()) .map(|table_proving_key| std::mem::take(&mut table_proving_key.trees)) // Take ownership .expect("Expected to find commitment_scheme in proving key"); + + let mut prover_channel = self + .proving_key + .borrow_mut() // Borrow as mutable using RefCell + .as_mut() + .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) + .and_then(|table_collection| table_collection.values_mut().next()) + .map(|table_proving_key| std::mem::take(&mut table_proving_key.prover_channel)) // Take ownership + .expect("Expected to find commitment_scheme in proving key"); + let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); commitment_scheme.trees = trees; @@ -200,7 +213,7 @@ where let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); - tree_builder.commit(prover_channel); + tree_builder.commit(&mut prover_channel); let component = PowdrComponent::new( &mut TraceLocationAllocator::default(), @@ -209,7 +222,7 @@ where let proof = stwo_prover::core::prover::prove::( &[&component], - prover_channel, + &mut prover_channel, &mut commitment_scheme, ) .unwrap(); From e688e3ec4a57ff2eea6b978c1270af452f4b760a Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 11:41:05 +0100 Subject: [PATCH 13/48] handle empty constant case --- backend/src/stwo/prover.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index c5e44029fb..5a72e65a78 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -174,23 +174,27 @@ where .half_coset, ); - let trees = self + let (trees, mut prover_channel) = self .proving_key - .borrow_mut() // Borrow as mutable using RefCell + .borrow_mut() .as_mut() .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) .and_then(|table_collection| table_collection.values_mut().next()) - .map(|table_proving_key| std::mem::take(&mut table_proving_key.trees)) // Take ownership - .expect("Expected to find commitment_scheme in proving key"); - - let mut prover_channel = self - .proving_key - .borrow_mut() // Borrow as mutable using RefCell - .as_mut() - .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) - .and_then(|table_collection| table_collection.values_mut().next()) - .map(|table_proving_key| std::mem::take(&mut table_proving_key.prover_channel)) // Take ownership - .expect("Expected to find commitment_scheme in proving key"); + .map(|table_proving_key| { + ( + std::mem::take(&mut table_proving_key.trees), + std::mem::take(&mut table_proving_key.prover_channel), + ) + }) + .unwrap_or_else(|| { + let mut prover_channel = ::C::default(); + let mut commitment_scheme = + CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals([]); + tree_builder.commit(&mut prover_channel); + (commitment_scheme.trees, prover_channel) + }); let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); commitment_scheme.trees = trees; From d9db23628555c6470473823cc73c966b18b8a7e4 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 13:17:52 +0100 Subject: [PATCH 14/48] add more test, and comments --- backend/src/stwo/circuit_builder.rs | 8 ++++++-- backend/src/stwo/mod.rs | 4 ++++ backend/src/stwo/prover.rs | 15 ++++++--------- test_data/pil/incremental_one.pil | 11 +++++++++++ 4 files changed, 27 insertions(+), 11 deletions(-) create mode 100644 test_data/pil/incremental_one.pil diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index b1f7422df1..e6aebd6ac1 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -112,7 +112,8 @@ impl FrameworkEval for PowdrEval { .map(|poly_id| { ( *poly_id, - eval.get_preprocessed_column(PreprocessedColumn::Plonk(3)), + // PreprocessedColumn::Plonk(0) is unused argument in get_preprocessed_column,0 has no meaning + eval.get_preprocessed_column(PreprocessedColumn::Plonk(0)), ) }) .collect(); @@ -173,7 +174,10 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => constant_eval[&poly_id].clone(), + PolynomialType::Constant => match r.next { + false => constant_eval[&poly_id].clone(), + true => panic!("Next on constant polynomials is not supported"), + }, PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 31fa47104b..31b33b890c 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -38,6 +38,10 @@ impl BackendFactory for RestrictedFactory { if proving_key.is_some() { return Err(Error::BackendError("Proving key unused".to_string())); } + if pil.degrees().len() > 1 { + return Err(Error::NoVariableDegreeAvailable); + } + let stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 5a72e65a78..b42805cc91 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -40,8 +40,6 @@ pub struct StwoProver + Send, MC: MerkleChannel, C: /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, - _backend_marker: PhantomData, - _merkle_channel_marker: PhantomData, } impl<'a, F: FieldElement, B, MC, C> StwoProver @@ -97,8 +95,8 @@ where .unwrap_or(0), ) .circle_domain(); - - let updated_fixed_columns: BTreeMap< + // witness and constant traces need to be bit reversed + let bit_reversed_fixed_columns: BTreeMap< DegreeType, Vec<(String, Vec)>, > = fixed_columns @@ -120,7 +118,7 @@ where let constant_trace: ColumnVec< CircleEvaluation, - > = updated_fixed_columns + > = bit_reversed_fixed_columns .values() .flat_map(|vec| { vec.iter().map(|(_name, values)| { @@ -159,8 +157,6 @@ where proving_key: RefCell::new(Some(proving_key)), _verifying_key: None, _channel_marker: PhantomData, - _backend_marker: PhantomData, - _merkle_channel_marker: PhantomData, }) } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { @@ -173,7 +169,7 @@ where .circle_domain() .half_coset, ); - + //TODO: make machines with multi degree sizes work, one only the first one is taken, multi degrees error is handled by NoVariableDegreeAvailable in mod.rs let (trees, mut prover_channel) = self .proving_key .borrow_mut() @@ -195,7 +191,8 @@ where tree_builder.commit(&mut prover_channel); (commitment_scheme.trees, prover_channel) }); - + //get the commitment for constant columns + //TODO: different degree sizes machines need to have their own twiddles, but now as only one the first one is taken, only one twiddles is used let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); commitment_scheme.trees = trees; diff --git a/test_data/pil/incremental_one.pil b/test_data/pil/incremental_one.pil new file mode 100644 index 0000000000..ccda39ccb4 --- /dev/null +++ b/test_data/pil/incremental_one.pil @@ -0,0 +1,11 @@ +let N = 32; + +// This uses the alternative nomenclature as well. + +namespace Incremental(N); + col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; + col witness x, y; + + ISLAST * (x' - 1) = 0; + + (1-ISLAST) * (x' - x-1) = 0; \ No newline at end of file From 3ea95b5f8b7b4245db7cf5acfac398987fcff246 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 25 Nov 2024 13:20:55 +0100 Subject: [PATCH 15/48] add test in pil --- pipeline/tests/pil.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index e0f97caf33..a382bfdeb2 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -281,6 +281,12 @@ fn stwo_fixed_columns() { test_stwo(f, Default::default()); } +#[test] +fn stwo_incremental_one() { + let f = "pil/incremental_one.pil"; + test_stwo(f, Default::default()); +} + #[test] fn simple_div() { let f = "pil/simple_div.pil"; From 6fa35a717cddd328a661c57e80f51275ccc99419 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 27 Nov 2024 11:33:38 +0100 Subject: [PATCH 16/48] remove prover channel from table key --- backend/Cargo.toml | 1 - backend/src/stwo/proof.rs | 1 - backend/src/stwo/prover.rs | 18 +++++++----------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index b065ecf6e3..e8773ab31a 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -63,7 +63,6 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } # TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3 -# stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" } stwo-prover = { git = "https://github.com/ShuangWu121/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe" } strum = { version = "0.24.1", features = ["derive"] } diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 69e974d15e..8fe76fd03e 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -9,7 +9,6 @@ pub type TableProvingKeyCollection = BTreeMap, MC: MerkleChannel> { pub trees: TreeVec>, - pub prover_channel: ::C, } pub struct StarkProvingKey, MC: MerkleChannel> { diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index b42805cc91..54a729d924 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -137,13 +137,7 @@ where tree_builder.commit(prover_channel); let trees = commitment_scheme.trees; - ( - size as usize, - TableProvingKey { - trees, - prover_channel: prover_channel.clone(), - }, - ) + (size as usize, TableProvingKey { trees }) }) .collect(), )) @@ -177,10 +171,12 @@ where .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) .and_then(|table_collection| table_collection.values_mut().next()) .map(|table_proving_key| { - ( - std::mem::take(&mut table_proving_key.trees), - std::mem::take(&mut table_proving_key.prover_channel), - ) + let mut prover_channel = ::C::default(); + let trees = std::mem::take(&mut table_proving_key.trees); + + //prover_channel consume the latest merkle tree commitment root. + MC::mix_root(&mut prover_channel, trees[0].commitment.root()); + (trees, prover_channel) }) .unwrap_or_else(|| { let mut prover_channel = ::C::default(); From 648435d0c70567e9bea7991cb0eea8ba83de4ddd Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 27 Nov 2024 13:54:43 +0100 Subject: [PATCH 17/48] avoid clone witness, using better API to do bit reverse order of the witness --- backend/src/stwo/circuit_builder.rs | 22 ++++++--- backend/src/stwo/prover.rs | 70 +++++++++++++---------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index e6aebd6ac1..0d0b1b6a97 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -17,11 +17,12 @@ use stwo_prover::constraint_framework::preprocessed_columns::PreprocessedColumn; use stwo_prover::constraint_framework::{ EvalAtRow, FrameworkComponent, FrameworkEval, ORIGINAL_TRACE_IDX, }; -use stwo_prover::core::backend::ColumnOps; +use stwo_prover::core::backend::{Column, ColumnOps}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use stwo_prover::core::ColumnVec; pub type PowdrComponent<'a, F> = FrameworkComponent>; @@ -40,15 +41,24 @@ where .all(|(_name, vec)| vec.len() == witness[0].1.len()), "All Vec in witness must have the same length. Mismatch found!" ); + let domain = CanonicCoset::new(witness[0].1.len().ilog2()).circle_domain(); witness .iter() .map(|(_name, values)| { - let values = values - .iter() - .map(|v| v.try_into_i32().unwrap().into()) - .collect(); - CircleEvaluation::new(domain, values) + let mut column: >::Column = + >::Column::zeros(values.len()); + values.iter().enumerate().for_each(|(i, v)| { + column.set( + bit_reverse_index( + coset_index_to_circle_domain_index(i, values.len().ilog2()), + values.len().ilog2(), + ), + v.try_into_i32().unwrap().into(), + ); + }); + + CircleEvaluation::new(domain, column) }) .collect() } diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 54a729d924..5182d9fb8e 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,7 +1,7 @@ use powdr_ast::analyzed::Analyzed; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; -use powdr_number::{DegreeType, FieldElement}; +use powdr_number::FieldElement; use serde::de::DeserializeOwned; use serde::ser::Serialize; use std::collections::BTreeMap; @@ -17,14 +17,14 @@ use stwo_prover::core::prover::StarkProof; use std::cell::RefCell; use stwo_prover::core::air::{Component, ComponentProver}; -use stwo_prover::core::backend::{Backend, BackendForChannel}; +use stwo_prover::core::backend::{Backend, BackendForChannel, Column, ColumnOps}; use stwo_prover::core::channel::{Channel, MerkleChannel}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; +use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use stwo_prover::core::ColumnVec; const FRI_LOG_BLOWUP: usize = 1; @@ -95,38 +95,28 @@ where .unwrap_or(0), ) .circle_domain(); - // witness and constant traces need to be bit reversed - let bit_reversed_fixed_columns: BTreeMap< - DegreeType, - Vec<(String, Vec)>, - > = fixed_columns - .iter() - .map(|(key, vec)| { - let transformed_vec: Vec<(String, Vec)> = vec - .iter() - .map(|(name, slice)| { - let mut values: Vec = slice.to_vec(); // Clone the slice into a Vec - bit_reverse_coset_to_circle_domain_order( - &mut values, - ); // Apply bit reversal - (name.clone(), values) // Return the updated tuple - }) - .collect(); // Collect the updated vector - (*key, transformed_vec) // Rebuild the BTreeMap with transformed vectors - }) - .collect(); let constant_trace: ColumnVec< CircleEvaluation, - > = bit_reversed_fixed_columns + > = fixed_columns .values() .flat_map(|vec| { vec.iter().map(|(_name, values)| { - let values = values - .iter() - .map(|v| v.try_into_i32().unwrap().into()) - .collect(); - CircleEvaluation::new(domain, values) + let mut column: >::Column = + >::Column::zeros(values.len()); + values.iter().enumerate().for_each(|(i, v)| { + column.set( + bit_reverse_index( + coset_index_to_circle_domain_index( + i, + values.len().ilog2(), + ), + values.len().ilog2(), + ), + v.try_into_i32().unwrap().into(), + ); + }); + CircleEvaluation::new(domain, column) }) }) .collect(); @@ -192,18 +182,18 @@ where let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); commitment_scheme.trees = trees; - let transformed_witness: Vec<(String, Vec)> = witness - .iter() - .map(|(name, vec)| (name.clone(), vec.to_vec())) - .collect(); + // let transformed_witness: Vec<(String, Vec)> = witness + // .iter() + // .map(|(name, vec)| (name.clone(), vec.to_vec())) + // .collect(); - let witness: &Vec<(String, Vec)> = &transformed_witness - .into_iter() - .map(|(name, mut vec)| { - bit_reverse_coset_to_circle_domain_order(&mut vec); - (name, vec) - }) - .collect(); + // let witness: &Vec<(String, Vec)> = &transformed_witness + // .into_iter() + // .map(|(name, mut vec)| { + // bit_reverse_coset_to_circle_domain_order(&mut vec); + // (name, vec) + // }) + // .collect(); // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); From b4a2411753b536dd97672f262b1eb0aa4186f1c2 Mon Sep 17 00:00:00 2001 From: ShuangWu121 <47602565+ShuangWu121@users.noreply.github.com> Date: Thu, 28 Nov 2024 08:53:41 +0100 Subject: [PATCH 18/48] Update backend/src/stwo/circuit_builder.rs Co-authored-by: Thibaut Schaeffer --- backend/src/stwo/circuit_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 0d0b1b6a97..b952a8d273 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -116,7 +116,7 @@ impl FrameworkEval for PowdrEval { ) }) .collect(); - let constant_eval: BTreeMap::F> = self + let constant_eval: BTreeMap<_, _> = self .constant_columns .keys() .map(|poly_id| { From fe6f22042a26a905f1a0d51e4afbe0f2dad3f56c Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 28 Nov 2024 17:52:32 +0100 Subject: [PATCH 19/48] cannot make setup work because of 'a --- backend/src/stwo/mod.rs | 13 +- backend/src/stwo/proof.rs | 15 +- backend/src/stwo/prover.rs | 319 +++++++++++++++++++++++++++---------- 3 files changed, 248 insertions(+), 99 deletions(-) diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 31b33b890c..f06f0aeb01 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -1,5 +1,6 @@ use serde::de::DeserializeOwned; use serde::Serialize; +use std::collections::BTreeMap; use std::io; use std::path::PathBuf; use std::sync::Arc; @@ -41,9 +42,11 @@ impl BackendFactory for RestrictedFactory { if pil.degrees().len() > 1 { return Err(Error::NoVariableDegreeAvailable); } - - let stwo: Box> = - Box::new(StwoProver::new(pil, fixed)?); + let split: BTreeMap> = + powdr_backend_utils::split_pil(&pil).into_iter().collect(); + let twiddle_map = prover::TwiddleMap::new(16, split); + let mut stwo: Box> = + Box::new(StwoProver::new(pil, fixed, &twiddle_map)?); Ok(stwo) } @@ -51,8 +54,8 @@ impl BackendFactory for RestrictedFactory { generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]); -impl Backend - for StwoProver +impl<'a, T: FieldElement, MC: MerkleChannel + Send, C: Channel + Send> Backend + for StwoProver<'a, T, SimdBackend, MC, C> where SimdBackend: BackendForChannel, MC: MerkleChannel, diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 8fe76fd03e..5df9ce7387 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,19 +1,20 @@ use std::collections::BTreeMap; +use std::sync::Arc; use stwo_prover::core::backend::BackendForChannel; use stwo_prover::core::channel::MerkleChannel; -use stwo_prover::core::pcs::CommitmentTreeProver; use stwo_prover::core::pcs::TreeVec; +use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentTreeProver}; /// For each possible size, the commitment and prover data -pub type TableProvingKeyCollection = BTreeMap>; +pub type TableProvingKeyCollection<'a, B, MC> = BTreeMap>; -pub struct TableProvingKey, MC: MerkleChannel> { - pub trees: TreeVec>, +pub struct TableProvingKey<'a, B: BackendForChannel, MC: MerkleChannel> { + pub commitment_scheme: CommitmentSchemeProver<'a, B, MC>, } -pub struct StarkProvingKey, MC: MerkleChannel> { +pub struct StarkProvingKey<'a, B: BackendForChannel, MC: MerkleChannel> { // for each table, the preprocessed data - pub preprocessed: BTreeMap>, + pub preprocessed: BTreeMap>, } -unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} +unsafe impl<'a, B: BackendForChannel, MC: MerkleChannel> Send for TableProvingKey<'a, B, MC> {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 5182d9fb8e..ef38ad7353 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,18 +1,26 @@ +use halo2_proofs::poly::commitment; use powdr_ast::analyzed::Analyzed; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_number::FieldElement; use serde::de::DeserializeOwned; use serde::ser::Serialize; +use std::borrow::Borrow; use std::collections::BTreeMap; -use std::io; use std::marker::PhantomData; +use std::rc::Rc; use std::sync::Arc; +use std::{clone, io}; +use stwo_prover::core::pcs::CommitmentTreeProver; +use stwo_prover::core::poly::twiddles::{self, TwiddleTree}; use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; -use stwo_prover::constraint_framework::TraceLocationAllocator; +use stwo_prover::constraint_framework::{ + TraceLocationAllocator, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, +}; +use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::prover::StarkProof; use std::cell::RefCell; @@ -32,19 +40,75 @@ const FRI_NUM_QUERIES: usize = 100; const FRI_PROOF_OF_WORK_BITS: usize = 16; const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; -pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { +pub struct TwiddleMap, MC: MerkleChannel, T: FieldElement> { + twiddles_map: BTreeMap>>, + pub _channel_marker: PhantomData, + pub _channel_marker1: PhantomData, +} + +impl, MC: MerkleChannel, T: FieldElement> TwiddleMap { + pub fn new(size: usize, split: BTreeMap>) -> Self { + let twiddles_map: BTreeMap>> = split + .iter() + .filter_map(|(_, pil)| { + if pil.constant_count() + pil.publics_count() == 0 { + None + } else { + // precompute twiddles for all sizes in the PIL + let twiddles_size: Vec<(usize, Arc>)> = pil + .committed_polys_in_source_order() + .flat_map(|(s, _)| { + s.degree.iter().flat_map(|range| { + let min = range.min; + let max = range.max; + + // Iterate over powers of 2 from min to max + (min..=max) + .filter(|&size| size.is_power_of_two()) // Only take powers of 2 + .map(|size| { + // Compute twiddles for this size + let twiddles = B::precompute_twiddles( + CanonicCoset::new( + size.ilog2() + 1 + FRI_LOG_BLOWUP as u32, + ) + .circle_domain() + .half_coset, + ); + (size as usize, Arc::new(twiddles)) + }) + .collect::>() + }) + }) + .collect(); + Some(twiddles_size.into_iter()) + } + }) + .flatten() + .collect(); + Self { + twiddles_map, + _channel_marker: PhantomData, + _channel_marker1: PhantomData, + } + } +} + +pub struct StwoProver<'a, T, B: BackendForChannel + Send, MC: MerkleChannel, C: Channel> { pub analyzed: Arc>, + /// The split analyzed PIL + split: BTreeMap>, + pub fixed: Arc)>>, /// Proving key placeholder - proving_key: RefCell>>, + proving_key: RefCell>>, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, } -impl<'a, F: FieldElement, B, MC, C> StwoProver +impl<'a, F: FieldElement, B, MC, C> StwoProver<'a, F, B, MC, C> where - B: Backend + Send + BackendForChannel, // Ensure B implements BackendForChannel + B: Backend + Send + BackendForChannel, MC: MerkleChannel + Send, C: Channel + Send, MC::H: DeserializeOwned + Serialize, @@ -53,12 +117,14 @@ where pub fn new( analyzed: Arc>, fixed: Arc)>>, + twiddle_map: &'a TwiddleMap, ) -> Result { let split: BTreeMap> = powdr_backend_utils::split_pil(&analyzed) .into_iter() .collect(); - let preprocessed: BTreeMap> = split + // commitment_scheme.twiddles is &'a TwiddleTree, in order to pass commitment_scheme to different functions, twiddles connot be owned by a temperary function. + let mut preprocessed: BTreeMap> = split .iter() .filter_map(|(namespace, pil)| { // if we have neither fixed columns nor publics, we don't need to commit to anything @@ -76,16 +142,20 @@ where // get the config let config = get_config(); - let twiddles = Arc::new(B::precompute_twiddles( - CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) - .circle_domain() - .half_coset, - )); - // Setup protocol. let prover_channel = &mut ::C::default(); let mut commitment_scheme = - CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); + CommitmentSchemeProver::<'_, B, MC>::new( + config, + Arc::as_ref( + &twiddle_map + .twiddles_map + .iter() + .find(|(s, _)| **s == (size as usize)) + .unwrap() + .1, + ), + ); let domain = CanonicCoset::new( fixed_columns @@ -125,96 +195,162 @@ where let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(constant_trace.clone()); tree_builder.commit(prover_channel); - let trees = commitment_scheme.trees; + //let trees = commitment_scheme.trees; - (size as usize, TableProvingKey { trees }) + (size as usize, TableProvingKey { commitment_scheme }) }) .collect(), )) } }) .collect(); - let proving_key = StarkProvingKey { preprocessed }; + let proving_key = StarkProvingKey { + preprocessed: preprocessed, + }; Ok(Self { analyzed, + split, + fixed, proving_key: RefCell::new(Some(proving_key)), _verifying_key: None, _channel_marker: PhantomData, }) } + // pub fn setup(&'static mut self) { + // let mut preprocessed: BTreeMap> = self + // .split + // .iter() + // .filter_map(|(namespace, pil)| { + // // if we have neither fixed columns nor publics, we don't need to commit to anything + // if pil.constant_count() + pil.publics_count() == 0 { + // None + // } else { + // let fixed_columns = machine_fixed_columns(&self.fixed, pil); + // Some(( + // namespace.to_string(), + // pil.committed_polys_in_source_order() + // .find_map(|(s, _)| s.degree) + // .unwrap() + // .iter() + // .map(|size| { + // // get the config + // let config = get_config(); + + // // Setup protocol. + // let prover_channel = &mut ::C::default(); + // let mut commitment_scheme = + // CommitmentSchemeProver::<'_, B, MC>::new( + // config, + // Arc::as_ref(twiddles_map + // .twiddles + // .iter() + // .find(|(s, _)| **s == (size as usize)) + // .unwrap() + // .1), + // ); + + // let domain = CanonicCoset::new( + // fixed_columns + // .keys() + // .next() + // .map(|&first_key| first_key.ilog2()) + // .unwrap_or(0), + // ) + // .circle_domain(); + + // let constant_trace: ColumnVec< + // CircleEvaluation, + // > = fixed_columns + // .values() + // .flat_map(|vec| { + // vec.iter().map(|(_name, values)| { + // let mut column: >::Column = + // >::Column::zeros(values.len()); + // values.iter().enumerate().for_each(|(i, v)| { + // column.set( + // bit_reverse_index( + // coset_index_to_circle_domain_index( + // i, + // values.len().ilog2(), + // ), + // values.len().ilog2(), + // ), + // v.try_into_i32().unwrap().into(), + // ); + // }); + // CircleEvaluation::new(domain, column) + // }) + // }) + // .collect(); + + // // commit to the fixed columns + // let mut tree_builder = commitment_scheme.tree_builder(); + // tree_builder.extend_evals(constant_trace.clone()); + // tree_builder.commit(prover_channel); + // //let trees = commitment_scheme.trees; + + // (size as usize, TableProvingKey { commitment_scheme }) + // }) + // .collect(), + // )) + // } + // }) + // .collect(); + // let proving_key = StarkProvingKey { + // preprocessed: preprocessed, + // }; + + // self.proving_key = RefCell::new(Some(proving_key)); + // } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { - let config = get_config(); - // twiddles are used for FFT, they are computed in a bigger group than the eval domain. - // the eval domain is the half coset G_{2n} + - // twiddles are computed in half coset G_{4n} + , double the size of eval doamin. - let twiddles = B::precompute_twiddles( - CanonicCoset::new(self.analyzed.degree().ilog2() + 1 + FRI_LOG_BLOWUP as u32) - .circle_domain() - .half_coset, - ); - //TODO: make machines with multi degree sizes work, one only the first one is taken, multi degrees error is handled by NoVariableDegreeAvailable in mod.rs - let (trees, mut prover_channel) = self - .proving_key - .borrow_mut() - .as_mut() - .and_then(|stark_proving_key| stark_proving_key.preprocessed.values_mut().next()) - .and_then(|table_collection| table_collection.values_mut().next()) - .map(|table_proving_key| { - let mut prover_channel = ::C::default(); - let trees = std::mem::take(&mut table_proving_key.trees); - - //prover_channel consume the latest merkle tree commitment root. - MC::mix_root(&mut prover_channel, trees[0].commitment.root()); - (trees, prover_channel) - }) - .unwrap_or_else(|| { - let mut prover_channel = ::C::default(); - let mut commitment_scheme = - CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals([]); - tree_builder.commit(&mut prover_channel); - (commitment_scheme.trees, prover_channel) - }); - //get the commitment for constant columns - //TODO: different degree sizes machines need to have their own twiddles, but now as only one the first one is taken, only one twiddles is used - let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles); - commitment_scheme.trees = trees; - - // let transformed_witness: Vec<(String, Vec)> = witness - // .iter() - // .map(|(name, vec)| (name.clone(), vec.to_vec())) - // .collect(); - - // let witness: &Vec<(String, Vec)> = &transformed_witness - // .into_iter() - // .map(|(name, mut vec)| { - // bit_reverse_coset_to_circle_domain_order(&mut vec); - // (name, vec) - // }) - // .collect(); - - // committed/witness trace - let trace = gen_stwo_circuit_trace::(witness); - - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(&mut prover_channel); + // Use RefCell to access proving_key mutably + if let Some(ref mut proving_key) = *self.proving_key.borrow_mut() { + // Access the proving_key without consuming the Option + let preprocessed = &mut proving_key.preprocessed; - let component = PowdrComponent::new( - &mut TraceLocationAllocator::default(), - PowdrEval::new(self.analyzed.clone()), - ); + // Use preprocessed as needed + let (commitment_scheme, mut prover_channel) = preprocessed + .iter_mut() + .next() + .and_then(|(_, table_collection)| table_collection.iter_mut().next()) + .map(|(_, table_proving_key)| { + let mut prover_channel = ::C::default(); + + // prover_channel consumes the latest Merkle tree commitment root + MC::mix_root( + &mut prover_channel, + table_proving_key.commitment_scheme.trees[0] + .commitment + .root(), + ); + (&mut table_proving_key.commitment_scheme, prover_channel) + }) + .unwrap_or_else(|| unimplemented!()); + + // committed/witness trace + let trace = gen_stwo_circuit_trace::(witness); - let proof = stwo_prover::core::prover::prove::( - &[&component], - &mut prover_channel, - &mut commitment_scheme, - ) - .unwrap(); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(&mut prover_channel); - Ok(bincode::serialize(&proof).unwrap()) + let component = PowdrComponent::new( + &mut TraceLocationAllocator::default(), + PowdrEval::new(self.analyzed.clone()), + ); + + let proof = stwo_prover::core::prover::prove::( + &[&component], + &mut prover_channel, + commitment_scheme, + ) + .unwrap(); + + Ok(bincode::serialize(&proof).unwrap()) + } else { + panic!("proving_key is None"); + } } pub fn verify(&self, proof: &[u8], _instances: &[F]) -> Result<(), String> { @@ -238,10 +374,19 @@ where ); // Retrieve the expected column sizes in each commitment interaction, from the AIR. + // the sizes include the degrees of the constant, witness, native lookups. Native lookups are not used yet. let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); + commitment_scheme.commit( + proof.commitments[PREPROCESSED_TRACE_IDX], + &sizes[PREPROCESSED_TRACE_IDX], + verifier_channel, + ); + commitment_scheme.commit( + proof.commitments[ORIGINAL_TRACE_IDX], + &sizes[ORIGINAL_TRACE_IDX], + verifier_channel, + ); stwo_prover::core::prover::verify(&[&component], verifier_channel, commitment_scheme, proof) .map_err(|e| e.to_string()) From 94c1ac291706249b93fabbcc2180a679b80a9c7e Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 00:26:57 +0100 Subject: [PATCH 20/48] avoid using refcell, and std::mem::take --- backend/src/stwo/mod.rs | 11 +- backend/src/stwo/proof.rs | 20 +-- backend/src/stwo/prover.rs | 321 ++++++++++++------------------------- 3 files changed, 114 insertions(+), 238 deletions(-) diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index f06f0aeb01..8ec9c11da5 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -1,6 +1,5 @@ use serde::de::DeserializeOwned; use serde::Serialize; -use std::collections::BTreeMap; use std::io; use std::path::PathBuf; use std::sync::Arc; @@ -42,11 +41,9 @@ impl BackendFactory for RestrictedFactory { if pil.degrees().len() > 1 { return Err(Error::NoVariableDegreeAvailable); } - let split: BTreeMap> = - powdr_backend_utils::split_pil(&pil).into_iter().collect(); - let twiddle_map = prover::TwiddleMap::new(16, split); let mut stwo: Box> = - Box::new(StwoProver::new(pil, fixed, &twiddle_map)?); + Box::new(StwoProver::new(pil, fixed)?); + stwo.setup(); Ok(stwo) } @@ -54,8 +51,8 @@ impl BackendFactory for RestrictedFactory { generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]); -impl<'a, T: FieldElement, MC: MerkleChannel + Send, C: Channel + Send> Backend - for StwoProver<'a, T, SimdBackend, MC, C> +impl Backend + for StwoProver where SimdBackend: BackendForChannel, MC: MerkleChannel, diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 5df9ce7387..ad21a43f40 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,20 +1,22 @@ use std::collections::BTreeMap; -use std::sync::Arc; use stwo_prover::core::backend::BackendForChannel; use stwo_prover::core::channel::MerkleChannel; -use stwo_prover::core::pcs::TreeVec; -use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentTreeProver}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::ColumnVec; /// For each possible size, the commitment and prover data -pub type TableProvingKeyCollection<'a, B, MC> = BTreeMap>; +pub type TableProvingKeyCollection = BTreeMap>; -pub struct TableProvingKey<'a, B: BackendForChannel, MC: MerkleChannel> { - pub commitment_scheme: CommitmentSchemeProver<'a, B, MC>, +pub struct TableProvingKey, MC: MerkleChannel> { + pub constant_trace_circle_domain: ColumnVec>, + pub _marker: std::marker::PhantomData, } -pub struct StarkProvingKey<'a, B: BackendForChannel, MC: MerkleChannel> { +pub struct StarkProvingKey, MC: MerkleChannel> { // for each table, the preprocessed data - pub preprocessed: BTreeMap>, + pub preprocessed: BTreeMap>, } -unsafe impl<'a, B: BackendForChannel, MC: MerkleChannel> Send for TableProvingKey<'a, B, MC> {} +unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index ef38ad7353..5f36dece68 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,18 +1,14 @@ -use halo2_proofs::poly::commitment; use powdr_ast::analyzed::Analyzed; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_number::FieldElement; use serde::de::DeserializeOwned; use serde::ser::Serialize; -use std::borrow::Borrow; use std::collections::BTreeMap; +use std::io; use std::marker::PhantomData; -use std::rc::Rc; use std::sync::Arc; -use std::{clone, io}; -use stwo_prover::core::pcs::CommitmentTreeProver; -use stwo_prover::core::poly::twiddles::{self, TwiddleTree}; +use stwo_prover::core::poly::twiddles::TwiddleTree; use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; @@ -20,10 +16,8 @@ use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollec use stwo_prover::constraint_framework::{ TraceLocationAllocator, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, }; -use stwo_prover::core::pcs::TreeVec; use stwo_prover::core::prover::StarkProof; -use std::cell::RefCell; use stwo_prover::core::air::{Component, ComponentProver}; use stwo_prover::core::backend::{Backend, BackendForChannel, Column, ColumnOps}; use stwo_prover::core::channel::{Channel, MerkleChannel}; @@ -40,73 +34,20 @@ const FRI_NUM_QUERIES: usize = 100; const FRI_PROOF_OF_WORK_BITS: usize = 16; const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; -pub struct TwiddleMap, MC: MerkleChannel, T: FieldElement> { - twiddles_map: BTreeMap>>, - pub _channel_marker: PhantomData, - pub _channel_marker1: PhantomData, -} - -impl, MC: MerkleChannel, T: FieldElement> TwiddleMap { - pub fn new(size: usize, split: BTreeMap>) -> Self { - let twiddles_map: BTreeMap>> = split - .iter() - .filter_map(|(_, pil)| { - if pil.constant_count() + pil.publics_count() == 0 { - None - } else { - // precompute twiddles for all sizes in the PIL - let twiddles_size: Vec<(usize, Arc>)> = pil - .committed_polys_in_source_order() - .flat_map(|(s, _)| { - s.degree.iter().flat_map(|range| { - let min = range.min; - let max = range.max; - - // Iterate over powers of 2 from min to max - (min..=max) - .filter(|&size| size.is_power_of_two()) // Only take powers of 2 - .map(|size| { - // Compute twiddles for this size - let twiddles = B::precompute_twiddles( - CanonicCoset::new( - size.ilog2() + 1 + FRI_LOG_BLOWUP as u32, - ) - .circle_domain() - .half_coset, - ); - (size as usize, Arc::new(twiddles)) - }) - .collect::>() - }) - }) - .collect(); - Some(twiddles_size.into_iter()) - } - }) - .flatten() - .collect(); - Self { - twiddles_map, - _channel_marker: PhantomData, - _channel_marker1: PhantomData, - } - } -} - -pub struct StwoProver<'a, T, B: BackendForChannel + Send, MC: MerkleChannel, C: Channel> { +pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { pub analyzed: Arc>, /// The split analyzed PIL split: BTreeMap>, pub fixed: Arc)>>, /// Proving key placeholder - proving_key: RefCell>>, + proving_key: Option>, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, } -impl<'a, F: FieldElement, B, MC, C> StwoProver<'a, F, B, MC, C> +impl<'a, F: FieldElement, B, MC, C> StwoProver where B: Backend + Send + BackendForChannel, MC: MerkleChannel + Send, @@ -117,21 +58,31 @@ where pub fn new( analyzed: Arc>, fixed: Arc)>>, - twiddle_map: &'a TwiddleMap, ) -> Result { let split: BTreeMap> = powdr_backend_utils::split_pil(&analyzed) .into_iter() .collect(); + Ok(Self { + analyzed, + split, + fixed, + proving_key: None, + _verifying_key: None, + _channel_marker: PhantomData, + }) + } + pub fn setup(&mut self) { // commitment_scheme.twiddles is &'a TwiddleTree, in order to pass commitment_scheme to different functions, twiddles connot be owned by a temperary function. - let mut preprocessed: BTreeMap> = split + let preprocessed: BTreeMap> = self + .split .iter() .filter_map(|(namespace, pil)| { // if we have neither fixed columns nor publics, we don't need to commit to anything if pil.constant_count() + pil.publics_count() == 0 { None } else { - let fixed_columns = machine_fixed_columns(&fixed, pil); + let fixed_columns = machine_fixed_columns(&self.fixed, pil); Some(( namespace.to_string(), pil.committed_polys_in_source_order() @@ -139,24 +90,6 @@ where .unwrap() .iter() .map(|size| { - // get the config - let config = get_config(); - - // Setup protocol. - let prover_channel = &mut ::C::default(); - let mut commitment_scheme = - CommitmentSchemeProver::<'_, B, MC>::new( - config, - Arc::as_ref( - &twiddle_map - .twiddles_map - .iter() - .find(|(s, _)| **s == (size as usize)) - .unwrap() - .1, - ), - ); - let domain = CanonicCoset::new( fixed_columns .keys() @@ -191,166 +124,110 @@ where }) .collect(); - // commit to the fixed columns - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(constant_trace.clone()); - tree_builder.commit(prover_channel); - //let trees = commitment_scheme.trees; - - (size as usize, TableProvingKey { commitment_scheme }) + ( + size as usize, + TableProvingKey { + constant_trace_circle_domain: constant_trace, + _marker: PhantomData, + }, + ) }) .collect(), )) } }) .collect(); - let proving_key = StarkProvingKey { - preprocessed: preprocessed, - }; - - Ok(Self { - analyzed, - split, - fixed, - proving_key: RefCell::new(Some(proving_key)), - _verifying_key: None, - _channel_marker: PhantomData, - }) + let proving_key = StarkProvingKey { preprocessed }; + self.proving_key = Some(proving_key); } - // pub fn setup(&'static mut self) { - // let mut preprocessed: BTreeMap> = self - // .split - // .iter() - // .filter_map(|(namespace, pil)| { - // // if we have neither fixed columns nor publics, we don't need to commit to anything - // if pil.constant_count() + pil.publics_count() == 0 { - // None - // } else { - // let fixed_columns = machine_fixed_columns(&self.fixed, pil); - // Some(( - // namespace.to_string(), - // pil.committed_polys_in_source_order() - // .find_map(|(s, _)| s.degree) - // .unwrap() - // .iter() - // .map(|size| { - // // get the config - // let config = get_config(); - - // // Setup protocol. - // let prover_channel = &mut ::C::default(); - // let mut commitment_scheme = - // CommitmentSchemeProver::<'_, B, MC>::new( - // config, - // Arc::as_ref(twiddles_map - // .twiddles - // .iter() - // .find(|(s, _)| **s == (size as usize)) - // .unwrap() - // .1), - // ); - // let domain = CanonicCoset::new( - // fixed_columns - // .keys() - // .next() - // .map(|&first_key| first_key.ilog2()) - // .unwrap_or(0), - // ) - // .circle_domain(); - - // let constant_trace: ColumnVec< - // CircleEvaluation, - // > = fixed_columns - // .values() - // .flat_map(|vec| { - // vec.iter().map(|(_name, values)| { - // let mut column: >::Column = - // >::Column::zeros(values.len()); - // values.iter().enumerate().for_each(|(i, v)| { - // column.set( - // bit_reverse_index( - // coset_index_to_circle_domain_index( - // i, - // values.len().ilog2(), - // ), - // values.len().ilog2(), - // ), - // v.try_into_i32().unwrap().into(), - // ); - // }); - // CircleEvaluation::new(domain, column) - // }) - // }) - // .collect(); + pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { + let config = get_config(); + let twiddles_map: BTreeMap> = self + .split + .iter() + .filter_map(|(_, pil)| { + if pil.constant_count() + pil.publics_count() == 0 { + None + } else { + // precompute twiddles for all sizes in the PIL + let twiddles_size: Vec<(usize, TwiddleTree)> = pil + .committed_polys_in_source_order() + .flat_map(|(s, _)| { + s.degree.iter().flat_map(|range| { + let min = range.min; + let max = range.max; - // // commit to the fixed columns - // let mut tree_builder = commitment_scheme.tree_builder(); - // tree_builder.extend_evals(constant_trace.clone()); - // tree_builder.commit(prover_channel); - // //let trees = commitment_scheme.trees; + // Iterate over powers of 2 from min to max + (min..=max) + .filter(|&size| size.is_power_of_two()) // Only take powers of 2 + .map(|size| { + // Compute twiddles for this size + let twiddles = B::precompute_twiddles( + CanonicCoset::new( + size.ilog2() + 1 + FRI_LOG_BLOWUP as u32, + ) + .circle_domain() + .half_coset, + ); + (size as usize, twiddles) + }) + .collect::>() + }) + }) + .collect(); + Some(twiddles_size.into_iter()) + } + }) + .flatten() + .collect(); + // Use RefCell to access proving_key mutably + let prover_channel = &mut ::C::default(); + let mut commitment_scheme = + CommitmentSchemeProver::<'_, B, MC>::new(config, twiddles_map.iter().next().unwrap().1); - // (size as usize, TableProvingKey { commitment_scheme }) - // }) - // .collect(), - // )) - // } - // }) - // .collect(); - // let proving_key = StarkProvingKey { - // preprocessed: preprocessed, - // }; + let mut tree_builder = commitment_scheme.tree_builder(); - // self.proving_key = RefCell::new(Some(proving_key)); - // } - pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { - // Use RefCell to access proving_key mutably - if let Some(ref mut proving_key) = *self.proving_key.borrow_mut() { + if let Some(proving_key) = &self.proving_key { // Access the proving_key without consuming the Option - let preprocessed = &mut proving_key.preprocessed; + let preprocessed = &proving_key.preprocessed; // Use preprocessed as needed - let (commitment_scheme, mut prover_channel) = preprocessed - .iter_mut() + preprocessed + .iter() .next() - .and_then(|(_, table_collection)| table_collection.iter_mut().next()) + .and_then(|(_, table_collection)| table_collection.iter().next()) .map(|(_, table_proving_key)| { - let mut prover_channel = ::C::default(); - - // prover_channel consumes the latest Merkle tree commitment root - MC::mix_root( - &mut prover_channel, - table_proving_key.commitment_scheme.trees[0] - .commitment - .root(), - ); - (&mut table_proving_key.commitment_scheme, prover_channel) + tree_builder + .extend_evals(table_proving_key.constant_trace_circle_domain.clone()); + tree_builder.commit(prover_channel); }) .unwrap_or_else(|| unimplemented!()); + } else { + tree_builder.extend_evals([]); + tree_builder.commit(prover_channel); + } - // committed/witness trace - let trace = gen_stwo_circuit_trace::(witness); + // committed/witness trace + let trace = gen_stwo_circuit_trace::(witness); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(&mut prover_channel); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(prover_channel); - let component = PowdrComponent::new( - &mut TraceLocationAllocator::default(), - PowdrEval::new(self.analyzed.clone()), - ); + let component = PowdrComponent::new( + &mut TraceLocationAllocator::default(), + PowdrEval::new(self.analyzed.clone()), + ); - let proof = stwo_prover::core::prover::prove::( - &[&component], - &mut prover_channel, - commitment_scheme, - ) - .unwrap(); + let proof = stwo_prover::core::prover::prove::( + &[&component], + prover_channel, + &mut commitment_scheme, + ) + .unwrap(); - Ok(bincode::serialize(&proof).unwrap()) - } else { - panic!("proving_key is None"); - } + Ok(bincode::serialize(&proof).unwrap()) } pub fn verify(&self, proof: &[u8], _instances: &[F]) -> Result<(), String> { From f19f55758739dcc4f9e63acde4a1f22e827fa274 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 10:29:36 +0100 Subject: [PATCH 21/48] fix error with empty constant, simplified code --- backend/src/stwo/proof.rs | 2 +- backend/src/stwo/prover.rs | 89 +++++++++++++++++--------------------- 2 files changed, 40 insertions(+), 51 deletions(-) diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index ad21a43f40..676167c089 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -16,7 +16,7 @@ pub struct TableProvingKey, MC: MerkleChannel> { pub struct StarkProvingKey, MC: MerkleChannel> { // for each table, the preprocessed data - pub preprocessed: BTreeMap>, + pub preprocessed: Option>>, } unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 5f36dece68..64ffe938c9 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -41,7 +41,7 @@ pub struct StwoProver + Send, MC: MerkleChannel, C: pub fixed: Arc)>>, /// Proving key placeholder - proving_key: Option>, + proving_key: StarkProvingKey, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, @@ -67,7 +67,7 @@ where analyzed, split, fixed, - proving_key: None, + proving_key: StarkProvingKey { preprocessed: None }, _verifying_key: None, _channel_marker: PhantomData, }) @@ -137,49 +137,42 @@ where } }) .collect(); - let proving_key = StarkProvingKey { preprocessed }; - self.proving_key = Some(proving_key); + let proving_key = StarkProvingKey { + preprocessed: Some(preprocessed), + }; + self.proving_key = proving_key; } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { let config = get_config(); let twiddles_map: BTreeMap> = self .split - .iter() - .filter_map(|(_, pil)| { - if pil.constant_count() + pil.publics_count() == 0 { - None - } else { - // precompute twiddles for all sizes in the PIL - let twiddles_size: Vec<(usize, TwiddleTree)> = pil - .committed_polys_in_source_order() - .flat_map(|(s, _)| { - s.degree.iter().flat_map(|range| { - let min = range.min; - let max = range.max; + .values() + .flat_map(|pil| { + // Precompute twiddles for all sizes in the PIL + pil.committed_polys_in_source_order() + .flat_map(|(s, _)| { + s.degree.iter().flat_map(|range| { + let min = range.min; + let max = range.max; - // Iterate over powers of 2 from min to max - (min..=max) - .filter(|&size| size.is_power_of_two()) // Only take powers of 2 - .map(|size| { - // Compute twiddles for this size - let twiddles = B::precompute_twiddles( - CanonicCoset::new( - size.ilog2() + 1 + FRI_LOG_BLOWUP as u32, - ) + // Iterate over powers of 2 from min to max + (min..=max) + .filter(|&size| size.is_power_of_two()) // Only take powers of 2 + .map(|size| { + // Compute twiddles for this size + let twiddles = B::precompute_twiddles( + CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) .circle_domain() .half_coset, - ); - (size as usize, twiddles) - }) - .collect::>() - }) + ); + (size as usize, twiddles) + }) + .collect::>() // Collect results into a Vec }) - .collect(); - Some(twiddles_size.into_iter()) - } + }) + .collect::>() // Collect the inner results into a Vec }) - .flatten() .collect(); // Use RefCell to access proving_key mutably let prover_channel = &mut ::C::default(); @@ -188,25 +181,21 @@ where let mut tree_builder = commitment_scheme.tree_builder(); - if let Some(proving_key) = &self.proving_key { - // Access the proving_key without consuming the Option - let preprocessed = &proving_key.preprocessed; - - // Use preprocessed as needed - preprocessed - .iter() - .next() - .and_then(|(_, table_collection)| table_collection.iter().next()) - .map(|(_, table_proving_key)| { - tree_builder - .extend_evals(table_proving_key.constant_trace_circle_domain.clone()); - tree_builder.commit(prover_channel); + if let Some((_, table_proving_key)) = + self.proving_key + .preprocessed + .as_ref() + .and_then(|preprocessed| { + preprocessed + .iter() + .find_map(|(_, table_collection)| table_collection.iter().next()) }) - .unwrap_or_else(|| unimplemented!()); + { + tree_builder.extend_evals(table_proving_key.constant_trace_circle_domain.clone()); } else { - tree_builder.extend_evals([]); - tree_builder.commit(prover_channel); + tree_builder.extend_evals(Vec::new()); } + tree_builder.commit(prover_channel); // committed/witness trace let trace = gen_stwo_circuit_trace::(witness); From 14f76bac454a548c9e9e731dc5fe21361164f3f4 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 11:02:53 +0100 Subject: [PATCH 22/48] add enumerate to plonk(i) --- backend/src/stwo/circuit_builder.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index b952a8d273..8f48f380ff 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -119,11 +119,12 @@ impl FrameworkEval for PowdrEval { let constant_eval: BTreeMap<_, _> = self .constant_columns .keys() - .map(|poly_id| { + .enumerate() + .map(|(i,poly_id)| { ( *poly_id, // PreprocessedColumn::Plonk(0) is unused argument in get_preprocessed_column,0 has no meaning - eval.get_preprocessed_column(PreprocessedColumn::Plonk(0)), + eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)), ) }) .collect(); From 07ceb5d17251ef6365c8e033ebe94f316140c628 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 11:18:04 +0100 Subject: [PATCH 23/48] clean up --- backend/src/stwo/circuit_builder.rs | 6 +++--- backend/src/stwo/prover.rs | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 8f48f380ff..7992dd3388 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,6 +1,7 @@ use num_traits::Zero; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; +use std::sync::Arc; extern crate alloc; use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; @@ -8,7 +9,6 @@ use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity, }; use powdr_number::{FieldElement, LargeInt}; -use std::sync::Arc; use powdr_ast::analyzed::{ AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType, @@ -120,10 +120,10 @@ impl FrameworkEval for PowdrEval { .constant_columns .keys() .enumerate() - .map(|(i,poly_id)| { + .map(|(i, poly_id)| { ( *poly_id, - // PreprocessedColumn::Plonk(0) is unused argument in get_preprocessed_column,0 has no meaning + // PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)), ) }) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 64ffe938c9..a1dfdaf4ce 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -8,7 +8,6 @@ use std::collections::BTreeMap; use std::io; use std::marker::PhantomData; use std::sync::Arc; -use stwo_prover::core::poly::twiddles::TwiddleTree; use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; @@ -25,6 +24,7 @@ use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::twiddles::TwiddleTree; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use stwo_prover::core::ColumnVec; @@ -38,9 +38,10 @@ pub struct StwoProver + Send, MC: MerkleChannel, C: pub analyzed: Arc>, /// The split analyzed PIL split: BTreeMap>, + /// The value of the fixed columns pub fixed: Arc)>>, - /// Proving key placeholder + /// Proving key proving_key: StarkProvingKey, /// Verifying key placeholder _verifying_key: Option<()>, @@ -73,7 +74,6 @@ where }) } pub fn setup(&mut self) { - // commitment_scheme.twiddles is &'a TwiddleTree, in order to pass commitment_scheme to different functions, twiddles connot be owned by a temperary function. let preprocessed: BTreeMap> = self .split .iter() @@ -174,13 +174,14 @@ where .collect::>() // Collect the inner results into a Vec }) .collect(); - // Use RefCell to access proving_key mutably + // only the frist one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. let prover_channel = &mut ::C::default(); let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, twiddles_map.iter().next().unwrap().1); let mut tree_builder = commitment_scheme.tree_builder(); + // only the frist one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. if let Some((_, table_proving_key)) = self.proving_key .preprocessed @@ -193,7 +194,7 @@ where { tree_builder.extend_evals(table_proving_key.constant_trace_circle_domain.clone()); } else { - tree_builder.extend_evals(Vec::new()); + tree_builder.extend_evals([]); } tree_builder.commit(prover_channel); From c422ffde6c131be22e2590d76558a1d09e84feca Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 11:59:42 +0100 Subject: [PATCH 24/48] add more comment --- backend/src/stwo/prover.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index a1dfdaf4ce..e3b23b34cd 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -74,6 +74,8 @@ where }) } pub fn setup(&mut self) { + // machines with varying sizes are not supported yet, and it is checked in backendfactory create function. + //TODO: support machines with varying sizes let preprocessed: BTreeMap> = self .split .iter() @@ -95,7 +97,7 @@ where .keys() .next() .map(|&first_key| first_key.ilog2()) - .unwrap_or(0), + .unwrap(), ) .circle_domain(); From d3f7782c41abe7e75021d887d6dee7d7cb4fc669 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 29 Nov 2024 13:51:56 +0100 Subject: [PATCH 25/48] add fail test --- pipeline/src/test_util.rs | 22 ++++++++++++++++++++++ pipeline/tests/pil.rs | 24 +++++++++++++++++++++++- test_data/pil/fibo_no_publics.pil | 2 +- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index d66c6b0a32..4365ce0135 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -598,6 +598,28 @@ pub fn test_stwo(file_name: &str, inputs: Vec) { .collect(); pipeline.verify(&proof, &[publics]).unwrap(); } +#[cfg(feature = "stwo")] +pub fn assert_proofs_fail_for_invalid_witnesses_stwo( + file_name: &str, + witness: &[(String, Vec)], +) { + let pipeline = Pipeline::::default() + .from_file(resolve_test_file(file_name)) + .set_witness(convert_witness(witness)); + + assert!(pipeline + .clone() + .with_backend(powdr_backend::BackendType::Stwo, None) + .compute_proof() + .is_err()); +} + +#[cfg(not(feature = "stwo"))] +pub fn assert_proofs_fail_for_invalid_witnesses_stwo( + _file_name: &str, + _witness: &[(String, Vec)], +) { +} #[cfg(not(feature = "stwo"))] pub fn test_stwo(_file_name: &str, _inputs: Vec) {} diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index a382bfdeb2..ca75e1ec67 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -5,7 +5,8 @@ use powdr_pipeline::{ test_util::{ assert_proofs_fail_for_invalid_witnesses, assert_proofs_fail_for_invalid_witnesses_estark, assert_proofs_fail_for_invalid_witnesses_halo2, - assert_proofs_fail_for_invalid_witnesses_pilcom, gen_estark_proof, + assert_proofs_fail_for_invalid_witnesses_pilcom, + assert_proofs_fail_for_invalid_witnesses_stwo, gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, make_simple_prepared_pipeline, regular_test, run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, test_pilcom, test_plonky3_with_backend_variant, test_stwo, @@ -286,7 +287,28 @@ fn stwo_incremental_one() { let f = "pil/incremental_one.pil"; test_stwo(f, Default::default()); } +#[test] +fn fibonacci_invalid_witness_stwo() { + let f = "pil/fibo_no_publics.pil"; + + // Changed one value and then continued. + // The following constraint should fail in row 1: + // (1-ISLAST) * (x' - y) = 0; + let witness = vec![ + ("Fibonacci::x".to_string(), vec![1, 1, 10, 3]), + ("Fibonacci::y".to_string(), vec![1, 2, 3, 13]), + ]; + assert_proofs_fail_for_invalid_witnesses_stwo(f, &witness); + // All constraints are valid, except the initial row. + // The following constraint should fail in row 3: + // ISLAST * (y' - 1) = 0; + let witness = vec![ + ("Fibonacci::x".to_string(), vec![1, 2, 3, 5]), + ("Fibonacci::y".to_string(), vec![2, 3, 5, 8]), + ]; + assert_proofs_fail_for_invalid_witnesses_stwo(f, &witness); +} #[test] fn simple_div() { let f = "pil/simple_div.pil"; diff --git a/test_data/pil/fibo_no_publics.pil b/test_data/pil/fibo_no_publics.pil index 3f092cb3a3..94c674b863 100644 --- a/test_data/pil/fibo_no_publics.pil +++ b/test_data/pil/fibo_no_publics.pil @@ -1,4 +1,4 @@ -let N = 262144; +let N = 4; // This uses the alternative nomenclature as well. From d4e0bb17b5305b7fedf2902e1877e2769783881e Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Sat, 30 Nov 2024 18:54:45 +0100 Subject: [PATCH 26/48] fix test case --- backend/src/stwo/prover.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index e3b23b34cd..061108f93c 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -212,12 +212,16 @@ where PowdrEval::new(self.analyzed.clone()), ); - let proof = stwo_prover::core::prover::prove::( + let proof_result = stwo_prover::core::prover::prove::( &[&component], prover_channel, &mut commitment_scheme, - ) - .unwrap(); + ); + + let proof = match proof_result { + Ok(value) => value, + Err(e) => return Err(e.to_string()), // Propagate the error instead of panicking + }; Ok(bincode::serialize(&proof).unwrap()) } From 998a240fd703f96d272a81f0bc949438da24772a Mon Sep 17 00:00:00 2001 From: ShuangWu121 <47602565+ShuangWu121@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:15:10 +0100 Subject: [PATCH 27/48] Update backend/src/stwo/prover.rs Co-authored-by: Thibaut Schaeffer --- backend/src/stwo/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 061108f93c..78f1122a6c 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -176,7 +176,7 @@ where .collect::>() // Collect the inner results into a Vec }) .collect(); - // only the frist one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. + // only the first one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. let prover_channel = &mut ::C::default(); let mut commitment_scheme = CommitmentSchemeProver::<'_, B, MC>::new(config, twiddles_map.iter().next().unwrap().1); From 26c8ce90c770c9a38778277778d973ccb9fc7d7e Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 4 Dec 2024 13:14:41 +0100 Subject: [PATCH 28/48] make gen_stwo_circle_column work on slice --- backend/src/stwo/circuit_builder.rs | 52 +++++++++---------- backend/src/stwo/prover.rs | 78 ++++++++++++++++++----------- 2 files changed, 72 insertions(+), 58 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 7992dd3388..c3363c6d5f 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,4 +1,5 @@ use num_traits::Zero; +use rayon::slice; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; use std::sync::Arc; @@ -20,47 +21,40 @@ use stwo_prover::constraint_framework::{ use stwo_prover::core::backend::{Column, ColumnOps}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use stwo_prover::core::ColumnVec; pub type PowdrComponent<'a, F> = FrameworkComponent>; -pub(crate) fn gen_stwo_circuit_trace( - witness: &[(String, Vec)], -) -> ColumnVec> +pub fn gen_stwo_circle_column( + domain: CircleDomain, + slice: &Vec, +) -> CircleEvaluation where - T: FieldElement, //only Merenne31Field is supported, checked in runtime - B: FieldOps + ColumnOps, // Ensure B implements FieldOps for M31 + T: FieldElement, + B: FieldOps + ColumnOps, + F: ExtensionOf, { assert!( - witness - .iter() - .all(|(_name, vec)| vec.len() == witness[0].1.len()), - "All Vec in witness must have the same length. Mismatch found!" + slice.len().ilog2() == (domain.size().ilog2() as u32), + "column size must be equal to domain size" ); + let mut column: >::Column = + >::Column::zeros(slice.len()); + slice.iter().enumerate().for_each(|(i, v)| { + column.set( + bit_reverse_index( + coset_index_to_circle_domain_index(i, slice.len().ilog2()), + slice.len().ilog2(), + ), + v.try_into_i32().unwrap().into(), + ); + }); - let domain = CanonicCoset::new(witness[0].1.len().ilog2()).circle_domain(); - witness - .iter() - .map(|(_name, values)| { - let mut column: >::Column = - >::Column::zeros(values.len()); - values.iter().enumerate().for_each(|(i, v)| { - column.set( - bit_reverse_index( - coset_index_to_circle_domain_index(i, values.len().ilog2()), - values.len().ilog2(), - ), - v.try_into_i32().unwrap().into(), - ); - }); - - CircleEvaluation::new(domain, column) - }) - .collect() + CircleEvaluation::new(domain, column) } pub struct PowdrEval { diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 78f1122a6c..8f71a3f3f3 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -9,7 +9,7 @@ use std::io; use std::marker::PhantomData; use std::sync::Arc; -use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; +use crate::stwo::circuit_builder::{gen_stwo_circle_column, PowdrComponent, PowdrEval}; use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; use stwo_prover::constraint_framework::{ @@ -23,7 +23,7 @@ use stwo_prover::core::channel::{Channel, MerkleChannel}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; use stwo_prover::core::poly::twiddles::TwiddleTree; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; @@ -76,15 +76,27 @@ where pub fn setup(&mut self) { // machines with varying sizes are not supported yet, and it is checked in backendfactory create function. //TODO: support machines with varying sizes + let domain_map: BTreeMap = self + .analyzed + .degrees() + .iter() + .map(|size| { + ( + (size.ilog2() as usize), + CanonicCoset::new(size.ilog2()).circle_domain(), + ) + }) + .collect(); let preprocessed: BTreeMap> = self .split .iter() .filter_map(|(namespace, pil)| { // if we have neither fixed columns nor publics, we don't need to commit to anything - if pil.constant_count() + pil.publics_count() == 0 { + if pil.constant_count() == 0 { None } else { let fixed_columns = machine_fixed_columns(&self.fixed, pil); + Some(( namespace.to_string(), pil.committed_polys_in_source_order() @@ -92,39 +104,22 @@ where .unwrap() .iter() .map(|size| { - let domain = CanonicCoset::new( - fixed_columns - .keys() - .next() - .map(|&first_key| first_key.ilog2()) - .unwrap(), - ) - .circle_domain(); - let constant_trace: ColumnVec< CircleEvaluation, > = fixed_columns .values() .flat_map(|vec| { vec.iter().map(|(_name, values)| { - let mut column: >::Column = - >::Column::zeros(values.len()); - values.iter().enumerate().for_each(|(i, v)| { - column.set( - bit_reverse_index( - coset_index_to_circle_domain_index( - i, - values.len().ilog2(), - ), - values.len().ilog2(), - ), - v.try_into_i32().unwrap().into(), - ); - }); - CircleEvaluation::new(domain, column) + gen_stwo_circle_column::( + *domain_map + .get(&(values.len().ilog2() as usize)) + .unwrap(), + &(values.to_vec()), + ) }) }) .collect(); + // Collect into a `Vec` ( size as usize, @@ -147,6 +142,17 @@ where pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { let config = get_config(); + let domain_map: BTreeMap = self + .analyzed + .degrees() + .iter() + .map(|size| { + ( + (size.ilog2() as usize), + CanonicCoset::new(size.ilog2()).circle_domain(), + ) + }) + .collect(); let twiddles_map: BTreeMap> = self .split .values() @@ -200,8 +206,22 @@ where } tree_builder.commit(prover_channel); - // committed/witness trace - let trace = gen_stwo_circuit_trace::(witness); + assert!( + witness + .iter() + .all(|(_name, vec)| vec.len() == witness[0].1.len()), + "All Vec in witness must have the same length. Mismatch found!" + ); + + let trace: ColumnVec> = witness + .iter() + .map(|(_name, values)| { + gen_stwo_circle_column::( + *domain_map.get(&(values.len().ilog2() as usize)).unwrap(), + values, + ) + }) + .collect(); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); From 567b0f86438d682f486a505121445c540bf0c432 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 5 Dec 2024 16:13:47 +0100 Subject: [PATCH 29/48] support constant column with next reference --- backend/Cargo.toml | 2 +- backend/src/stwo/circuit_builder.rs | 130 ++++++++++++++++++++++++---- backend/src/stwo/prover.rs | 51 +++++++++-- 3 files changed, 158 insertions(+), 25 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index e8773ab31a..8797c48fcb 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -63,7 +63,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } # TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3 -stwo-prover = { git = "https://github.com/ShuangWu121/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe" } +stwo-prover = { git = "https://github.com/ShuangWu121/stwo.git", optional = true, rev = "564a4ddcde376ba0ae78da4d86ea5ad7338ef6fe",features = ["parallel"] } strum = { version = "0.24.1", features = ["derive"] } log = "0.4.17" diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index c3363c6d5f..eb38ba0029 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,11 +1,10 @@ use num_traits::Zero; -use rayon::slice; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; use std::sync::Arc; extern crate alloc; -use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; +use alloc::collections::btree_map::BTreeMap; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity, }; @@ -21,16 +20,15 @@ use stwo_prover::constraint_framework::{ use stwo_prover::core::backend::{Column, ColumnOps}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; +use stwo_prover::core::poly::circle::{CircleDomain, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; -use stwo_prover::core::ColumnVec; pub type PowdrComponent<'a, F> = FrameworkComponent>; pub fn gen_stwo_circle_column( domain: CircleDomain, - slice: &Vec, + slice: &[T], ) -> CircleEvaluation where T: FieldElement, @@ -39,7 +37,7 @@ where F: ExtensionOf, { assert!( - slice.len().ilog2() == (domain.size().ilog2() as u32), + slice.len().ilog2() == domain.size().ilog2(), "column size must be equal to domain size" ); let mut column: >::Column = @@ -60,6 +58,7 @@ where pub struct PowdrEval { analyzed: Arc>, witness_columns: BTreeMap, + constant_with_next_columns: BTreeMap, constant_columns: BTreeMap, } @@ -72,16 +71,28 @@ impl PowdrEval { .map(|(index, (_, id))| (id, index)) .collect(); + let constant_with_next_list = get_constant_with_next_list(&analyzed); + + let constant_with_next_columns: BTreeMap = analyzed + .definitions_in_source_order(PolynomialType::Constant) + .flat_map(|(symbol, _)| symbol.array_elements()) + .enumerate() + .filter(|(_, (_, id))| constant_with_next_list.contains(&(id.id as usize))) + .map(|(index, (_, id))| (id, index)) + .collect(); + let constant_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() + .filter(|(_, (_, id))| !constant_with_next_list.contains(&(id.id as usize))) .map(|(index, (_, id))| (id, index)) .collect(); Self { analyzed, witness_columns, + constant_with_next_columns, constant_columns, } } @@ -110,6 +121,18 @@ impl FrameworkEval for PowdrEval { ) }) .collect(); + + let constant_with_next_eval: BTreeMap::F; 2]> = self + .constant_with_next_columns + .keys() + .map(|poly_id| { + ( + *poly_id, + eval.next_interaction_mask(ORIGINAL_TRACE_IDX, [0, 1]), + ) + }) + .collect(); + let constant_eval: BTreeMap<_, _> = self .constant_columns .keys() @@ -129,8 +152,12 @@ impl FrameworkEval for PowdrEval { { match id { Identity::Polynomial(identity) => { - let expr = - to_stwo_expression(&identity.expression, &witness_eval, &constant_eval); + let expr = to_stwo_expression( + &identity.expression, + &witness_eval, + &constant_with_next_eval, + &constant_eval, + ); eval.add_constraint(expr); } Identity::Connect(..) => { @@ -153,6 +180,7 @@ impl FrameworkEval for PowdrEval { fn to_stwo_expression( expr: &AlgebraicExpression, witness_eval: &BTreeMap, + constant_with_next_eval: &BTreeMap, constant_eval: &BTreeMap, ) -> F where @@ -179,10 +207,19 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => match r.next { - false => constant_eval[&poly_id].clone(), - true => panic!("Next on constant polynomials is not supported"), - }, + PolynomialType::Constant => { + if !constant_with_next_eval.contains_key(&poly_id) { + match r.next { + false => constant_eval[&poly_id].clone(), + true => panic!("Next on a constant polynomial filter fails"), + } + } else { + match r.next { + false => constant_with_next_eval[&poly_id][0].clone(), + true => constant_with_next_eval[&poly_id][1].clone(), + } + } + } PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } @@ -198,15 +235,18 @@ where right, }) => match **right { AlgebraicExpression::Number(n) => { - let left = to_stwo_expression(left, witness_eval, constant_eval); + let left = + to_stwo_expression(left, witness_eval, constant_with_next_eval, constant_eval); (0u32..n.to_integer().try_into_u32().unwrap()) .fold(F::one(), |acc, _| acc * left.clone()) } _ => unimplemented!("pow with non-constant exponent"), }, AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { - let left = to_stwo_expression(left, witness_eval, constant_eval); - let right = to_stwo_expression(right, witness_eval, constant_eval); + let left = + to_stwo_expression(left, witness_eval, constant_with_next_eval, constant_eval); + let right = + to_stwo_expression(right, witness_eval, constant_with_next_eval, constant_eval); match op { Add => left + right, @@ -216,7 +256,8 @@ where } } AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { - let expr = to_stwo_expression(expr, witness_eval, constant_eval); + let expr = + to_stwo_expression(expr, witness_eval, constant_with_next_eval, constant_eval); match op { AlgebraicUnaryOperator::Minus => -expr, @@ -227,3 +268,60 @@ where } } } + +pub fn constant_with_next_to_witness_col( + expr: &AlgebraicExpression, + constant_with_next_list: &mut Vec, +) { + use AlgebraicBinaryOperator::*; + match expr { + AlgebraicExpression::Reference(r) => { + let poly_id = r.poly_id; + + match poly_id.ptype { + PolynomialType::Committed => {} + PolynomialType::Constant => match r.next { + false => {} + true => { + constant_with_next_list.push(r.poly_id.id as usize); + } + }, + PolynomialType::Intermediate => {} + } + } + AlgebraicExpression::PublicReference(..) => {} + AlgebraicExpression::Number(_) => {} + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: Pow, + right, + }) => match **right { + AlgebraicExpression::Number(n) => { + let left = constant_with_next_to_witness_col::(left, constant_with_next_list); + } + _ => unimplemented!("pow with non-constant exponent"), + }, + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { + constant_with_next_to_witness_col::(left, constant_with_next_list); + constant_with_next_to_witness_col::(right, constant_with_next_list); + } + AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { + constant_with_next_to_witness_col::(expr, constant_with_next_list); + } + AlgebraicExpression::Challenge(_challenge) => {} + } +} + +pub fn get_constant_with_next_list(analyzed: &Arc>) -> Vec { + let mut all_constant_with_next: Vec = Vec::new(); + for id in analyzed.identities_with_inlined_intermediate_polynomials() { + if let Identity::Polynomial(identity) = id { + let mut constant_with_next: Vec = Vec::new(); + constant_with_next_to_witness_col::(&identity.expression, &mut constant_with_next); + all_constant_with_next.extend(constant_with_next) + } + } + all_constant_with_next.sort_unstable(); + all_constant_with_next.dedup(); + all_constant_with_next +} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 8f71a3f3f3..350cd4cf74 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,4 +1,4 @@ -use powdr_ast::analyzed::Analyzed; +use powdr_ast::analyzed::{Analyzed, Identity}; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_number::FieldElement; @@ -9,7 +9,9 @@ use std::io; use std::marker::PhantomData; use std::sync::Arc; -use crate::stwo::circuit_builder::{gen_stwo_circle_column, PowdrComponent, PowdrEval}; +use crate::stwo::circuit_builder::{ + gen_stwo_circle_column, get_constant_with_next_list, PowdrComponent, PowdrEval, +}; use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; use stwo_prover::constraint_framework::{ @@ -18,7 +20,7 @@ use stwo_prover::constraint_framework::{ use stwo_prover::core::prover::StarkProof; use stwo_prover::core::air::{Component, ComponentProver}; -use stwo_prover::core::backend::{Backend, BackendForChannel, Column, ColumnOps}; +use stwo_prover::core::backend::{Backend, BackendForChannel}; use stwo_prover::core::channel::{Channel, MerkleChannel}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fri::FriConfig; @@ -26,7 +28,6 @@ use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, P use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; use stwo_prover::core::poly::twiddles::TwiddleTree; use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use stwo_prover::core::ColumnVec; const FRI_LOG_BLOWUP: usize = 1; @@ -87,6 +88,7 @@ where ) }) .collect(); + let preprocessed: BTreeMap> = self .split .iter() @@ -99,6 +101,7 @@ where Some(( namespace.to_string(), + //why here it is committed_polys_in_source_order() instead of constant polys? pil.committed_polys_in_source_order() .find_map(|(s, _)| s.degree) .unwrap() @@ -114,12 +117,11 @@ where *domain_map .get(&(values.len().ilog2() as usize)) .unwrap(), - &(values.to_vec()), + values, ) }) }) .collect(); - // Collect into a `Vec` ( size as usize, @@ -190,6 +192,9 @@ where let mut tree_builder = commitment_scheme.tree_builder(); // only the frist one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. + let constant_list: Vec = get_constant_with_next_list(&self.analyzed); + + //commit to the constant polynomials with next reference constraint if let Some((_, table_proving_key)) = self.proving_key .preprocessed @@ -200,7 +205,15 @@ where .find_map(|(_, table_collection)| table_collection.iter().next()) }) { - tree_builder.extend_evals(table_proving_key.constant_trace_circle_domain.clone()); + tree_builder.extend_evals( + table_proving_key + .constant_trace_circle_domain + .clone() + .into_iter() // Convert it into an iterator + .enumerate() // Enumerate to get (index, value) + .filter(|(index, _)| !constant_list.contains(index)) // Keep only elements whose index is not in `constant_list` + .map(|(_, element)| element), + ); } else { tree_builder.extend_evals([]); } @@ -213,7 +226,7 @@ where "All Vec in witness must have the same length. Mismatch found!" ); - let trace: ColumnVec> = witness + let mut trace: ColumnVec> = witness .iter() .map(|(_name, values)| { gen_stwo_circle_column::( @@ -223,6 +236,28 @@ where }) .collect(); + if let Some((_, table_proving_key)) = + self.proving_key + .preprocessed + .as_ref() + .and_then(|preprocessed| { + preprocessed + .iter() + .find_map(|(_, table_collection)| table_collection.iter().next()) + }) + { + let constants_with_next: Vec> = + table_proving_key + .constant_trace_circle_domain + .clone() + .into_iter() + .enumerate() + .filter(|(index, _)| constant_list.contains(index)) // Keep only elements whose index is not in `constant_list` + .map(|(_, element)| element) + .collect(); + trace.extend(constants_with_next); + } + let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); tree_builder.commit(prover_channel); From 2e627dd1b66d6ec8eb4d7ef5fc49d6017a7f294b Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 5 Dec 2024 16:35:32 +0100 Subject: [PATCH 30/48] use a wrong clippy command, now fixed it --- backend/src/stwo/circuit_builder.rs | 8 ++++---- backend/src/stwo/prover.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index eb38ba0029..e9514ea15b 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -296,16 +296,16 @@ pub fn constant_with_next_to_witness_col( op: Pow, right, }) => match **right { - AlgebraicExpression::Number(n) => { - let left = constant_with_next_to_witness_col::(left, constant_with_next_list); + AlgebraicExpression::Number(_) => { + constant_with_next_to_witness_col::(left, constant_with_next_list); } _ => unimplemented!("pow with non-constant exponent"), }, - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op: _, right }) => { constant_with_next_to_witness_col::(left, constant_with_next_list); constant_with_next_to_witness_col::(right, constant_with_next_list); } - AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { + AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op: _, expr }) => { constant_with_next_to_witness_col::(expr, constant_with_next_list); } AlgebraicExpression::Challenge(_challenge) => {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 350cd4cf74..5f90869dc9 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,4 +1,4 @@ -use powdr_ast::analyzed::{Analyzed, Identity}; +use powdr_ast::analyzed::Analyzed; use powdr_backend_utils::machine_fixed_columns; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_number::FieldElement; From 0f05f7245ee5d7ab307264174b66f480c3282f22 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Thu, 5 Dec 2024 17:03:47 +0100 Subject: [PATCH 31/48] use identities, so no panic for lookups --- backend/src/stwo/circuit_builder.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index e9514ea15b..8e0b97c14d 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -146,10 +146,7 @@ impl FrameworkEval for PowdrEval { }) .collect(); - for id in self - .analyzed - .identities_with_inlined_intermediate_polynomials() - { + for id in self.analyzed.identities.clone() { match id { Identity::Polynomial(identity) => { let expr = to_stwo_expression( From bc4bf3046523f2d90b4f4d51e7d2ba3f2c4c4f03 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 6 Dec 2024 09:42:17 +0100 Subject: [PATCH 32/48] add test for fixed col with next reference, remove mc generic in proving key --- backend/src/stwo/circuit_builder.rs | 1 + backend/src/stwo/mod.rs | 7 ++++--- backend/src/stwo/proof.rs | 15 +++++++-------- backend/src/stwo/prover.rs | 29 +++++++++++++++++------------ pipeline/tests/pil.rs | 7 +++++++ test_data/pil/fixed_with_next.pil | 14 ++++++++++++++ 6 files changed, 50 insertions(+), 23 deletions(-) create mode 100644 test_data/pil/fixed_with_next.pil diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 8e0b97c14d..feffc3eb14 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -309,6 +309,7 @@ pub fn constant_with_next_to_witness_col( } } +// This function creates a list of indexs of the constant polynomials that have next references constraint pub fn get_constant_with_next_list(analyzed: &Arc>) -> Vec { let mut all_constant_with_next: Vec = Vec::new(); for id in analyzed.identities_with_inlined_intermediate_polynomials() { diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 8ec9c11da5..87b497e063 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -38,9 +38,10 @@ impl BackendFactory for RestrictedFactory { if proving_key.is_some() { return Err(Error::BackendError("Proving key unused".to_string())); } - if pil.degrees().len() > 1 { - return Err(Error::NoVariableDegreeAvailable); - } + // if pil.degrees().len() > 1 { + // return Err(Error::NoVariableDegreeAvailable); + // } + let mut stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); stwo.setup(); diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 676167c089..077108aa83 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,22 +1,21 @@ use std::collections::BTreeMap; -use stwo_prover::core::backend::BackendForChannel; -use stwo_prover::core::channel::MerkleChannel; +use stwo_prover::core::backend::Backend; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::poly::circle::CircleEvaluation; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::ColumnVec; /// For each possible size, the commitment and prover data -pub type TableProvingKeyCollection = BTreeMap>; +pub type TableProvingKeyCollection = BTreeMap>; -pub struct TableProvingKey, MC: MerkleChannel> { +#[derive(Debug)] +pub struct TableProvingKey { pub constant_trace_circle_domain: ColumnVec>, - pub _marker: std::marker::PhantomData, } -pub struct StarkProvingKey, MC: MerkleChannel> { +pub struct StarkProvingKey { // for each table, the preprocessed data - pub preprocessed: Option>>, + pub preprocessed: Option>>, } -unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} +//unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 5f90869dc9..5536352ac4 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -43,10 +43,11 @@ pub struct StwoProver + Send, MC: MerkleChannel, C: pub fixed: Arc)>>, /// Proving key - proving_key: StarkProvingKey, + proving_key: StarkProvingKey, /// Verifying key placeholder _verifying_key: Option<()>, _channel_marker: PhantomData, + _merkle_channel_marker: PhantomData, } impl<'a, F: FieldElement, B, MC, C> StwoProver @@ -72,6 +73,7 @@ where proving_key: StarkProvingKey { preprocessed: None }, _verifying_key: None, _channel_marker: PhantomData, + _merkle_channel_marker: PhantomData, }) } pub fn setup(&mut self) { @@ -89,7 +91,7 @@ where }) .collect(); - let preprocessed: BTreeMap> = self + let preprocessed: BTreeMap> = self .split .iter() .filter_map(|(namespace, pil)| { @@ -127,7 +129,6 @@ where size as usize, TableProvingKey { constant_trace_circle_domain: constant_trace, - _marker: PhantomData, }, ) }) @@ -143,6 +144,14 @@ where } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { + + assert!( + witness + .iter() + .all(|(_name, vec)| vec.len() == witness[0].1.len()), + "All Vec in witness must have the same length. Mismatch found!" + ); + let config = get_config(); let domain_map: BTreeMap = self .analyzed @@ -191,10 +200,10 @@ where let mut tree_builder = commitment_scheme.tree_builder(); - // only the frist one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. + // Get the list of constant polynomials with next reference constraint let constant_list: Vec = get_constant_with_next_list(&self.analyzed); - //commit to the constant polynomials with next reference constraint + //commit to the constant polynomials that are without next reference constraint if let Some((_, table_proving_key)) = self.proving_key .preprocessed @@ -219,12 +228,7 @@ where } tree_builder.commit(prover_channel); - assert!( - witness - .iter() - .all(|(_name, vec)| vec.len() == witness[0].1.len()), - "All Vec in witness must have the same length. Mismatch found!" - ); + let mut trace: ColumnVec> = witness .iter() @@ -235,7 +239,8 @@ where ) }) .collect(); - + + //extend the witness trace with the constant polys that have next reference constraint if let Some((_, table_proving_key)) = self.proving_key .preprocessed diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index b4e6a6d6da..c63a8706eb 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -296,6 +296,13 @@ fn stwo_incremental_one() { let f = "pil/incremental_one.pil"; test_stwo(f, Default::default()); } + +#[test] +fn stwo_constant_next_test() { + let f = "pil/fixed_with_incremental.pil"; + test_stwo(f, Default::default()); +} + #[test] fn fibonacci_invalid_witness_stwo() { let f = "pil/fibo_no_publics.pil"; diff --git a/test_data/pil/fixed_with_next.pil b/test_data/pil/fixed_with_next.pil new file mode 100644 index 0000000000..786876d57f --- /dev/null +++ b/test_data/pil/fixed_with_next.pil @@ -0,0 +1,14 @@ +let N = 32; + +// This uses the alternative nomenclature as well. + +namespace Incremental(N); + col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; + col fixed INCREMENT(i) { i + 1 }; + col witness x, y; + + ISLAST * (x' - 1) = 0; + ISLAST * (INCREMENT' - 1) = 0; + + (1-ISLAST) * (x' - x-1) = 0; + (1-ISLAST) * (INCREMENT' - INCREMENT - 1) = 0; \ No newline at end of file From 50521e92691d67e612cdd5351ea7d7a448001eec Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 6 Dec 2024 10:09:56 +0100 Subject: [PATCH 33/48] fix fmt --- backend/src/stwo/prover.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 5536352ac4..8b33276cbb 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -144,7 +144,6 @@ where } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { - assert!( witness .iter() @@ -228,8 +227,6 @@ where } tree_builder.commit(prover_channel); - - let mut trace: ColumnVec> = witness .iter() .map(|(_name, values)| { @@ -239,7 +236,7 @@ where ) }) .collect(); - + //extend the witness trace with the constant polys that have next reference constraint if let Some((_, table_proving_key)) = self.proving_key From 410c68a5bc64f5b3706ddaea14723f5f259bd91a Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Fri, 6 Dec 2024 15:20:49 +0100 Subject: [PATCH 34/48] no intermidate panic --- backend/src/stwo/circuit_builder.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index feffc3eb14..a745fe675c 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -146,7 +146,11 @@ impl FrameworkEval for PowdrEval { }) .collect(); - for id in self.analyzed.identities.clone() { + println!("self.analyzed.identities_with_inlined_intermediate_polynomials(): {:?}", self.analyzed.identities_with_inlined_intermediate_polynomials()); + + println!("\n self.analyzed.identities is {:?}", self.analyzed.identities); + + for id in self.analyzed.identities_with_inlined_intermediate_polynomials() { match id { Identity::Polynomial(identity) => { let expr = to_stwo_expression( From a1a663c6e00cc6ec427873ed31c63541a9429441 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 9 Dec 2024 09:29:01 +0100 Subject: [PATCH 35/48] serialize and deserialize proving keys --- backend/src/stwo/circuit_builder.rs | 9 +- backend/src/stwo/mod.rs | 23 ++++-- backend/src/stwo/proof.rs | 122 +++++++++++++++++++++++++++- backend/src/stwo/prover.rs | 46 +++++++++-- 4 files changed, 179 insertions(+), 21 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index a745fe675c..79282929ec 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -146,11 +146,14 @@ impl FrameworkEval for PowdrEval { }) .collect(); - println!("self.analyzed.identities_with_inlined_intermediate_polynomials(): {:?}", self.analyzed.identities_with_inlined_intermediate_polynomials()); + //println!("self.analyzed.identities_with_inlined_intermediate_polynomials(): {:?}", self.analyzed.identities_with_inlined_intermediate_polynomials()); - println!("\n self.analyzed.identities is {:?}", self.analyzed.identities); + //println!("\n self.analyzed.identities is {:?}", self.analyzed.identities); - for id in self.analyzed.identities_with_inlined_intermediate_polynomials() { + for id in self + .analyzed + .identities_with_inlined_intermediate_polynomials() + { match id { Identity::Polynomial(identity) => { let expr = to_stwo_expression( diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 87b497e063..7f3febcaf0 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -38,13 +38,22 @@ impl BackendFactory for RestrictedFactory { if proving_key.is_some() { return Err(Error::BackendError("Proving key unused".to_string())); } - // if pil.degrees().len() > 1 { - // return Err(Error::NoVariableDegreeAvailable); - // } + if pil.degrees().len() > 1 { + return Err(Error::NoVariableDegreeAvailable); + } let mut stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); - stwo.setup(); + + match (proving_key, verification_key) { + (Some(pk), Some(vk)) => { + stwo.set_proving_key(pk); + //stwo.set_verifying_key(vk); + } + _ => { + stwo.setup(); + } + } Ok(stwo) } @@ -80,8 +89,8 @@ where } Ok(StwoProver::prove(self, witness)?) } - #[allow(unused_variables)] - fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> { - unimplemented!() + fn export_proving_key(&self, output: &mut dyn io::Write) -> Result<(), Error> { + self.export_proving_key(output) + .map_err(|e| Error::BackendError(e.to_string())) } } diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 077108aa83..9e4a6e8c21 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,21 +1,135 @@ +use serde::Deserialize; +use serde::Serialize; use std::collections::BTreeMap; use stwo_prover::core::backend::Backend; +use stwo_prover::core::backend::Column; +use stwo_prover::core::backend::ColumnOps; use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::ColumnVec; /// For each possible size, the commitment and prover data pub type TableProvingKeyCollection = BTreeMap>; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TableProvingKey { pub constant_trace_circle_domain: ColumnVec>, } +impl From for TableProvingKey { + fn from(serializable: SerializableTableProvingKey) -> Self { + let constant_trace_circle_domain = serializable + .constant_trace_circle_domain + .into_iter() + .map(|circle_eval| { + let mut column: >::Column = + >::Column::zeros(circle_eval.values.len()); + circle_eval.values.iter().enumerate().for_each(|(i, v)| { + column.set(i, *v); + }); + CircleEvaluation::::new( + CanonicCoset::new(circle_eval.domain_log_size).circle_domain(), + column, + ) + }) + .collect::>(); + + TableProvingKey { + constant_trace_circle_domain, + } + } +} + +#[derive(Debug, Clone)] pub struct StarkProvingKey { - // for each table, the preprocessed data pub preprocessed: Option>>, } -//unsafe impl, MC: MerkleChannel> Send for TableProvingKey {} +impl From for StarkProvingKey { + fn from(serializable: SerializableStarkProvingKey) -> Self { + let preprocessed = serializable.preprocessed.map(|map| { + map.into_iter() + .map(|(key, value)| { + ( + key, + value + .into_iter() + .map(|(inner_key, inner_value)| { + (inner_key, TableProvingKey::::from(inner_value)) + }) + .collect::>(), + ) + }) + .collect::>() + }); + + StarkProvingKey { preprocessed } + } +} + +#[derive(Serialize, Deserialize)] +pub struct SerializableCircleEvaluation { + domain_log_size: u32, + values: Vec, +} + +impl From> + for SerializableCircleEvaluation +{ + fn from(circle_evaluation: CircleEvaluation) -> Self { + let domain_log_size = circle_evaluation.domain.log_size(); + let values = circle_evaluation.values.to_cpu(); + Self { + domain_log_size, + values, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct SerializableTableProvingKey { + constant_trace_circle_domain: Vec, +} + +impl From> for SerializableTableProvingKey { + fn from(table_proving_key: TableProvingKey) -> Self { + let constant_trace_circle_domain = table_proving_key + .constant_trace_circle_domain + .iter() + .map(|circle_eval| SerializableCircleEvaluation::from(circle_eval.clone())) + .collect(); + + Self { + constant_trace_circle_domain, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct SerializableStarkProvingKey { + preprocessed: Option>>, +} + +impl From> for SerializableStarkProvingKey { + fn from(stark_proving_key: StarkProvingKey) -> Self { + let preprocessed = stark_proving_key.preprocessed.map(|map| { + map.into_iter() + .map(|(key, value)| { + ( + key, + value + .into_iter() + .map(|(inner_key, inner_value)| { + (inner_key, SerializableTableProvingKey::from(inner_value)) + }) + .collect::>(), + ) + }) + .collect::>() + }); + + Self { preprocessed } + } +} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 8b33276cbb..7391ba877d 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -5,14 +5,16 @@ use powdr_number::FieldElement; use serde::de::DeserializeOwned; use serde::ser::Serialize; use std::collections::BTreeMap; -use std::io; use std::marker::PhantomData; use std::sync::Arc; +use std::{fmt, io}; use crate::stwo::circuit_builder::{ gen_stwo_circle_column, get_constant_with_next_list, PowdrComponent, PowdrEval, }; -use crate::stwo::proof::{StarkProvingKey, TableProvingKey, TableProvingKeyCollection}; +use crate::stwo::proof::{ + SerializableStarkProvingKey, StarkProvingKey, TableProvingKey, TableProvingKeyCollection, +}; use stwo_prover::constraint_framework::{ TraceLocationAllocator, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, @@ -35,6 +37,20 @@ const FRI_NUM_QUERIES: usize = 100; const FRI_PROOF_OF_WORK_BITS: usize = 16; const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; +pub enum KeyExportError { + NoProvingKey, + //NoVerificationKey, +} + +impl fmt::Display for KeyExportError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NoProvingKey => write!(f, "No proving key set"), + // Self::NoVerificationKey => write!(f, "No verification key set"), + } + } +} + pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { pub analyzed: Arc>, /// The split analyzed PIL @@ -76,6 +92,25 @@ where _merkle_channel_marker: PhantomData, }) } + + pub fn set_proving_key(&mut self, rdr: &mut dyn std::io::Read) { + let serializable_key: SerializableStarkProvingKey = bincode::deserialize_from(rdr).unwrap(); + self.proving_key = StarkProvingKey::from(serializable_key); + } + + pub fn export_proving_key( + &self, + writer: &mut dyn std::io::Write, + ) -> Result<(), KeyExportError> { + let pk = SerializableStarkProvingKey::from(self.proving_key.clone()); + self.proving_key + .preprocessed + .as_ref() + .ok_or(KeyExportError::NoProvingKey)?; + bincode::serialize_into(writer, &pk).unwrap(); + Ok(()) + } + pub fn setup(&mut self) { // machines with varying sizes are not supported yet, and it is checked in backendfactory create function. //TODO: support machines with varying sizes @@ -171,11 +206,8 @@ where pil.committed_polys_in_source_order() .flat_map(|(s, _)| { s.degree.iter().flat_map(|range| { - let min = range.min; - let max = range.max; - - // Iterate over powers of 2 from min to max - (min..=max) + range + .iter() .filter(|&size| size.is_power_of_two()) // Only take powers of 2 .map(|size| { // Compute twiddles for this size From 9503341bc87f7c65260f1481edc446a502ad45c9 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 9 Dec 2024 10:12:38 +0100 Subject: [PATCH 36/48] add test file --- test_data/pil/fixed_with_incremental.pil | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 test_data/pil/fixed_with_incremental.pil diff --git a/test_data/pil/fixed_with_incremental.pil b/test_data/pil/fixed_with_incremental.pil new file mode 100644 index 0000000000..dfd46c0f55 --- /dev/null +++ b/test_data/pil/fixed_with_incremental.pil @@ -0,0 +1,13 @@ +let N = 32; + +// This uses the alternative nomenclature as well. + +namespace Incremental(N); + col fixed ISLAST(i) { if i == N - 1 { 1 } else { 0 } }; + col witness x ; + col fixed INCREMENTAL(i) {i+1}; + ISLAST * (x' - 1) = 0; + ISLAST * (INCREMENTAL' - 1) = 0; + + (1-ISLAST) * (x' - x-1) = 0; + (1-ISLAST) * (INCREMENTAL' - INCREMENTAL-1) = 0; \ No newline at end of file From 66b1c1b254cef69fdc0787d3122bf6addd8f8b27 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 9 Dec 2024 11:34:33 +0100 Subject: [PATCH 37/48] simplify serilization --- backend/src/stwo/proof.rs | 42 ++++++++++++--------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 9e4a6e8c21..19762abc0d 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -23,14 +23,14 @@ impl From for TableProvingKey { let constant_trace_circle_domain = serializable .constant_trace_circle_domain .into_iter() - .map(|circle_eval| { + .map(|(size, values)| { let mut column: >::Column = - >::Column::zeros(circle_eval.values.len()); - circle_eval.values.iter().enumerate().for_each(|(i, v)| { + >::Column::zeros(values.len()); + values.iter().enumerate().for_each(|(i, v)| { column.set(i, *v); }); CircleEvaluation::::new( - CanonicCoset::new(circle_eval.domain_log_size).circle_domain(), + CanonicCoset::new(size as u32).circle_domain(), column, ) }) @@ -69,37 +69,21 @@ impl From for StarkProvingKey { } } -#[derive(Serialize, Deserialize)] -pub struct SerializableCircleEvaluation { - domain_log_size: u32, - values: Vec, -} - -impl From> - for SerializableCircleEvaluation -{ - fn from(circle_evaluation: CircleEvaluation) -> Self { - let domain_log_size = circle_evaluation.domain.log_size(); - let values = circle_evaluation.values.to_cpu(); - Self { - domain_log_size, - values, - } - } -} - #[derive(Serialize, Deserialize)] pub struct SerializableTableProvingKey { - constant_trace_circle_domain: Vec, + // usize is the domain log size, Vec is the values of the circle evaluation + constant_trace_circle_domain: BTreeMap>, // Single BTreeMap } impl From> for SerializableTableProvingKey { fn from(table_proving_key: TableProvingKey) -> Self { - let constant_trace_circle_domain = table_proving_key - .constant_trace_circle_domain - .iter() - .map(|circle_eval| SerializableCircleEvaluation::from(circle_eval.clone())) - .collect(); + let mut constant_trace_circle_domain = BTreeMap::new(); + + for circle_eval in &table_proving_key.constant_trace_circle_domain { + let domain_log_size = circle_eval.domain.log_size() as usize; + let values = circle_eval.values.to_cpu(); + constant_trace_circle_domain.insert(domain_log_size, values); + } Self { constant_trace_circle_domain, From b18e937357b9d01394d1806aa80e58a366abb7a6 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 9 Dec 2024 13:31:30 +0100 Subject: [PATCH 38/48] simplify serilization --- backend/src/stwo/proof.rs | 101 +++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 57 deletions(-) diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 19762abc0d..3fadf056e9 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -18,46 +18,40 @@ pub struct TableProvingKey { pub constant_trace_circle_domain: ColumnVec>, } -impl From for TableProvingKey { - fn from(serializable: SerializableTableProvingKey) -> Self { - let constant_trace_circle_domain = serializable - .constant_trace_circle_domain - .into_iter() - .map(|(size, values)| { - let mut column: >::Column = - >::Column::zeros(values.len()); - values.iter().enumerate().for_each(|(i, v)| { - column.set(i, *v); - }); - CircleEvaluation::::new( - CanonicCoset::new(size as u32).circle_domain(), - column, - ) - }) - .collect::>(); - - TableProvingKey { - constant_trace_circle_domain, - } - } -} - #[derive(Debug, Clone)] pub struct StarkProvingKey { pub preprocessed: Option>>, } impl From for StarkProvingKey { - fn from(serializable: SerializableStarkProvingKey) -> Self { - let preprocessed = serializable.preprocessed.map(|map| { + fn from(serializable_stark_provingkey: SerializableStarkProvingKey) -> Self { + let preprocessed = serializable_stark_provingkey.preprocessed.map(|map| { map.into_iter() - .map(|(key, value)| { + .map(|(namespace, table_provingkey_collection)| { ( - key, - value + namespace, + table_provingkey_collection .into_iter() - .map(|(inner_key, inner_value)| { - (inner_key, TableProvingKey::::from(inner_value)) + .map(|(machine_size, table_provingkey)| { + ( + machine_size, + TableProvingKey{ + constant_trace_circle_domain: table_provingkey + .into_iter() + .map(|(size,values)|{ + let mut column: >::Column = + >::Column::zeros(values.len()); + values.iter().enumerate().for_each(|(i, v)| { + column.set(i, *v); + }); + CircleEvaluation::::new( + CanonicCoset::new(size as u32).circle_domain(), + column, + ) + }) + .collect::>(), + } + ) }) .collect::>(), ) @@ -69,44 +63,37 @@ impl From for StarkProvingKey { } } -#[derive(Serialize, Deserialize)] -pub struct SerializableTableProvingKey { - // usize is the domain log size, Vec is the values of the circle evaluation - constant_trace_circle_domain: BTreeMap>, // Single BTreeMap -} - -impl From> for SerializableTableProvingKey { - fn from(table_proving_key: TableProvingKey) -> Self { - let mut constant_trace_circle_domain = BTreeMap::new(); - - for circle_eval in &table_proving_key.constant_trace_circle_domain { - let domain_log_size = circle_eval.domain.log_size() as usize; - let values = circle_eval.values.to_cpu(); - constant_trace_circle_domain.insert(domain_log_size, values); - } - - Self { - constant_trace_circle_domain, - } - } -} +type CircleEvaluationMap = BTreeMap>; #[derive(Serialize, Deserialize)] pub struct SerializableStarkProvingKey { - preprocessed: Option>>, + // usize is the domain log size, Vec is the values of the circle evaluation + preprocessed: Option>>, } impl From> for SerializableStarkProvingKey { fn from(stark_proving_key: StarkProvingKey) -> Self { let preprocessed = stark_proving_key.preprocessed.map(|map| { map.into_iter() - .map(|(key, value)| { + .map(|(namespace, value)| { ( - key, + namespace, value .into_iter() - .map(|(inner_key, inner_value)| { - (inner_key, SerializableTableProvingKey::from(inner_value)) + .map(|(machine_size, table_provingkey)| { + ( + machine_size, + table_provingkey + .constant_trace_circle_domain + .into_iter() + .map(|circle_evaluation| { + ( + circle_evaluation.domain.log_size() as usize, + circle_evaluation.values.to_cpu(), + ) + }) + .collect::>(), + ) }) .collect::>(), ) From 524d9ff93ab1d85eb1f92cdf5b1ce7c87b924283 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Mon, 9 Dec 2024 16:20:36 +0100 Subject: [PATCH 39/48] clean up --- backend/src/stwo/circuit_builder.rs | 8 ++------ backend/src/stwo/mod.rs | 3 ++- backend/src/stwo/prover.rs | 9 ++++----- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 79282929ec..eafc2d0041 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -70,7 +70,7 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - + // create a list of indexs of the constant polynomials that have next references constraint let constant_with_next_list = get_constant_with_next_list(&analyzed); let constant_with_next_columns: BTreeMap = analyzed @@ -146,10 +146,6 @@ impl FrameworkEval for PowdrEval { }) .collect(); - //println!("self.analyzed.identities_with_inlined_intermediate_polynomials(): {:?}", self.analyzed.identities_with_inlined_intermediate_polynomials()); - - //println!("\n self.analyzed.identities is {:?}", self.analyzed.identities); - for id in self .analyzed .identities_with_inlined_intermediate_polynomials() @@ -215,7 +211,7 @@ where if !constant_with_next_eval.contains_key(&poly_id) { match r.next { false => constant_eval[&poly_id].clone(), - true => panic!("Next on a constant polynomial filter fails"), + true => panic!("constant polynomial with next reference filter fails"), } } else { match r.next { diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 7f3febcaf0..14630dfc08 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -38,6 +38,7 @@ impl BackendFactory for RestrictedFactory { if proving_key.is_some() { return Err(Error::BackendError("Proving key unused".to_string())); } + if pil.degrees().len() > 1 { return Err(Error::NoVariableDegreeAvailable); } @@ -76,7 +77,7 @@ where Ok(self.verify(proof, instances)?) } - #[allow(unreachable_code)] + #[allow(unused_variables)] fn prove( &self, diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 7391ba877d..ba2d166b5c 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -130,7 +130,7 @@ where .split .iter() .filter_map(|(namespace, pil)| { - // if we have neither fixed columns nor publics, we don't need to commit to anything + // if we no fixed columns, we don't need to commit to anything, publics is not supported yet. if pil.constant_count() == 0 { None } else { @@ -208,9 +208,8 @@ where s.degree.iter().flat_map(|range| { range .iter() - .filter(|&size| size.is_power_of_two()) // Only take powers of 2 + .filter(|&size| size.is_power_of_two()) .map(|size| { - // Compute twiddles for this size let twiddles = B::precompute_twiddles( CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) .circle_domain() @@ -218,10 +217,10 @@ where ); (size as usize, twiddles) }) - .collect::>() // Collect results into a Vec + .collect::>() }) }) - .collect::>() // Collect the inner results into a Vec + .collect::>() }) .collect(); // only the first one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. From 5e592f791408f5ca32023e590ba34fccdde3e999 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Tue, 10 Dec 2024 10:37:35 +0100 Subject: [PATCH 40/48] refactor proof serilization and next reference on constant --- backend/src/stwo/circuit_builder.rs | 59 ++++++-------- backend/src/stwo/proof.rs | 118 +++++++++++++++++----------- backend/src/stwo/prover.rs | 65 +++++++-------- 3 files changed, 125 insertions(+), 117 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index eafc2d0041..0bd57b16d9 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -58,7 +58,7 @@ where pub struct PowdrEval { analyzed: Arc>, witness_columns: BTreeMap, - constant_with_next_columns: BTreeMap, + constant_shifted: BTreeMap, constant_columns: BTreeMap, } @@ -73,7 +73,7 @@ impl PowdrEval { // create a list of indexs of the constant polynomials that have next references constraint let constant_with_next_list = get_constant_with_next_list(&analyzed); - let constant_with_next_columns: BTreeMap = analyzed + let constant_shifted: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() @@ -85,14 +85,14 @@ impl PowdrEval { .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() - .filter(|(_, (_, id))| !constant_with_next_list.contains(&(id.id as usize))) + // .filter(|(_, (_, id))| !constant_with_next_list.contains(&(id.id as usize))) .map(|(index, (_, id))| (id, index)) .collect(); Self { analyzed, witness_columns, - constant_with_next_columns, + constant_shifted, constant_columns, } } @@ -122,26 +122,30 @@ impl FrameworkEval for PowdrEval { }) .collect(); - let constant_with_next_eval: BTreeMap::F; 2]> = self - .constant_with_next_columns + let constant_eval: BTreeMap<_, _> = self + .constant_columns .keys() - .map(|poly_id| { + .enumerate() + .map(|(i, poly_id)| { ( *poly_id, - eval.next_interaction_mask(ORIGINAL_TRACE_IDX, [0, 1]), + // PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column + eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)), ) }) .collect(); - let constant_eval: BTreeMap<_, _> = self - .constant_columns + let constant_shifted_eval: BTreeMap<_, _> = self + .constant_shifted .keys() .enumerate() .map(|(i, poly_id)| { ( *poly_id, // PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column - eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)), + eval.get_preprocessed_column(PreprocessedColumn::Plonk( + i + constant_eval.len(), + )), ) }) .collect(); @@ -155,7 +159,7 @@ impl FrameworkEval for PowdrEval { let expr = to_stwo_expression( &identity.expression, &witness_eval, - &constant_with_next_eval, + &constant_shifted_eval, &constant_eval, ); eval.add_constraint(expr); @@ -180,7 +184,7 @@ impl FrameworkEval for PowdrEval { fn to_stwo_expression( expr: &AlgebraicExpression, witness_eval: &BTreeMap, - constant_with_next_eval: &BTreeMap, + constant_shifted_eval: &BTreeMap, constant_eval: &BTreeMap, ) -> F where @@ -207,19 +211,10 @@ where false => witness_eval[&poly_id][0].clone(), true => witness_eval[&poly_id][1].clone(), }, - PolynomialType::Constant => { - if !constant_with_next_eval.contains_key(&poly_id) { - match r.next { - false => constant_eval[&poly_id].clone(), - true => panic!("constant polynomial with next reference filter fails"), - } - } else { - match r.next { - false => constant_with_next_eval[&poly_id][0].clone(), - true => constant_with_next_eval[&poly_id][1].clone(), - } - } - } + PolynomialType::Constant => match r.next { + false => constant_eval[&poly_id].clone(), + true => constant_shifted_eval[&poly_id].clone(), + }, PolynomialType::Intermediate => { unimplemented!("Intermediate polynomials are not supported in stwo yet") } @@ -236,17 +231,16 @@ where }) => match **right { AlgebraicExpression::Number(n) => { let left = - to_stwo_expression(left, witness_eval, constant_with_next_eval, constant_eval); + to_stwo_expression(left, witness_eval, constant_shifted_eval, constant_eval); (0u32..n.to_integer().try_into_u32().unwrap()) .fold(F::one(), |acc, _| acc * left.clone()) } _ => unimplemented!("pow with non-constant exponent"), }, AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { - let left = - to_stwo_expression(left, witness_eval, constant_with_next_eval, constant_eval); + let left = to_stwo_expression(left, witness_eval, constant_shifted_eval, constant_eval); let right = - to_stwo_expression(right, witness_eval, constant_with_next_eval, constant_eval); + to_stwo_expression(right, witness_eval, constant_shifted_eval, constant_eval); match op { Add => left + right, @@ -256,8 +250,7 @@ where } } AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { - let expr = - to_stwo_expression(expr, witness_eval, constant_with_next_eval, constant_eval); + let expr = to_stwo_expression(expr, witness_eval, constant_shifted_eval, constant_eval); match op { AlgebraicUnaryOperator::Minus => -expr, @@ -313,7 +306,7 @@ pub fn constant_with_next_to_witness_col( } // This function creates a list of indexs of the constant polynomials that have next references constraint -pub fn get_constant_with_next_list(analyzed: &Arc>) -> Vec { +pub fn get_constant_with_next_list(analyzed: &Analyzed) -> Vec { let mut all_constant_with_next: Vec = Vec::new(); for id in analyzed.identities_with_inlined_intermediate_polynomials() { if let Identity::Polynomial(identity) = id { diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 3fadf056e9..9a332f432a 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -13,6 +13,39 @@ use stwo_prover::core::ColumnVec; /// For each possible size, the commitment and prover data pub type TableProvingKeyCollection = BTreeMap>; +impl From for TableProvingKeyCollection { + fn from(serializable: SerializableTableProvingKeyCollection) -> Self { + let constant_trace_circle_domain_collection = serializable + .constant_trace_circle_domain_collection + .into_iter() + .map(|(size, table_provingkey)| { + let domain = CanonicCoset::new(size as u32).circle_domain(); + let constant_trace_circle_domain = table_provingkey + .into_values() + .map(|values| { + let mut column: >::Column = + >::Column::zeros(values.len()); + values.iter().enumerate().for_each(|(i, v)| { + column.set(i, *v); + }); + + CircleEvaluation::::new(domain, column) + }) + .collect::>(); + + ( + size, + TableProvingKey { + constant_trace_circle_domain, + }, + ) + }) + .collect::>(); + + constant_trace_circle_domain_collection + } +} + #[derive(Debug, Clone)] pub struct TableProvingKey { pub constant_trace_circle_domain: ColumnVec>, @@ -24,36 +57,13 @@ pub struct StarkProvingKey { } impl From for StarkProvingKey { - fn from(serializable_stark_provingkey: SerializableStarkProvingKey) -> Self { - let preprocessed = serializable_stark_provingkey.preprocessed.map(|map| { + fn from(serializable: SerializableStarkProvingKey) -> Self { + let preprocessed = serializable.preprocessed.map(|map| { map.into_iter() .map(|(namespace, table_provingkey_collection)| { ( namespace, - table_provingkey_collection - .into_iter() - .map(|(machine_size, table_provingkey)| { - ( - machine_size, - TableProvingKey{ - constant_trace_circle_domain: table_provingkey - .into_iter() - .map(|(size,values)|{ - let mut column: >::Column = - >::Column::zeros(values.len()); - values.iter().enumerate().for_each(|(i, v)| { - column.set(i, *v); - }); - CircleEvaluation::::new( - CanonicCoset::new(size as u32).circle_domain(), - column, - ) - }) - .collect::>(), - } - ) - }) - .collect::>(), + TableProvingKeyCollection::::from(table_provingkey_collection), ) }) .collect::>() @@ -63,39 +73,51 @@ impl From for StarkProvingKey { } } -type CircleEvaluationMap = BTreeMap>; +#[derive(Serialize, Deserialize)] +pub struct SerializableTableProvingKeyCollection { + constant_trace_circle_domain_collection: BTreeMap>>, +} + +impl From> for SerializableTableProvingKeyCollection { + fn from(table_provingkey_collection: TableProvingKeyCollection) -> Self { + let mut constant_trace_circle_domain_collection = BTreeMap::new(); + + table_provingkey_collection + .iter() + .for_each(|(&size, trable_provingkey)| { + let mut values: BTreeMap> = BTreeMap::new(); + trable_provingkey + .constant_trace_circle_domain + .iter() + .for_each(|circle_eval| { + values.insert( + circle_eval.domain.log_size() as usize, + circle_eval.values.to_cpu().to_vec(), + ); + }); + + constant_trace_circle_domain_collection.insert(size, values); + }); + + Self { + constant_trace_circle_domain_collection, + } + } +} #[derive(Serialize, Deserialize)] pub struct SerializableStarkProvingKey { - // usize is the domain log size, Vec is the values of the circle evaluation - preprocessed: Option>>, + preprocessed: Option>, } impl From> for SerializableStarkProvingKey { fn from(stark_proving_key: StarkProvingKey) -> Self { let preprocessed = stark_proving_key.preprocessed.map(|map| { map.into_iter() - .map(|(namespace, value)| { + .map(|(namespace, table_provingkey_collection)| { ( namespace, - value - .into_iter() - .map(|(machine_size, table_provingkey)| { - ( - machine_size, - table_provingkey - .constant_trace_circle_domain - .into_iter() - .map(|circle_evaluation| { - ( - circle_evaluation.domain.log_size() as usize, - circle_evaluation.values.to_cpu(), - ) - }) - .collect::>(), - ) - }) - .collect::>(), + SerializableTableProvingKeyCollection::from(table_provingkey_collection), ) }) .collect::>() diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index ba2d166b5c..47f1e504b8 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -138,13 +138,12 @@ where Some(( namespace.to_string(), - //why here it is committed_polys_in_source_order() instead of constant polys? pil.committed_polys_in_source_order() .find_map(|(s, _)| s.degree) .unwrap() .iter() .map(|size| { - let constant_trace: ColumnVec< + let mut constant_trace: ColumnVec< CircleEvaluation, > = fixed_columns .values() @@ -160,6 +159,31 @@ where }) .collect(); + let constant_with_next_list = get_constant_with_next_list(pil); + + let constant_shifted_trace: ColumnVec< + CircleEvaluation, + > = fixed_columns + .values() + .flat_map(|vec| { + vec.iter() + .enumerate() + .filter(|(i, _)| constant_with_next_list.contains(i)) + .map(|(_, (_name, values))| { + let mut rotated_values = values.to_vec(); + rotated_values.rotate_left(1); + gen_stwo_circle_column::( + *domain_map + .get(&(values.len().ilog2() as usize)) + .unwrap(), + &rotated_values, + ) + }) + }) + .collect(); + + constant_trace.extend(constant_shifted_trace); + ( size as usize, TableProvingKey { @@ -231,7 +255,7 @@ where let mut tree_builder = commitment_scheme.tree_builder(); // Get the list of constant polynomials with next reference constraint - let constant_list: Vec = get_constant_with_next_list(&self.analyzed); + // let constant_list: Vec = get_constant_with_next_list(&self.analyzed); //commit to the constant polynomials that are without next reference constraint if let Some((_, table_proving_key)) = @@ -244,21 +268,13 @@ where .find_map(|(_, table_collection)| table_collection.iter().next()) }) { - tree_builder.extend_evals( - table_proving_key - .constant_trace_circle_domain - .clone() - .into_iter() // Convert it into an iterator - .enumerate() // Enumerate to get (index, value) - .filter(|(index, _)| !constant_list.contains(index)) // Keep only elements whose index is not in `constant_list` - .map(|(_, element)| element), - ); + tree_builder.extend_evals(table_proving_key.constant_trace_circle_domain.clone()); } else { tree_builder.extend_evals([]); } tree_builder.commit(prover_channel); - let mut trace: ColumnVec> = witness + let trace: ColumnVec> = witness .iter() .map(|(_name, values)| { gen_stwo_circle_column::( @@ -268,29 +284,6 @@ where }) .collect(); - //extend the witness trace with the constant polys that have next reference constraint - if let Some((_, table_proving_key)) = - self.proving_key - .preprocessed - .as_ref() - .and_then(|preprocessed| { - preprocessed - .iter() - .find_map(|(_, table_collection)| table_collection.iter().next()) - }) - { - let constants_with_next: Vec> = - table_proving_key - .constant_trace_circle_domain - .clone() - .into_iter() - .enumerate() - .filter(|(index, _)| constant_list.contains(index)) // Keep only elements whose index is not in `constant_list` - .map(|(_, element)| element) - .collect(); - trace.extend(constants_with_next); - } - let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); tree_builder.commit(prover_channel); From 30bd4aee7b3d19b70a294fb11dfa927e3634eb9e Mon Sep 17 00:00:00 2001 From: ShuangWu121 <47602565+ShuangWu121@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:38:22 +0100 Subject: [PATCH 41/48] Update backend/src/stwo/prover.rs Co-authored-by: Thibaut Schaeffer --- backend/src/stwo/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 47f1e504b8..028380e28b 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -130,7 +130,7 @@ where .split .iter() .filter_map(|(namespace, pil)| { - // if we no fixed columns, we don't need to commit to anything, publics is not supported yet. + // if we have no fixed columns, we don't need to commit to anything. if pil.constant_count() == 0 { None } else { From 9ab7a3b98a3db9b14ad517a5e6cce923bf165d93 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Tue, 10 Dec 2024 10:43:53 +0100 Subject: [PATCH 42/48] clean up --- backend/src/stwo/circuit_builder.rs | 2 -- backend/src/stwo/prover.rs | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 0bd57b16d9..0d7370f0f1 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -85,7 +85,6 @@ impl PowdrEval { .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() - // .filter(|(_, (_, id))| !constant_with_next_list.contains(&(id.id as usize))) .map(|(index, (_, id))| (id, index)) .collect(); @@ -142,7 +141,6 @@ impl FrameworkEval for PowdrEval { .map(|(i, poly_id)| { ( *poly_id, - // PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column eval.get_preprocessed_column(PreprocessedColumn::Plonk( i + constant_eval.len(), )), diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 028380e28b..aaee78e395 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -254,10 +254,8 @@ where let mut tree_builder = commitment_scheme.tree_builder(); - // Get the list of constant polynomials with next reference constraint - // let constant_list: Vec = get_constant_with_next_list(&self.analyzed); - //commit to the constant polynomials that are without next reference constraint + //commit to the constant and shifted constant polynomials if let Some((_, table_proving_key)) = self.proving_key .preprocessed From 3e2a078a02d3bd2a6bc346a74e94fcb1f538f287 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Tue, 10 Dec 2024 10:44:36 +0100 Subject: [PATCH 43/48] clean up --- backend/src/stwo/prover.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index aaee78e395..d57eb0a227 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -254,8 +254,7 @@ where let mut tree_builder = commitment_scheme.tree_builder(); - - //commit to the constant and shifted constant polynomials + //commit to the constant and shifted constant polynomials if let Some((_, table_proving_key)) = self.proving_key .preprocessed From 2b565878320ab642c6d858ea7a556fe1865266be Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Tue, 10 Dec 2024 11:13:19 +0100 Subject: [PATCH 44/48] log size correction --- backend/src/stwo/proof.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 9a332f432a..8177511f97 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -19,7 +19,7 @@ impl From for TableProvingKey .constant_trace_circle_domain_collection .into_iter() .map(|(size, table_provingkey)| { - let domain = CanonicCoset::new(size as u32).circle_domain(); + let domain = CanonicCoset::new(size.ilog2()).circle_domain(); let constant_trace_circle_domain = table_provingkey .into_values() .map(|values| { @@ -86,14 +86,12 @@ impl From> for SerializableTableProving .iter() .for_each(|(&size, trable_provingkey)| { let mut values: BTreeMap> = BTreeMap::new(); + let log_size = size.ilog2(); trable_provingkey .constant_trace_circle_domain .iter() .for_each(|circle_eval| { - values.insert( - circle_eval.domain.log_size() as usize, - circle_eval.values.to_cpu().to_vec(), - ); + values.insert(log_size as usize, circle_eval.values.to_cpu().to_vec()); }); constant_trace_circle_domain_collection.insert(size, values); From 8f6eb44f7ded350c4fa977da133f0e8d1918b2eb Mon Sep 17 00:00:00 2001 From: ShuangWu121 <47602565+ShuangWu121@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:31:20 +0100 Subject: [PATCH 45/48] Update backend/src/stwo/circuit_builder.rs Co-authored-by: Thibaut Schaeffer --- backend/src/stwo/circuit_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 0d7370f0f1..fb5746d097 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -303,7 +303,7 @@ pub fn constant_with_next_to_witness_col( } } -// This function creates a list of indexs of the constant polynomials that have next references constraint +// This function creates a list of indices of the constant polynomials that have next references constraint pub fn get_constant_with_next_list(analyzed: &Analyzed) -> Vec { let mut all_constant_with_next: Vec = Vec::new(); for id in analyzed.identities_with_inlined_intermediate_polynomials() { From af614130ac955fe011d9812718eb5bbde2cb8457 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 11 Dec 2024 13:30:12 +0100 Subject: [PATCH 46/48] simplify the function to create list of constant with next reference --- backend/src/stwo/circuit_builder.rs | 80 ++++++++--------------------- backend/src/stwo/prover.rs | 3 +- 2 files changed, 24 insertions(+), 59 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index fb5746d097..d41b144b49 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,4 +1,5 @@ use num_traits::Zero; +use std::collections::HashSet; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; use std::sync::Arc; @@ -6,7 +7,8 @@ use std::sync::Arc; extern crate alloc; use alloc::collections::btree_map::BTreeMap; use powdr_ast::analyzed::{ - AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity, + AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, AlgebraicReference, + Analyzed, Identity, }; use powdr_number::{FieldElement, LargeInt}; @@ -70,8 +72,9 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - // create a list of indexs of the constant polynomials that have next references constraint - let constant_with_next_list = get_constant_with_next_list(&analyzed); + let mut analyzed_mut = (*analyzed).clone(); + + let constant_with_next_list = get_constant_with_next_list(&mut analyzed_mut); let constant_shifted: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) @@ -260,60 +263,21 @@ where } } -pub fn constant_with_next_to_witness_col( - expr: &AlgebraicExpression, - constant_with_next_list: &mut Vec, -) { - use AlgebraicBinaryOperator::*; - match expr { - AlgebraicExpression::Reference(r) => { - let poly_id = r.poly_id; - - match poly_id.ptype { - PolynomialType::Committed => {} - PolynomialType::Constant => match r.next { - false => {} - true => { - constant_with_next_list.push(r.poly_id.id as usize); - } - }, - PolynomialType::Intermediate => {} - } - } - AlgebraicExpression::PublicReference(..) => {} - AlgebraicExpression::Number(_) => {} - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { - left, - op: Pow, - right, - }) => match **right { - AlgebraicExpression::Number(_) => { - constant_with_next_to_witness_col::(left, constant_with_next_list); - } - _ => unimplemented!("pow with non-constant exponent"), - }, - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op: _, right }) => { - constant_with_next_to_witness_col::(left, constant_with_next_list); - constant_with_next_to_witness_col::(right, constant_with_next_list); - } - AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op: _, expr }) => { - constant_with_next_to_witness_col::(expr, constant_with_next_list); - } - AlgebraicExpression::Challenge(_challenge) => {} - } -} - // This function creates a list of indices of the constant polynomials that have next references constraint -pub fn get_constant_with_next_list(analyzed: &Analyzed) -> Vec { - let mut all_constant_with_next: Vec = Vec::new(); - for id in analyzed.identities_with_inlined_intermediate_polynomials() { - if let Identity::Polynomial(identity) = id { - let mut constant_with_next: Vec = Vec::new(); - constant_with_next_to_witness_col::(&identity.expression, &mut constant_with_next); - all_constant_with_next.extend(constant_with_next) - } - } - all_constant_with_next.sort_unstable(); - all_constant_with_next.dedup(); - all_constant_with_next +pub fn get_constant_with_next_list(analyzed: &mut Analyzed) -> HashSet { + let mut constant_with_next_list: HashSet = HashSet::new(); + analyzed.post_visit_expressions_in_identities_mut(&mut |e| { + if let AlgebraicExpression::Reference(AlgebraicReference { + name: _, + poly_id, + next, + }) = e + { + if matches!(poly_id.ptype, PolynomialType::Constant) && *next { + // add the index of the constant polynomial to the list + constant_with_next_list.insert(poly_id.id as usize); + } + }; + }); + constant_with_next_list } diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index d57eb0a227..87a2168758 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -159,7 +159,8 @@ where }) .collect(); - let constant_with_next_list = get_constant_with_next_list(pil); + let constant_with_next_list = + get_constant_with_next_list(&mut pil.clone()); let constant_shifted_trace: ColumnVec< CircleEvaluation, From f446e46182796b34c821b6087080ae85ab8ffa5d Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 11 Dec 2024 15:52:15 +0100 Subject: [PATCH 47/48] use all_children to create list --- backend/src/stwo/circuit_builder.rs | 8 ++++---- backend/src/stwo/prover.rs | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index d41b144b49..4c8c1bcbde 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,4 +1,5 @@ use num_traits::Zero; +use powdr_ast::parsed::visitor::AllChildren; use std::collections::HashSet; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; @@ -72,9 +73,8 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - let mut analyzed_mut = (*analyzed).clone(); - let constant_with_next_list = get_constant_with_next_list(&mut analyzed_mut); + let constant_with_next_list = get_constant_with_next_list(&analyzed); let constant_shifted: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Constant) @@ -264,9 +264,9 @@ where } // This function creates a list of indices of the constant polynomials that have next references constraint -pub fn get_constant_with_next_list(analyzed: &mut Analyzed) -> HashSet { +pub fn get_constant_with_next_list(analyzed: &Analyzed) -> HashSet { let mut constant_with_next_list: HashSet = HashSet::new(); - analyzed.post_visit_expressions_in_identities_mut(&mut |e| { + analyzed.all_children().for_each(|e| { if let AlgebraicExpression::Reference(AlgebraicReference { name: _, poly_id, diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 87a2168758..d57eb0a227 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -159,8 +159,7 @@ where }) .collect(); - let constant_with_next_list = - get_constant_with_next_list(&mut pil.clone()); + let constant_with_next_list = get_constant_with_next_list(pil); let constant_shifted_trace: ColumnVec< CircleEvaluation, From 0557c6059583403875d627f9bb5ea6c968c43180 Mon Sep 17 00:00:00 2001 From: ShuangWu121 Date: Wed, 11 Dec 2024 16:27:15 +0100 Subject: [PATCH 48/48] create name list for constant with next --- backend/src/stwo/circuit_builder.rs | 14 +++++++------- backend/src/stwo/prover.rs | 7 ++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 4c8c1bcbde..3957573c8f 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -80,7 +80,7 @@ impl PowdrEval { .definitions_in_source_order(PolynomialType::Constant) .flat_map(|(symbol, _)| symbol.array_elements()) .enumerate() - .filter(|(_, (_, id))| constant_with_next_list.contains(&(id.id as usize))) + .filter(|(_, (name, _))| constant_with_next_list.contains(name)) .map(|(index, (_, id))| (id, index)) .collect(); @@ -263,19 +263,19 @@ where } } -// This function creates a list of indices of the constant polynomials that have next references constraint -pub fn get_constant_with_next_list(analyzed: &Analyzed) -> HashSet { - let mut constant_with_next_list: HashSet = HashSet::new(); +// This function creates a list of the names of the constant polynomials that have next references constraint +pub fn get_constant_with_next_list(analyzed: &Analyzed) -> HashSet<&String> { + let mut constant_with_next_list: HashSet<&String> = HashSet::new(); analyzed.all_children().for_each(|e| { if let AlgebraicExpression::Reference(AlgebraicReference { - name: _, + name, poly_id, next, }) = e { if matches!(poly_id.ptype, PolynomialType::Constant) && *next { - // add the index of the constant polynomial to the list - constant_with_next_list.insert(poly_id.id as usize); + // add the name of the constant polynomial to the list + constant_with_next_list.insert(name); } }; }); diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index d57eb0a227..8bdb954b8e 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -167,9 +167,10 @@ where .values() .flat_map(|vec| { vec.iter() - .enumerate() - .filter(|(i, _)| constant_with_next_list.contains(i)) - .map(|(_, (_name, values))| { + .filter(|(name, _)| { + constant_with_next_list.contains(name) + }) + .map(|(_, values)| { let mut rotated_values = values.to_vec(); rotated_values.rotate_left(1); gen_stwo_circle_column::(