From 0f66c6a7550687e8a034ca22040b6af467849c64 Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Thu, 17 Oct 2024 13:49:09 +0200 Subject: [PATCH] Integrated plonky3 prover (#1857) This is a major change to the plonky3 prover to support proving many machines. # Sharing costs across tables - at setup phase, fixed columns for each machine are committed to for each possible size. This happens in separate commitments, so that the prover and verifier can pick the relevant ones for a given execution - for each phase of the proving, the corresponding traces across all machines are committed to jointly - the quotient chunks are committed to jointly across all tables # Multi-stage publics The implementation supports public values for each stage of each table. This is tested internally in the plonky3 crate but not end-to-end in pipeline tests. --------- Co-authored-by: Leo Alt --- Cargo.toml | 2 + ast/src/analyzed/mod.rs | 21 +- backend-utils/Cargo.toml | 16 + .../split.rs => backend-utils/src/lib.rs | 21 +- backend/Cargo.toml | 1 + backend/src/composite/mod.rs | 10 +- backend/src/lib.rs | 7 - backend/src/plonky3/mod.rs | 24 +- executor/src/witgen/mod.rs | 13 +- pipeline/src/pipeline.rs | 2 +- pipeline/src/test_util.rs | 26 +- pipeline/tests/asm.rs | 9 +- pipeline/tests/pil.rs | 14 +- pipeline/tests/powdr_std.rs | 22 +- plonky3/Cargo.toml | 1 + plonky3/src/check_constraints.rs | 161 ---- plonky3/src/circuit_builder.rs | 380 ++++----- plonky3/src/folder.rs | 66 +- plonky3/src/lib.rs | 3 - plonky3/src/proof.rs | 47 +- plonky3/src/prover.rs | 769 ++++++++++++------ plonky3/src/stark.rs | 369 +++++---- plonky3/src/symbolic_builder.rs | 27 +- plonky3/src/traits.rs | 33 +- plonky3/src/verifier.rs | 558 ++++++++----- 25 files changed, 1461 insertions(+), 1141 deletions(-) create mode 100644 backend-utils/Cargo.toml rename backend/src/composite/split.rs => backend-utils/src/lib.rs (94%) delete mode 100644 plonky3/src/check_constraints.rs diff --git a/Cargo.toml b/Cargo.toml index 75bac00111..31eb9a727b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "riscv-executor", "riscv-syscalls", "schemas", + "backend-utils", ] exclude = [ "riscv-runtime" ] @@ -47,6 +48,7 @@ powdr-asm-to-pil = { path = "./asm-to-pil", version = "0.1.0-alpha.2" } powdr-isa-utils = { path = "./isa-utils", version = "0.1.0-alpha.2" } powdr-analysis = { path = "./analysis", version = "0.1.0-alpha.2" } powdr-backend = { path = "./backend", version = "0.1.0-alpha.2" } +powdr-backend-utils = { path = "./backend-utils", version = "0.1.0-alpha.2" } powdr-executor = { path = "./executor", version = "0.1.0-alpha.2" } powdr-importer = { path = "./importer", version = "0.1.0-alpha.2" } powdr-jit-compiler = { path = "./jit-compiler", version = "0.1.0-alpha.2" } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 44f8ba7111..46bd15d816 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -326,23 +326,26 @@ impl Analyzed { .for_each(|definition| definition.post_visit_expressions_mut(f)) } - /// Retrieves (col_name, poly_id, offset) of each public witness in the trace. - pub fn get_publics(&self) -> Vec<(String, PolyID, usize)> { + /// Retrieves (col_name, poly_id, offset, stage) of each public witness in the trace. 
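The new `stage` component returned by `get_publics` is what lets the prover bucket public values per proof stage. A minimal self-contained sketch of that grouping, with plain `usize` standing in for powdr's `PolyID` (the `fold` mirrors the one added to `ConstraintSystem` further down in this patch):

```rust
// Sketch: bucket (name, poly_id, row, stage) tuples per stage, as the new
// `ConstraintSystem::publics_by_stage` does. `usize` stands in for `PolyID`.
fn publics_by_stage(
    publics: Vec<(String, usize, usize, u8)>,
    stage_count: usize,
) -> Vec<Vec<(String, usize, usize)>> {
    publics.into_iter().fold(
        vec![vec![]; stage_count],
        |mut acc, (name, id, row, stage)| {
            acc[stage as usize].push((name, id, row));
            acc
        },
    )
}

fn main() {
    let publics = vec![
        ("main::out0".to_string(), 0, 7, 0),
        ("main::out1".to_string(), 1, 7, 1),
    ];
    let by_stage = publics_by_stage(publics, 2);
    assert_eq!(by_stage[0].len(), 1);
    assert_eq!(by_stage[1].len(), 1);
}
```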
+    pub fn get_publics(&self) -> Vec<(String, PolyID, usize, u8)> {
         let mut publics = self
             .public_declarations
             .values()
             .map(|public_declaration| {
                 let column_name = public_declaration.referenced_poly_name();
-                let poly_id = {
+                let (poly_id, stage) = {
                     let symbol = &self.definitions[&public_declaration.polynomial.name].0;
-                    symbol
-                        .array_elements()
-                        .nth(public_declaration.array_index.unwrap_or_default())
-                        .unwrap()
-                        .1
+                    (
+                        symbol
+                            .array_elements()
+                            .nth(public_declaration.array_index.unwrap_or_default())
+                            .unwrap()
+                            .1,
+                        symbol.stage.unwrap_or_default() as u8,
+                    )
                 };
                 let row_offset = public_declaration.index as usize;
-                (column_name, poly_id, row_offset)
+                (column_name, poly_id, row_offset, stage)
             })
             .collect::<Vec<_>>();
diff --git a/backend-utils/Cargo.toml b/backend-utils/Cargo.toml
new file mode 100644
index 0000000000..6876f981d0
--- /dev/null
+++ b/backend-utils/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "powdr-backend-utils"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+homepage.workspace = true
+repository.workspace = true
+
+[dependencies]
+powdr-pil-analyzer.workspace = true
+powdr-parser.workspace = true
+powdr-ast.workspace = true
+powdr-number.workspace = true
+powdr-executor.workspace = true
+log = "0.4.22"
+itertools = "0.13.0"
diff --git a/backend/src/composite/split.rs b/backend-utils/src/lib.rs
similarity index 94%
rename from backend/src/composite/split.rs
rename to backend-utils/src/lib.rs
index f5691df17d..2b095e193b 100644
--- a/backend/src/composite/split.rs
+++ b/backend-utils/src/lib.rs
@@ -25,7 +25,7 @@ const DUMMY_COLUMN_NAME: &str = "__dummy";
 /// 1. The PIL is split into namespaces
 /// 2. Namespaces without any columns are duplicated and merged with the other namespaces
 /// 3. Any lookups or permutations that reference multiple namespaces are removed.
-pub(crate) fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String, Analyzed<F>> {
+pub fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String, Analyzed<F>> {
     let statements_by_namespace = split_by_namespace(pil);
     let statements_by_machine = merge_empty_namespaces(statements_by_namespace, pil);
 
@@ -37,9 +37,9 @@ pub(crate) fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String,
-pub(crate) fn machine_witness_columns<F: FieldElement>(
+pub fn machine_witness_columns<F: FieldElement>(
     all_witness_columns: &[(String, Vec<F>)],
     machine_pil: &Analyzed<F>,
     machine_name: &str,
@@ -71,10 +71,10 @@ pub(crate) fn machine_witness_columns<F: FieldElement>(
 }
 
 /// Given a set of columns and a PIL describing the machine, returns the fixed columns that belong to the machine.
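Throughout these utilities, a column belongs to the machine named by its `namespace::column` prefix. A self-contained sketch of the selection step; `select_machine_columns` is a hypothetical stand-in here, and the real `machine_witness_columns` additionally handles column sizing and the `__dummy` column:

```rust
// Sketch: select the witness columns belonging to one machine by matching
// the `namespace::column` prefix.
fn select_machine_columns<F: Clone>(
    all_columns: &[(String, Vec<F>)],
    machine_name: &str,
) -> Vec<(String, Vec<F>)> {
    all_columns
        .iter()
        .filter(|(name, _)| name.split("::").next() == Some(machine_name))
        .cloned()
        .collect()
}

fn main() {
    let witness = vec![
        ("main::a".to_string(), vec![1u64, 2]),
        ("memory::addr".to_string(), vec![0u64, 4]),
    ];
    assert_eq!(select_machine_columns(&witness, "main").len(), 1);
}
```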
-pub(crate) fn machine_fixed_columns( - all_fixed_columns: &[(String, VariablySizedColumn)], - machine_pil: &Analyzed, -) -> BTreeMap)>> { +pub fn machine_fixed_columns<'a, F: FieldElement>( + all_fixed_columns: &'a [(String, VariablySizedColumn)], + machine_pil: &'a Analyzed, +) -> BTreeMap> { let machine_columns = select_machine_columns( all_fixed_columns, machine_pil.constant_polys_in_source_order(), @@ -106,12 +106,7 @@ pub(crate) fn machine_fixed_columns( size, machine_columns .iter() - .map(|(name, column)| { - ( - name.clone(), - column.get_by_size(size).unwrap().to_vec().into(), - ) - }) + .map(|(name, column)| (name.clone(), column.get_by_size(size).unwrap())) .collect::>(), ) }) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 1d1798b0d1..770d861ce7 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -20,6 +20,7 @@ powdr-parser.workspace = true powdr-pil-analyzer.workspace = true powdr-executor.workspace = true powdr-parser-util.workspace = true +powdr-backend-utils.workspace = true powdr-plonky3 = { path = "../plonky3", optional = true } diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index 92032a52bc..ddce97366a 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -10,17 +10,15 @@ use std::{ use itertools::Itertools; use powdr_ast::analyzed::Analyzed; +use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns}; use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback}; use powdr_number::{DegreeType, FieldElement}; use serde::{Deserialize, Serialize}; -use split::{machine_fixed_columns, machine_witness_columns}; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; use self::sub_prover::RunStatus; -mod split; - /// Maps each size to the corresponding verification key. type VerificationKeyBySize = BTreeMap>; @@ -76,7 +74,7 @@ impl> BackendFactory for CompositeBacke unimplemented!(); } - let pils = split::split_pil(&pil); + let pils = powdr_backend_utils::split_pil(&pil); // Read the setup once (if any) to pass to all backends. 
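The signature change above makes `machine_fixed_columns` return borrowed column slices keyed by machine size instead of cloned data, leaving the conversion to owned values to callers such as the composite backend below. A sketch of that shape, with a plain `BTreeMap` standing in for `VariablySizedColumn`:

```rust
use std::collections::BTreeMap;

// Sketch: for each supported size, collect (name, borrowed column) pairs.
// A BTreeMap<usize, Vec<F>> stands in for powdr's VariablySizedColumn.
fn fixed_by_size<'a, F>(
    columns: &'a [(String, BTreeMap<usize, Vec<F>>)],
    sizes: &[usize],
) -> BTreeMap<usize, Vec<(String, &'a [F])>> {
    sizes
        .iter()
        .map(|&size| {
            let cols = columns
                .iter()
                .map(|(name, by_size)| (name.clone(), by_size[&size].as_slice()))
                .collect();
            (size, cols)
        })
        .collect()
}

fn main() {
    let fixed = vec![(
        "main::FIRST".to_string(),
        BTreeMap::from([(4usize, vec![1u64, 0, 0, 0])]),
    )];
    let by_size = fixed_by_size(&fixed, &[4]);
    assert_eq!(by_size[&4][0].1, &[1, 0, 0, 0]);
}
```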
let setup_bytes = setup.map(|setup| { @@ -109,6 +107,10 @@ impl> BackendFactory for CompositeBacke machine_fixed_columns(&fixed, &pil) .into_iter() .map(|(size, fixed)| { + let fixed = fixed + .into_iter() + .map(|(name, values)| (name, values.to_vec().into())) + .collect(); let pil = set_size(pil.clone(), size as DegreeType); // Set up readers for the setup and verification key let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new); diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 86ab2ab089..97878163b7 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -46,9 +46,6 @@ pub enum BackendType { #[cfg(feature = "plonky3")] #[strum(serialize = "plonky3")] Plonky3, - #[cfg(feature = "plonky3")] - #[strum(serialize = "plonky3-composite")] - Plonky3Composite, } pub type BackendOptions = String; @@ -87,10 +84,6 @@ impl BackendType { } #[cfg(feature = "plonky3")] BackendType::Plonky3 => Box::new(plonky3::Factory), - #[cfg(feature = "plonky3")] - BackendType::Plonky3Composite => { - Box::new(composite::CompositeBackendFactory::new(plonky3::Factory)) - } } } } diff --git a/backend/src/plonky3/mod.rs b/backend/src/plonky3/mod.rs index f7f1837ff5..c8edba6ffe 100644 --- a/backend/src/plonky3/mod.rs +++ b/backend/src/plonky3/mod.rs @@ -1,10 +1,7 @@ use std::{io, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; -use powdr_executor::{ - constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, - witgen::WitgenCallback, -}; +use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback}; use powdr_number::{BabyBearField, GoldilocksField, Mersenne31Field}; use powdr_plonky3::{Commitment, FieldElementMap, Plonky3Prover, ProverData}; @@ -35,22 +32,8 @@ where if verification_app_key.is_some() { return Err(Error::NoAggregationAvailable); } - if pil.degrees().len() > 1 { - return Err(Error::NoVariableDegreeAvailable); - } - if pil.public_declarations_in_source_order().any(|(_, d)| { - pil.definitions.iter().any(|(_, (symbol, _))| { - symbol.absolute_name == d.name && symbol.stage.unwrap_or_default() > 0 - }) - }) { - return Err(Error::NoLaterStagePublicAvailable); - } - let fixed = Arc::new( - get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, - ); - - let mut p3 = Box::new(Plonky3Prover::new(pil.clone(), fixed.clone())); + let mut p3 = Box::new(Plonky3Prover::new(pil.clone(), fixed)); if let Some(verification_key) = verification_key { p3.set_verifying_key(verification_key); @@ -70,6 +53,9 @@ where Commitment: Send, { fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { + assert_eq!(instances.len(), 1); + let instances = &instances[0]; + Ok(self.verify(proof, instances)?) 
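With the composite wrapper gone, the plonky3 verifier expects exactly one instance vector, and publics are `Option`al until all stages have run. A sketch of the collection step the callers below perform:

```rust
// Sketch: turn (name, Option<value>) publics into the single instance vector
// the verifier expects; all values must be known once a proof exists.
fn instance_vector(publics: &[(String, Option<u64>)]) -> Vec<u64> {
    publics
        .iter()
        .map(|(_name, v)| v.expect("all publics should be known since we created a proof"))
        .collect()
}

fn main() {
    let publics = vec![("main::out".to_string(), Some(42u64))];
    assert_eq!(instance_vector(&publics), vec![42]);
}
```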
} diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index e14dfee33a..d47ff2cd49 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -288,7 +288,12 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { log::debug!("Publics:"); for (name, value) in extract_publics(&witness_cols, self.analyzed) { - log::debug!(" {name:>30}: {value}"); + log::debug!( + " {name:>30}: {}", + value + .map(|value| value.to_string()) + .unwrap_or_else(|| "Not yet known at this stage".to_string()) + ); } witness_cols } @@ -297,7 +302,7 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { pub fn extract_publics( witness: &[(String, Vec)], pil: &Analyzed, -) -> Vec<(String, T)> { +) -> Vec<(String, Option)> { let witness = witness .iter() .map(|(name, col)| (name.clone(), col)) @@ -306,7 +311,9 @@ pub fn extract_publics( .map(|(name, public_declaration)| { let poly_name = &public_declaration.referenced_poly_name(); let poly_index = public_declaration.index; - let value = witness[poly_name][poly_index as usize]; + let value = witness + .get(poly_name) + .map(|column| column[poly_index as usize]); ((*name).clone(), value) }) .collect() diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index f83cf9a490..31d1d9eb50 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -921,7 +921,7 @@ impl Pipeline { Ok(self.artifact.witness.as_ref().unwrap().clone()) } - pub fn publics(&self) -> Result, Vec> { + pub fn publics(&self) -> Result)>, Vec> { let pil = self.optimized_pil()?; let witness = self.witness()?; Ok(extract_publics(&witness, &pil)) diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 70401310d7..924ce5b21c 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -180,7 +180,7 @@ pub fn gen_estark_proof_with_backend_variant( .publics() .unwrap() .iter() - .map(|(_name, v)| *v) + .map(|(_name, v)| v.expect("all publics should be known since we created a proof")) .collect(); pipeline.verify(&proof, &[publics]).unwrap(); @@ -277,7 +277,7 @@ pub fn gen_halo2_proof(pipeline: Pipeline, backend: BackendVariant) .publics() .unwrap() .iter() - .map(|(_name, v)| *v) + .map(|(_name, v)| v.expect("all publics should be known since we created a proof")) .collect(); pipeline.verify(&proof, &[publics]).unwrap(); @@ -287,15 +287,8 @@ pub fn gen_halo2_proof(pipeline: Pipeline, backend: BackendVariant) pub fn gen_halo2_proof(_pipeline: Pipeline, _backend: BackendVariant) {} #[cfg(feature = "plonky3")] -pub fn test_plonky3_with_backend_variant( - file_name: &str, - inputs: Vec, - backend: BackendVariant, -) { - let backend = match backend { - BackendVariant::Monolithic => powdr_backend::BackendType::Plonky3, - BackendVariant::Composite => powdr_backend::BackendType::Plonky3Composite, - }; +pub fn test_plonky3(file_name: &str, inputs: Vec) { + let backend = powdr_backend::BackendType::Plonky3; let mut pipeline = Pipeline::default() .with_tmp_output() .from_file(resolve_test_file(file_name)) @@ -310,7 +303,7 @@ pub fn test_plonky3_with_backend_variant( .clone() .unwrap() .iter() - .map(|(_name, v)| *v) + .map(|(_name, v)| v.expect("all publics should be known since we created a proof")) .collect(); pipeline.verify(&proof, &[publics.clone()]).unwrap(); @@ -333,7 +326,7 @@ pub fn test_plonky3_with_backend_variant( #[cfg(feature = "plonky3")] pub fn test_plonky3_pipeline(pipeline: Pipeline) { - let mut pipeline = pipeline.with_backend(powdr_backend::BackendType::Plonky3Composite, None); + let mut pipeline = 
pipeline.with_backend(powdr_backend::BackendType::Plonky3, None); pipeline.compute_witness().unwrap(); @@ -349,7 +342,7 @@ pub fn test_plonky3_pipeline(pipeline: Pipeline) { .clone() .unwrap() .iter() - .map(|(_name, v)| *v) + .map(|(_name, v)| v.expect("all publics should be known since we created a proof")) .collect(); pipeline.verify(&proof, &[publics.clone()]).unwrap(); @@ -371,7 +364,10 @@ pub fn test_plonky3_pipeline(pipeline: Pipeline) { } #[cfg(not(feature = "plonky3"))] -pub fn test_plonky3_with_backend_variant(_: &str, _: Vec, _: BackendVariant) {} +pub fn test_plonky3(_: &str, _: Vec) {} + +#[cfg(not(feature = "plonky3"))] +pub fn test_plonky3_pipeline(_: Pipeline) {} #[cfg(not(feature = "plonky3"))] pub fn gen_plonky3_proof(_: &str, _: Vec) {} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 7aa991785c..43ebc07040 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -8,8 +8,7 @@ use powdr_pipeline::{ asm_string_to_pil, gen_estark_proof_with_backend_variant, make_prepared_pipeline, make_simple_prepared_pipeline, regular_test, regular_test_without_babybear, resolve_test_file, run_pilcom_with_backend_variant, test_halo2, - test_halo2_with_backend_variant, test_pilcom, test_plonky3_with_backend_variant, - BackendVariant, + test_halo2_with_backend_variant, test_pilcom, test_plonky3, BackendVariant, }, util::{FixedPolySet, PolySet, WitnessPolySet}, Pipeline, @@ -39,11 +38,7 @@ fn simple_sum_asm() { let f = "asm/simple_sum.asm"; let i = [16, 4, 1, 2, 8, 5]; regular_test(f, &i); - test_plonky3_with_backend_variant::( - f, - slice_to_vec(&i), - BackendVariant::Composite, - ); + test_plonky3::(f, slice_to_vec(&i)); } #[test] diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 3c50a0c0d6..178bf4a66f 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -7,7 +7,7 @@ use powdr_pipeline::test_util::{ assert_proofs_fail_for_invalid_witnesses_pilcom, gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, make_simple_prepared_pipeline, regular_test, run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, - test_pilcom, test_plonky3_with_backend_variant, BackendVariant, + test_pilcom, test_plonky3, BackendVariant, }; use test_log::test; @@ -87,11 +87,7 @@ fn permutation_with_selector() { fn fibonacci() { let f = "pil/fibonacci.pil"; regular_test(f, Default::default()); - test_plonky3_with_backend_variant::( - f, - Default::default(), - BackendVariant::Monolithic, - ); + test_plonky3::(f, Default::default()); } #[test] @@ -245,11 +241,7 @@ fn halo_without_lookup() { #[test] fn add() { let f = "pil/add.pil"; - test_plonky3_with_backend_variant::( - f, - Default::default(), - BackendVariant::Monolithic, - ); + test_plonky3::(f, Default::default()); } #[test] diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 02f96d304a..cff9effceb 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -7,8 +7,8 @@ use powdr_pipeline::{ test_util::{ evaluate_function, evaluate_integer_function, execute_test_file, gen_estark_proof, gen_halo2_proof, make_simple_prepared_pipeline, regular_test, - regular_test_without_babybear, std_analyzed, test_halo2, test_pilcom, - test_plonky3_with_backend_variant, BackendVariant, + regular_test_without_babybear, std_analyzed, test_halo2, test_pilcom, test_plonky3, + BackendVariant, }, Pipeline, }; @@ -48,7 +48,7 @@ fn poseidon_gl_memory_test() { #[ignore = "Too slow"] fn keccakf16_test() { let f = "std/keccakf16_test.asm"; - 
test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] @@ -71,21 +71,21 @@ fn split_gl_test() { #[ignore = "Too slow"] fn split_bb_test() { let f = "std/split_bb_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] #[ignore = "Too slow"] fn add_sub_small_test() { let f = "std/add_sub_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] #[ignore = "Too slow"] fn arith_small_test() { let f = "std/arith_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] @@ -131,7 +131,7 @@ fn memory_large_test_parallel_accesses() { #[ignore = "Too slow"] fn memory_small_test() { let f = "std/memory_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] @@ -228,14 +228,14 @@ fn binary_large_test() { #[ignore = "Too slow"] fn binary_small_8_test() { let f = "std/binary_small_8_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] #[ignore = "Too slow"] fn binary_small_test() { let f = "std/binary_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] @@ -250,7 +250,7 @@ fn shift_large_test() { #[ignore = "Too slow"] fn shift_small_test() { let f = "std/shift_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] @@ -265,7 +265,7 @@ fn rotate_large_test() { #[ignore = "Too slow"] fn rotate_small_test() { let f = "std/rotate_small_test.asm"; - test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); + test_plonky3::(f, vec![]); } #[test] diff --git a/plonky3/Cargo.toml b/plonky3/Cargo.toml index d4c474ef4f..e71e1dbcf4 100644 --- a/plonky3/Cargo.toml +++ b/plonky3/Cargo.toml @@ -9,6 +9,7 @@ repository.workspace = true [dependencies] powdr-ast.workspace = true powdr-number.workspace = true +powdr-backend-utils.workspace = true rand = "0.8.5" powdr-analysis = { path = "../analysis" } powdr-executor = { path = "../executor" } diff --git a/plonky3/src/check_constraints.rs b/plonky3/src/check_constraints.rs deleted file mode 100644 index d033ef4d2d..0000000000 --- a/plonky3/src/check_constraints.rs +++ /dev/null @@ -1,161 +0,0 @@ -use alloc::vec::Vec; - -use itertools::Itertools; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; -use p3_field::Field; -use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; -use p3_matrix::stack::VerticalPair; -use p3_matrix::Matrix; -use tracing::instrument; - -use crate::traits::MultistageAirBuilder; - -#[instrument(name = "check constraints", skip_all)] -pub(crate) fn check_constraints( - air: &A, - preprocessed: &RowMajorMatrix, - traces_by_stage: Vec<&RowMajorMatrix>, - public_values_by_stage: &Vec<&Vec>, - challenges: Vec<&Vec>, -) where - F: Field, - A: for<'a> Air>, -{ - let num_stages = traces_by_stage.len(); - let height = traces_by_stage[0].height(); - - (0..height).for_each(|i| { - let i_next = (i + 1) % height; - - let local_preprocessed = preprocessed.row_slice(i); - let next_preprocessed = preprocessed.row_slice(i_next); - let preprocessed = VerticalPair::new( - RowMajorMatrixView::new_row(&*local_preprocessed), - 
RowMajorMatrixView::new_row(&*next_preprocessed), - ); - - let stages_local_next = traces_by_stage - .iter() - .map(|trace| { - let stage_local = trace.row_slice(i); - let stage_next = trace.row_slice(i_next); - (stage_local, stage_next) - }) - .collect_vec(); - - let traces_by_stage = (0..num_stages) - .map(|stage| { - VerticalPair::new( - RowMajorMatrixView::new_row(&*stages_local_next[stage].0), - RowMajorMatrixView::new_row(&*stages_local_next[stage].1), - ) - }) - .collect(); - - let mut builder = DebugConstraintBuilder { - row_index: i, - challenges: challenges.clone(), - preprocessed, - traces_by_stage, - public_values_by_stage, - is_first_row: F::from_bool(i == 0), - is_last_row: F::from_bool(i == height - 1), - is_transition: F::from_bool(i != height - 1), - }; - - air.eval(&mut builder); - }); -} - -/// An `AirBuilder` which asserts that each constraint is zero, allowing any failed constraints to -/// be detected early. -#[derive(Debug)] -pub struct DebugConstraintBuilder<'a, F: Field> { - row_index: usize, - preprocessed: VerticalPair, RowMajorMatrixView<'a, F>>, - challenges: Vec<&'a Vec>, - traces_by_stage: Vec, RowMajorMatrixView<'a, F>>>, - public_values_by_stage: &'a [&'a Vec], - is_first_row: F, - is_last_row: F, - is_transition: F, -} - -impl<'a, F> AirBuilder for DebugConstraintBuilder<'a, F> -where - F: Field, -{ - type F = F; - type Expr = F; - type Var = F; - type M = VerticalPair, RowMajorMatrixView<'a, F>>; - - fn is_first_row(&self) -> Self::Expr { - self.is_first_row - } - - fn is_last_row(&self) -> Self::Expr { - self.is_last_row - } - - fn is_transition_window(&self, size: usize) -> Self::Expr { - if size == 2 { - self.is_transition - } else { - panic!("only supports a window size of 2") - } - } - - fn main(&self) -> Self::M { - self.traces_by_stage[0] - } - - fn assert_zero>(&mut self, x: I) { - assert_eq!( - x.into(), - F::zero(), - "constraints had nonzero value on row {}", - self.row_index - ); - } - - fn assert_eq, I2: Into>(&mut self, x: I1, y: I2) { - let x = x.into(); - let y = y.into(); - assert_eq!( - x, y, - "values didn't match on row {}: {} != {}", - self.row_index, x, y - ); - } -} - -impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> { - type PublicVar = Self::F; - - fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) - } -} - -impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { - fn preprocessed(&self) -> Self::M { - self.preprocessed - } -} - -impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> { - type Challenge = Self::Expr; - - fn stage_public_values(&self, stage: usize) -> &[Self::F] { - self.public_values_by_stage[stage] - } - - fn stage_trace(&self, stage: usize) -> Self::M { - self.traces_by_stage[stage] - } - - fn stage_challenges(&self, stage: usize) -> &[Self::Expr] { - self.challenges[stage] - } -} diff --git a/plonky3/src/circuit_builder.rs b/plonky3/src/circuit_builder.rs index 70e7c98970..0856e266ef 100644 --- a/plonky3/src/circuit_builder.rs +++ b/plonky3/src/circuit_builder.rs @@ -8,51 +8,46 @@ use itertools::Itertools; use p3_field::AbstractField; -use std::{ - cell::RefCell, - collections::{BTreeMap, BTreeSet, HashMap}, - sync::Mutex, -}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; -use crate::params::{Commitment, FieldElementMap, Plonky3Field, ProverData}; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; +use crate::{ + params::{Commitment, FieldElementMap, Plonky3Field, ProverData}, + 
AirStage, +}; +use p3_air::{Air, AirBuilder, BaseAir, PairBuilder}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, - AlgebraicUnaryOperation, AlgebraicUnaryOperator, Analyzed, Challenge, Identity, IdentityKind, - PolyID, PolynomialType, SelectedExpressions, + AlgebraicUnaryOperation, AlgebraicUnaryOperator, Analyzed, Identity, IdentityKind, PolyID, + PolynomialType, SelectedExpressions, }; -use crate::{CallbackResult, MultiStageAir, MultistageAirBuilder, NextStageTraceCallback}; +use crate::{CallbackResult, MultiStageAir, MultistageAirBuilder}; use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_executor::witgen::WitgenCallback; use powdr_number::{FieldElement, LargeInt}; -type Witness = Mutex)>>>; - /// A description of the constraint system. /// All of the data is derived from the analyzed PIL, but is materialized /// here for performance reasons. -struct ConstraintSystem { +pub struct ConstraintSystem { // for each witness column, the stage and index of this column in this stage witness_columns: HashMap, // for each fixed column, the index of this column in the fixed columns fixed_columns: HashMap, identities: Vec>>>, - publics: Vec<(String, PolyID, usize)>, - commitment_count: usize, + // for each public column, the name, poly_id, index in the witness columns, and stage + pub(crate) publics_by_stage: Vec>, constant_count: usize, - // for each stage, the number of witness columns + // for each stage, the number of witness columns. There is always a least one stage, possibly empty stage_widths: Vec, - challenges: BTreeSet, + challenges_by_stage: Vec>, } impl From<&Analyzed> for ConstraintSystem { fn from(analyzed: &Analyzed) -> Self { let identities = analyzed.identities_with_inlined_intermediate_polynomials(); - let publics = analyzed.get_publics(); - let commitment_count = analyzed.commitment_count(); let constant_count = analyzed.constant_count(); let stage_widths = (0..analyzed.stage_count() as u32) .map(|stage| { @@ -88,42 +83,104 @@ impl From<&Analyzed> for ConstraintSystem { }) .collect(); - let mut challenges = BTreeSet::default(); + let mut challenges_by_stage = vec![vec![]; analyzed.stage_count()]; for identity in &identities { identity.pre_visit_expressions(&mut |expr| { if let AlgebraicExpression::Challenge(challenge) = expr { - challenges.insert(*challenge); + challenges_by_stage[challenge.stage as usize].push(challenge.id); } }); } + let publics_by_stage = analyzed.get_publics().into_iter().fold( + vec![vec![]; analyzed.stage_count()], + |mut acc, (name, id, row, stage)| { + acc[stage as usize].push((name, id, row)); + acc + }, + ); + Self { identities, - publics, - commitment_count, + publics_by_stage, constant_count, stage_widths, witness_columns, fixed_columns, - challenges, + challenges_by_stage, } } } -pub(crate) struct PowdrCircuit +pub(crate) struct PowdrCircuit<'a, T: FieldElementMap> where ProverData: Send, Commitment: Send, { - /// The constraint system description - constraint_system: ConstraintSystem, - /// The values of the witness, in a [RefCell] as it gets mutated as we go through stages - witness_so_far: Witness, + /// The split program + pub split: &'a BTreeMap, ConstraintSystem)>, /// Callback to augment the witness in the later stages witgen_callback: Option>, - /// The matrix of preprocessed values, used in debug mode to check the constraints before proving - #[cfg(debug_assertions)] - preprocessed: Option>>, +} + +impl<'a, T: FieldElementMap> 
PowdrCircuit<'a, T> +where + ProverData: Send, + Commitment: Send, +{ + pub(crate) fn new(split: &'a BTreeMap, ConstraintSystem)>) -> Self { + Self { + split, + witgen_callback: None, + } + } + + /// Calculates public values from generated witness values. + /// For stages in which there are no public values, return an empty vector + pub(crate) fn public_values_so_far( + &self, + witness: &[(String, Vec)], + ) -> BTreeMap>>> { + let witness = witness + .iter() + // this map seems redundant but it turns a reference over a tuple into a tuple of references + .map(|(name, values)| (name, values)) + .collect::>(); + + self.split + .iter() + .map(|(name, (_, table))| { + let res = table + .publics_by_stage + .iter() + .map(|publics| { + publics + .iter() + .map(|(name, _, row)| witness.get(name).map(|column| column[*row])) + .collect() + }) + .collect(); + + (name.clone(), res) + }) + .collect() + } + + pub(crate) fn with_witgen_callback(self, witgen_callback: WitgenCallback) -> Self { + Self { + witgen_callback: Some(witgen_callback), + ..self + } + } +} + +pub(crate) struct PowdrTable<'a, T: FieldElementMap> +where + ProverData: Send, + Commitment: Send, +{ + /// The constraint system description + constraint_system: &'a ConstraintSystem, } /// Convert a witness for a stage @@ -152,64 +209,13 @@ where RowMajorMatrix::new(values, width) } -impl PowdrCircuit +impl<'a, T: FieldElementMap> PowdrTable<'a, T> where ProverData: Send, Commitment: Send, { - pub(crate) fn new(analyzed: &Analyzed) -> Self { - Self { - constraint_system: analyzed.into(), - witgen_callback: None, - witness_so_far: Default::default(), - #[cfg(debug_assertions)] - preprocessed: None, - } - } - - /// Calculates public values from generated witness values. - pub(crate) fn public_values_so_far(&self) -> Vec> { - let binding = &self.witness_so_far.lock().unwrap(); - let witness = binding.borrow(); - - let witness = witness - .iter() - .map(|(name, values)| (name, values)) - .collect::>(); - - self.constraint_system - .publics - .iter() - .filter_map(|(col_name, _, idx)| { - witness - .get(&col_name) - .map(|column| column[*idx].into_p3_field()) - }) - .collect() - } - - pub(crate) fn with_phase_0_witness(self, witness: &[(String, Vec)]) -> Self { - assert!(self.witness_so_far.lock().unwrap().borrow().is_empty()); - Self { - witness_so_far: RefCell::new(witness.to_vec()).into(), - ..self - } - } - - pub(crate) fn with_witgen_callback(self, witgen_callback: WitgenCallback) -> Self { - Self { - witgen_callback: Some(witgen_callback), - ..self - } - } - - #[cfg(debug_assertions)] - pub(crate) fn with_preprocessed( - mut self, - preprocessed_matrix: RowMajorMatrix>, - ) -> Self { - self.preprocessed = Some(preprocessed_matrix); - self + pub(crate) fn new(constraint_system: &'a ConstraintSystem) -> Self { + Self { constraint_system } } /// Conversion to plonky3 expression @@ -218,8 +224,8 @@ where e: &AlgebraicExpression, traces_by_stage: &[AB::M], fixed: &AB::M, - publics: &BTreeMap<&String, ::PublicVar>, - challenges: &BTreeMap::Challenge>>, + publics: &BTreeMap<&String, ::PublicVar>, + challenges: &[BTreeMap<&u64, ::Challenge>], ) -> AB::Expr { use AlgebraicBinaryOperator::*; let res = match e { @@ -288,9 +294,10 @@ where AlgebraicUnaryOperator::Minus => -expr, } } - AlgebraicExpression::Challenge(challenge) => { - challenges[&challenge.stage][&challenge.id].clone().into() - } + AlgebraicExpression::Challenge(challenge) => challenges[challenge.stage as usize] + [&challenge.id] + .clone() + .into(), }; res } @@ -298,85 +305,86 @@ 
where /// An extension of [Air] allowing access to the number of fixed columns -impl BaseAir> for PowdrCircuit +impl<'a, T: FieldElementMap> BaseAir> for PowdrTable<'a, T> where ProverData: Send, Commitment: Send, { fn width(&self) -> usize { - self.constraint_system.commitment_count + unimplemented!("use MultiStageAir method instead") } fn preprocessed_trace(&self) -> Option>> { - #[cfg(debug_assertions)] - { - self.preprocessed.clone() - } - #[cfg(not(debug_assertions))] unimplemented!() } } -impl< - T: FieldElementMap, - AB: AirBuilderWithPublicValues> + PairBuilder + MultistageAirBuilder, - > Air for PowdrCircuit +impl<'a, T: FieldElementMap, AB: PairBuilder + MultistageAirBuilder>> Air + for PowdrTable<'a, T> where ProverData: Send, Commitment: Send, { fn eval(&self, builder: &mut AB) { let stage_count = >::stage_count(self); - let trace_by_stage: Vec = (0..stage_count).map(|i| builder.stage_trace(i)).collect(); + let traces_by_stage: Vec = + (0..stage_count).map(|i| builder.stage_trace(i)).collect(); let fixed = builder.preprocessed(); - let pi = builder.public_values(); + let public_input_values_by_stage = (0..stage_count) + .map(|i| builder.stage_public_values(i)) + .collect_vec(); // for each stage, the values of the challenges drawn at the end of that stage - let challenges: BTreeMap> = self + let challenges_by_stage: Vec> = self .constraint_system - .challenges + .challenges_by_stage .iter() - .map(|c| (c.stage, c.id)) - .into_group_map() - .into_iter() + .enumerate() .map(|(stage, ids)| { - let p3_challenges = builder.stage_challenges(stage as usize).to_vec(); - assert_eq!(p3_challenges.len(), ids.len()); - (stage, ids.into_iter().zip(p3_challenges).collect()) + let stage_challenges = builder.stage_challenges(stage as u8); + ( + stage, + ids.iter() + .zip_eq(stage_challenges.iter().cloned()) + .collect(), + ) }) - .collect(); - assert_eq!(self.constraint_system.publics.len(), pi.len()); - - let stage_0_local = trace_by_stage[0].row_slice(0); + .fold( + vec![BTreeMap::default(); stage_count as usize], + |mut acc, (stage, challenges)| { + acc[stage] = challenges; + acc + }, + ); // public constraints let public_vals_by_id = self .constraint_system - .publics + .publics_by_stage .iter() - .zip(pi.to_vec()) - .map(|((id, _, _), val)| (id, val)) - .collect::::PublicVar>>(); + .zip_eq(public_input_values_by_stage) + .flat_map(|(publics, values)| publics.iter().zip_eq(values.iter())) + .map(|((id, _, _), pi)| (id, *pi)) + .collect::::PublicVar>>(); // constrain public inputs using witness columns in stage 0 let fixed_local = fixed.row_slice(0); let public_offset = self.constraint_system.constant_count; - self.constraint_system.publics.iter().enumerate().for_each( - |(index, (pub_id, poly_id, _))| { + self.constraint_system + .publics_by_stage + .iter() + .flatten() + .enumerate() + .for_each(|(index, (pub_id, poly_id, _))| { let selector = fixed_local[public_offset + index]; let (stage, index) = self.constraint_system.witness_columns[poly_id]; - assert_eq!( - stage, 0, - "public inputs are only allowed in the first stage" - ); - let witness_col = stage_0_local[index]; + let witness_col = traces_by_stage[stage].row_slice(0)[index]; let public_value = public_vals_by_id[pub_id]; // constraining s(i) * (pub[i] - x(i)) = 0 builder.assert_zero(selector * (public_value.into() - witness_col)); - }, - ); + }); // circuit constraints for identity in &self.constraint_system.identities { @@ -388,10 +396,10 @@ where let left = self.to_plonky3_expr::( identity.left.selector.as_ref().unwrap(), - 
&trace_by_stage, + &traces_by_stage, &fixed, &public_vals_by_id, - &challenges, + &challenges_by_stage, ); builder.assert_zero(left); @@ -406,95 +414,105 @@ where } } -impl< - T: FieldElementMap, - AB: AirBuilderWithPublicValues> + PairBuilder + MultistageAirBuilder, - > MultiStageAir for PowdrCircuit +impl<'a, T: FieldElementMap, AB: PairBuilder + MultistageAirBuilder>> + MultiStageAir for PowdrTable<'a, T> where ProverData: Send, Commitment: Send, { + fn stage_public_count(&self, stage: u8) -> usize { + self.constraint_system.publics_by_stage[stage as usize].len() + } + fn preprocessed_width(&self) -> usize { - self.constraint_system.constant_count + self.constraint_system.publics.len() + self.constraint_system.constant_count + + self + .constraint_system + .publics_by_stage + .iter() + .map(|publics| publics.len()) + .sum::() } - fn stage_count(&self) -> usize { - self.constraint_system.stage_widths.len() + fn stage_count(&self) -> u8 { + self.constraint_system.stage_widths.len() as u8 } - fn stage_trace_width(&self, stage: u32) -> usize { + fn stage_trace_width(&self, stage: u8) -> usize { self.constraint_system.stage_widths[stage as usize] } - fn stage_challenge_count(&self, stage: u32) -> usize { - self.constraint_system - .challenges - .iter() - .filter(|c| c.stage == stage) - .count() + fn stage_challenge_count(&self, stage: u8) -> usize { + self.constraint_system.challenges_by_stage[stage as usize].len() } } -impl NextStageTraceCallback for PowdrCircuit +impl<'a, T: FieldElementMap> PowdrCircuit<'a, T> where ProverData: Send, Commitment: Send, { - // this wraps the witgen callback to make it compatible with p3: - // - p3 passes its local challenge values and the stage id - // - it receives the trace for the next stage in the expected format, - // as well as the public values for this stage and the updated challenges - // for the previous stage - // - internally, the full witness is accumulated in [Self] as it's needed - // in order to call the witgen callback - fn compute_stage( + /// Computes the stage data for stage number `trace_stage` based on `new_challenge_values` drawn at the end of stage `trace_stage - 1`. 
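`compute_stage` starts by pairing the challenge ids declared for the previous stage, collected across all tables and deduplicated, with the values just drawn from the challenger. A standalone sketch of that pairing, with `u64` standing in for field elements:

```rust
use std::collections::{BTreeMap, BTreeSet};

// Sketch: pair the challenge ids declared for the previous stage (collected
// across all tables, deduplicated) with the freshly drawn values, in id order.
fn challenge_map(ids_by_table: &[Vec<u64>], drawn: &[u64]) -> BTreeMap<u64, u64> {
    let ids: BTreeSet<u64> = ids_by_table.iter().flatten().copied().collect();
    assert_eq!(ids.len(), drawn.len());
    ids.into_iter().zip(drawn.iter().copied()).collect()
}

fn main() {
    let map = challenge_map(&[vec![3, 1], vec![1]], &[100, 200]);
    assert_eq!(map[&1], 100); // BTreeSet iterates ids in increasing order
    assert_eq!(map[&3], 200);
}
```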
+ pub fn compute_stage( &self, - trace_stage: u32, + trace_stage: u8, new_challenge_values: &[Plonky3Field], + witness: &mut Vec<(String, Vec)>, ) -> CallbackResult> { - let witness = self.witness_so_far.lock().unwrap(); - let mut witness = witness.borrow_mut(); - - let previous_stage_challenges: BTreeSet = self - .constraint_system - .challenges - .iter() - .filter(|c| c.stage == trace_stage - 1) - .cloned() + let previous_stage_challenges: BTreeSet<&u64> = self + .split + .values() + .flat_map(|(_, constraint_system)| { + &constraint_system.challenges_by_stage[trace_stage as usize - 1] + }) .collect(); + assert_eq!(previous_stage_challenges.len(), new_challenge_values.len()); let challenge_map = previous_stage_challenges .into_iter() .zip(new_challenge_values) - .map(|(c, v)| (c.id, T::from_p3_field(*v))) + .map(|(c, v)| (*c, T::from_p3_field(*v))) .collect(); + // remember the columns we already know about let columns_before: BTreeSet = witness.iter().map(|(name, _)| name.clone()).collect(); - // to call the witgen callback, we need to pass the witness for all stages so far + // call the witgen callback, updating the witness *witness = { self.witgen_callback.as_ref().unwrap().next_stage_witness( - &witness, + witness, challenge_map, - trace_stage as u8, + trace_stage, ) }; + let public_values = self.public_values_so_far(witness); + // generate the next trace in the format p3 expects - // TODO: since the current witgen callback returns the entire witness so far, - // we filter out the columns we already know about. Instead, return only the - // new witness in the witgen callback. - let trace = generate_matrix( - witness - .iter() - .filter(|(name, _)| !columns_before.contains(name)) - .map(|(name, values)| (name, values.as_ref())), - ); + // since the witgen callback returns the entire witness so far, + // we filter out the columns we already know about + let air_stages = witness + .iter() + .filter(|(name, _)| !columns_before.contains(name)) + .map(|(name, values)| (name, values.as_ref())) + .into_group_map_by(|(name, _)| name.split("::").next().unwrap()) + .into_iter() + .map(|(table_name, columns)| { + ( + table_name.to_string(), + AirStage { + trace: generate_matrix(columns.into_iter()), + public_values: public_values[table_name][trace_stage as usize] + .iter() + .map(|v| v.expect("public value for stage {trace_stage} should be available at this point").into_p3_field()) + .collect(), + }, + ) + }) + .collect(); - // return the next trace - // later stage publics are not supported, so we return an empty vector. TODO: change this - // shared challenges are unsupported so we return the local challenges. 
TODO: change this - CallbackResult::new(trace, vec![], new_challenge_values.to_vec()) + // return the next stage for each table + CallbackResult { air_stages } } } diff --git a/plonky3/src/folder.rs b/plonky3/src/folder.rs index d581e1dbf9..4e6ac42b0a 100644 --- a/plonky3/src/folder.rs +++ b/plonky3/src/folder.rs @@ -1,8 +1,8 @@ use alloc::vec::Vec; -use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use p3_air::{AirBuilder, PairBuilder}; use p3_field::AbstractField; -use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; +use p3_matrix::dense::RowMajorMatrixView; use p3_matrix::stack::VerticalPair; use crate::traits::MultistageAirBuilder; @@ -10,10 +10,10 @@ use p3_uni_stark::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; #[derive(Debug)] pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { - pub challenges: Vec>>, - pub traces_by_stage: Vec>>, - pub preprocessed: RowMajorMatrix>, - pub public_values_by_stage: &'a Vec>>, + pub challenges: &'a [Vec>], + pub traces_by_stage: Vec>>, + pub preprocessed: RowMajorMatrixView<'a, PackedVal>, + pub public_values_by_stage: &'a [Vec>], pub is_first_row: PackedVal, pub is_last_row: PackedVal, pub is_transition: PackedVal, @@ -25,10 +25,10 @@ type ViewPair<'a, T> = VerticalPair, RowMajorMatrixVie #[derive(Debug)] pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { - pub challenges: Vec>>, + pub challenges: &'a [Vec>], pub traces_by_stage: Vec>, pub preprocessed: ViewPair<'a, SC::Challenge>, - pub public_values_by_stage: Vec<&'a Vec>>, + pub public_values_by_stage: &'a [Vec>], pub is_first_row: SC::Challenge, pub is_last_row: SC::Challenge, pub is_transition: SC::Challenge, @@ -40,10 +40,10 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { type F = Val; type Expr = PackedVal; type Var = PackedVal; - type M = RowMajorMatrix>; + type M = RowMajorMatrixView<'a, PackedVal>; fn main(&self) -> Self::M { - self.traces_by_stage[0].clone() + unimplemented!("use MultiStageAirBuilder instead") } fn is_first_row(&self) -> Self::Expr { @@ -69,32 +69,25 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { } } -impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> { - type PublicVar = Val; - - fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) - } -} - impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for ProverConstraintFolder<'a, SC> { type Challenge = Val; + type PublicVar = Val; - fn stage_trace(&self, stage: usize) -> ::M { - self.traces_by_stage[stage].clone() + fn stage_trace(&self, stage: u8) -> ::M { + self.traces_by_stage[stage as usize] } - fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { - &self.challenges[stage] + fn stage_challenges(&self, stage: u8) -> &[Self::Challenge] { + &self.challenges[stage as usize] } - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - &self.public_values_by_stage[stage] + fn stage_public_values(&self, stage: u8) -> &[Self::PublicVar] { + &self.public_values_by_stage[stage as usize] } } impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { fn preprocessed(&self) -> Self::M { - self.preprocessed.clone() + self.preprocessed } } @@ -105,7 +98,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> type M = ViewPair<'a, SC::Challenge>; fn main(&self) -> Self::M { - self.traces_by_stage[0] + unimplemented!("use MultiStageAirBuilder instead") } fn 
is_first_row(&self) -> Self::Expr { @@ -131,26 +124,19 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> } } -impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> { - type PublicVar = Val; - - fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) - } -} - impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for VerifierConstraintFolder<'a, SC> { type Challenge = Val; + type PublicVar = Val; - fn stage_trace(&self, stage: usize) -> ::M { - self.traces_by_stage[stage] + fn stage_trace(&self, stage: u8) -> ::M { + self.traces_by_stage[stage as usize] } - fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { - &self.challenges[stage] + fn stage_challenges(&self, stage: u8) -> &[Self::Challenge] { + &self.challenges[stage as usize] } - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - self.public_values_by_stage[stage] + fn stage_public_values(&self, stage: u8) -> &[Self::PublicVar] { + &self.public_values_by_stage[stage as usize] } } diff --git a/plonky3/src/lib.rs b/plonky3/src/lib.rs index 961fe9635e..92fb67de65 100644 --- a/plonky3/src/lib.rs +++ b/plonky3/src/lib.rs @@ -15,9 +15,6 @@ use prover::*; use traits::*; use verifier::*; -#[cfg(debug_assertions)] -mod check_constraints; - mod circuit_builder; mod params; mod stark; diff --git a/plonky3/src/proof.rs b/plonky3/src/proof.rs index c44c1642d6..ec5306dfb9 100644 --- a/plonky3/src/proof.rs +++ b/plonky3/src/proof.rs @@ -1,16 +1,17 @@ +use std::collections::BTreeMap; + use alloc::vec::Vec; use p3_commit::Pcs; -use p3_matrix::dense::RowMajorMatrix; use serde::{Deserialize, Serialize}; use p3_uni_stark::{StarkGenericConfig, Val}; -type Com = <::Pcs as Pcs< +pub type Com = <::Pcs as Pcs< ::Challenge, ::Challenger, >>::Commitment; -type PcsProof = <::Pcs as Pcs< +pub type PcsProof = <::Pcs as Pcs< ::Challenge, ::Challenger, >>::Proof; @@ -25,7 +26,6 @@ pub struct Proof { pub(crate) commitments: Commitments>, pub(crate) opened_values: OpenedValues, pub(crate) opening_proof: PcsProof, - pub(crate) degree_bits: usize, } #[derive(Debug, Serialize, Deserialize)] @@ -34,31 +34,48 @@ pub struct Commitments { pub(crate) quotient_chunks: Com, } +pub type OpenedValues = BTreeMap>; + #[derive(Debug, Serialize, Deserialize)] -pub struct OpenedValues { - pub(crate) preprocessed_local: Vec, - pub(crate) preprocessed_next: Vec, - pub(crate) traces_by_stage_local: Vec>, - pub(crate) traces_by_stage_next: Vec>, +pub struct TableOpenedValues { + pub(crate) preprocessed: Option>, + pub(crate) traces_by_stage: Vec>, pub(crate) quotient_chunks: Vec>, + pub(crate) log_degree: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct StageOpenedValues { + pub(crate) local: Vec, + pub(crate) next: Vec, } pub struct StarkProvingKey { - pub preprocessed_commit: Com, - pub preprocessed_data: PcsProverData, + // for each table, the preprocessed data + pub preprocessed: BTreeMap>, +} + +/// For each possible size, the commitment and prover data +pub type TableProvingKeyCollection = BTreeMap>; + +/// For each possible size, the commitment +pub type TableVerifyingKeyCollection = BTreeMap>; + +pub struct TableProvingKey { + pub commitment: Com, + pub prover_data: PcsProverData, } #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub struct StarkVerifyingKey { - pub preprocessed_commit: Com, + // for each table, for each possible size, the commitment + pub preprocessed: BTreeMap>, } pub struct ProcessedStage { pub(crate) 
commitment: Com, pub(crate) prover_data: PcsProverData, pub(crate) challenge_values: Vec>, - pub(crate) public_values: Vec>, - #[cfg(debug_assertions)] - pub(crate) trace: RowMajorMatrix>, + pub(crate) public_values: Vec>>, } diff --git a/plonky3/src/prover.rs b/plonky3/src/prover.rs index 2a21338f7f..3fb48b38c3 100644 --- a/plonky3/src/prover.rs +++ b/plonky3/src/prover.rs @@ -1,71 +1,468 @@ -use alloc::borrow::ToOwned; use alloc::vec; use alloc::vec::Vec; use core::iter::{self, once}; +use powdr_backend_utils::machine_witness_columns; +use std::collections::BTreeMap; -use itertools::{izip, Itertools}; +use itertools::Itertools; use p3_air::Air; use p3_challenger::{CanObserve, CanSample, FieldChallenger}; -use p3_commit::{Pcs, PolynomialSpace}; +use p3_commit::{Pcs as _, PolynomialSpace}; use p3_field::{AbstractExtensionField, AbstractField, PackedValue}; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::dense::{DenseMatrix, RowMajorMatrix}; use p3_matrix::Matrix; use p3_maybe_rayon::prelude::*; use p3_util::log2_strict_usize; use tracing::{info_span, instrument}; +use crate::circuit_builder::{generate_matrix, PowdrCircuit, PowdrTable}; +use crate::params::{Challenge, Challenger, Pcs}; +use crate::proof::{OpenedValues, StageOpenedValues}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; use crate::traits::MultiStageAir; use crate::{ - Commitments, OpenedValues, ProcessedStage, Proof, ProverConstraintFolder, StarkProvingKey, + Com, Commitment, Commitments, FieldElementMap, PcsProof, PcsProverData, ProcessedStage, Proof, + ProverConstraintFolder, ProverData, StarkProvingKey, TableOpenedValues, + TableProvingKeyCollection, }; use p3_uni_stark::{Domain, PackedChallenge, PackedVal, StarkGenericConfig, Val}; +pub(crate) struct MultiTable<'a, T: FieldElementMap> +where + ProverData: Send, + Commitment: Send, +{ + pub(crate) tables: BTreeMap>, +} + +impl<'a, T: FieldElementMap> MultiTable<'a, T> +where + ProverData: Send, + Commitment: Send, +{ + fn table_count(&self) -> usize { + self.tables.len() + } + + fn table_names(&self) -> Vec<&String> { + self.tables.keys().collect() + } + + /// Returns the number of stages in the table with the most stages. + /// + /// # Panics + /// + /// Panics if there are no tables. + fn stage_count(&self) -> u8 { + self.tables + .values() + .map(|i| &i.air) + .map(<_ as MultiStageAir>>::stage_count) + .max() + .expect("expected at least one table") + } + + /// Observe the instance for each table. + fn observe_instances(&self, challenger: &mut Challenger) { + for input in self.tables.values() { + input.observe_instance(challenger); + } + } + + fn quotient_chunks_count(&self) -> usize { + self.tables + .values() + .map(|table| 1 << table.log_quotient_degree()) + .sum() + } + + /// Commit to the quotient polynomial across all tables. + /// + /// Returns a single commitment and the prover data. 
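The proving key is now a two-level map, table name first and then table size. A sketch of the per-size lookup the prover performs when observing the commitments for the sizes actually used in an execution, with `String` standing in for the PCS commitment type:

```rust
use std::collections::BTreeMap;

// Sketch: per-table, per-size preprocessed commitments. For every table with
// fixed columns, the prover observes the commitment matching that table's
// execution size.
type TableKeyCollection = BTreeMap<usize, String>;

fn commitments_to_observe<'a>(
    preprocessed: &'a BTreeMap<String, TableKeyCollection>,
    degrees: &BTreeMap<String, usize>,
) -> Vec<&'a String> {
    preprocessed
        .iter()
        .map(|(table, by_size)| &by_size[&degrees[table]])
        .collect()
}

fn main() {
    let key = BTreeMap::from([(
        "main".to_string(),
        BTreeMap::from([(8usize, "commit_main_8".to_string())]),
    )]);
    let degrees = BTreeMap::from([("main".to_string(), 8usize)]);
    assert_eq!(commitments_to_observe(&key, &degrees), vec!["commit_main_8"]);
}
```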
+    fn commit_to_quotient(
+        &self,
+        state: &mut ProverState<'a, T>,
+        proving_key: Option<&StarkProvingKey>,
+    ) -> (Com, PcsProverData) {
+        let alpha: Challenge = state.challenger.sample_ext_element();
+
+        // get the quotient domains and chunks for each table
+        let quotient_domains_and_chunks: Vec<_> = self
+            .tables
+            .iter()
+            .enumerate()
+            .flat_map(|(index, (name, i))| {
+                i.quotient_domains_and_chunks(
+                    index,
+                    state,
+                    proving_key
+                        .as_ref()
+                        .and_then(|proving_key| proving_key.preprocessed.get(name)),
+                    alpha,
+                )
+            })
+            .collect();
+
+        assert_eq!(
+            quotient_domains_and_chunks.len(),
+            self.quotient_chunks_count()
+        );
+
+        // commit to the chunks
+        let (quotient_commit, quotient_data) = info_span!("commit to quotient poly chunks")
+            .in_scope(|| state.pcs.commit(quotient_domains_and_chunks));
+        // observe the commitment
+        state.challenger.observe(quotient_commit.clone());
+
+        (quotient_commit, quotient_data)
+    }
+
+    /// Opens the commitments to the preprocessed trace, the traces, and the quotient polynomial.
+    fn open(
+        &self,
+        state: &mut ProverState,
+        proving_key: Option<&StarkProvingKey>,
+        quotient_data: PcsProverData,
+    ) -> (OpenedValues>, PcsProof) {
+        let zeta: Challenge = state.challenger.sample();
+
+        let preprocessed_data_and_opening_points = proving_key
+            .as_ref()
+            .map(|key| {
+                self.tables.iter().filter_map(|(name, table)| {
+                    key.preprocessed.get(name).map(|preprocessed| {
+                        (
+                            // pick the preprocessed data for this table in the correct size
+                            &preprocessed[&(1 << table.log_degree())].prover_data,
+                            vec![vec![
+                                zeta,
+                                table.trace_domain(state.pcs).next_point(zeta).unwrap(),
+                            ]],
+                        )
+                    })
+                })
+            })
+            .into_iter()
+            .flatten();
+
+        let trace_data_and_points_per_stage: Vec<(_, Vec>)> = state
+            .processed_stages
+            .iter()
+            .map(|processed_stage| {
+                let points = self
+                    .tables
+                    .values()
+                    .map(|input| {
+                        vec![
+                            zeta,
+                            input.trace_domain(state.pcs).next_point(zeta).unwrap(),
+                        ]
+                    })
+                    .collect();
+                (&processed_stage.prover_data, points)
+            })
+            .collect();
+
+        let quotient_opening_points: Vec<_> = (0..self.quotient_chunks_count())
+            .map(|_| vec![zeta])
+            .collect();
+
+        let (opened_values, proof) = state.pcs.open(
+            preprocessed_data_and_opening_points
+                .chain(trace_data_and_points_per_stage)
+                .chain(once((&quotient_data, quotient_opening_points)))
+                .collect(),
+            state.challenger,
+        );
+
+        let mut opened_values = opened_values.into_iter();
+
+        // maybe get values for the preprocessed columns
+        let preprocessed: Vec<_> = if let Some(proving_key) = proving_key {
+            state
+                .program
+                .tables
+                .keys()
+                .map(|name| {
+                    proving_key.preprocessed.contains_key(name).then(|| {
+                        let value = opened_values.next().unwrap();
+                        assert_eq!(value.len(), 1);
+                        StageOpenedValues {
+                            local: value[0][0].clone(),
+                            next: value[0][1].clone(),
+                        }
+                    })
+                })
+                .collect()
+        } else {
+            vec![None; state.program.table_count()]
+        };
+
+        // get values for the traces
+        let traces_by_table_by_stage: Vec> = state.processed_stages.iter().fold(
+            vec![vec![]; state.program.table_count()],
+            |mut traces_by_table, _| {
+                let values = opened_values.next().unwrap();
+                for (values, v) in traces_by_table.iter_mut().zip_eq(values) {
+                    let [local, next] = v.try_into().unwrap();
+
+                    values.push(StageOpenedValues { next, local });
+                }
+                traces_by_table
+            },
+        );
+
+        // get values for the quotient
+        let mut value = opened_values.next().unwrap().into_iter();
+        let quotient_chunks: Vec>>> = self
+            .tables
+            .values()
+            .map(|i| {
+                let log_quotient_degree = i.log_quotient_degree();
+                let
quotient_degree = 1 << log_quotient_degree; + (&mut value) + .take(quotient_degree) + .map(|v| { + let [v] = v.try_into().unwrap(); + v + }) + .collect() + }) + .collect(); + + assert!(opened_values.next().is_none()); + + let opened_values = state + .program + .tables + .iter() + .zip_eq(preprocessed) + .zip_eq(traces_by_table_by_stage) + .zip_eq(quotient_chunks) + .map( + |((((name, table), preprocessed), traces_by_stage), quotient_chunks)| { + ( + name.clone(), + TableOpenedValues { + preprocessed, + traces_by_stage, + quotient_chunks, + log_degree: table.log_degree(), + }, + ) + }, + ) + .collect(); + (opened_values, proof) + } + + /// For a given stage, return the number of challenges required by the table with the most challenges. + /// + /// # Panics + /// + /// Panics if there are no tables. + fn stage_challenge_count(&self, stage_id: u8) -> usize { + self.tables + .values() + .map(|table| { + <_ as MultiStageAir>>::stage_challenge_count( + &table.air, stage_id, + ) + }) + .max() + .unwrap() + } +} + +/// A sub-table to be proven, in the form of an air and a degree +pub(crate) struct Table<'a, T: FieldElementMap> +where + ProverData: Send, + Commitment: Send, +{ + air: PowdrTable<'a, T>, + degree: usize, +} + +impl<'a, T: FieldElementMap> Table<'a, T> +where + ProverData: Send, + Commitment: Send, +{ + fn log_degree(&self) -> usize { + log2_strict_usize(self.degree) + } + + fn trace_domain(&self, pcs: &Pcs) -> Domain { + pcs.natural_domain_for_degree(self.degree) + } + + fn public_input_count_per_stage(&self) -> Vec { + (0..<_ as MultiStageAir>>::stage_count(&self.air)) + .map(|stage| { + <_ as MultiStageAir>>::stage_public_count(&self.air, stage) + }) + .collect() + } + + fn log_quotient_degree(&self) -> usize { + get_log_quotient_degree(&self.air, &self.public_input_count_per_stage()) + } + + fn observe_instance(&self, challenger: &mut Challenger) { + challenger.observe(Val::::from_canonical_usize(self.log_degree())); + // TODO: Might be best practice to include other instance data here; see verifier comment. + } + + /// Compute the quotient domains and chunks for this table. + /// * Arguments: + /// * `table_index`: The index of the table in the program. This is used as the index for this table in the mmcs. + /// * `state`: The current prover state. + /// * `table_preprocessed_data`: The preprocessed data for this table, if it exists. + /// * `alpha`: The challenge value for the quotient polynomial. 
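Because the quotient chunks of all tables go into a single commitment, prover and verifier must agree on a flat ordering: each table contributes `2^log_quotient_degree` consecutive chunks. A sketch of that bookkeeping:

```rust
// Sketch: each table is split into 2^log_quotient_degree quotient chunks, and
// the chunks of all tables are committed to jointly; a (table, local chunk)
// pair maps to a flat index in commitment order.
fn flat_chunk_index(log_quotient_degrees: &[usize], table: usize, chunk: usize) -> usize {
    assert!(chunk < 1 << log_quotient_degrees[table]);
    log_quotient_degrees[..table]
        .iter()
        .map(|log_d| 1usize << log_d)
        .sum::<usize>()
        + chunk
}

fn main() {
    // two tables with quotient degrees 2 and 4: flat chunks 0..2, then 2..6
    assert_eq!(flat_chunk_index(&[1, 2], 1, 0), 2);
    assert_eq!(flat_chunk_index(&[1, 2], 1, 3), 5);
}
```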
+ fn quotient_domains_and_chunks( + &self, + table_index: usize, + state: &ProverState, + table_preprocessed_data: Option<&TableProvingKeyCollection>, + alpha: Challenge, + ) -> impl Iterator, DenseMatrix>)> { + let quotient_domain = self + .trace_domain(state.pcs) + .create_disjoint_domain(1 << (self.log_degree() + self.log_quotient_degree())); + + let preprocessed_on_quotient_domain = table_preprocessed_data.map(|preprocessed| { + state.pcs.get_evaluations_on_domain( + &preprocessed[&(1 << self.log_degree())].prover_data, + // the index is 0 because we committed to each preprocessed matrix alone, see setup + 0, + quotient_domain, + ) + }); + + let traces_on_quotient_domain = state + .processed_stages + .iter() + .map(|s| { + state + .pcs + // the index is `table_index` because we committed to all table for a given stage together, and this is the `table_index`th table + .get_evaluations_on_domain(&s.prover_data, table_index, quotient_domain) + }) + .collect(); + + let challenges = state + .processed_stages + .iter() + .map(|stage| stage.challenge_values.clone()) + .collect_vec(); + + let public_values_by_stage = state + .processed_stages + .iter() + .map(|stage| stage.public_values[table_index].clone()) + .collect_vec(); + + let quotient_values = quotient_values::( + &self.air, + &public_values_by_stage, + self.trace_domain(state.pcs), + quotient_domain, + preprocessed_on_quotient_domain, + traces_on_quotient_domain, + &challenges, + alpha, + ); + + let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base(); + + let quotient_degree = 1 << self.log_quotient_degree(); + let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat); + let qc_domains = quotient_domain.split_domains(quotient_degree); + qc_domains.into_iter().zip_eq(quotient_chunks) + } +} + #[instrument(skip_all)] #[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses? 
-pub fn prove_with_key< - SC, - #[cfg(debug_assertions)] A: for<'a> Air>>, - #[cfg(not(debug_assertions))] A, - C, ->( - config: &SC, - proving_key: Option<&StarkProvingKey>, - air: &A, - challenger: &mut SC::Challenger, - stage_0_trace: RowMajorMatrix>, - next_stage_trace_callback: &C, - #[allow(clippy::ptr_arg)] - // we do not use `&[Val]` in order to keep the same API - stage_0_public_values: &Vec>, -) -> Proof +pub fn prove( + proving_key: Option<&StarkProvingKey>, + program: &PowdrCircuit, + witness: &mut Vec<(String, Vec)>, + challenger: &mut Challenger, +) -> Proof where - SC: StarkGenericConfig, - A: MultiStageAir>> - + for<'a> MultiStageAir>, - C: NextStageTraceCallback, + ProverData: Send, + Commitment: Send, { - let degree = stage_0_trace.height(); - let log_degree = log2_strict_usize(degree); + let (tables, stage_0): (BTreeMap<_, _>, BTreeMap<_, _>) = program + .split + .iter() + .map(|(name, (pil, constraint_system))| { + let columns = machine_witness_columns(witness, pil, name); + let degree = columns[0].1.len(); + + ( + ( + name.clone(), + Table { + air: PowdrTable::new(constraint_system), + degree, + }, + ), + ( + name.clone(), + AirStage { + trace: generate_matrix( + columns.iter().map(|(name, values)| (name, values.as_ref())), + ), + public_values: constraint_system.publics_by_stage[0] + .iter() + .map(|(name, _, row)| { + witness + .iter() + .find_map(|(n, v)| (n == name).then(|| v[*row])) + .unwrap() + .into_p3_field() + }) + .collect(), + }, + ), + ) + }) + .unzip(); - let stage_count = >>::stage_count(air); + if tables.is_empty() { + panic!("No tables to prove"); + } - let pcs = config.pcs(); - let trace_domain = pcs.natural_domain_for_degree(degree); + let multi_table = MultiTable { tables }; + + let config = T::get_config(); - // Observe the instance. - challenger.observe(Val::::from_canonical_usize(log_degree)); - // TODO: Might be best practice to include other instance data here; see verifier comment. 
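// A minimal sketch of the stage-0 public-value lookup performed when building
// `stage_0` above: a public input is declared as (column_name, row) and read
// directly out of the corresponding witness column. The u64 values stand in
// for field elements.
fn public_value(witness: &[(String, Vec<u64>)], name: &str, row: usize) -> u64 {
    witness
        .iter()
        .find_map(|(n, v)| (n == name).then(|| v[row]))
        .expect("a public value must reference a known witness column")
}

fn main() {
    let witness = vec![("Global::x".to_string(), vec![1, 2, 3, 4])];
    assert_eq!(public_value(&witness, "Global::x", 3), 4);
}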
+ assert_eq!(stage_0.keys().collect_vec(), multi_table.table_names()); + let stage_count = multi_table.stage_count(); + + let pcs = config.pcs(); + + // observe the parts of the proving key which correspond to the sizes of the tables we are proving if let Some(proving_key) = proving_key { - challenger.observe(proving_key.preprocessed_commit.clone()) + for commitment in proving_key + .preprocessed + .iter() + .map(|(name, map)| &map[&multi_table.tables[name].degree].commitment) + { + challenger.observe(commitment.clone()); + } }; - let mut state: ProverState = ProverState::new(pcs, trace_domain, challenger); + multi_table.observe_instances(challenger); + + let mut state = ProverState::new(&multi_table, pcs, challenger); let mut stage = Stage { - trace: stage_0_trace, - challenge_count: >>::stage_challenge_count(air, 0), - public_values: stage_0_public_values.to_owned(), + id: 0, + air_stages: stage_0, }; assert!(stage_count >= 1); @@ -74,21 +471,15 @@ where state = state.run_stage(stage); // get the challenges drawn at the end of the previous stage let local_challenges = &state.processed_stages.last().unwrap().challenge_values; - let CallbackResult { - trace, - public_values, - challenges, - } = next_stage_trace_callback.compute_stage(stage_id as u32, local_challenges); - // replace the challenges of the last stage with the ones received - state.processed_stages.last_mut().unwrap().challenge_values = challenges; + let CallbackResult { air_stages } = + program.compute_stage(stage_id, local_challenges, witness); + + assert_eq!(air_stages.len(), multi_table.table_count()); + // go to the next stage stage = Stage { - trace, - challenge_count: >>::stage_challenge_count( - air, - stage_id as u32, - ), - public_values, + id: stage_id, + air_stages, }; } @@ -103,83 +494,9 @@ where .challenge_values .is_empty()); // sanity check that we processed as many stages as expected - assert_eq!(state.processed_stages.len(), stage_count); - - // with the witness complete, check the constraints - #[cfg(debug_assertions)] - crate::check_constraints::check_constraints( - air, - &air.preprocessed_trace() - .unwrap_or(RowMajorMatrix::new(Default::default(), 0)), - state.processed_stages.iter().map(|s| &s.trace).collect(), - &state - .processed_stages - .iter() - .map(|s| &s.public_values) - .collect(), - state - .processed_stages - .iter() - .map(|s| &s.challenge_values) - .collect(), - ); - - let log_quotient_degree = get_log_quotient_degree::, A>( - air, - &state - .processed_stages - .iter() - .map(|s| s.public_values.len()) - .collect::>(), - ); - let quotient_degree = 1 << log_quotient_degree; - - let challenger = &mut state.challenger; - - let alpha: SC::Challenge = challenger.sample_ext_element(); - - let quotient_domain = - trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree)); - - let preprocessed_on_quotient_domain = proving_key.map(|proving_key| { - pcs.get_evaluations_on_domain(&proving_key.preprocessed_data, 0, quotient_domain) - }); - - let traces_on_quotient_domain = state - .processed_stages - .iter() - .map(|s| pcs.get_evaluations_on_domain(&s.prover_data, 0, quotient_domain)) - .collect(); - - let challenges = state - .processed_stages - .iter() - .map(|stage| stage.challenge_values.clone()) - .collect(); + assert_eq!(state.processed_stages.len() as u8, stage_count); - let public_values_by_stage = state - .processed_stages - .iter() - .map(|stage| stage.public_values.clone()) - .collect(); - - let quotient_values = quotient_values( - air, - &public_values_by_stage, - 
trace_domain,
-        quotient_domain,
-        preprocessed_on_quotient_domain,
-        traces_on_quotient_domain,
-        challenges,
-        alpha,
-    );
 
-    let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
 
-    let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
-    let qc_domains = quotient_domain.split_domains(quotient_degree);
-
-    let (quotient_commit, quotient_data) = info_span!("commit to quotient poly chunks")
-        .in_scope(|| pcs.commit(izip!(qc_domains, quotient_chunks).collect_vec()));
-    challenger.observe(quotient_commit.clone());
+    let (quotient_commit, quotient_data) = multi_table.commit_to_quotient(&mut state, proving_key);
 
     let commitments = Commitments {
         traces_by_stage: state
@@ -190,94 +507,30 @@ where
         quotient_chunks: quotient_commit,
     };
 
-    let zeta: SC::Challenge = challenger.sample();
-    let zeta_next = trace_domain.next_point(zeta).unwrap();
+    let (opened_values, opening_proof) = multi_table.open(&mut state, proving_key, quotient_data);
 
-    let (opened_values, opening_proof) = pcs.open(
-        iter::empty()
-            .chain(
-                proving_key
-                    .map(|proving_key| {
-                        (&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]])
-                    })
-                    .into_iter(),
-            )
-            .chain(
-                state
-                    .processed_stages
-                    .iter()
-                    .map(|processed_stage| {
-                        (&processed_stage.prover_data, vec![vec![zeta, zeta_next]])
-                    })
-                    .collect_vec(),
-            )
-            .chain(once((
-                &quotient_data,
-                // open every chunk at zeta
-                (0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
-            )))
-            .collect_vec(),
-        challenger,
-    );
-    let mut opened_values = opened_values.iter();
-
-    // maybe get values for the preprocessed columns
-    let (preprocessed_local, preprocessed_next) = if proving_key.is_some() {
-        let value = opened_values.next().unwrap();
-        assert_eq!(value.len(), 1);
-        assert_eq!(value[0].len(), 2);
-        (value[0][0].clone(), value[0][1].clone())
-    } else {
-        (vec![], vec![])
-    };
-
-    // get values for the traces
-    let (traces_by_stage_local, traces_by_stage_next): (Vec<_>, Vec<_>) = state
-        .processed_stages
-        .iter()
-        .map(|_| {
-            let value = opened_values.next().unwrap();
-            assert_eq!(value.len(), 1);
-            assert_eq!(value[0].len(), 2);
-            (value[0][0].clone(), value[0][1].clone())
-        })
-        .unzip();
-
-    // get values for the quotient
-    let value = opened_values.next().unwrap();
-    assert_eq!(value.len(), quotient_degree);
-    let quotient_chunks = value.iter().map(|v| v[0].clone()).collect_vec();
-
-    let opened_values = OpenedValues {
-        traces_by_stage_local,
-        traces_by_stage_next,
-        preprocessed_local,
-        preprocessed_next,
-        quotient_chunks,
-    };
     Proof {
         commitments,
        opened_values,
        opening_proof,
-        degree_bits: log_degree,
    }
 }
 
 #[allow(clippy::too_many_arguments)]
 #[instrument(name = "compute quotient polynomial", skip_all)]
-fn quotient_values<'a, SC, A, Mat>(
+fn quotient_values(
    air: &A,
-    public_values_by_stage: &'a Vec>>,
+    public_values_by_stage: &[Vec>],
    trace_domain: Domain,
    quotient_domain: Domain,
    preprocessed_on_quotient_domain: Option,
    traces_on_quotient_domain: Vec,
-    challenges: Vec>>,
+    challenges: &[Vec>],
    alpha: SC::Challenge,
) -> Vec
where
    SC: StarkGenericConfig,
-    A: Air>,
+    A: for<'a> Air>,
    Mat: Matrix> + Sync,
{
    let quotient_size = quotient_domain.size();
@@ -313,10 +566,13 @@ where
        let preprocessed = RowMajorMatrix::new(
            preprocessed_on_quotient_domain
                .as_ref()
-                .map(|on_quotient_domain| {
+                .map(|preprocessed_on_quotient_domain| {
                    iter::empty()
-                        .chain(on_quotient_domain.vertically_packed_row(i_start))
-                        .chain(on_quotient_domain.vertically_packed_row(i_start + next_step))
+
.chain(preprocessed_on_quotient_domain.vertically_packed_row(i_start)) + .chain( + preprocessed_on_quotient_domain + .vertically_packed_row(i_start + next_step), + ) .collect_vec() }) .unwrap_or_default(), @@ -336,13 +592,16 @@ where trace_on_quotient_domain.width(), ) }) - .collect(); + .collect_vec(); let accumulator = PackedChallenge::::zero(); let mut folder = ProverConstraintFolder { - challenges: challenges.clone(), - traces_by_stage, - preprocessed, + challenges, + traces_by_stage: traces_by_stage + .iter() + .map(|trace| trace.as_view()) + .collect(), + preprocessed: preprocessed.as_view(), public_values_by_stage, is_first_row, is_last_row, @@ -366,84 +625,92 @@ where .collect() } -pub struct ProverState<'a, SC: StarkGenericConfig> { - pub(crate) processed_stages: Vec>, - pub(crate) challenger: &'a mut SC::Challenger, - pub(crate) pcs: &'a ::Pcs, - pub(crate) trace_domain: Domain, +struct ProverState<'a, T: FieldElementMap> +where + ProverData: Send, + Commitment: Send, +{ + pub(crate) program: &'a MultiTable<'a, T>, + pub(crate) processed_stages: Vec>, + pub(crate) challenger: &'a mut Challenger, + pub(crate) pcs: &'a Pcs, } -impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> { +impl<'a, T: FieldElementMap> ProverState<'a, T> +where + ProverData: Send, + Commitment: Send, +{ pub(crate) fn new( - pcs: &'a ::Pcs, - trace_domain: Domain, - challenger: &'a mut ::Challenger, + program: &'a MultiTable<'a, T>, + pcs: &'a ::Pcs, + challenger: &'a mut ::Challenger, ) -> Self { Self { + program, processed_stages: Default::default(), challenger, pcs, - trace_domain, } } - pub(crate) fn run_stage(mut self, stage: Stage) -> Self { - #[cfg(debug_assertions)] - let trace = stage.trace.clone(); - - // commit to the trace for this stage - let (commitment, prover_data) = info_span!("commit to stage {stage} data") - .in_scope(|| self.pcs.commit(vec![(self.trace_domain, stage.trace)])); + pub(crate) fn run_stage(mut self, stage: Stage>) -> Self { + let (commit_inputs, public_values): (_, Vec<_>) = stage + .air_stages + .into_values() + .map(|air_stage| { + ( + ( + self.pcs.natural_domain_for_degree(air_stage.trace.height()), + air_stage.trace, + ), + air_stage.public_values, + ) + }) + .unzip(); + // commit to the traces + let (commitment, prover_data) = + info_span!("commit to stage {stage} data").in_scope(|| self.pcs.commit(commit_inputs)); self.challenger.observe(commitment.clone()); - // observe the public inputs for this stage - self.challenger.observe_slice(&stage.public_values); - let challenge_values = (0..stage.challenge_count) + // observe the public inputs + for public_values in &public_values { + self.challenger.observe_slice(public_values); + } + + // draw challenges + let challenge_values = (0..self.program.stage_challenge_count(stage.id)) .map(|_| self.challenger.sample()) .collect(); + // update the state with the output of this stage self.processed_stages.push(ProcessedStage { - public_values: stage.public_values, + public_values, prover_data, commitment, challenge_values, - #[cfg(debug_assertions)] - trace, }); + self } } -pub struct Stage { +pub struct AirStage { /// the witness for this stage - pub(crate) trace: RowMajorMatrix>, - /// the number of challenges to be drawn at the end of this stage - pub(crate) challenge_count: usize, + pub(crate) trace: RowMajorMatrix, /// the public values for this stage - pub(crate) public_values: Vec>, -} - -pub struct CallbackResult { - /// the trace for this stage - pub(crate) trace: RowMajorMatrix, - /// the values of the public inputs of 
this stage - pub(crate) public_values: Vec, - /// the values of the challenges drawn at the previous stage - pub(crate) challenges: Vec, + pub(crate) public_values: Vec, } -impl CallbackResult { - pub fn new(trace: RowMajorMatrix, public_values: Vec, challenges: Vec) -> Self { - Self { - trace, - public_values, - challenges, - } - } +pub struct Stage { + /// the id of this stage + pub(crate) id: u8, + /// the stage trace for each air + air_stages: BTreeMap>, } -pub trait NextStageTraceCallback { - /// Computes the stage number `trace_stage` based on `challenges` drawn at the end of stage `trace_stage - 1` - fn compute_stage(&self, stage: u32, challenges: &[Val]) -> CallbackResult>; +pub struct CallbackResult { + /// the next stage for each air + pub(crate) air_stages: BTreeMap>, } diff --git a/plonky3/src/stark.rs b/plonky3/src/stark.rs index 9747dfd601..405d18a91d 100644 --- a/plonky3/src/stark.rs +++ b/plonky3/src/stark.rs @@ -1,22 +1,29 @@ //! A plonky3 prover using FRI and Poseidon +use itertools::Itertools; +use p3_commit::Pcs; use p3_matrix::dense::RowMajorMatrix; +use powdr_backend_utils::machine_fixed_columns; +use powdr_executor::constant_evaluator::VariablySizedColumn; use core::fmt; -use std::iter::{once, repeat}; +use std::collections::BTreeMap; use std::sync::Arc; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; -use crate::{prove_with_key, verify_with_key, Proof, StarkProvingKey, StarkVerifyingKey}; +use crate::{ + circuit_builder::ConstraintSystem, prove, verify, Proof, StarkProvingKey, StarkVerifyingKey, + TableProvingKey, TableProvingKeyCollection, +}; use p3_uni_stark::StarkGenericConfig; use crate::{ - circuit_builder::{generate_matrix, PowdrCircuit}, - params::{Challenger, Commitment, FieldElementMap, Plonky3Field, ProverData}, + circuit_builder::PowdrCircuit, + params::{Challenger, Commitment, FieldElementMap, ProverData}, }; pub struct Plonky3Prover @@ -26,8 +33,10 @@ where { /// The analyzed PIL analyzed: Arc>, + /// The split analyzed PIL + split: BTreeMap, ConstraintSystem)>, /// The value of the fixed columns - fixed: Arc)>>, + fixed: Arc)>>, /// Proving key proving_key: Option>, /// Verifying key @@ -51,8 +60,18 @@ where ProverData: Send, Commitment: Send, { - pub fn new(analyzed: Arc>, fixed: Arc)>>) -> Self { + pub fn new( + analyzed: Arc>, + fixed: Arc)>>, + ) -> Self { Self { + split: powdr_backend_utils::split_pil(&analyzed) + .into_iter() + .map(|(name, pil)| { + let constraint_system = ConstraintSystem::from(&pil); + (name, (pil, constraint_system)) + }) + .collect(), analyzed, fixed, proving_key: None, @@ -60,6 +79,10 @@ where } } + pub fn analyzed(&self) -> &Analyzed { + &self.analyzed + } + pub fn set_verifying_key(&mut self, rdr: &mut dyn std::io::Read) { self.verifying_key = Some(bincode::deserialize_from(rdr).unwrap()); } @@ -72,30 +95,6 @@ where ) .unwrap()) } - - /// Returns preprocessed matrix based on the fixed inputs [`Plonky3Prover`]. 
- /// This is used when running the setup phase - pub fn get_preprocessed_matrix(&self) -> RowMajorMatrix> { - let publics = self - .analyzed - .get_publics() - .into_iter() - .map(|(name, _, row_id)| { - let selector = (0..self.analyzed.degree()) - .map(move |i| T::from(i == row_id as u64)) - .collect::>(); - (name, selector) - }) - .collect::>(); - - let fixed_with_public_selectors = self - .fixed - .iter() - .chain(publics.iter()) - .map(|(name, values)| (name, values.as_ref())); - - generate_matrix(fixed_with_public_selectors) - } } impl Plonky3Prover @@ -104,61 +103,89 @@ where Commitment: Send, { pub fn setup(&mut self) { - // get fixed columns - let fixed = &self.fixed; - - // get selector columns for public values - let publics = self - .analyzed - .get_publics() - .into_iter() - .map(|(name, _, row_id)| { - let selector = (0..self.analyzed.degree()) - .map(move |i| T::from(i == row_id as u64)) - .collect::>(); - (name, selector) + let preprocessed: BTreeMap> = self + .split + .iter() + .filter_map(|(namespace, (pil, _))| { + // if we have neither fixed columns nor publics, we don't need to commit to anything + if pil.constant_count() + pil.publics_count() == 0 { + None + } else { + let fixed_columns = machine_fixed_columns(&self.fixed, pil); + Some(( + namespace.to_string(), + pil.committed_polys_in_source_order() + .find_map(|(s, _)| s.degree) + .unwrap() + .iter() + .map(|size| { + // get selector columns for the public inputs, as closures + let publics = pil + .get_publics() + .into_iter() + .map(|(_, _, row_id, _)| move |i| T::from(i == row_id as u64)) + .collect::>(); + + // get the config + let config = T::get_config(); + + // commit to the fixed columns + let pcs = config.pcs(); + let domain = pcs.natural_domain_for_degree(size as usize); + let fixed_columns = &fixed_columns[&size]; + + // generate the preprocessed matrix row by row + let matrix = RowMajorMatrix::new( + (0..size) + .flat_map(|i| { + fixed_columns + .iter() + .map(move |(_, column)| column[i as usize]) + .chain(publics.iter().map(move |f| f(i))) + .map(|value| value.into_p3_field()) + }) + .collect(), + fixed_columns.len() + publics.len(), + ); + + let evaluations = vec![(domain, matrix)]; + + // commit to the evaluations + let (commitment, prover_data) = + <_ as p3_commit::Pcs<_, Challenger>>::commit( + pcs, + evaluations, + ); + ( + size as usize, + TableProvingKey { + commitment, + prover_data, + }, + ) + }) + .collect(), + )) + } }) - .collect::)>>(); - - if fixed.is_empty() && publics.is_empty() { - return; - } + .collect(); - // get the config - let config = T::get_config(); - - // commit to the fixed columns - let pcs = config.pcs(); - let domain = <_ as p3_commit::Pcs<_, Challenger>>::natural_domain_for_degree( - pcs, - self.analyzed.degree() as usize, - ); - // write fixed into matrix row by row - let matrix = RowMajorMatrix::new( - (0..self.analyzed.degree()) - .flat_map(|i| { - fixed - .iter() - .chain(publics.iter()) - .map(move |(_, values)| values[i as usize].into_p3_field()) + let verifying_key = StarkVerifyingKey { + preprocessed: preprocessed + .iter() + .map(|(table_name, data)| { + ( + table_name.clone(), + data.iter() + .map(|(size, table_proving_key)| { + (*size, table_proving_key.commitment.clone()) + }) + .collect(), + ) }) .collect(), - self.fixed.len() + publics.len(), - ); - - let evaluations = vec![(domain, matrix)]; - - // commit to the evaluations - let (fixed_commit, fixed_data) = - <_ as p3_commit::Pcs<_, Challenger>>::commit(pcs, evaluations); - - let proving_key = 
StarkProvingKey { - preprocessed_commit: fixed_commit.clone(), - preprocessed_data: fixed_data, - }; - let verifying_key = StarkVerifyingKey { - preprocessed_commit: fixed_commit, }; + let proving_key = StarkProvingKey { preprocessed }; self.proving_key = Some(proving_key); self.verifying_key = Some(verifying_key); @@ -169,46 +196,44 @@ where witness: &[(String, Vec)], witgen_callback: WitgenCallback, ) -> Result, String> { - let stage_0_trace = - generate_matrix(witness.iter().map(|(name, value)| (name, value.as_ref()))); + // here we need to clone the witness because the callback will modify it + let witness = &mut witness.to_vec(); - let circuit = PowdrCircuit::new(&self.analyzed) - .with_witgen_callback(witgen_callback) - .with_phase_0_witness(witness); - - #[cfg(debug_assertions)] - let circuit = circuit.with_preprocessed(self.get_preprocessed_matrix()); - - let stage_0_publics = circuit.public_values_so_far(); - - let config = T::get_config(); + let circuit = PowdrCircuit::new(&self.split).with_witgen_callback(witgen_callback); let mut challenger = T::get_challenger(); let proving_key = self.proving_key.as_ref(); - let proof = prove_with_key( - &config, - proving_key, - &circuit, - &mut challenger, - stage_0_trace, - &circuit, - &stage_0_publics, - ); + let proof = prove(proving_key, &circuit, witness, &mut challenger); let mut challenger = T::get_challenger(); let verifying_key = self.verifying_key.as_ref(); - let empty_public = vec![]; - let public_values = once(&stage_0_publics) - .chain(repeat(&empty_public)) - .take(self.analyzed.stage_count()) + let public_values = circuit.public_values_so_far(witness); + + // extract the full map of public values by unwrapping all the options + let public_values = public_values + .into_iter() + .map(|(name, values)| { + ( + name, + values + .into_iter() + .map(|v| { + v.into_iter() + .map(|v| { + v.expect("all public values should be known after execution") + }) + .collect() + }) + .collect(), + ) + }) .collect(); - verify_with_key( - &config, + verify( verifying_key, &circuit, &mut challenger, @@ -219,34 +244,41 @@ where Ok(bincode::serialize(&proof).unwrap()) } - pub fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), String> { + // verify the proof given the instances for each table, for each stage + pub fn verify(&self, proof: &[u8], instances: &[T]) -> Result<(), String> { let proof: Proof<_> = bincode::deserialize(proof).map_err(|e| format!("Failed to deserialize proof: {e}"))?; - let publics = instances - .iter() - .flatten() - .map(|v| v.into_p3_field()) - .collect(); - - let config = T::get_config(); let mut challenger = T::get_challenger(); let verifying_key = self.verifying_key.as_ref(); - let empty_public = vec![]; - let public_values = once(&publics) - .chain(repeat(&empty_public)) - .take(self.analyzed.stage_count()) + let stage_count = self.analyzed.stage_count(); + + let mut instance_map: BTreeMap>> = self + .split + .keys() + .map(|name| (name.clone(), vec![vec![]; stage_count])) .collect(); - verify_with_key( - &config, + self.analyzed + .get_publics() + .iter() + .zip_eq(instances.iter()) + .map(|((poly_name, _, _, stage), value)| { + let namespace = poly_name.split("::").next().unwrap(); + (namespace, stage, value) + }) + .for_each(|(namespace, stage, value)| { + instance_map.get_mut(namespace).unwrap()[*stage as usize].push(*value); + }); + + verify( verifying_key, - &PowdrCircuit::new(&self.analyzed), + &PowdrCircuit::new(&self.split), &mut challenger, &proof, - public_values, + instance_map, ) .map_err(|e| 
format!("Failed to verify proof: {e:?}")) } @@ -254,9 +286,7 @@ where #[cfg(test)] mod tests { - use std::sync::Arc; - use powdr_executor::constant_evaluator::get_uniquely_sized_cloned; use powdr_number::{BabyBearField, GoldilocksField, Mersenne31Field}; use powdr_pipeline::Pipeline; use test_log::test; @@ -265,12 +295,16 @@ mod tests { /// Prove and verify execution over all supported fields fn run_test(pil: &str) { - run_test_publics::(pil, None); - run_test_publics::(pil, None); - run_test_publics::(pil, None); + run_test_publics(pil, &None); } - fn run_test_publics(pil: &str, malicious_publics: Option>) + fn run_test_publics(pil: &str, malicious_publics: &Option>) { + run_test_publics_aux::(pil, malicious_publics); + run_test_publics_aux::(pil, malicious_publics); + run_test_publics_aux::(pil, malicious_publics); + } + + fn run_test_publics_aux(pil: &str, malicious_publics: &Option>) where ProverData: Send, Commitment: Send, @@ -279,33 +313,28 @@ mod tests { let pil = pipeline.compute_optimized_pil().unwrap(); let witness_callback = pipeline.witgen_callback().unwrap(); - let witness = pipeline.compute_witness().unwrap(); + let witness = &mut pipeline.compute_witness().unwrap(); let fixed = pipeline.compute_fixed_cols().unwrap(); - let fixed = Arc::new(get_uniquely_sized_cloned(&fixed).unwrap()); let mut prover = Plonky3Prover::new(pil, fixed); prover.setup(); - let proof = prover.prove(&witness, witness_callback); + let proof = prover.prove(witness, witness_callback); assert!(proof.is_ok()); if let Some(publics) = malicious_publics { - prover.verify(&proof.unwrap(), &[publics]).unwrap() + prover + .verify( + &proof.unwrap(), + &publics + .iter() + .map(|i| F::from(*i as u64)) + .collect::>(), + ) + .unwrap() } } - #[test] - fn add_baby_bear() { - let content = r#" - namespace Add(8); - col witness x; - col witness y; - col witness z; - x + y = z; - "#; - run_test_publics::(content, None); - } - #[test] fn public_values() { let content = "namespace Global(8); pol witness x; x * (x - 1) = 0; public out = x(7);"; @@ -340,18 +369,12 @@ mod tests { public outz = z(7); "#; - let gl_malicious_publics = Some(vec![GoldilocksField::from(0)]); - run_test_publics(content, gl_malicious_publics); - - let bb_malicious_publics = Some(vec![BabyBearField::from(0)]); - run_test_publics(content, bb_malicious_publics); - - let m31_malicious_publics = Some(vec![Mersenne31Field::from(0)]); - run_test_publics(content, m31_malicious_publics); + let malicious_publics = Some(vec![0]); + run_test_publics(content, &malicious_publics); } #[test] - #[should_panic = "assertion `left == right` failed: Not a power of two: 0\n left: 0\n right: 1"] + #[should_panic = "No tables to prove"] fn empty() { let content = "namespace Global(8);"; run_test(content); @@ -369,6 +392,17 @@ mod tests { run_test(content); } + #[test] + fn next() { + let content = r#" + namespace Next(8); + col witness x; + col witness y; + x' + y = 0; + "#; + run_test(content); + } + #[test] fn fixed() { let content = r#" @@ -380,6 +414,28 @@ mod tests { run_test(content); } + #[test] + fn two_tables() { + // This test is a bit contrived but witgen wouldn't allow a more direct example + let content = r#" + namespace Add(8); + col witness x; + col witness y; + col witness z; + x = 0; + y = 0; + x + y = z; + 1 $ [ x, y, z ] in 1 $ [ Mul::x, Mul::y, Mul::z ]; + + namespace Mul(16); + col witness x; + col witness y; + col witness z; + x * y = z; + "#; + run_test(content); + } + #[test] fn challenge() { let content = r#" @@ -396,9 +452,7 @@ mod tests { } 
#[test] - #[should_panic = "no entry found for key"] fn stage_1_public() { - // this currently fails because we try to extract the public values from the stage 0 witness only let content = r#" let N: int = 8; @@ -411,8 +465,7 @@ mod tests { public out = y(N - 1); "#; - let malicious_publics = Some(vec![GoldilocksField::from(0)]); - run_test_publics::(content, malicious_publics); + run_test(content); } #[test] diff --git a/plonky3/src/symbolic_builder.rs b/plonky3/src/symbolic_builder.rs index 074d44b9a0..a197162690 100644 --- a/plonky3/src/symbolic_builder.rs +++ b/plonky3/src/symbolic_builder.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use p3_air::{AirBuilder, PairBuilder}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; @@ -50,10 +50,10 @@ where A: MultiStageAir>, { let widths: Vec<_> = (0..air.stage_count()) - .map(|i| air.stage_trace_width(i as u32)) + .map(|i| air.stage_trace_width(i)) .collect(); let challenges: Vec<_> = (0..air.stage_count()) - .map(|i| air.stage_challenge_count(i as u32)) + .map(|i| air.stage_challenge_count(i)) .collect(); let mut builder = SymbolicAirBuilder::new( air.preprocessed_width(), @@ -173,27 +173,20 @@ impl AirBuilder for SymbolicAirBuilder { } } -impl AirBuilderWithPublicValues for SymbolicAirBuilder { - type PublicVar = SymbolicVariable; - - fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) - } -} - impl MultistageAirBuilder for SymbolicAirBuilder { type Challenge = Self::Var; + type PublicVar = SymbolicVariable; - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - &self.public_values_by_stage[stage] + fn stage_trace(&self, stage: u8) -> Self::M { + self.traces_by_stage[stage as usize].clone() } - fn stage_trace(&self, stage: usize) -> Self::M { - self.traces_by_stage[stage].clone() + fn stage_public_values(&self, stage: u8) -> &[Self::PublicVar] { + &self.public_values_by_stage[stage as usize] } - fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { - &self.challenges[stage] + fn stage_challenges(&self, stage: u8) -> &[Self::Challenge] { + &self.challenges[stage as usize] } } diff --git a/plonky3/src/traits.rs b/plonky3/src/traits.rs index f12fb0f82b..4e523a8f4c 100644 --- a/plonky3/src/traits.rs +++ b/plonky3/src/traits.rs @@ -1,40 +1,29 @@ -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues}; +use p3_air::{Air, AirBuilder}; -pub trait MultistageAirBuilder: AirBuilderWithPublicValues { +pub trait MultistageAirBuilder: AirBuilder { type Challenge: Clone + Into; + type PublicVar: Into + Copy; /// Traces from each stage. 
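// With the defaults removed from `MultiStageAir` below, every AIR must answer
// the per-stage questions explicitly. A sketch of the answers a single-stage
// AIR gives, using a simplified stand-in trait (the real trait is generic
// over an AirBuilder):
trait StageInfo {
    fn stage_count(&self) -> u8;
    fn stage_trace_width(&self, stage: u8) -> usize;
    fn stage_challenge_count(&self, stage: u8) -> usize;
    fn stage_public_count(&self, stage: u8) -> usize;
}

struct SingleStageAir {
    width: usize,
    public_count: usize,
}

impl StageInfo for SingleStageAir {
    fn stage_count(&self) -> u8 {
        1
    }
    fn stage_trace_width(&self, stage: u8) -> usize {
        assert_eq!(stage, 0);
        self.width
    }
    fn stage_challenge_count(&self, _stage: u8) -> usize {
        // a single-stage AIR draws no challenges
        0
    }
    fn stage_public_count(&self, stage: u8) -> usize {
        assert_eq!(stage, 0);
        self.public_count
    }
}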
- fn stage_trace(&self, stage: usize) -> Self::M; + fn stage_trace(&self, stage: u8) -> Self::M; /// Challenges from each stage, drawn from the base field - fn stage_challenges(&self, stage: usize) -> &[Self::Challenge]; + fn stage_challenges(&self, stage: u8) -> &[Self::Challenge]; /// Public values for each stage - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - match stage { - 0 => self.public_values(), - _ => unimplemented!(), - } - } + fn stage_public_values(&self, stage: u8) -> &[Self::PublicVar]; } pub trait MultiStageAir: Air { + fn stage_public_count(&self, stage: u8) -> usize; + fn preprocessed_width(&self) -> usize; - fn stage_count(&self) -> usize { - 1 - } + fn stage_count(&self) -> u8; /// The number of trace columns in this stage - fn stage_trace_width(&self, stage: u32) -> usize { - match stage { - 0 => self.width(), - _ => unimplemented!(), - } - } + fn stage_trace_width(&self, stage: u8) -> usize; /// The number of challenges produced at the end of each stage - fn stage_challenge_count(&self, _stage: u32) -> usize { - 0 - } + fn stage_challenge_count(&self, _stage: u8) -> usize; } diff --git a/plonky3/src/verifier.rs b/plonky3/src/verifier.rs index 7ed205d516..98a0d2dba8 100644 --- a/plonky3/src/verifier.rs +++ b/plonky3/src/verifier.rs @@ -1,259 +1,431 @@ use alloc::vec; use alloc::vec::Vec; -use core::iter; +use p3_air::Air; +use std::collections::BTreeMap; +use std::iter::once; -use itertools::{izip, Itertools}; +use itertools::Itertools; use p3_challenger::{CanObserve, CanSample, FieldChallenger}; -use p3_commit::{Pcs, PolynomialSpace}; +use p3_commit::{Pcs as _, PolynomialSpace}; use p3_field::{AbstractExtensionField, AbstractField, Field}; use p3_matrix::dense::RowMajorMatrixView; use p3_matrix::stack::VerticalPair; use tracing::instrument; +use crate::circuit_builder::{PowdrCircuit, PowdrTable}; +use crate::params::{Challenge, Challenger, Commitment, Pcs, ProverData}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; -use crate::{MultiStageAir, Proof, StarkVerifyingKey, VerifierConstraintFolder}; -use p3_uni_stark::{PcsError, StarkGenericConfig, Val}; +use crate::{ + FieldElementMap, MultiStageAir, Proof, StageOpenedValues, StarkVerifyingKey, TableOpenedValues, + TableVerifyingKeyCollection, VerifierConstraintFolder, +}; +use p3_uni_stark::{Domain, PcsError, StarkGenericConfig, Val}; -#[instrument(skip_all)] -pub fn verify( - config: &SC, - air: &A, - challenger: &mut SC::Challenger, - proof: &Proof, - public_values: &Vec>, -) -> Result<(), VerificationError>> +/// A sub-table to be proven, in the form of an air and values for the public inputs +struct Table<'a, T: FieldElementMap> where - SC: StarkGenericConfig, - A: MultiStageAir>> - + for<'a> MultiStageAir>, + ProverData: Send, + Commitment: Send, { - verify_with_key(config, None, air, challenger, proof, vec![public_values]) + air: PowdrTable<'a, T>, + preprocessed: Option<&'a TableVerifyingKeyCollection>, + opened_values: &'a TableOpenedValues>, + public_values_by_stage: &'a [Vec>], +} + +impl<'a, T: FieldElementMap> Table<'a, T> +where + ProverData: Send, + Commitment: Send, +{ + fn get_log_quotient_degree(&self) -> usize { + get_log_quotient_degree( + &self.air, + &self + .public_values_by_stage + .iter() + .map(|values| values.len()) + .collect::>(), + ) + } + + fn natural_domain(&self, pcs: &Pcs) -> Domain { + let degree = 1 << self.opened_values.log_degree; + pcs.natural_domain_for_degree(degree) + } + + fn preprocessed_commit(&self) -> Option<&Commitment> { 
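// A sketch of the size-indexed lookup this accessor performs (see the body
// just below): the verifying key holds one preprocessed commitment per
// supported table size, and the verifier picks the one matching the degree
// claimed in the proof. The commitment type here is a placeholder.
use std::collections::BTreeMap;

type DummyCommitment = [u8; 32];

fn pick_preprocessed_commitment(
    by_size: &BTreeMap<usize, DummyCommitment>,
    log_degree: usize,
) -> Option<&DummyCommitment> {
    by_size.get(&(1 << log_degree))
}

fn main() {
    let by_size: BTreeMap<usize, DummyCommitment> =
        [(8, [0u8; 32]), (16, [1u8; 32])].into_iter().collect();
    assert!(pick_preprocessed_commitment(&by_size, 3).is_some()); // size 8 is known
    assert!(pick_preprocessed_commitment(&by_size, 5).is_none()); // size 32 is not
}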
+ self.preprocessed + .as_ref() + .map(|preprocessed| &preprocessed[&(1 << self.opened_values.log_degree)]) + } + + fn quotient_domains(&self, pcs: &Pcs) -> Vec> { + let log_quotient_degree = self.get_log_quotient_degree(); + self.natural_domain(pcs) + .create_disjoint_domain(1 << (self.opened_values.log_degree + log_quotient_degree)) + .split_domains(1 << log_quotient_degree) + } } #[instrument(skip_all)] -pub fn verify_with_key( - config: &SC, - verifying_key: Option<&StarkVerifyingKey>, - air: &A, - challenger: &mut SC::Challenger, - proof: &Proof, - public_values_by_stage: Vec<&Vec>>, -) -> Result<(), VerificationError>> +pub fn verify( + verifying_key: Option<&StarkVerifyingKey>, + program: &PowdrCircuit, + challenger: &mut Challenger, + proof: &Proof, + public_inputs: BTreeMap>>, +) -> Result<(), VerificationError>> where - SC: StarkGenericConfig, - A: MultiStageAir>> - + for<'a> MultiStageAir>, + ProverData: Send, + Commitment: Send, { + let public_inputs = public_inputs + .into_iter() + .map(|(name, values)| { + ( + name, + values + .into_iter() + .map(|values| values.into_iter().map(|v| v.into_p3_field()).collect_vec()) + .collect_vec(), + ) + }) + .collect::>(); + let Proof { commitments, opened_values, opening_proof, - degree_bits, } = proof; - let degree = 1 << degree_bits; - let log_quotient_degree = get_log_quotient_degree::, A>( - air, - &public_values_by_stage - .iter() - .map(|values| values.len()) - .collect::>(), - ); - let quotient_degree = 1 << log_quotient_degree; - let stage_count = proof.commitments.traces_by_stage.len(); - let challenge_counts: Vec = (0..stage_count) - .map(|i| >>::stage_challenge_count(air, i as u32)) + // sanity check that the two maps have the same keys + itertools::assert_equal(program.split.keys(), public_inputs.keys()); + + // error out if the opened values do not have the same keys as the tables + if !itertools::equal(program.split.keys(), opened_values.keys()) { + return Err(VerificationError::InvalidProofShape); + } + + let tables: BTreeMap<&String, Table<_>> = program + .split + .values() + .zip_eq(public_inputs.iter()) + .zip_eq(opened_values.values()) + .map( + |(((_, constraints), (name, public_values_by_stage)), opened_values)| { + ( + name, + Table { + air: PowdrTable::new(constraints), + opened_values, + public_values_by_stage, + preprocessed: verifying_key + .as_ref() + .and_then(|vk| vk.preprocessed.get(name)), + }, + ) + }, + ) .collect(); - let pcs = config.pcs(); - let trace_domain = pcs.natural_domain_for_degree(degree); - let quotient_domain = - trace_domain.create_disjoint_domain(1 << (degree_bits + log_quotient_degree)); - let quotient_chunks_domains = quotient_domain.split_domains(quotient_degree); + let config = T::get_config(); - let air_widths = (0..stage_count) - .map(|stage| { - >>>::stage_trace_width(air, stage as u32) - }) - .collect::>(); - let air_fixed_width = - >>>::preprocessed_width(air); - let valid_shape = opened_values.preprocessed_local.len() == air_fixed_width - && opened_values.preprocessed_next.len() == air_fixed_width - && opened_values - .traces_by_stage_local - .iter() - .zip(&air_widths) - .all(|(stage, air_width)| stage.len() == *air_width) - && opened_values - .traces_by_stage_next - .iter() - .zip(&air_widths) - .all(|(stage, air_width)| stage.len() == *air_width) - && opened_values.quotient_chunks.len() == quotient_degree - && opened_values - .quotient_chunks - .iter() - .all(|qc| qc.len() == >>::D) - && public_values_by_stage.len() == stage_count - && challenge_counts.len() == stage_count; + 
let pcs = config.pcs();
 
-    if !valid_shape {
-        return Err(VerificationError::InvalidProofShape);
+    for table in tables.values() {
+        if let Some(preprocessed_commit) = table.preprocessed_commit() {
+            challenger.observe(preprocessed_commit.clone());
+        }
     }
 
-    // Observe the instance.
-    challenger.observe(Val::::from_canonical_usize(proof.degree_bits));
+    // Observe the instances.
+    for table in tables.values() {
+        challenger.observe(Val::::from_canonical_usize(
+            table.opened_values.log_degree,
+        ));
+    }
     // TODO: Might be best practice to include other instance data here in the transcript, like some
     // encoding of the AIR. This protects against transcript collisions between distinct instances.
     // Practically speaking though, the only related known attack is from failing to include public
     // values. It's not clear if failing to include other instance data could enable a transcript
     // collision, since most such changes would completely change the set of satisfying witnesses.
-    if let Some(verifying_key) = verifying_key {
-        challenger.observe(verifying_key.preprocessed_commit.clone())
-    };
+    let stage_count = tables
+        .values()
+        .map(|i| &i.air)
+        .map(<_ as MultiStageAir>>::stage_count)
+        .max()
+        .unwrap();
 
-    let mut challenges = vec![];
+    let challenge_count_by_stage: Vec = (0..stage_count)
+        .map(|stage_id| {
+            tables
+                .values()
+                .map(|table| {
+                    <_ as MultiStageAir>>::stage_challenge_count(
+                        &table.air, stage_id,
+                    )
+                })
+                .max()
+                .unwrap()
+        })
+        .collect();
 
-    commitments
+    let challenges_by_stage = commitments
         .traces_by_stage
         .iter()
-        .zip(&public_values_by_stage)
-        .zip(challenge_counts)
-        .for_each(|((commitment, public_values), challenge_count)| {
+        .zip_eq((0..stage_count).map(|i| {
+            tables
+                .values()
+                .map(|table| &table.public_values_by_stage[i as usize])
+                .collect_vec()
+        }))
+        .zip_eq(challenge_count_by_stage)
+        .map(|((commitment, public_values_by_stage), challenge_count)| {
             challenger.observe(commitment.clone());
-            challenger.observe_slice(public_values);
-            challenges.push((0..challenge_count).map(|_| challenger.sample()).collect());
-        });
-    let alpha: SC::Challenge = challenger.sample_ext_element();
+            for public_values in &public_values_by_stage {
+                challenger.observe_slice(public_values);
+            }
+            (0..challenge_count)
+                .map(|_| challenger.sample())
+                .collect_vec()
+        })
+        .collect_vec();
+
+    let alpha: Challenge = challenger.sample_ext_element();
     challenger.observe(commitments.quotient_chunks.clone());
 
-    let zeta: SC::Challenge = challenger.sample();
-    let zeta_next = trace_domain.next_point(zeta).unwrap();
+    let zeta: Challenge = challenger.sample();
 
-    pcs.verify(
-        iter::empty()
-            .chain(
-                verifying_key
-                    .map(|verifying_key| {
+    // for preprocessed commitments, we have one optional commitment per table, opened on the trace domain at `zeta` and `zeta_next`
+    let preprocessed_domains_points_and_opens: Vec<(_, Vec<(_, _)>)> =
+        tables
+            .values()
+            .flat_map(|table| {
+                let trace_domain = table.natural_domain(pcs);
+
+                let zeta_next = trace_domain.next_point(zeta).unwrap();
+
+                table.opened_values.preprocessed.iter().map(
+                    move |StageOpenedValues { local, next }| {
                         (
-                        verifying_key.preprocessed_commit.clone(),
-                        (vec![(
+                            // choose the correct preprocessed commitment based on the degree in the proof
+                            // this could be optimized by putting the preprocessed commitments in a merkle tree
+                            // and having the prover prove that it used commitments matching the lengths of the traces
+                            // this way the verifier does not need to have all the preprocessed commitments for all
sizes + table.preprocessed_commit().expect("a preprocessed commitment was expected because a preprocessed opening was found").clone(), + vec![( trace_domain, - vec![ - (zeta, opened_values.preprocessed_local.clone()), - (zeta_next, opened_values.preprocessed_next.clone()), - ], - )]), + vec![(zeta, local.clone()), (zeta_next, next.clone())], + )], ) - }) - .into_iter(), - ) - .chain( - izip!( - commitments.traces_by_stage.iter(), - opened_values.traces_by_stage_local.iter(), - opened_values.traces_by_stage_next.iter() + }, ) - .map(|(trace_commit, opened_local, opened_next)| { - ( - trace_commit.clone(), - vec![( + }) + .collect(); + + // for trace commitments, we have one commitment per stage, opened on each trace domain at `zeta` and `zeta_next` + let trace_domains_points_and_opens_by_stage: Vec<(_, Vec<(_, _)>)> = proof + .commitments + .traces_by_stage + .iter() + .zip_eq((0..stage_count as usize).map(|i| { + tables + .values() + .map(|table| &table.opened_values.traces_by_stage[i]) + .collect_vec() + })) + .map(|(commit, openings)| { + ( + commit.clone(), + tables + .values() + .zip_eq(openings) + .map(|(table, StageOpenedValues { local, next })| { + let trace_domain = table.natural_domain(pcs); + let zeta_next = trace_domain.next_point(zeta).unwrap(); + ( trace_domain, - vec![ - (zeta, opened_local.clone()), - (zeta_next, opened_next.clone()), - ], - )], - ) - }) - .collect_vec(), - ) - .chain([( - commitments.quotient_chunks.clone(), - quotient_chunks_domains - .iter() - .zip(&opened_values.quotient_chunks) - .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) + vec![(zeta, local.clone()), (zeta_next, next.clone())], + ) + }) .collect_vec(), - )]) + ) + }) + .collect(); + + // for quotient commitments, we have a single commitment, opened on each quotient domain at many points + let quotient_chunks_domain_point_and_opens: (_, Vec<(_, _)>) = ( + proof.commitments.quotient_chunks.clone(), + tables + .values() + .flat_map(|table| { + let quotient_domains = table.quotient_domains(pcs); + quotient_domains + .into_iter() + .zip_eq(table.opened_values.quotient_chunks.iter()) + .map(|(domain, chunk)| (domain, vec![(zeta, chunk.clone())])) + }) .collect_vec(), - opening_proof, - challenger, - ) - .map_err(VerificationError::InvalidOpeningArgument)?; + ); - let zps = quotient_chunks_domains - .iter() - .enumerate() - .map(|(i, domain)| { - quotient_chunks_domains - .iter() - .enumerate() - .filter(|(j, _)| *j != i) - .map(|(_, other_domain)| { - other_domain.zp_at_point(zeta) - * other_domain.zp_at_point(domain.first_point()).inverse() - }) - .product::() - }) - .collect_vec(); + let verify_input = preprocessed_domains_points_and_opens + .into_iter() + .chain(trace_domains_points_and_opens_by_stage) + .chain(once(quotient_chunks_domain_point_and_opens)) + .collect(); - let quotient = opened_values - .quotient_chunks - .iter() - .enumerate() - .map(|(ch_i, ch)| { - ch.iter() - .enumerate() - .map(|(e_i, &c)| zps[ch_i] * SC::Challenge::monomial(e_i) * c) - .sum::() - }) - .sum::(); + pcs.verify(verify_input, opening_proof, challenger) + .map_err(VerificationError::InvalidOpeningArgument)?; - let sels = trace_domain.selectors_at_point(zeta); + // Verify the constraint evaluations. + for table in tables.values() { + // Verify the shape of the opening arguments matches the expected values. + verify_opening_shape(table)?; + // Verify the constraint evaluation. 
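// The `zps` computed below recombine the opened quotient chunks into a single
// value quotient(zeta). A toy version over a small prime field, treating each
// chunk domain as an explicit list of points so that Z_D(x) = prod_{d in D}
// (x - d); real domains are multiplicative cosets with a closed-form Z_D, so
// this is for illustration only.
const P: u64 = 101;

// Z_D(x) over the toy field
fn zp(domain: &[u64], x: u64) -> u64 {
    domain.iter().fold(1, |acc, d| acc * ((x + P - d) % P) % P)
}

// modular exponentiation, used for Fermat inverses (P is prime)
fn pow_mod(mut b: u64, mut e: u64) -> u64 {
    let mut r = 1;
    while e > 0 {
        if e & 1 == 1 {
            r = r * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    r
}

fn main() {
    let domains = [vec![1u64, 2], vec![3, 4]];
    let zeta = 7u64;
    // zps[i] = prod_{j != i} Z_{D_j}(zeta) / Z_{D_j}(first_point(D_i))
    let zps: Vec<u64> = (0..domains.len())
        .map(|i| {
            (0..domains.len())
                .filter(|&j| j != i)
                .fold(1u64, |acc, j| {
                    let num = zp(&domains[j], zeta);
                    let den_inv = pow_mod(zp(&domains[j], domains[i][0]), P - 2);
                    acc * (num * den_inv % P) % P
                })
        })
        .collect();
    assert_eq!(zps.len(), domains.len());
}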
+ let zps = table + .quotient_domains(pcs) + .iter() + .enumerate() + .map(|(i, domain)| { + table + .quotient_domains(pcs) + .iter() + .enumerate() + .filter(|(j, _)| *j != i) + .map(|(_, other_domain)| { + other_domain.zp_at_point(zeta) + * other_domain.zp_at_point(domain.first_point()).inverse() + }) + .product::>() + }) + .collect_vec(); - let preprocessed = VerticalPair::new( - RowMajorMatrixView::new_row(&opened_values.preprocessed_local), - RowMajorMatrixView::new_row(&opened_values.preprocessed_next), - ); + let quotient = table + .opened_values + .quotient_chunks + .iter() + .enumerate() + .map(|(ch_i, ch)| { + ch.iter() + .enumerate() + .map(|(e_i, &c)| zps[ch_i] * Challenge::::monomial(e_i) * c) + .sum() + }) + .sum(); - let traces_by_stage = opened_values - .traces_by_stage_local - .iter() - .zip(opened_values.traces_by_stage_next.iter()) - .map(|(trace_local, trace_next)| { + let sels = table.natural_domain(pcs).selectors_at_point(zeta); + + let empty_vec = vec![]; + + let preprocessed = if let Some(preprocessed) = table.opened_values.preprocessed.as_ref() { VerticalPair::new( - RowMajorMatrixView::new_row(trace_local), - RowMajorMatrixView::new_row(trace_next), + RowMajorMatrixView::new_row(&preprocessed.local), + RowMajorMatrixView::new_row(&preprocessed.next), ) - }) - .collect::>>(); - - let mut folder = VerifierConstraintFolder { - challenges, - preprocessed, - traces_by_stage, - public_values_by_stage, - is_first_row: sels.is_first_row, - is_last_row: sels.is_last_row, - is_transition: sels.is_transition, - alpha, - accumulator: SC::Challenge::zero(), - }; - air.eval(&mut folder); - let folded_constraints = folder.accumulator; - - // Finally, check that - // folded_constraints(zeta) / Z_H(zeta) = quotient(zeta) - if folded_constraints * sels.inv_zeroifier != quotient { - return Err(VerificationError::OodEvaluationMismatch); + } else { + VerticalPair::new( + RowMajorMatrixView::new(&empty_vec, 0), + RowMajorMatrixView::new(&empty_vec, 0), + ) + }; + + let traces_by_stage = table + .opened_values + .traces_by_stage + .iter() + .map(|trace| { + VerticalPair::new( + RowMajorMatrixView::new_row(&trace.local), + RowMajorMatrixView::new_row(&trace.next), + ) + }) + .collect::>>(); + + let mut folder: VerifierConstraintFolder<'_, T::Config> = VerifierConstraintFolder { + challenges: &challenges_by_stage, + preprocessed, + traces_by_stage, + public_values_by_stage: table.public_values_by_stage, + is_first_row: sels.is_first_row, + is_last_row: sels.is_last_row, + is_transition: sels.is_transition, + alpha, + accumulator: Challenge::::zero(), + }; + table.air.eval(&mut folder); + let folded_constraints = folder.accumulator; + + // Finally, check that + // folded_constraints(zeta) / Z_H(zeta) = quotient(zeta) + if folded_constraints * sels.inv_zeroifier != quotient { + return Err(VerificationError::OodEvaluationMismatch); + } } Ok(()) } +fn verify_opening_shape( + table: &Table<'_, T>, +) -> Result<(), VerificationError>> +where + ProverData: Send, + Commitment: Send, +{ + let log_quotient_degree = get_log_quotient_degree::, _>( + &table.air, + &table + .public_values_by_stage + .iter() + .map(|values| values.len()) + .collect::>(), + ); + let quotient_degree = 1 << log_quotient_degree; + let stage_count = <_ as MultiStageAir>>::stage_count(&table.air); + let challenge_counts: Vec = (0..stage_count) + .map(|i| <_ as MultiStageAir>>::stage_challenge_count(&table.air, i)) + .collect(); + + let air_widths = (0..stage_count) + .map(|stage| { + <_ as MultiStageAir>>>::stage_trace_width( 
+ &table.air, stage, + ) + }) + .collect::>(); + let air_fixed_width = + <_ as MultiStageAir>>>::preprocessed_width(&table.air); + let res = table + .opened_values + .preprocessed + .as_ref() + .map(|StageOpenedValues { local, next }| { + local.len() == air_fixed_width && next.len() == air_fixed_width + }) + .unwrap_or(true) + && table + .opened_values + .traces_by_stage + .iter() + .zip_eq(&air_widths) + .all(|(StageOpenedValues { local, next }, air_width)| { + local.len() == *air_width && next.len() == *air_width + }) + && table.opened_values.quotient_chunks.len() == quotient_degree + && table + .opened_values + .quotient_chunks + .iter() + .all(|qc| qc.len() == as AbstractExtensionField>>::D) + && table.public_values_by_stage.len() as u8 == stage_count + && challenge_counts.len() as u8 == stage_count; + + res.then_some(()) + .ok_or(VerificationError::InvalidProofShape) +} + #[derive(Debug)] pub enum VerificationError { InvalidProofShape,