From 99b795a203b4b0be920d34092e1d2ababdaadb3f Mon Sep 17 00:00:00 2001 From: erabinov Date: Tue, 6 Aug 2024 15:00:05 -0700 Subject: [PATCH 1/4] initial commit --- recursion/circuit/src/challenger.rs | 3 +-- recursion/circuit/src/fri.rs | 3 +++ recursion/compiler/src/constraints/mod.rs | 4 ++++ recursion/compiler/src/constraints/opcodes.rs | 1 + recursion/compiler/src/ir/builder.rs | 6 ++++++ recursion/compiler/src/ir/instructions.rs | 2 ++ recursion/gnark-ffi/go/sp1/babybear/babybear.go | 14 +++++++++++++- recursion/gnark-ffi/go/sp1/sp1.go | 2 ++ 8 files changed, 32 insertions(+), 3 deletions(-) diff --git a/recursion/circuit/src/challenger.rs b/recursion/circuit/src/challenger.rs index aa947f6b04..ac623ba027 100644 --- a/recursion/circuit/src/challenger.rs +++ b/recursion/circuit/src/challenger.rs @@ -122,8 +122,7 @@ pub fn reduce_32(builder: &mut Builder, vals: &[Felt]) -> Va let mut power = C::N::one(); let result: Var = builder.eval(C::N::zero()); for val in vals.iter() { - let bits = builder.num2bits_f_circuit(*val); - let val = builder.bits2num_v_circuit(&bits); + let val = builder.felt2var_circuit(*val); builder.assign(result, result + val * power); power *= C::N::from_canonical_u64(1u64 << 32); } diff --git a/recursion/circuit/src/fri.rs b/recursion/circuit/src/fri.rs index 7e30d656b0..73724c8163 100644 --- a/recursion/circuit/src/fri.rs +++ b/recursion/circuit/src/fri.rs @@ -214,6 +214,9 @@ pub fn verify_query( let index_sibling: Var<_> = builder.eval(one - index_bits.clone()[offset]); let index_pair = &index_bits[(offset + 1)..]; + builder.reduce_e(folded_eval); + builder.reduce_e(step.sibling_value); + let evals_ext = [ builder.select_ef(index_sibling, folded_eval, step.sibling_value), builder.select_ef(index_sibling, step.sibling_value, folded_eval), diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index 612aee4a46..841b4862ed 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -362,6 +362,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::ReduceE, args: vec![vec![a.id()]], }), + DslIr::CircuitFelt2Var(a, b) => constraints.push(Constraint { + opcode: ConstraintOpcode::CircuitFelt2Var, + args: vec![vec![b.id()], vec![a.id()]], + }), _ => panic!("unsupported {:?}", instruction), }; } diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 02ed47eea5..edb6b1c2e0 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,6 +46,7 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + CircuitFelt2Var, PermuteBabyBear, ReduceE, } diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index f9236b8970..980f3a9567 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -519,6 +519,12 @@ impl Builder { self.operations.push(DslIr::ReduceE(ext)); } + pub fn felt2var_circuit(&mut self, felt: Felt) -> Var { + let var = self.uninit(); + self.operations.push(DslIr::CircuitFelt2Var(felt, var)); + var + } + pub fn cycle_tracker(&mut self, name: &str) { self.operations.push(DslIr::CycleTracker(name.to_string())); } diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index 636f4c86a2..b0416dd171 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -191,6 +191,8 @@ pub enum DslIr { CircuitNum2BitsV(Var, usize, Vec>), /// Decompose a field element into bits (bits = num2bits(felt)). Should only be used when target is a gnark circuit. CircuitNum2BitsF(Felt, Vec>), + /// Convert a Felt to a Var in a circuit. Avoids decomposing to bits and then reconstructing. + CircuitFelt2Var(Felt, Var), // Hashing. /// Permutes an array of baby bear elements using Poseidon2 (output = p2_permute(array)). diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 8259d486c7..96c58c5c9a 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -53,7 +53,19 @@ func Zero() Variable { } } +func One() Variable { + return Variable{ + Value: frontend.Variable("1"), + NbBits: 1, + } +} + func NewF(value string) Variable { + if value == "0" { + return Zero() + } else if value == "1" { + return One() + } return Variable{ Value: frontend.Variable(value), NbBits: 31, @@ -283,7 +295,7 @@ func (p *Chip) reduceFast(x Variable) Variable { } func (p *Chip) ReduceSlow(x Variable) Variable { - if x.NbBits == 31 { + if x.NbBits > 31 { return x } return Variable{ diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index 5760a43d90..b1ffa56748 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -195,6 +195,8 @@ func (circuit *Circuit) Define(api frontend.API) error { api.AssertIsEqual(circuit.CommitedValuesDigest, element) case "CircuitFelts2Ext": exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]]) + case "CircuitFelt2Var": + vars[cs.Args[0][0]] = fieldAPI.ReduceSlow(felts[cs.Args[1][0]]).Value case "ReduceE": exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]]) default: From 9eab7866b5e0505bc601f0d387a28c02e9e30bd8 Mon Sep 17 00:00:00 2001 From: erabinov Date: Tue, 6 Aug 2024 15:29:00 -0700 Subject: [PATCH 2/4] optimistic reduction --- prover/src/lib.rs | 60 +++++++++---------- .../gnark-ffi/go/sp1/babybear/babybear.go | 44 +++++++++----- .../go/sp1/poseidon2/poseidon2_babybear.go | 4 +- recursion/gnark-ffi/go/sp1/sp1.go | 2 +- 4 files changed, 63 insertions(+), 47 deletions(-) diff --git a/prover/src/lib.rs b/prover/src/lib.rs index c77d7473bb..f4c539056f 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -986,45 +986,45 @@ pub mod tests { tracing::info!("setup elf"); let (pk, vk) = prover.setup(elf); - tracing::info!("prove core"); - let stdin = SP1Stdin::new(); - let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; - let public_values = core_proof.public_values.clone(); + // tracing::info!("prove core"); + // let stdin = SP1Stdin::new(); + // let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; + // let public_values = core_proof.public_values.clone(); - tracing::info!("verify core"); - prover.verify(&core_proof.proof, &vk)?; + // tracing::info!("verify core"); + // prover.verify(&core_proof.proof, &vk)?; - if test_kind == Test::Core { - return Ok(()); - } + // if test_kind == Test::Core { + // return Ok(()); + // } - tracing::info!("compress"); - let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; + // tracing::info!("compress"); + // let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; - tracing::info!("verify compressed"); - prover.verify_compressed(&compressed_proof, &vk)?; + // tracing::info!("verify compressed"); + // prover.verify_compressed(&compressed_proof, &vk)?; - if test_kind == Test::Compress { - return Ok(()); - } + // if test_kind == Test::Compress { + // return Ok(()); + // } - tracing::info!("shrink"); - let shrink_proof = prover.shrink(compressed_proof, opts)?; + // tracing::info!("shrink"); + // let shrink_proof = prover.shrink(compressed_proof, opts)?; - tracing::info!("verify shrink"); - prover.verify_shrink(&shrink_proof, &vk)?; + // tracing::info!("verify shrink"); + // prover.verify_shrink(&shrink_proof, &vk)?; - if test_kind == Test::Shrink { - return Ok(()); - } + // if test_kind == Test::Shrink { + // return Ok(()); + // } - tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); + // tracing::info!("wrap bn254"); + // let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + // let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); - // Save the proof. - let mut file = File::create("proof-with-pis.bin").unwrap(); - file.write_all(bytes.as_slice()).unwrap(); + // // Save the proof. + // let mut file = File::create("proof-with-pis.bin").unwrap(); + // file.write_all(bytes.as_slice()).unwrap(); // Load the proof. let mut file = File::open("proof-with-pis.bin").unwrap(); @@ -1054,7 +1054,7 @@ pub mod tests { let plonk_bn254_proof = prover.wrap_plonk_bn254(wrapped_bn254_proof, &artifacts_dir); println!("{:?}", plonk_bn254_proof); - prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; + // prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; Ok(()) } diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 96c58c5c9a..cc8dde71f4 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -102,11 +102,22 @@ func (c *Chip) SubF(a, b Variable) Variable { return c.AddF(a, negB) } -func (c *Chip) MulF(a, b Variable) Variable { - return c.reduceFast(Variable{ - Value: c.api.Mul(a.Value, b.Value), - NbBits: a.NbBits + b.NbBits, - }) +func (c *Chip) MulF(a, b Variable) (Variable, Variable, Variable) { + varC := a + varD := b + + for varC.NbBits+varD.NbBits > 250 { + if varC.NbBits > varD.NbBits { + varC = Variable{Value: c.reduceWithMaxBits(varC.Value, uint64(varC.NbBits)), NbBits: 31} + } else { + varD = Variable{Value: c.reduceWithMaxBits(varD.Value, uint64(varD.NbBits)), NbBits: 31} + } + } + + return Variable{ + Value: c.api.Mul(varC.Value, varD.Value), + NbBits: varC.NbBits + varD.NbBits, + }, varC, varD } func (c *Chip) MulFConst(a Variable, b int) Variable { @@ -143,7 +154,7 @@ func (c *Chip) invF(in Variable) Variable { Value: result[0], NbBits: 31, } - product := c.MulF(in, xinv) + product, _, _ := c.MulF(in, xinv) c.AssertIsEqualF(product, NewF("1")) return xinv @@ -219,14 +230,19 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { Zero(), Zero(), } + varA := a.Value + varB := b.Value for i := 0; i < 4; i++ { for j := 0; j < 4; j++ { + newVal, newA, newB := c.MulF(varA[i], varB[j]) if i+j >= 4 { - v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(c.MulF(a.Value[i], b.Value[j]), 11)) + v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(newVal, 11)) } else { - v2[i+j] = c.AddF(v2[i+j], c.MulF(a.Value[i], b.Value[j])) + v2[i+j] = c.AddF(v2[i+j], newVal) } + varA[i] = newA + varB[j] = newB } } @@ -234,10 +250,10 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { } func (c *Chip) MulEF(a ExtensionVariable, b Variable) ExtensionVariable { - v1 := c.MulF(a.Value[0], b) - v2 := c.MulF(a.Value[1], b) - v3 := c.MulF(a.Value[2], b) - v4 := c.MulF(a.Value[3], b) + v1, _, newB := c.MulF(a.Value[0], b) + v2, _, newB := c.MulF(a.Value[1], newB) + v3, _, newB := c.MulF(a.Value[2], newB) + v4, _, _ := c.MulF(a.Value[3], newB) return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } @@ -285,7 +301,7 @@ func (c *Chip) ToBinary(in Variable) []frontend.Variable { } func (p *Chip) reduceFast(x Variable) Variable { - if x.NbBits >= uint(126) { + if x.NbBits >= uint(252) { return Variable{ Value: p.reduceWithMaxBits(x.Value, uint64(x.NbBits)), NbBits: 31, @@ -295,7 +311,7 @@ func (p *Chip) reduceFast(x Variable) Variable { } func (p *Chip) ReduceSlow(x Variable) Variable { - if x.NbBits > 31 { + if x.NbBits <= 31 { return x } return Variable{ diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go index a16cc609fe..616e26316c 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -135,7 +135,7 @@ func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babyb montyInverse := babybear.NewF("943718400") p.matmulInternal(state, &matInternalDiagM1) for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], montyInverse) + state[i], _, _ = p.fieldApi.MulF(state[i], montyInverse) } } @@ -150,7 +150,7 @@ func (p *Poseidon2BabyBearChip) matmulInternal( } for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i], _, _ = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) state[i] = p.fieldApi.AddF(state[i], sum) } } diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index b1ffa56748..85035170d2 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -101,7 +101,7 @@ func (circuit *Circuit) Define(api frontend.API) error { case "MulV": vars[cs.Args[0][0]] = api.Mul(vars[cs.Args[1][0]], vars[cs.Args[2][0]]) case "MulF": - felts[cs.Args[0][0]] = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) + felts[cs.Args[0][0]], _, _ = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "MulE": exts[cs.Args[0][0]] = fieldAPI.MulE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) case "MulEF": From 5d2f5a432052330ff7a48b99588d871977f48717 Mon Sep 17 00:00:00 2001 From: erabinov Date: Tue, 6 Aug 2024 15:29:22 -0700 Subject: [PATCH 3/4] uncomment --- prover/src/lib.rs | 58 +++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/prover/src/lib.rs b/prover/src/lib.rs index f4c539056f..a80353146c 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -986,45 +986,45 @@ pub mod tests { tracing::info!("setup elf"); let (pk, vk) = prover.setup(elf); - // tracing::info!("prove core"); - // let stdin = SP1Stdin::new(); - // let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; - // let public_values = core_proof.public_values.clone(); + tracing::info!("prove core"); + let stdin = SP1Stdin::new(); + let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; + let public_values = core_proof.public_values.clone(); - // tracing::info!("verify core"); - // prover.verify(&core_proof.proof, &vk)?; + tracing::info!("verify core"); + prover.verify(&core_proof.proof, &vk)?; - // if test_kind == Test::Core { - // return Ok(()); - // } + if test_kind == Test::Core { + return Ok(()); + } - // tracing::info!("compress"); - // let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; + tracing::info!("compress"); + let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; - // tracing::info!("verify compressed"); - // prover.verify_compressed(&compressed_proof, &vk)?; + tracing::info!("verify compressed"); + prover.verify_compressed(&compressed_proof, &vk)?; - // if test_kind == Test::Compress { - // return Ok(()); - // } + if test_kind == Test::Compress { + return Ok(()); + } - // tracing::info!("shrink"); - // let shrink_proof = prover.shrink(compressed_proof, opts)?; + tracing::info!("shrink"); + let shrink_proof = prover.shrink(compressed_proof, opts)?; - // tracing::info!("verify shrink"); - // prover.verify_shrink(&shrink_proof, &vk)?; + tracing::info!("verify shrink"); + prover.verify_shrink(&shrink_proof, &vk)?; - // if test_kind == Test::Shrink { - // return Ok(()); - // } + if test_kind == Test::Shrink { + return Ok(()); + } - // tracing::info!("wrap bn254"); - // let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - // let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); + tracing::info!("wrap bn254"); + let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); - // // Save the proof. - // let mut file = File::create("proof-with-pis.bin").unwrap(); - // file.write_all(bytes.as_slice()).unwrap(); + // Save the proof. + let mut file = File::create("proof-with-pis.bin").unwrap(); + file.write_all(bytes.as_slice()).unwrap(); // Load the proof. let mut file = File::open("proof-with-pis.bin").unwrap(); From f0da2f8f3346e66b44ef25b4537542528afe5610 Mon Sep 17 00:00:00 2001 From: erabinov Date: Tue, 6 Aug 2024 15:32:08 -0700 Subject: [PATCH 4/4] uncomment --- prover/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prover/src/lib.rs b/prover/src/lib.rs index a80353146c..c77d7473bb 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -1054,7 +1054,7 @@ pub mod tests { let plonk_bn254_proof = prover.wrap_plonk_bn254(wrapped_bn254_proof, &artifacts_dir); println!("{:?}", plonk_bn254_proof); - // prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; + prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; Ok(()) }