Integrated plonky3 prover (#1857)
This is a major change to the plonky3 prover to support proving many
machines.

# Sharing costs across tables
- at setup phase, the fixed columns of each machine are committed to for
each possible size. These commitments are kept separate so that the
prover and verifier can pick the relevant ones for a given execution
- for each phase of the proving, the corresponding traces across all
machines are committed to jointly
- the quotient chunks are committed to jointly across all tables (see the
sketch below)
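
As a rough illustration of the resulting commitment schedule, here is a
minimal sketch (all type and function names below are illustrative
stand-ins, not the actual powdr API):

```rust
use std::collections::BTreeMap;

// Illustrative stand-ins, not the actual powdr types.
type MachineName = String;
type Size = usize;
type Column = Vec<u64>;
type Commitment = Vec<u8>;

/// Setup: the fixed columns of each machine are committed to once per
/// possible size, in *separate* commitments, so that prover and verifier
/// can later pick the ones matching the sizes of a concrete execution.
fn setup(
    fixed: &BTreeMap<(MachineName, Size), Vec<Column>>,
) -> BTreeMap<(MachineName, Size), Commitment> {
    fixed
        .iter()
        .map(|(key, columns)| (key.clone(), commit(columns)))
        .collect()
}

/// Proving: per stage, the traces of *all* machines go into one joint
/// commitment; the quotient chunks of all tables are also committed to
/// jointly at the end.
fn prove(traces_per_stage: &[Vec<Column>], quotient_chunks: &[Column]) -> Vec<Commitment> {
    let mut commitments: Vec<Commitment> =
        traces_per_stage.iter().map(|traces| commit(traces)).collect();
    commitments.push(commit(quotient_chunks));
    commitments
}

/// Placeholder for the backend's polynomial commitment routine.
fn commit(columns: &[Column]) -> Commitment {
    columns.iter().flatten().flat_map(|v| v.to_le_bytes()).collect()
}
```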

# Multi-stage publics
The implementation supports public values for each stage of each table.
This is tested internally in the plonky3 crate but not end-to-end in
pipeline tests.
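
As a sketch of the data shape this implies (illustrative names, not the
actual API), each table exposes a possibly empty list of public values
per stage:

```rust
use std::collections::BTreeMap;

// Illustrative stand-ins, not the actual powdr types.
type TableName = String;
type Stage = u8;

/// publics[table][stage] is the list of public values that `table`
/// exposes at that proving stage.
type MultiStagePublics<F> = BTreeMap<TableName, BTreeMap<Stage, Vec<F>>>;
```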

---------

Co-authored-by: Leo Alt <[email protected]>
Schaeff and leonardoalt authored Oct 17, 2024
1 parent fe054dd commit 0f66c6a
Showing 25 changed files with 1,461 additions and 1,141 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ members = [
     "riscv-executor",
     "riscv-syscalls",
     "schemas",
+    "backend-utils",
 ]
 
 exclude = [ "riscv-runtime" ]
@@ -47,6 +48,7 @@ powdr-asm-to-pil = { path = "./asm-to-pil", version = "0.1.0-alpha.2" }
 powdr-isa-utils = { path = "./isa-utils", version = "0.1.0-alpha.2" }
 powdr-analysis = { path = "./analysis", version = "0.1.0-alpha.2" }
 powdr-backend = { path = "./backend", version = "0.1.0-alpha.2" }
+powdr-backend-utils = { path = "./backend-utils", version = "0.1.0-alpha.2" }
 powdr-executor = { path = "./executor", version = "0.1.0-alpha.2" }
 powdr-importer = { path = "./importer", version = "0.1.0-alpha.2" }
 powdr-jit-compiler = { path = "./jit-compiler", version = "0.1.0-alpha.2" }
21 changes: 12 additions & 9 deletions ast/src/analyzed/mod.rs
@@ -326,23 +326,26 @@ impl<T> Analyzed<T> {
             .for_each(|definition| definition.post_visit_expressions_mut(f))
     }
 
-    /// Retrieves (col_name, poly_id, offset) of each public witness in the trace.
-    pub fn get_publics(&self) -> Vec<(String, PolyID, usize)> {
+    /// Retrieves (col_name, poly_id, offset, stage) of each public witness in the trace.
+    pub fn get_publics(&self) -> Vec<(String, PolyID, usize, u8)> {
         let mut publics = self
             .public_declarations
             .values()
             .map(|public_declaration| {
                 let column_name = public_declaration.referenced_poly_name();
-                let poly_id = {
+                let (poly_id, stage) = {
                     let symbol = &self.definitions[&public_declaration.polynomial.name].0;
-                    symbol
-                        .array_elements()
-                        .nth(public_declaration.array_index.unwrap_or_default())
-                        .unwrap()
-                        .1
+                    (
+                        symbol
+                            .array_elements()
+                            .nth(public_declaration.array_index.unwrap_or_default())
+                            .unwrap()
+                            .1,
+                        symbol.stage.unwrap_or_default() as u8,
+                    )
                 };
                 let row_offset = public_declaration.index as usize;
-                (column_name, poly_id, row_offset)
+                (column_name, poly_id, row_offset, stage)
             })
             .collect::<Vec<_>>();
 
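For illustration, a hypothetical caller of the extended method (a sketch
only; `analyzed` is assumed to be an `Analyzed<T>` in scope and is not
part of this commit):

```rust
// Sketch: each public now also carries the stage of its witness column.
for (name, poly_id, row_offset, stage) in analyzed.get_publics() {
    println!("public `{name}`: column {poly_id:?}, row {row_offset}, stage {stage}");
}
```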
16 changes: 16 additions & 0 deletions backend-utils/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "powdr-backend-utils"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+homepage.workspace = true
+repository.workspace = true
+
+[dependencies]
+powdr-pil-analyzer.workspace = true
+powdr-parser.workspace = true
+powdr-ast.workspace = true
+powdr-number.workspace = true
+powdr-executor.workspace = true
+log = "0.4.22"
+itertools = "0.13.0"
21 changes: 8 additions & 13 deletions backend/src/composite/split.rs → backend-utils/src/lib.rs
@@ -25,7 +25,7 @@ const DUMMY_COLUMN_NAME: &str = "__dummy";
 /// 1. The PIL is split into namespaces
 /// 2. Namespaces without any columns are duplicated and merged with the other namespaces
 /// 3. Any lookups or permutations that reference multiple namespaces are removed.
-pub(crate) fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String, Analyzed<F>> {
+pub fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String, Analyzed<F>> {
     let statements_by_namespace = split_by_namespace(pil);
     let statements_by_machine = merge_empty_namespaces(statements_by_namespace, pil);
 
@@ -37,9 +37,9 @@ pub(crate) fn split_pil<F: FieldElement>(pil: &Analyzed<F>) -> BTreeMap<String,
         .collect()
 }
 
-/// Given a set of columns and a PIL describing the machine, returns the witness column that belong to the machine.
+/// Given a set of columns and a PIL describing the machine, returns the witness columns that belong to the machine.
 /// Note that this also adds the dummy column.
-pub(crate) fn machine_witness_columns<F: FieldElement>(
+pub fn machine_witness_columns<F: FieldElement>(
     all_witness_columns: &[(String, Vec<F>)],
     machine_pil: &Analyzed<F>,
     machine_name: &str,
@@ -71,10 +71,10 @@ pub(crate) fn machine_witness_columns<F: FieldElement>(
 }
 
 /// Given a set of columns and a PIL describing the machine, returns the fixed column that belong to the machine.
-pub(crate) fn machine_fixed_columns<F: FieldElement>(
-    all_fixed_columns: &[(String, VariablySizedColumn<F>)],
-    machine_pil: &Analyzed<F>,
-) -> BTreeMap<DegreeType, Vec<(String, VariablySizedColumn<F>)>> {
+pub fn machine_fixed_columns<'a, F: FieldElement>(
+    all_fixed_columns: &'a [(String, VariablySizedColumn<F>)],
+    machine_pil: &'a Analyzed<F>,
+) -> BTreeMap<DegreeType, Vec<(String, &'a [F])>> {
     let machine_columns = select_machine_columns(
         all_fixed_columns,
         machine_pil.constant_polys_in_source_order(),
@@ -106,12 +106,7 @@ pub(crate) fn machine_fixed_columns<F: FieldElement>(
                 size,
                 machine_columns
                     .iter()
-                    .map(|(name, column)| {
-                        (
-                            name.clone(),
-                            column.get_by_size(size).unwrap().to_vec().into(),
-                        )
-                    })
+                    .map(|(name, column)| (name.clone(), column.get_by_size(size).unwrap()))
                     .collect::<Vec<_>>(),
             )
         })
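Taken together, a backend can combine these helpers roughly as follows (a
sketch under the signatures above; `prove_all_machines` and its body are
hypothetical, not code from this commit):

```rust
use powdr_ast::analyzed::Analyzed;
use powdr_executor::constant_evaluator::VariablySizedColumn;
use powdr_number::FieldElement;

// Hypothetical driver showing how the helpers fit together.
fn prove_all_machines<F: FieldElement>(
    pil: &Analyzed<F>,
    witness: &[(String, Vec<F>)],
    fixed: &[(String, VariablySizedColumn<F>)],
) {
    for (machine_name, machine_pil) in powdr_backend_utils::split_pil(pil) {
        // Witness columns belonging to this machine (plus the dummy column).
        let machine_witness =
            powdr_backend_utils::machine_witness_columns(witness, &machine_pil, &machine_name);
        // Fixed columns grouped by size; the values are borrowed `&[F]`
        // slices, so no per-size copy of the fixed data is made here.
        for (size, columns) in powdr_backend_utils::machine_fixed_columns(fixed, &machine_pil) {
            // ... commit to `columns` for this `size` and prove against
            // `machine_witness` ...
            let _ = (size, columns, &machine_witness);
        }
    }
}
```

Returning borrowed slices instead of owned `VariablySizedColumn`s lets each
caller decide when to copy; the composite backend does this explicitly with
`to_vec()` in the diff below.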
1 change: 1 addition & 0 deletions backend/Cargo.toml
@@ -20,6 +20,7 @@ powdr-parser.workspace = true
 powdr-pil-analyzer.workspace = true
 powdr-executor.workspace = true
 powdr-parser-util.workspace = true
+powdr-backend-utils.workspace = true
 
 powdr-plonky3 = { path = "../plonky3", optional = true }
 
10 changes: 6 additions & 4 deletions backend/src/composite/mod.rs
@@ -10,17 +10,15 @@ use std::{
 
 use itertools::Itertools;
 use powdr_ast::analyzed::Analyzed;
+use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns};
 use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
 use powdr_number::{DegreeType, FieldElement};
 use serde::{Deserialize, Serialize};
-use split::{machine_fixed_columns, machine_witness_columns};
 
 use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};
 
 use self::sub_prover::RunStatus;
 
-mod split;
-
 /// Maps each size to the corresponding verification key.
 type VerificationKeyBySize = BTreeMap<DegreeType, Vec<u8>>;
 
@@ -76,7 +74,7 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
             unimplemented!();
         }
 
-        let pils = split::split_pil(&pil);
+        let pils = powdr_backend_utils::split_pil(&pil);
 
         // Read the setup once (if any) to pass to all backends.
         let setup_bytes = setup.map(|setup| {
@@ -109,6 +107,10 @@
                 machine_fixed_columns(&fixed, &pil)
                     .into_iter()
                     .map(|(size, fixed)| {
+                        let fixed = fixed
+                            .into_iter()
+                            .map(|(name, values)| (name, values.to_vec().into()))
+                            .collect();
                         let pil = set_size(pil.clone(), size as DegreeType);
                         // Set up readers for the setup and verification key
                         let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new);
7 changes: 0 additions & 7 deletions backend/src/lib.rs
@@ -46,9 +46,6 @@ pub enum BackendType {
     #[cfg(feature = "plonky3")]
     #[strum(serialize = "plonky3")]
     Plonky3,
-    #[cfg(feature = "plonky3")]
-    #[strum(serialize = "plonky3-composite")]
-    Plonky3Composite,
 }
 
 pub type BackendOptions = String;
@@ -87,10 +84,6 @@ impl BackendType {
             }
             #[cfg(feature = "plonky3")]
             BackendType::Plonky3 => Box::new(plonky3::Factory),
-            #[cfg(feature = "plonky3")]
-            BackendType::Plonky3Composite => {
-                Box::new(composite::CompositeBackendFactory::new(plonky3::Factory))
-            }
         }
     }
 }
24 changes: 5 additions & 19 deletions backend/src/plonky3/mod.rs
@@ -1,10 +1,7 @@
 use std::{io, path::PathBuf, sync::Arc};
 
 use powdr_ast::analyzed::Analyzed;
-use powdr_executor::{
-    constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn},
-    witgen::WitgenCallback,
-};
+use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
 use powdr_number::{BabyBearField, GoldilocksField, Mersenne31Field};
 use powdr_plonky3::{Commitment, FieldElementMap, Plonky3Prover, ProverData};
 
@@ -35,22 +32,8 @@
         if verification_app_key.is_some() {
             return Err(Error::NoAggregationAvailable);
         }
-        if pil.degrees().len() > 1 {
-            return Err(Error::NoVariableDegreeAvailable);
-        }
-        if pil.public_declarations_in_source_order().any(|(_, d)| {
-            pil.definitions.iter().any(|(_, (symbol, _))| {
-                symbol.absolute_name == d.name && symbol.stage.unwrap_or_default() > 0
-            })
-        }) {
-            return Err(Error::NoLaterStagePublicAvailable);
-        }
 
-        let fixed = Arc::new(
-            get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?,
-        );
-
-        let mut p3 = Box::new(Plonky3Prover::new(pil.clone(), fixed.clone()));
+        let mut p3 = Box::new(Plonky3Prover::new(pil.clone(), fixed));
 
         if let Some(verification_key) = verification_key {
             p3.set_verifying_key(verification_key);
@@ -70,6 +53,9 @@
     Commitment<T>: Send,
 {
     fn verify(&self, proof: &[u8], instances: &[Vec<T>]) -> Result<(), Error> {
+        assert_eq!(instances.len(), 1);
+        let instances = &instances[0];
+
         Ok(self.verify(proof, instances)?)
     }
 
13 changes: 10 additions & 3 deletions executor/src/witgen/mod.rs
@@ -288,7 +288,12 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> {
 
         log::debug!("Publics:");
         for (name, value) in extract_publics(&witness_cols, self.analyzed) {
-            log::debug!(" {name:>30}: {value}");
+            log::debug!(
+                " {name:>30}: {}",
+                value
+                    .map(|value| value.to_string())
+                    .unwrap_or_else(|| "Not yet known at this stage".to_string())
+            );
         }
         witness_cols
     }
@@ -297,7 +302,7 @@
 pub fn extract_publics<T: FieldElement>(
     witness: &[(String, Vec<T>)],
     pil: &Analyzed<T>,
-) -> Vec<(String, T)> {
+) -> Vec<(String, Option<T>)> {
     let witness = witness
         .iter()
         .map(|(name, col)| (name.clone(), col))
@@ -306,7 +311,9 @@
         .map(|(name, public_declaration)| {
             let poly_name = &public_declaration.referenced_poly_name();
             let poly_index = public_declaration.index;
-            let value = witness[poly_name][poly_index as usize];
+            let value = witness
+                .get(poly_name)
+                .map(|column| column[poly_index as usize]);
             ((*name).clone(), value)
         })
        .collect()
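Callers now have to handle publics that are not yet known at the current
stage; a minimal sketch of the consuming side (illustrative, not code from
this commit):

```rust
// Sketch: `publics` is the `Vec<(String, Option<T>)>` returned by
// `extract_publics` (or by `Pipeline::publics`, changed below).
for (name, value) in &publics {
    match value {
        Some(v) => println!("{name} = {v}"),
        // The referenced witness column was not computed in this stage.
        None => println!("{name} is not yet known at this stage"),
    }
}
```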
2 changes: 1 addition & 1 deletion pipeline/src/pipeline.rs
@@ -921,7 +921,7 @@ impl<T: FieldElement> Pipeline<T> {
         Ok(self.artifact.witness.as_ref().unwrap().clone())
     }
 
-    pub fn publics(&self) -> Result<Vec<(String, T)>, Vec<String>> {
+    pub fn publics(&self) -> Result<Vec<(String, Option<T>)>, Vec<String>> {
         let pil = self.optimized_pil()?;
         let witness = self.witness()?;
         Ok(extract_publics(&witness, &pil))
26 changes: 11 additions & 15 deletions pipeline/src/test_util.rs
@@ -180,7 +180,7 @@ pub fn gen_estark_proof_with_backend_variant(
         .publics()
         .unwrap()
         .iter()
-        .map(|(_name, v)| *v)
+        .map(|(_name, v)| v.expect("all publics should be known since we created a proof"))
         .collect();
 
     pipeline.verify(&proof, &[publics]).unwrap();
@@ -277,7 +277,7 @@ pub fn gen_halo2_proof(pipeline: Pipeline<Bn254Field>, backend: BackendVariant)
         .publics()
         .unwrap()
         .iter()
-        .map(|(_name, v)| *v)
+        .map(|(_name, v)| v.expect("all publics should be known since we created a proof"))
         .collect();
 
     pipeline.verify(&proof, &[publics]).unwrap();
@@ -287,15 +287,8 @@ pub fn gen_halo2_proof(pipeline: Pipeline<Bn254Field>, backend: BackendVariant)
 pub fn gen_halo2_proof(_pipeline: Pipeline<Bn254Field>, _backend: BackendVariant) {}
 
 #[cfg(feature = "plonky3")]
-pub fn test_plonky3_with_backend_variant<T: FieldElement>(
-    file_name: &str,
-    inputs: Vec<T>,
-    backend: BackendVariant,
-) {
-    let backend = match backend {
-        BackendVariant::Monolithic => powdr_backend::BackendType::Plonky3,
-        BackendVariant::Composite => powdr_backend::BackendType::Plonky3Composite,
-    };
+pub fn test_plonky3<T: FieldElement>(file_name: &str, inputs: Vec<T>) {
+    let backend = powdr_backend::BackendType::Plonky3;
     let mut pipeline = Pipeline::default()
         .with_tmp_output()
         .from_file(resolve_test_file(file_name))
@@ -310,7 +303,7 @@
         .clone()
         .unwrap()
         .iter()
-        .map(|(_name, v)| *v)
+        .map(|(_name, v)| v.expect("all publics should be known since we created a proof"))
         .collect();
 
     pipeline.verify(&proof, &[publics.clone()]).unwrap();
@@ -333,7 +326,7 @@
 
 #[cfg(feature = "plonky3")]
 pub fn test_plonky3_pipeline<T: FieldElement>(pipeline: Pipeline<T>) {
-    let mut pipeline = pipeline.with_backend(powdr_backend::BackendType::Plonky3Composite, None);
+    let mut pipeline = pipeline.with_backend(powdr_backend::BackendType::Plonky3, None);
 
     pipeline.compute_witness().unwrap();
 
@@ -349,7 +342,7 @@
         .clone()
         .unwrap()
         .iter()
-        .map(|(_name, v)| *v)
+        .map(|(_name, v)| v.expect("all publics should be known since we created a proof"))
         .collect();
 
     pipeline.verify(&proof, &[publics.clone()]).unwrap();
@@ -371,7 +364,10 @@
 }
 
 #[cfg(not(feature = "plonky3"))]
-pub fn test_plonky3_with_backend_variant<T: FieldElement>(_: &str, _: Vec<T>, _: BackendVariant) {}
+pub fn test_plonky3<T: FieldElement>(_: &str, _: Vec<T>) {}
+
+#[cfg(not(feature = "plonky3"))]
+pub fn test_plonky3_pipeline<T: FieldElement>(_: Pipeline<T>) {}
 
 #[cfg(not(feature = "plonky3"))]
 pub fn gen_plonky3_proof<T: FieldElement>(_: &str, _: Vec<T>) {}
9 changes: 2 additions & 7 deletions pipeline/tests/asm.rs
@@ -8,8 +8,7 @@ use powdr_pipeline::{
         asm_string_to_pil, gen_estark_proof_with_backend_variant, make_prepared_pipeline,
         make_simple_prepared_pipeline, regular_test, regular_test_without_babybear,
         resolve_test_file, run_pilcom_with_backend_variant, test_halo2,
-        test_halo2_with_backend_variant, test_pilcom, test_plonky3_with_backend_variant,
-        BackendVariant,
+        test_halo2_with_backend_variant, test_pilcom, test_plonky3, BackendVariant,
     },
     util::{FixedPolySet, PolySet, WitnessPolySet},
     Pipeline,
@@ -39,11 +38,7 @@ fn simple_sum_asm() {
     let f = "asm/simple_sum.asm";
     let i = [16, 4, 1, 2, 8, 5];
    regular_test(f, &i);
-    test_plonky3_with_backend_variant::<GoldilocksField>(
-        f,
-        slice_to_vec(&i),
-        BackendVariant::Composite,
-    );
+    test_plonky3::<GoldilocksField>(f, slice_to_vec(&i));
 }
 
 #[test]