diff --git a/prover/src/lib.rs b/prover/src/lib.rs index efd218f860..08cc9444f6 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -765,40 +765,40 @@ pub mod tests { let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; let public_values = core_proof.public_values.clone(); - tracing::info!("verify core"); - prover.verify(&core_proof.proof, &vk)?; + // tracing::info!("verify core"); + // prover.verify(&core_proof.proof, &vk)?; - if test_kind == Test::Core { - return Ok(()); - } + // if test_kind == Test::Core { + // return Ok(()); + // } - tracing::info!("compress"); - let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; + // tracing::info!("compress"); + // let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; - tracing::info!("verify compressed"); - prover.verify_compressed(&compressed_proof, &vk)?; + // tracing::info!("verify compressed"); + // prover.verify_compressed(&compressed_proof, &vk)?; - if test_kind == Test::Compress { - return Ok(()); - } + // if test_kind == Test::Compress { + // return Ok(()); + // } - tracing::info!("shrink"); - let shrink_proof = prover.shrink(compressed_proof, opts)?; + // tracing::info!("shrink"); + // let shrink_proof = prover.shrink(compressed_proof, opts)?; - tracing::info!("verify shrink"); - prover.verify_shrink(&shrink_proof, &vk)?; + // tracing::info!("verify shrink"); + // prover.verify_shrink(&shrink_proof, &vk)?; - if test_kind == Test::Shrink { - return Ok(()); - } + // if test_kind == Test::Shrink { + // return Ok(()); + // } - tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); + // tracing::info!("wrap bn254"); + // let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + // let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); - // Save the proof. - let mut file = File::create("proof-with-pis.bin").unwrap(); - file.write_all(bytes.as_slice()).unwrap(); + // // Save the proof. + // let mut file = File::create("proof-with-pis.bin").unwrap(); + // file.write_all(bytes.as_slice()).unwrap(); // Load the proof. let mut file = File::open("proof-with-pis.bin").unwrap(); @@ -828,7 +828,7 @@ pub mod tests { let plonk_bn254_proof = prover.wrap_plonk_bn254(wrapped_bn254_proof, &artifacts_dir); println!("{:?}", plonk_bn254_proof); - prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; + // prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; Ok(()) } diff --git a/recursion/circuit-v2/Cargo.toml b/recursion/circuit-v2/Cargo.toml index ae8c6cf89c..984574a77b 100644 --- a/recursion/circuit-v2/Cargo.toml +++ b/recursion/circuit-v2/Cargo.toml @@ -52,3 +52,7 @@ rand = "0.8.5" [features] native-gnark = ["sp1-recursion-gnark-ffi/native"] + +[[bin]] +name = "circuit-architecture-sweep" +path = "scripts/circuit_architecture_sweep.rs" \ No newline at end of file diff --git a/recursion/circuit-v2/scripts/circuit_architecture_sweep.rs b/recursion/circuit-v2/scripts/circuit_architecture_sweep.rs index 0e76784618..40a69732fd 100644 --- a/recursion/circuit-v2/scripts/circuit_architecture_sweep.rs +++ b/recursion/circuit-v2/scripts/circuit_architecture_sweep.rs @@ -18,49 +18,49 @@ fn machine_with_dummy( fn main() { // Test the performance of the full architecture with different degrees. - let machine_maker_3 = || machine_with_all_chips::<3>(16, 16, 16); - let machine_maker_5 = || machine_with_all_chips::<5>(16, 16, 16); - let machine_maker_9 = || machine_with_all_chips::<9>(16, 16, 16); - let machine_maker_17 = || machine_with_all_chips::<17>(16, 16, 16); - test_machine(machine_maker_3); - test_machine(machine_maker_5); - test_machine(machine_maker_9); - test_machine(machine_maker_17); + // let machine_maker_3 = || machine_with_all_chips::<3>(16, 16, 16); + // let machine_maker_5 = || machine_with_all_chips::<5>(16, 16, 16); + // let machine_maker_9 = || machine_with_all_chips::<9>(16, 16, 16); + // let machine_maker_17 = || machine_with_all_chips::<17>(16, 16, 16); + // test_machine(machine_maker_3); + // test_machine(machine_maker_5); + // test_machine(machine_maker_9); + // test_machine(machine_maker_17); // Test the performance of the machine with the full architecture for different numbers of rows // in the precompiles. Degree is set to 9. - let machine_maker = |i| machine_with_all_chips::<9>(i, i, i); - for i in 1..=5 { - test_machine(|| machine_maker(i)); - } + // let machine_maker = |i| machine_with_all_chips::<9>(i, i, i); + // for i in 1..=5 { + // test_machine(|| machine_maker(i)); + // } // Test the performance of the dummy machine for different numbers of columns in the dummy table. // Degree is kept fixed at 9. - test_machine(|| machine_with_dummy::<9, 1>(16)); - test_machine(|| machine_with_dummy::<9, 50>(16)); - test_machine(|| machine_with_dummy::<9, 100>(16)); - test_machine(|| machine_with_dummy::<9, 150>(16)); - test_machine(|| machine_with_dummy::<9, 200>(16)); - test_machine(|| machine_with_dummy::<9, 250>(16)); - test_machine(|| machine_with_dummy::<9, 300>(16)); - test_machine(|| machine_with_dummy::<9, 350>(16)); - test_machine(|| machine_with_dummy::<9, 400>(16)); - test_machine(|| machine_with_dummy::<9, 450>(16)); - test_machine(|| machine_with_dummy::<9, 500>(16)); - test_machine(|| machine_with_dummy::<9, 550>(16)); - test_machine(|| machine_with_dummy::<9, 600>(16)); - test_machine(|| machine_with_dummy::<9, 650>(16)); - test_machine(|| machine_with_dummy::<9, 700>(16)); - test_machine(|| machine_with_dummy::<9, 750>(16)); + // test_machine(|| machine_with_dummy::<9, 1>(16)); + // test_machine(|| machine_with_dummy::<9, 50>(16)); + // test_machine(|| machine_with_dummy::<9, 100>(16)); + // test_machine(|| machine_with_dummy::<9, 150>(16)); + // test_machine(|| machine_with_dummy::<9, 200>(16)); + // test_machine(|| machine_with_dummy::<9, 250>(16)); + // test_machine(|| machine_with_dummy::<9, 300>(16)); + // test_machine(|| machine_with_dummy::<9, 350>(16)); + // test_machine(|| machine_with_dummy::<9, 400>(16)); + // test_machine(|| machine_with_dummy::<9, 450>(16)); + // test_machine(|| machine_with_dummy::<9, 500>(16)); + // test_machine(|| machine_with_dummy::<9, 550>(16)); + // test_machine(|| machine_with_dummy::<9, 600>(16)); + // test_machine(|| machine_with_dummy::<9, 650>(16)); + // test_machine(|| machine_with_dummy::<9, 700>(16)); + test_machine(|| machine_with_dummy::<5, 130>(18)); // Test the performance of the dummy machine for different heights of the dummy table. - for i in 4..=7 { - test_machine(|| machine_with_dummy::<9, 1>(i)); - } + // for i in 4..=7 { + // test_machine(|| machine_with_dummy::<9, 1>(i)); + // } // Change the degree for the dummy table, keeping other parameters fixed. - test_machine(|| machine_with_dummy::<3, 500>(16)); - test_machine(|| machine_with_dummy::<5, 500>(16)); - test_machine(|| machine_with_dummy::<9, 500>(16)); - test_machine(|| machine_with_dummy::<17, 500>(16)); + // test_machine(|| machine_with_dummy::<3, 500>(16)); + // test_machine(|| machine_with_dummy::<5, 500>(16)); + // test_machine(|| machine_with_dummy::<9, 500>(16)); + // test_machine(|| machine_with_dummy::<17, 500>(16)); } diff --git a/recursion/circuit-v2/src/build_wrap_v2.rs b/recursion/circuit-v2/src/build_wrap_v2.rs index e9a6bc33da..39ab349d66 100644 --- a/recursion/circuit-v2/src/build_wrap_v2.rs +++ b/recursion/circuit-v2/src/build_wrap_v2.rs @@ -1,5 +1,5 @@ -use std::borrow::Borrow; use std::iter::once; +use std::{borrow::Borrow, os}; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_bn254_fr::Bn254Fr; @@ -27,9 +27,11 @@ use sp1_recursion_circuit::{ use sp1_recursion_compiler::{ config::OuterConfig, constraints::{Constraint, ConstraintCompiler}, - ir::{Builder, Config, Ext, Felt, Usize, Var, Witness}, + ir::{Builder, Config, DslIr, Ext, Felt, Usize, Var, Witness}, +}; +use sp1_recursion_core::{ + air::RecursionPublicValues, range_check, stark::config::BabyBearPoseidon2Outer, }; -use sp1_recursion_core::{air::RecursionPublicValues, stark::config::BabyBearPoseidon2Outer}; use sp1_recursion_program::types::QuotientDataValues; type OuterSC = BabyBearPoseidon2Outer; @@ -159,6 +161,10 @@ where // builder.assert_felt_eq(*expected_elm, *calculated_elm); // } + builder + .operations + .push(DslIr::CycleTracker("Hello world".to_string())); + let mut backend = ConstraintCompiler::::default(); backend.emit(builder.operations) } diff --git a/recursion/circuit/src/challenger.rs b/recursion/circuit/src/challenger.rs index aa947f6b04..ac623ba027 100644 --- a/recursion/circuit/src/challenger.rs +++ b/recursion/circuit/src/challenger.rs @@ -122,8 +122,7 @@ pub fn reduce_32(builder: &mut Builder, vals: &[Felt]) -> Va let mut power = C::N::one(); let result: Var = builder.eval(C::N::zero()); for val in vals.iter() { - let bits = builder.num2bits_f_circuit(*val); - let val = builder.bits2num_v_circuit(&bits); + let val = builder.felt2var_circuit(*val); builder.assign(result, result + val * power); power *= C::N::from_canonical_u64(1u64 << 32); } diff --git a/recursion/circuit/src/fri.rs b/recursion/circuit/src/fri.rs index 6a5d518859..97a42abe97 100644 --- a/recursion/circuit/src/fri.rs +++ b/recursion/circuit/src/fri.rs @@ -43,11 +43,26 @@ pub fn verify_shape_and_sample_challenges( challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness); let log_max_height = proof.commit_phase_commits.len() + config.log_blowup; + + let mut precomputed_generator_powers: Vec> = vec![ + builder.eval(SymbolicFelt::from_f(C::F::one())), + builder.eval(SymbolicFelt::from_f(C::F::two_adic_generator( + log_max_height, + ))), + ]; + + let mut cur = precomputed_generator_powers[1]; + for _ in 2..log_max_height { + cur = builder.eval(cur * cur); + precomputed_generator_powers.push(cur); + } + let query_indices: Vec> = (0..config.num_queries) .map(|_| challenger.sample_bits(builder, log_max_height)) .collect(); FriChallenges { + precomputed_generator_powers, query_indices, betas, } @@ -60,6 +75,7 @@ pub fn verify_two_adic_pcs( challenger: &mut MultiField32ChallengerVariable, rounds: Vec>, ) { + let mut counter = 0; let alpha = challenger.sample_ext(builder); let fri_challenges = @@ -125,16 +141,22 @@ pub fn verify_two_adic_pcs( let two_adic_generator_exp = builder.exp_f_bits(two_adic_generator, rev_reduced_index); let x: Felt<_> = builder.eval(g * two_adic_generator_exp); + builder.operations.push(DslIr::ReduceF(x)); for (z, ps_at_z) in izip!(mat_points, mat_values) { let mut acc: Ext = builder.eval(SymbolicExt::from_f(C::EF::zero())); + counter += 1; for (p_at_x, &p_at_z) in izip!(mat_opening.clone(), ps_at_z) { let pow = log_height_pow[log_height]; + // Fill in any missing powers of alpha. (alpha_pows.len()..pow + 1).for_each(|_| { - alpha_pows.push(builder.eval(*alpha_pows.last().unwrap() * alpha)); + let alpha = builder.eval(*alpha_pows.last().unwrap() * alpha); + builder.operations.push(DslIr::ReduceE(alpha)); + alpha_pows.push(alpha); }); + acc = builder.eval(acc + (alpha_pows[pow] * (p_at_z - p_at_x[0]))); log_height_pow[log_height] += 1; } @@ -146,6 +168,8 @@ pub fn verify_two_adic_pcs( }) .collect::>(); + println!("Counter: {}", counter); + verify_challenges( builder, config, @@ -163,6 +187,8 @@ pub fn verify_challenges( reduced_openings: Vec<[Ext; 32]>, ) { let log_max_height = proof.commit_phase_commits.len() + config.log_blowup; + println!("Log max height: {}", log_max_height); + println!("Num queries: {}", config.num_queries); for (&index, query_proof, ro) in izip!( &challenges.query_indices, &proof.query_proofs, @@ -176,6 +202,7 @@ pub fn verify_challenges( challenges.betas.clone(), ro, log_max_height, + challenges.precomputed_generator_powers.clone(), ); builder.assert_ext_eq(folded_eval, proof.final_poly); @@ -190,14 +217,12 @@ pub fn verify_query( betas: Vec>, reduced_openings: [Ext; 32], log_max_height: usize, + precomputed_generator_powers: Vec>, ) -> Ext { let mut folded_eval: Ext = builder.eval(SymbolicExt::from_f(C::EF::zero())); - let two_adic_generator = builder.eval(SymbolicExt::from_f(C::EF::two_adic_generator( - log_max_height, - ))); let index_bits = builder.num2bits_v_circuit(index, 32); let rev_reduced_index = builder.reverse_bits_len_circuit(index_bits.clone(), log_max_height); - let mut x = builder.exp_e_bits(two_adic_generator, rev_reduced_index); + let mut x = builder.exp_f_bits_precomputed(rev_reduced_index, &precomputed_generator_powers); let mut offset = 0; for (log_folded_height, commit, step, beta) in izip!( @@ -212,6 +237,9 @@ pub fn verify_query( let index_sibling: Var<_> = builder.eval(one - index_bits.clone()[offset]); let index_pair = &index_bits[(offset + 1)..]; + builder.operations.push(DslIr::ReduceE(folded_eval)); + builder.operations.push(DslIr::ReduceE(step.sibling_value)); + let evals_ext = [ builder.select_ef(index_sibling, folded_eval, step.sibling_value), builder.select_ef(index_sibling, step.sibling_value, folded_eval), @@ -234,13 +262,15 @@ pub fn verify_query( step.opening_proof.clone(), ); - let xs_new = builder.eval(x * C::EF::two_adic_generator(1)); + let xs_new = builder.eval(x * C::F::two_adic_generator(1)); let xs = [ - builder.select_ef(index_sibling, x, xs_new), - builder.select_ef(index_sibling, xs_new, x), + builder.select_f(index_sibling, x, xs_new), + builder.select_f(index_sibling, xs_new, x), ]; - folded_eval = builder - .eval(evals_ext[0] + (beta - xs[0]) * (evals_ext[1] - evals_ext[0]) / (xs[1] - xs[0])); + let one: Felt<_> = builder.eval(C::F::one()); + let inv: Felt<_> = builder.eval(one / (xs[1] - xs[0])); + folded_eval = + builder.eval(evals_ext[0] + (beta - xs[0]) * (evals_ext[1] - evals_ext[0]) * inv); x = builder.eval(x * x); offset += 1; } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 526466cd6f..65cace5e81 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -19,7 +19,7 @@ use sp1_core::{ }; use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler}; -use sp1_recursion_compiler::ir::{Builder, Config, Ext, Felt, Var}; +use sp1_recursion_compiler::ir::{Builder, Config, DslIr, Ext, Felt, Var}; use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}; @@ -358,6 +358,10 @@ pub fn build_wrap_circuit( builder.assert_felt_eq(*expected_elm, *calculated_elm); } + builder + .operations + .push(DslIr::CycleTracker("Hello World".to_string())); + let mut backend = ConstraintCompiler::::default(); backend.emit(builder.operations) } diff --git a/recursion/circuit/src/types.rs b/recursion/circuit/src/types.rs index 562204dcf5..948a7ea5b2 100644 --- a/recursion/circuit/src/types.rs +++ b/recursion/circuit/src/types.rs @@ -48,6 +48,7 @@ pub struct FriQueryProofVariable { /// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/fri/src/verifier.rs#L22 #[derive(Clone)] pub struct FriChallenges { + pub precomputed_generator_powers: Vec>, pub query_indices: Vec>, pub betas: Vec>, } diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index eb43951358..d59866cad3 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -207,10 +207,21 @@ impl ConstraintCompiler { args: vec![vec![a.id()], vec![b.id()], vec![tmp]], }); } + DslIr::MulFI(a, b, c) => { + let tmp = self.alloc_f(&mut constraints, c); + constraints.push(Constraint { + opcode: ConstraintOpcode::MulF, + args: vec![vec![a.id()], vec![b.id()], vec![tmp]], + }); + } DslIr::MulEF(a, b, c) => constraints.push(Constraint { opcode: ConstraintOpcode::MulEF, args: vec![vec![a.id()], vec![b.id()], vec![c.id()]], }), + DslIr::DivF(a, b, c) => constraints.push(Constraint { + opcode: ConstraintOpcode::DivF, + args: vec![vec![a.id()], vec![b.id()], vec![c.id()]], + }), DslIr::DivFIN(a, b, c) => { let tmp = self.alloc_f(&mut constraints, b.inverse()); constraints.push(Constraint { @@ -245,6 +256,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::Num2BitsF, args: vec![output.iter().map(|x| x.id()).collect(), vec![value.id()]], }), + DslIr::CircuitFelt2Var(value, output) => constraints.push(Constraint { + opcode: ConstraintOpcode::CircuitFelt2Var, + args: vec![vec![output.id()], vec![value.id()]], + }), DslIr::CircuitPoseidon2Permute(state) => constraints.push(Constraint { opcode: ConstraintOpcode::Permute, args: state.iter().map(|x| vec![x.id()]).collect(), @@ -358,6 +373,18 @@ impl ConstraintCompiler { vec![a[3].id()], ], }), + DslIr::ReduceF(a) => constraints.push(Constraint { + opcode: ConstraintOpcode::ReduceF, + args: vec![vec![a.id()]], + }), + DslIr::ReduceE(a) => constraints.push(Constraint { + opcode: ConstraintOpcode::ReduceE, + args: vec![vec![a.id()]], + }), + DslIr::CycleTracker(a) => constraints.push(Constraint { + opcode: ConstraintOpcode::CycleTracker, + args: vec![vec![a]], + }), _ => panic!("unsupported {:?}", instruction), }; } diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 4911e0f108..437eb8b821 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,5 +46,9 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + CircuitFelt2Var, PermuteBabyBear, + ReduceF, + ReduceE, + CycleTracker, } diff --git a/recursion/compiler/src/ir/bits.rs b/recursion/compiler/src/ir/bits.rs index b08a7e29e5..f5fbdd33ad 100644 --- a/recursion/compiler/src/ir/bits.rs +++ b/recursion/compiler/src/ir/bits.rs @@ -92,6 +92,14 @@ impl Builder { output } + /// Converts a felt to a var inside a circuit. + pub fn felt2var_circuit(&mut self, num: Felt) -> Var { + let output = self.uninit(); + self.push(DslIr::CircuitFelt2Var(num, output)); + + output + } + /// Convert bits to a variable. pub fn bits2num_v(&mut self, bits: &Array>) -> Var { let num: Var<_> = self.eval(C::N::zero()); diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index a135a3b8eb..3a69ae8aca 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -17,6 +17,10 @@ pub enum DslIr { /// Assigns an ext field immediate to an extension field element (ext = ext field imm). ImmE(Ext, C::EF), + /// Force the reduction of an extension field element. + ReduceF(Felt), + ReduceE(Ext), + // Additions. /// Add two variables (var = var + var). AddV(Var, Var, Var), @@ -288,6 +292,8 @@ pub enum DslIr { CircuitExt2Felt([Felt; 4], Ext), /// Converts a slice of felts to an ext. Should only be used when target is a gnark circuit. CircuitFelts2Ext([Felt; 4], Ext), + /// Converts a felt to a var. Should only be used when target is a gnark circuit. + CircuitFelt2Var(Felt, Var), // Debugging instructions. /// Executes less than (var = var < var). This operation is NOT constrained. diff --git a/recursion/compiler/src/ir/types.rs b/recursion/compiler/src/ir/types.rs index c30c8fa4a4..a751dc1a65 100644 --- a/recursion/compiler/src/ir/types.rs +++ b/recursion/compiler/src/ir/types.rs @@ -2,6 +2,7 @@ use alloc::format; use core::marker::PhantomData; use std::collections::HashMap; use std::hash::Hash; +use std::rc::Rc; use p3_field::AbstractField; use p3_field::ExtensionField; @@ -993,6 +994,14 @@ impl> Ext { (SymbolicExt::Val(lhs, _), SymbolicExt::Val(rhs, _)) => { builder.push(DslIr::MulE(*self, *lhs, *rhs)); } + + (SymbolicExt::Base(lhs, _), SymbolicExt::Val(rhs, _)) => { + let lhs_val: Felt<_> = builder.uninit(); + let s = lhs.clone(); + let z = (*s).clone(); + lhs_val.assign_with_cache(z, builder, base_cache); + builder.push(DslIr::MulEF(*self, *rhs, lhs_val)); + } (SymbolicExt::Val(lhs, _), rhs) => { let rhs_value = Self::uninit(builder); rhs_value.assign_with_caches(rhs.clone(), builder, ext_cache, base_cache); @@ -1005,6 +1014,19 @@ impl> Ext { ext_cache.insert(lhs.clone(), lhs_value); builder.push(DslIr::MulEI(*self, lhs_value, *rhs)); } + (lhs, SymbolicExt::Base(rhs, _)) => { + let rhs_val: Felt<_> = builder.uninit(); + let s = rhs.clone(); + let z = (*s).clone(); + rhs_val.assign_with_cache(z, builder, base_cache); + + let lhs_value = Self::uninit(builder); + lhs_value.assign_with_caches(lhs.clone(), builder, ext_cache, base_cache); + ext_cache.insert(lhs.clone(), lhs_value); + + builder.push(DslIr::MulEF(*self, lhs_value, rhs_val)); + } + (lhs, SymbolicExt::Val(rhs, _)) => { let lhs_value = Self::uninit(builder); lhs_value.assign_with_caches(lhs.clone(), builder, ext_cache, base_cache); diff --git a/recursion/compiler/src/ir/utils.rs b/recursion/compiler/src/ir/utils.rs index 5c161a85ca..1cde4e8f37 100644 --- a/recursion/compiler/src/ir/utils.rs +++ b/recursion/compiler/src/ir/utils.rs @@ -1,7 +1,9 @@ use p3_field::{AbstractExtensionField, AbstractField}; use std::ops::{Add, Mul, MulAssign}; -use super::{Array, Builder, Config, DslIr, Ext, Felt, SymbolicExt, Usize, Var, Variable}; +use super::{ + Array, Builder, Config, DslIr, Ext, Felt, SymbolicExt, SymbolicFelt, Usize, Var, Variable, +}; impl Builder { /// The generator for the field. @@ -101,6 +103,22 @@ impl Builder { result } + /// Exponentiates a felt x to a list of bits in little endian. Uses precomputed powers + /// of x. + pub fn exp_f_bits_precomputed( + &mut self, + power_bits: Vec>, + two_adic_powers_of_x: &[Felt], + ) -> Felt { + let mut result: Felt<_> = self.eval(C::F::one()); + for i in 0..power_bits.len() { + let bit = power_bits[i]; + let tmp = self.eval(result * two_adic_powers_of_x[i]); + result = self.select_f(bit, tmp, result); + } + result + } + /// Exponetiates a varibale to a list of reversed bits with a given length. /// /// Reference: [p3_util::reverse_bits_len] diff --git a/recursion/gnark-ffi/go/main.go b/recursion/gnark-ffi/go/main.go index ed782400f2..f4eb43f036 100644 --- a/recursion/gnark-ffi/go/main.go +++ b/recursion/gnark-ffi/go/main.go @@ -18,11 +18,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/test/unsafekzg" "github.com/succinctlabs/sp1-recursion-gnark/sp1" "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" @@ -75,13 +72,16 @@ func VerifyPlonkBn254(dataDir *C.char, proof *C.char, vkeyHash *C.char, commited var testMutex = &sync.Mutex{} //export TestPlonkBn254 -func TestPlonkBn254(witnessPath *C.char, constraintsJson *C.char) *C.char { +func TestPlonkBn254(witnessPath *C.char, constraintsJson *C.char, rangeChecker *C.char) *C.char { // Because of the global env variables used here, we need to lock this function testMutex.Lock() witnessPathString := C.GoString(witnessPath) constraintsJsonString := C.GoString(constraintsJson) + rangeCheckerString := C.GoString(rangeChecker) os.Setenv("WITNESS_JSON", witnessPathString) os.Setenv("CONSTRAINTS_JSON", constraintsJsonString) + os.Setenv("RANGE_CHECKER", rangeCheckerString) + fmt.Print(rangeCheckerString) err := TestMain() testMutex.Unlock() if err != nil { @@ -112,36 +112,37 @@ func TestMain() error { // Compile the circuit. circuit := sp1.NewCircuit(inputs) - builder := scs.NewBuilder - scs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) - if err != nil { - return err - } - fmt.Println("[sp1] gnark verifier constraints:", scs.GetNbConstraints()) - - // Run the dummy setup. - srs, srsLagrange, err := unsafekzg.NewSRS(scs) - if err != nil { - return err - } - var pk plonk.ProvingKey - pk, _, err = plonk.Setup(scs, srs, srsLagrange) - if err != nil { - return err - } - - // Generate witness. - assignment := sp1.NewCircuit(inputs) - witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - if err != nil { - return err - } - - // Generate the proof. - _, err = plonk.Prove(scs, pk, witness) + // builder := scs.NewBuilder + // Compile the circuit. + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) if err != nil { - return err - } + panic(err) + } + fmt.Println("Groth16 Constraints:", r1cs.GetNbConstraints()) + + // // Run the dummy setup. + // srs, srsLagrange, err := unsafekzg.NewSRS(scs) + // if err != nil { + // return err + // } + // var pk plonk.ProvingKey + // pk, _, err = plonk.Setup(scs, srs, srsLagrange) + // if err != nil { + // return err + // } + + // // Generate witness. + // assignment := sp1.NewCircuit(inputs) + // witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + // if err != nil { + // return err + // } + + // // Generate the proof. + // _, err = plonk.Prove(scs, pk, witness) + // if err != nil { + // return err + // } return nil } diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 066edd2653..3f86d2e33d 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -6,7 +6,10 @@ package babybear import "C" import ( + "fmt" + "math" "math/big" + "os" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -32,24 +35,53 @@ type ExtensionVariable struct { } type Chip struct { - api frontend.API - rangeChecker frontend.Rangechecker + api frontend.API + rangeChecker frontend.Rangechecker + ReduceMaxBitsCounter int + AddFCounter int + MulFCounter int + AddEFCounter int + MulEFCounter int + AddECounter int + MulECounter int + ReduceMaxBitsMap map[string]int } func NewChip(api frontend.API) *Chip { return &Chip{ - api: api, - rangeChecker: rangecheck.New(api), + api: api, + rangeChecker: rangecheck.New(api), + ReduceMaxBitsCounter: 0, + ReduceMaxBitsMap: make(map[string]int), } } func NewF(value string) Variable { + if value == "0" { + return Zero() + } else if value == "1" { + return One() + } return Variable{ Value: frontend.Variable(value), NbBits: 31, } } +func Zero() Variable { + return Variable{ + Value: frontend.Variable("0"), + NbBits: 1, + } +} + +func One() Variable { + return Variable{ + Value: frontend.Variable("1"), + NbBits: 1, + } +} + func NewE(value []string) ExtensionVariable { a := NewF(value[0]) b := NewF(value[1]) @@ -64,45 +96,76 @@ func Felts2Ext(a, b, c, d Variable) ExtensionVariable { func (c *Chip) AddF(a, b Variable) Variable { var maxBits uint + c.AddFCounter++ if a.NbBits > b.NbBits { maxBits = a.NbBits } else { maxBits = b.NbBits } - return c.reduceFast(Variable{ + curNumReduce := c.ReduceMaxBitsCounter + retVal := c.reduceFast(Variable{ Value: c.api.Add(a.Value, b.Value), - NbBits: maxBits + 1, + NbBits: maxBits, }) + c.ReduceMaxBitsMap["AddF"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal + } func (c *Chip) SubF(a, b Variable) Variable { + curNumReduce := c.ReduceMaxBitsCounter negB := c.negF(b) - return c.AddF(a, negB) + retVal := c.AddF(a, negB) + c.ReduceMaxBitsMap["SubF"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } -func (c *Chip) MulF(a, b Variable) Variable { - return c.reduceFast(Variable{ - Value: c.api.Mul(a.Value, b.Value), - NbBits: a.NbBits + b.NbBits, - }) +func (c *Chip) MulF(a, b Variable) (Variable, Variable, Variable) { + curNumReduce := c.ReduceMaxBitsCounter + varC := a + varD := b + + for varC.NbBits+varD.NbBits > 252 { + if varC.NbBits > varD.NbBits { + varC = Variable{Value: c.reduceWithMaxBits(varC.Value, uint64(varC.NbBits)), NbBits: 31} + } else { + varD = Variable{Value: c.reduceWithMaxBits(varD.Value, uint64(varD.NbBits)), NbBits: 31} + } + } + c.MulFCounter += c.ReduceMaxBitsCounter - curNumReduce + + return Variable{ + Value: c.api.Mul(varC.Value, varD.Value), + NbBits: varC.NbBits + varD.NbBits, + }, varC, varD } func (c *Chip) MulFConst(a Variable, b int) Variable { - return c.reduceFast(Variable{ + curNumReduce := c.ReduceMaxBitsCounter + retVal := c.reduceFast(Variable{ Value: c.api.Mul(a.Value, b), NbBits: a.NbBits + 4, }) + c.ReduceMaxBitsMap["MulFConst"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (c *Chip) negF(a Variable) Variable { + var retVal Variable + curNumReduce := c.ReduceMaxBitsCounter if a.NbBits == 31 { - return Variable{Value: c.api.Sub(modulus, a.Value), NbBits: 31} + retVal = Variable{Value: c.api.Sub(modulus, a.Value), NbBits: 31} + } else { + negOne := NewF("2013265920") + retVal, _, _ = c.MulF(a, negOne) } - negOne := NewF("2013265920") - return c.MulF(a, negOne) + + c.ReduceMaxBitsMap["negF"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (c *Chip) invF(in Variable) Variable { + curNumReduce := c.ReduceMaxBitsCounter in = c.ReduceSlow(in) result, err := c.api.Compiler().NewHint(InvFHint, 1, in.Value) if err != nil { @@ -113,40 +176,49 @@ func (c *Chip) invF(in Variable) Variable { Value: result[0], NbBits: 31, } - product := c.MulF(in, xinv) + product, _, _ := c.MulF(in, xinv) c.AssertIsEqualF(product, NewF("1")) + c.ReduceMaxBitsMap["invF"] += c.ReduceMaxBitsCounter - curNumReduce return xinv } func (c *Chip) AssertIsEqualF(a, b Variable) { + curNumReduce := c.ReduceMaxBitsCounter a2 := c.ReduceSlow(a) b2 := c.ReduceSlow(b) c.api.AssertIsEqual(a2.Value, b2.Value) + c.ReduceMaxBitsMap["AssertIsEqualF"] += c.ReduceMaxBitsCounter - curNumReduce } func (c *Chip) AssertIsEqualE(a, b ExtensionVariable) { + curNumReduce := c.ReduceMaxBitsCounter c.AssertIsEqualF(a.Value[0], b.Value[0]) c.AssertIsEqualF(a.Value[1], b.Value[1]) c.AssertIsEqualF(a.Value[2], b.Value[2]) c.AssertIsEqualF(a.Value[3], b.Value[3]) + c.ReduceMaxBitsMap["AssertIsEqualE"] += c.ReduceMaxBitsCounter - curNumReduce } func (c *Chip) SelectF(cond frontend.Variable, a, b Variable) Variable { + curNumReduce := c.ReduceMaxBitsCounter var nbBits uint if a.NbBits > b.NbBits { nbBits = a.NbBits } else { nbBits = b.NbBits } - return Variable{ + retVal := Variable{ Value: c.api.Select(cond, a.Value, b.Value), NbBits: nbBits, } + c.ReduceMaxBitsMap["SelectF"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (c *Chip) SelectE(cond frontend.Variable, a, b ExtensionVariable) ExtensionVariable { - return ExtensionVariable{ + curNumReduce := c.ReduceMaxBitsCounter + retVal := ExtensionVariable{ Value: [4]Variable{ c.SelectF(cond, a.Value[0], b.Value[0]), c.SelectF(cond, a.Value[1], b.Value[1]), @@ -154,68 +226,92 @@ func (c *Chip) SelectE(cond frontend.Variable, a, b ExtensionVariable) Extension c.SelectF(cond, a.Value[3], b.Value[3]), }, } + c.ReduceMaxBitsMap["SelectE"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (c *Chip) AddEF(a ExtensionVariable, b Variable) ExtensionVariable { + c.AddEFCounter++ + curNumReduce := c.ReduceMaxBitsCounter v1 := c.AddF(a.Value[0], b) + c.ReduceMaxBitsMap["AddEF"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, a.Value[1], a.Value[2], a.Value[3]}} } func (c *Chip) AddE(a, b ExtensionVariable) ExtensionVariable { + c.AddECounter++ + curNumReduce := c.ReduceMaxBitsCounter v1 := c.AddF(a.Value[0], b.Value[0]) v2 := c.AddF(a.Value[1], b.Value[1]) v3 := c.AddF(a.Value[2], b.Value[2]) v4 := c.AddF(a.Value[3], b.Value[3]) + c.ReduceMaxBitsMap["AddE"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } func (c *Chip) SubE(a, b ExtensionVariable) ExtensionVariable { + curNumReduce := c.ReduceMaxBitsCounter v1 := c.SubF(a.Value[0], b.Value[0]) v2 := c.SubF(a.Value[1], b.Value[1]) v3 := c.SubF(a.Value[2], b.Value[2]) v4 := c.SubF(a.Value[3], b.Value[3]) + c.ReduceMaxBitsMap["SubE"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } func (c *Chip) SubEF(a ExtensionVariable, b Variable) ExtensionVariable { + curNumReduce := c.ReduceMaxBitsCounter v1 := c.SubF(a.Value[0], b) + c.ReduceMaxBitsMap["SubEF"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, a.Value[1], a.Value[2], a.Value[3]}} } func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { + c.MulECounter++ v2 := [4]Variable{ - NewF("0"), - NewF("0"), - NewF("0"), - NewF("0"), + Zero(), + Zero(), + Zero(), + Zero(), } + newA := a + newB := b for i := 0; i < 4; i++ { for j := 0; j < 4; j++ { + newVal, newAEntry, newBEntry := c.MulF(newA.Value[i], newB.Value[j]) if i+j >= 4 { - v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(c.MulF(a.Value[i], b.Value[j]), 11)) + v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(newVal, 11)) } else { - v2[i+j] = c.AddF(v2[i+j], c.MulF(a.Value[i], b.Value[j])) + v2[i+j] = c.AddF(v2[i+j], newVal) } + newA.Value[i] = newAEntry + newB.Value[j] = newBEntry } } return ExtensionVariable{Value: v2} + } func (c *Chip) MulEF(a ExtensionVariable, b Variable) ExtensionVariable { - v1 := c.MulF(a.Value[0], b) - v2 := c.MulF(a.Value[1], b) - v3 := c.MulF(a.Value[2], b) - v4 := c.MulF(a.Value[3], b) + c.MulEFCounter++ + curNumReduce := c.ReduceMaxBitsCounter + v1, _, newB := c.MulF(a.Value[0], b) + v2, _, newB := c.MulF(a.Value[1], newB) + v3, _, newB := c.MulF(a.Value[2], newB) + v4, _, _ := c.MulF(a.Value[3], newB) + c.ReduceMaxBitsMap["MulEF"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } func (c *Chip) InvE(in ExtensionVariable) ExtensionVariable { + curNumReduce := c.ReduceMaxBitsCounter in.Value[0] = c.ReduceSlow(in.Value[0]) in.Value[1] = c.ReduceSlow(in.Value[1]) in.Value[2] = c.ReduceSlow(in.Value[2]) in.Value[3] = c.ReduceSlow(in.Value[3]) + result, err := c.api.Compiler().NewHint(InvEHint, 4, in.Value[0].Value, in.Value[1].Value, in.Value[2].Value, in.Value[3].Value) if err != nil { panic(err) @@ -229,6 +325,7 @@ func (c *Chip) InvE(in ExtensionVariable) ExtensionVariable { product := c.MulE(in, out) c.AssertIsEqualE(product, NewE([]string{"1", "0", "0", "0"})) + c.ReduceMaxBitsMap["InvE"] += c.ReduceMaxBitsCounter - curNumReduce return out } @@ -237,25 +334,39 @@ func (c *Chip) Ext2Felt(in ExtensionVariable) [4]Variable { return in.Value } +func (c *Chip) DivF(a, b Variable) Variable { + bInv := c.invF(b) + x, _, _ := c.MulF(a, bInv) + return x +} + func (c *Chip) DivE(a, b ExtensionVariable) ExtensionVariable { + curNumReduce := c.ReduceMaxBitsCounter bInv := c.InvE(b) - return c.MulE(a, bInv) + retVal := c.MulE(a, bInv) + c.ReduceMaxBitsMap["DivE"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (c *Chip) NegE(a ExtensionVariable) ExtensionVariable { + curNumReduce := c.ReduceMaxBitsCounter v1 := c.negF(a.Value[0]) v2 := c.negF(a.Value[1]) v3 := c.negF(a.Value[2]) v4 := c.negF(a.Value[3]) + c.ReduceMaxBitsMap["NegE"] += c.ReduceMaxBitsCounter - curNumReduce return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } func (c *Chip) ToBinary(in Variable) []frontend.Variable { - return c.api.ToBinary(c.ReduceSlow(in).Value, 32) + curNumReduce := c.ReduceMaxBitsCounter + retVal := c.api.ToBinary(c.ReduceSlow(in).Value, 32) + c.ReduceMaxBitsMap["ToBinary"] += c.ReduceMaxBitsCounter - curNumReduce + return retVal } func (p *Chip) reduceFast(x Variable) Variable { - if x.NbBits >= uint(120) { + if x.NbBits >= uint(252) { return Variable{ Value: p.reduceWithMaxBits(x.Value, uint64(x.NbBits)), NbBits: 31, @@ -274,19 +385,81 @@ func (p *Chip) ReduceSlow(x Variable) Variable { } } +func (p *Chip) ReduceF(x Variable) Variable { + return p.ReduceSlow(x) +} + +func (p *Chip) ReduceE(x ExtensionVariable) ExtensionVariable { + for i := 0; i < 4; i++ { + x.Value[i] = p.ReduceSlow(x.Value[i]) + } + return x +} + func (p *Chip) reduceWithMaxBits(x frontend.Variable, maxNbBits uint64) frontend.Variable { + if maxNbBits <= 31 { + return x + } result, err := p.api.Compiler().NewHint(ReduceHint, 2, x) if err != nil { panic(err) } + p.ReduceMaxBitsCounter++ quotient := result[0] - p.rangeChecker.Check(quotient, int(maxNbBits-31)) - remainder := result[1] - p.rangeChecker.Check(remainder, 31) - p.api.AssertIsEqual(x, p.api.Add(p.api.Mul(quotient, modulus), result[1])) + if os.Getenv("RANGE_CHECKER") == "true" { + p.rangeChecker.Check(quotient, int(maxNbBits-31)) + // Check that the remainder has size less than the BabyBear modulus, by decomposing it into a 27 + // bit limb and a 4 bit limb. + new_result, new_err := p.api.Compiler().NewHint(SplitLimbsHint, 2, remainder) + if new_err != nil { + panic(new_err) + } + + lowLimb := new_result[0] + highLimb := new_result[1] + + // Check that the hint is correct. + p.api.AssertIsEqual( + p.api.Add( + p.api.Mul(highLimb, frontend.Variable(uint64(math.Pow(2, 27)))), + lowLimb, + ), + remainder, + ) + p.rangeChecker.Check(highLimb, 4) + p.rangeChecker.Check(lowLimb, 27) + + // If the most significant bits are all 1, then we need to check that the least significant bits + // are all zero in order for element to be less than the BabyBear modulus. Otherwise, we don't + // need to do any checks, since we already know that the element is less than the BabyBear modulus. + shouldCheck := p.api.IsZero(p.api.Sub(highLimb, uint64(math.Pow(2, 4))-1)) + p.api.AssertIsEqual( + p.api.Select( + shouldCheck, + lowLimb, + frontend.Variable(0), + ), + frontend.Variable(0), + ) + } else { + bits := p.api.ToBinary(remainder, 31) + p.api.ToBinary(quotient, int(maxNbBits-31)) + lowBits := frontend.Variable(0) + highBits := frontend.Variable(0) + for i := 0; i < 27; i++ { + lowBits = p.api.Add(lowBits, bits[i]) + } + for i := 27; i < 31; i++ { + highBits = p.api.Add(highBits, bits[i]) + } + highBitsIsFour := p.api.IsZero(p.api.Sub(highBits, 4)) + p.api.AssertIsEqual(p.api.Select(highBitsIsFour, lowBits, frontend.Variable(0)), frontend.Variable(0)) + } + + p.api.AssertIsEqual(x, p.api.Add(p.api.Mul(quotient, modulus), remainder)) return remainder } @@ -304,6 +477,30 @@ func ReduceHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { return nil } +// The hint used to split a BabyBear Variable into a 4 bit limb (the most significant bits) and a +// 27 bit limb. +func SplitLimbsHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + if len(inputs) != 1 { + panic("SplitLimbsHint expects 1 input operand") + } + + // The BabyBear field element + input := inputs[0] + + if input.Cmp(modulus) == 0 || input.Cmp(modulus) == 1 { + return fmt.Errorf("input is not in the field") + } + + two_27 := big.NewInt(int64(math.Pow(2, 27))) + + // The least significant bits + results[0] = new(big.Int).Rem(input, two_27) + // The most significant bits + results[1] = new(big.Int).Quo(input, two_27) + + return nil +} + func InvFHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { a := C.uint(inputs[0].Uint64()) ainv := C.babybearinv(a) @@ -316,10 +513,12 @@ func InvEHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { b := C.uint(inputs[1].Uint64()) c := C.uint(inputs[2].Uint64()) d := C.uint(inputs[3].Uint64()) + ainv := C.babybearextinv(a, b, c, d, 0) binv := C.babybearextinv(a, b, c, d, 1) cinv := C.babybearextinv(a, b, c, d, 2) dinv := C.babybearextinv(a, b, c, d, 3) + results[0].SetUint64(uint64(ainv)) results[1].SetUint64(uint64(binv)) results[2].SetUint64(uint64(cinv)) diff --git a/recursion/gnark-ffi/go/sp1/build.go b/recursion/gnark-ffi/go/sp1/build.go index 5531b7ebb5..2e89dc0518 100644 --- a/recursion/gnark-ffi/go/sp1/build.go +++ b/recursion/gnark-ffi/go/sp1/build.go @@ -6,11 +6,14 @@ import ( "log" "os" "strings" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test/unsafekzg" "github.com/succinctlabs/sp1-recursion-gnark/sp1/trusted_setup" @@ -40,6 +43,33 @@ func Build(dataDir string) { // Initialize the circuit. circuit := NewCircuit(witnessInput) + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } + fmt.Println("Groth16 Constraints:", r1cs.GetNbConstraints()) + + pk2, err := groth16.DummySetup(r1cs) + if err != nil { + panic(err) + } + + // Generate proof. + start := time.Now() + assignment := NewCircuit(witnessInput) + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + + _, err = groth16.Prove(r1cs, pk2, witness) + if err != nil { + panic(err) + } + + // Print the proof time. + fmt.Println("Groth16 proof time:", time.Since(start)) + // Compile the circuit. scs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) if err != nil { @@ -128,8 +158,8 @@ func Build(dataDir string) { } // Generate proof. - assignment := NewCircuit(witnessInput) - witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + assignment = NewCircuit(witnessInput) + witness, err = frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) if err != nil { panic(err) } diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go index a16cc609fe..5f9301d6e7 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -11,13 +11,13 @@ const babybearNumInternalRounds = 13 type Poseidon2BabyBearChip struct { api frontend.API - fieldApi *babybear.Chip + FieldApi *babybear.Chip } func NewBabyBearChip(api frontend.API) *Poseidon2BabyBearChip { return &Poseidon2BabyBearChip{ api: api, - fieldApi: babybear.NewChip(api), + FieldApi: babybear.NewChip(api), } } @@ -37,7 +37,7 @@ func (p *Poseidon2BabyBearChip) PermuteMut(state *[BABYBEAR_WIDTH]babybear.Varia // The internal rounds. p_end := roundsFBeggining + babybearNumInternalRounds for r := roundsFBeggining; r < p_end; r++ { - state[0] = p.fieldApi.AddF(state[0], rc16[r][0]) + state[0] = p.FieldApi.AddF(state[0], rc16[r][0]) state[0] = p.sboxP(state[0]) p.diffusionPermuteMut(state) } @@ -52,20 +52,20 @@ func (p *Poseidon2BabyBearChip) PermuteMut(state *[BABYBEAR_WIDTH]babybear.Varia func (p *Poseidon2BabyBearChip) addRc(state *[BABYBEAR_WIDTH]babybear.Variable, rc [BABYBEAR_WIDTH]babybear.Variable) { for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.AddF(state[i], rc[i]) + state[i] = p.FieldApi.AddF(state[i], rc[i]) } } func (p *Poseidon2BabyBearChip) sboxP(input babybear.Variable) babybear.Variable { zero := babybear.NewF("0") - inputCpy := p.fieldApi.AddF(input, zero) - inputCpy = p.fieldApi.ReduceSlow(inputCpy) + inputCpy := p.FieldApi.AddF(input, zero) + inputCpy = p.FieldApi.ReduceSlow(inputCpy) inputValue := inputCpy.Value i2 := p.api.Mul(inputValue, inputValue) i4 := p.api.Mul(i2, i2) i6 := p.api.Mul(i4, i2) i7 := p.api.Mul(i6, inputValue) - i7bb := p.fieldApi.ReduceSlow(babybear.Variable{ + i7bb := p.FieldApi.ReduceSlow(babybear.Variable{ Value: i7, NbBits: 31 * 7, }) @@ -79,15 +79,15 @@ func (p *Poseidon2BabyBearChip) sbox(state *[BABYBEAR_WIDTH]babybear.Variable) { } func (p *Poseidon2BabyBearChip) mdsLightPermutation4x4(state []babybear.Variable) { - t01 := p.fieldApi.AddF(state[0], state[1]) - t23 := p.fieldApi.AddF(state[2], state[3]) - t0123 := p.fieldApi.AddF(t01, t23) - t01123 := p.fieldApi.AddF(t0123, state[1]) - t01233 := p.fieldApi.AddF(t0123, state[3]) - state[3] = p.fieldApi.AddF(t01233, p.fieldApi.MulFConst(state[0], 2)) - state[1] = p.fieldApi.AddF(t01123, p.fieldApi.MulFConst(state[2], 2)) - state[0] = p.fieldApi.AddF(t01123, t01) - state[2] = p.fieldApi.AddF(t01233, t23) + t01 := p.FieldApi.AddF(state[0], state[1]) + t23 := p.FieldApi.AddF(state[2], state[3]) + t0123 := p.FieldApi.AddF(t01, t23) + t01123 := p.FieldApi.AddF(t0123, state[1]) + t01233 := p.FieldApi.AddF(t0123, state[3]) + state[3] = p.FieldApi.AddF(t01233, p.FieldApi.MulFConst(state[0], 2)) + state[1] = p.FieldApi.AddF(t01123, p.FieldApi.MulFConst(state[2], 2)) + state[0] = p.FieldApi.AddF(t01123, t01) + state[2] = p.FieldApi.AddF(t01233, t23) } func (p *Poseidon2BabyBearChip) externalLinearLayer(state *[BABYBEAR_WIDTH]babybear.Variable) { @@ -102,14 +102,14 @@ func (p *Poseidon2BabyBearChip) externalLinearLayer(state *[BABYBEAR_WIDTH]babyb state[3], } for i := 4; i < BABYBEAR_WIDTH; i += 4 { - sums[0] = p.fieldApi.AddF(sums[0], state[i]) - sums[1] = p.fieldApi.AddF(sums[1], state[i+1]) - sums[2] = p.fieldApi.AddF(sums[2], state[i+2]) - sums[3] = p.fieldApi.AddF(sums[3], state[i+3]) + sums[0] = p.FieldApi.AddF(sums[0], state[i]) + sums[1] = p.FieldApi.AddF(sums[1], state[i+1]) + sums[2] = p.FieldApi.AddF(sums[2], state[i+2]) + sums[3] = p.FieldApi.AddF(sums[3], state[i+3]) } for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.AddF(state[i], sums[i%4]) + state[i] = p.FieldApi.AddF(state[i], sums[i%4]) } } @@ -135,7 +135,7 @@ func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babyb montyInverse := babybear.NewF("943718400") p.matmulInternal(state, &matInternalDiagM1) for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], montyInverse) + state[i], _, _ = p.FieldApi.MulF(state[i], montyInverse) } } @@ -146,11 +146,11 @@ func (p *Poseidon2BabyBearChip) matmulInternal( ) { sum := babybear.NewF("0") for i := 0; i < BABYBEAR_WIDTH; i++ { - sum = p.fieldApi.AddF(sum, state[i]) + sum = p.FieldApi.AddF(sum, state[i]) } for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) - state[i] = p.fieldApi.AddF(state[i], sum) + state[i], _, _ = p.FieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i] = p.FieldApi.AddF(state[i], sum) } } diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index a7fe4b651c..731a17fd5f 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -88,6 +88,10 @@ func (circuit *Circuit) Define(api frontend.API) error { felts[cs.Args[0][0]] = fieldAPI.AddF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "AddE": exts[cs.Args[0][0]] = fieldAPI.AddE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) + case "ReduceF": + felts[cs.Args[0][0]] = fieldAPI.ReduceF(felts[cs.Args[0][0]]) + case "ReduceE": + exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]]) case "AddEF": exts[cs.Args[0][0]] = fieldAPI.AddEF(exts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "SubV": @@ -101,11 +105,13 @@ func (circuit *Circuit) Define(api frontend.API) error { case "MulV": vars[cs.Args[0][0]] = api.Mul(vars[cs.Args[1][0]], vars[cs.Args[2][0]]) case "MulF": - felts[cs.Args[0][0]] = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) + felts[cs.Args[0][0]], _, _ = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "MulE": exts[cs.Args[0][0]] = fieldAPI.MulE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) case "MulEF": exts[cs.Args[0][0]] = fieldAPI.MulEF(exts[cs.Args[1][0]], felts[cs.Args[2][0]]) + case "DivF": + felts[cs.Args[0][0]] = fieldAPI.DivF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "DivE": exts[cs.Args[0][0]] = fieldAPI.DivE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) case "NegE": @@ -195,10 +201,30 @@ func (circuit *Circuit) Define(api frontend.API) error { api.AssertIsEqual(circuit.CommitedValuesDigest, element) case "CircuitFelts2Ext": exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]]) + case "CircuitFelt2Var": + vars[cs.Args[0][0]] = fieldAPI.ReduceSlow(felts[cs.Args[1][0]]).Value + case "CycleTracker": + fmt.Println("BabyBear API Reduce Counts: ", hashBabyBearAPI.FieldApi.ReduceMaxBitsMap) + fmt.Println("Field API Reduce Counts: ", fieldAPI.ReduceMaxBitsMap) default: return fmt.Errorf("unhandled opcode: %s", cs.Opcode) } } + fmt.Println("BabyBear API ReduceMaxBitsCount:", hashBabyBearAPI.FieldApi.ReduceMaxBitsCounter) + fmt.Println("Field API ReduceMaxBitsCount: ", fieldAPI.ReduceMaxBitsCounter) + fmt.Println("Field API AddFCounter: ", fieldAPI.AddFCounter) + fmt.Println("Field API AddECounter: ", fieldAPI.AddECounter) + fmt.Println("Field API AddEFCounter: ", fieldAPI.AddEFCounter) + fmt.Println("Field API MulFCounter: ", fieldAPI.MulFCounter) + fmt.Println("Field API MulECounter: ", fieldAPI.MulECounter) + fmt.Println("Field API MulEFCounter: ", fieldAPI.MulEFCounter) + fmt.Println("BabyBear API AddFCounter: ", hashBabyBearAPI.FieldApi.AddFCounter) + fmt.Println("BabyBear API MulFCounter: ", hashBabyBearAPI.FieldApi.MulFCounter) + fmt.Println("BabyBear API AddEFCounter: ", hashBabyBearAPI.FieldApi.AddEFCounter) + fmt.Println("BabyBear API MulEFCounter: ", hashBabyBearAPI.FieldApi.MulEFCounter) + fmt.Println("BabyBear API AddECounter: ", hashBabyBearAPI.FieldApi.AddECounter) + fmt.Println("BabyBear API MulECounter: ", hashBabyBearAPI.FieldApi.MulECounter) + return nil } diff --git a/recursion/gnark-ffi/src/ffi/native.rs b/recursion/gnark-ffi/src/ffi/native.rs index 9d88a6b5b6..0a70be1c85 100644 --- a/recursion/gnark-ffi/src/ffi/native.rs +++ b/recursion/gnark-ffi/src/ffi/native.rs @@ -8,7 +8,10 @@ use crate::PlonkBn254Proof; use cfg_if::cfg_if; use sp1_core::SP1_CIRCUIT_VERSION; -use std::ffi::{c_char, CString}; +use std::{ + env, + ffi::{c_char, CString}, +}; #[allow(warnings, clippy::all)] mod bind { @@ -73,9 +76,20 @@ pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { unsafe { let witness_json = CString::new(witness_json).expect("CString::new failed"); let build_dir = CString::new(constraints_json).expect("CString::new failed"); + println!( + "RANGE_CHECKER env variable set to: {:?}", + env::var("RANGE_CHECKER") + ); + let range_checker = match env::var("RANGE_CHECKER") { + Ok(value) => value, + Err(_) => "true".to_string(), + }; + let range_checker = CString::new(range_checker).expect("CString::new failed"); + let err_ptr = bind::TestPlonkBn254( witness_json.as_ptr() as *mut c_char, build_dir.as_ptr() as *mut c_char, + range_checker.as_ptr() as *mut c_char, ); if !err_ptr.is_null() { // Safety: The error message is returned from the go code and is guaranteed to be valid.