diff --git a/prover/src/lib.rs b/prover/src/lib.rs index c77d7473bb..f4c539056f 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -986,45 +986,45 @@ pub mod tests { tracing::info!("setup elf"); let (pk, vk) = prover.setup(elf); - tracing::info!("prove core"); - let stdin = SP1Stdin::new(); - let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; - let public_values = core_proof.public_values.clone(); + // tracing::info!("prove core"); + // let stdin = SP1Stdin::new(); + // let core_proof = prover.prove_core(&pk, &stdin, opts, context)?; + // let public_values = core_proof.public_values.clone(); - tracing::info!("verify core"); - prover.verify(&core_proof.proof, &vk)?; + // tracing::info!("verify core"); + // prover.verify(&core_proof.proof, &vk)?; - if test_kind == Test::Core { - return Ok(()); - } + // if test_kind == Test::Core { + // return Ok(()); + // } - tracing::info!("compress"); - let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; + // tracing::info!("compress"); + // let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; - tracing::info!("verify compressed"); - prover.verify_compressed(&compressed_proof, &vk)?; + // tracing::info!("verify compressed"); + // prover.verify_compressed(&compressed_proof, &vk)?; - if test_kind == Test::Compress { - return Ok(()); - } + // if test_kind == Test::Compress { + // return Ok(()); + // } - tracing::info!("shrink"); - let shrink_proof = prover.shrink(compressed_proof, opts)?; + // tracing::info!("shrink"); + // let shrink_proof = prover.shrink(compressed_proof, opts)?; - tracing::info!("verify shrink"); - prover.verify_shrink(&shrink_proof, &vk)?; + // tracing::info!("verify shrink"); + // prover.verify_shrink(&shrink_proof, &vk)?; - if test_kind == Test::Shrink { - return Ok(()); - } + // if test_kind == Test::Shrink { + // return Ok(()); + // } - tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); + // tracing::info!("wrap bn254"); + // let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + // let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); - // Save the proof. - let mut file = File::create("proof-with-pis.bin").unwrap(); - file.write_all(bytes.as_slice()).unwrap(); + // // Save the proof. + // let mut file = File::create("proof-with-pis.bin").unwrap(); + // file.write_all(bytes.as_slice()).unwrap(); // Load the proof. let mut file = File::open("proof-with-pis.bin").unwrap(); @@ -1054,7 +1054,7 @@ pub mod tests { let plonk_bn254_proof = prover.wrap_plonk_bn254(wrapped_bn254_proof, &artifacts_dir); println!("{:?}", plonk_bn254_proof); - prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; + // prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; Ok(()) } diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 96c58c5c9a..cc8dde71f4 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -102,11 +102,22 @@ func (c *Chip) SubF(a, b Variable) Variable { return c.AddF(a, negB) } -func (c *Chip) MulF(a, b Variable) Variable { - return c.reduceFast(Variable{ - Value: c.api.Mul(a.Value, b.Value), - NbBits: a.NbBits + b.NbBits, - }) +func (c *Chip) MulF(a, b Variable) (Variable, Variable, Variable) { + varC := a + varD := b + + for varC.NbBits+varD.NbBits > 250 { + if varC.NbBits > varD.NbBits { + varC = Variable{Value: c.reduceWithMaxBits(varC.Value, uint64(varC.NbBits)), NbBits: 31} + } else { + varD = Variable{Value: c.reduceWithMaxBits(varD.Value, uint64(varD.NbBits)), NbBits: 31} + } + } + + return Variable{ + Value: c.api.Mul(varC.Value, varD.Value), + NbBits: varC.NbBits + varD.NbBits, + }, varC, varD } func (c *Chip) MulFConst(a Variable, b int) Variable { @@ -143,7 +154,7 @@ func (c *Chip) invF(in Variable) Variable { Value: result[0], NbBits: 31, } - product := c.MulF(in, xinv) + product, _, _ := c.MulF(in, xinv) c.AssertIsEqualF(product, NewF("1")) return xinv @@ -219,14 +230,19 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { Zero(), Zero(), } + varA := a.Value + varB := b.Value for i := 0; i < 4; i++ { for j := 0; j < 4; j++ { + newVal, newA, newB := c.MulF(varA[i], varB[j]) if i+j >= 4 { - v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(c.MulF(a.Value[i], b.Value[j]), 11)) + v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(newVal, 11)) } else { - v2[i+j] = c.AddF(v2[i+j], c.MulF(a.Value[i], b.Value[j])) + v2[i+j] = c.AddF(v2[i+j], newVal) } + varA[i] = newA + varB[j] = newB } } @@ -234,10 +250,10 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { } func (c *Chip) MulEF(a ExtensionVariable, b Variable) ExtensionVariable { - v1 := c.MulF(a.Value[0], b) - v2 := c.MulF(a.Value[1], b) - v3 := c.MulF(a.Value[2], b) - v4 := c.MulF(a.Value[3], b) + v1, _, newB := c.MulF(a.Value[0], b) + v2, _, newB := c.MulF(a.Value[1], newB) + v3, _, newB := c.MulF(a.Value[2], newB) + v4, _, _ := c.MulF(a.Value[3], newB) return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } @@ -285,7 +301,7 @@ func (c *Chip) ToBinary(in Variable) []frontend.Variable { } func (p *Chip) reduceFast(x Variable) Variable { - if x.NbBits >= uint(126) { + if x.NbBits >= uint(252) { return Variable{ Value: p.reduceWithMaxBits(x.Value, uint64(x.NbBits)), NbBits: 31, @@ -295,7 +311,7 @@ func (p *Chip) reduceFast(x Variable) Variable { } func (p *Chip) ReduceSlow(x Variable) Variable { - if x.NbBits > 31 { + if x.NbBits <= 31 { return x } return Variable{ diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go index a16cc609fe..616e26316c 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -135,7 +135,7 @@ func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babyb montyInverse := babybear.NewF("943718400") p.matmulInternal(state, &matInternalDiagM1) for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], montyInverse) + state[i], _, _ = p.fieldApi.MulF(state[i], montyInverse) } } @@ -150,7 +150,7 @@ func (p *Poseidon2BabyBearChip) matmulInternal( } for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i], _, _ = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) state[i] = p.fieldApi.AddF(state[i], sum) } } diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index b1ffa56748..85035170d2 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -101,7 +101,7 @@ func (circuit *Circuit) Define(api frontend.API) error { case "MulV": vars[cs.Args[0][0]] = api.Mul(vars[cs.Args[1][0]], vars[cs.Args[2][0]]) case "MulF": - felts[cs.Args[0][0]] = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) + felts[cs.Args[0][0]], _, _ = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "MulE": exts[cs.Args[0][0]] = fieldAPI.MulE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) case "MulEF":