diff --git a/recursion/circuit/src/challenger.rs b/recursion/circuit/src/challenger.rs index aa947f6b04..ac623ba027 100644 --- a/recursion/circuit/src/challenger.rs +++ b/recursion/circuit/src/challenger.rs @@ -122,8 +122,7 @@ pub fn reduce_32(builder: &mut Builder, vals: &[Felt]) -> Va let mut power = C::N::one(); let result: Var = builder.eval(C::N::zero()); for val in vals.iter() { - let bits = builder.num2bits_f_circuit(*val); - let val = builder.bits2num_v_circuit(&bits); + let val = builder.felt2var_circuit(*val); builder.assign(result, result + val * power); power *= C::N::from_canonical_u64(1u64 << 32); } diff --git a/recursion/circuit/src/fri.rs b/recursion/circuit/src/fri.rs index 7e30d656b0..73724c8163 100644 --- a/recursion/circuit/src/fri.rs +++ b/recursion/circuit/src/fri.rs @@ -214,6 +214,9 @@ pub fn verify_query( let index_sibling: Var<_> = builder.eval(one - index_bits.clone()[offset]); let index_pair = &index_bits[(offset + 1)..]; + builder.reduce_e(folded_eval); + builder.reduce_e(step.sibling_value); + let evals_ext = [ builder.select_ef(index_sibling, folded_eval, step.sibling_value), builder.select_ef(index_sibling, step.sibling_value, folded_eval), diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index 612aee4a46..841b4862ed 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -362,6 +362,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::ReduceE, args: vec![vec![a.id()]], }), + DslIr::CircuitFelt2Var(a, b) => constraints.push(Constraint { + opcode: ConstraintOpcode::CircuitFelt2Var, + args: vec![vec![b.id()], vec![a.id()]], + }), _ => panic!("unsupported {:?}", instruction), }; } diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 02ed47eea5..edb6b1c2e0 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,6 +46,7 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + CircuitFelt2Var, PermuteBabyBear, ReduceE, } diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index f9236b8970..980f3a9567 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -519,6 +519,12 @@ impl Builder { self.operations.push(DslIr::ReduceE(ext)); } + pub fn felt2var_circuit(&mut self, felt: Felt) -> Var { + let var = self.uninit(); + self.operations.push(DslIr::CircuitFelt2Var(felt, var)); + var + } + pub fn cycle_tracker(&mut self, name: &str) { self.operations.push(DslIr::CycleTracker(name.to_string())); } diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index 636f4c86a2..b0416dd171 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -191,6 +191,8 @@ pub enum DslIr { CircuitNum2BitsV(Var, usize, Vec>), /// Decompose a field element into bits (bits = num2bits(felt)). Should only be used when target is a gnark circuit. CircuitNum2BitsF(Felt, Vec>), + /// Convert a Felt to a Var in a circuit. Avoids decomposing to bits and then reconstructing. + CircuitFelt2Var(Felt, Var), // Hashing. /// Permutes an array of baby bear elements using Poseidon2 (output = p2_permute(array)). diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 8259d486c7..cc8dde71f4 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -53,7 +53,19 @@ func Zero() Variable { } } +func One() Variable { + return Variable{ + Value: frontend.Variable("1"), + NbBits: 1, + } +} + func NewF(value string) Variable { + if value == "0" { + return Zero() + } else if value == "1" { + return One() + } return Variable{ Value: frontend.Variable(value), NbBits: 31, @@ -90,11 +102,22 @@ func (c *Chip) SubF(a, b Variable) Variable { return c.AddF(a, negB) } -func (c *Chip) MulF(a, b Variable) Variable { - return c.reduceFast(Variable{ - Value: c.api.Mul(a.Value, b.Value), - NbBits: a.NbBits + b.NbBits, - }) +func (c *Chip) MulF(a, b Variable) (Variable, Variable, Variable) { + varC := a + varD := b + + for varC.NbBits+varD.NbBits > 250 { + if varC.NbBits > varD.NbBits { + varC = Variable{Value: c.reduceWithMaxBits(varC.Value, uint64(varC.NbBits)), NbBits: 31} + } else { + varD = Variable{Value: c.reduceWithMaxBits(varD.Value, uint64(varD.NbBits)), NbBits: 31} + } + } + + return Variable{ + Value: c.api.Mul(varC.Value, varD.Value), + NbBits: varC.NbBits + varD.NbBits, + }, varC, varD } func (c *Chip) MulFConst(a Variable, b int) Variable { @@ -131,7 +154,7 @@ func (c *Chip) invF(in Variable) Variable { Value: result[0], NbBits: 31, } - product := c.MulF(in, xinv) + product, _, _ := c.MulF(in, xinv) c.AssertIsEqualF(product, NewF("1")) return xinv @@ -207,14 +230,19 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { Zero(), Zero(), } + varA := a.Value + varB := b.Value for i := 0; i < 4; i++ { for j := 0; j < 4; j++ { + newVal, newA, newB := c.MulF(varA[i], varB[j]) if i+j >= 4 { - v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(c.MulF(a.Value[i], b.Value[j]), 11)) + v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(newVal, 11)) } else { - v2[i+j] = c.AddF(v2[i+j], c.MulF(a.Value[i], b.Value[j])) + v2[i+j] = c.AddF(v2[i+j], newVal) } + varA[i] = newA + varB[j] = newB } } @@ -222,10 +250,10 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable { } func (c *Chip) MulEF(a ExtensionVariable, b Variable) ExtensionVariable { - v1 := c.MulF(a.Value[0], b) - v2 := c.MulF(a.Value[1], b) - v3 := c.MulF(a.Value[2], b) - v4 := c.MulF(a.Value[3], b) + v1, _, newB := c.MulF(a.Value[0], b) + v2, _, newB := c.MulF(a.Value[1], newB) + v3, _, newB := c.MulF(a.Value[2], newB) + v4, _, _ := c.MulF(a.Value[3], newB) return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}} } @@ -273,7 +301,7 @@ func (c *Chip) ToBinary(in Variable) []frontend.Variable { } func (p *Chip) reduceFast(x Variable) Variable { - if x.NbBits >= uint(126) { + if x.NbBits >= uint(252) { return Variable{ Value: p.reduceWithMaxBits(x.Value, uint64(x.NbBits)), NbBits: 31, @@ -283,7 +311,7 @@ func (p *Chip) reduceFast(x Variable) Variable { } func (p *Chip) ReduceSlow(x Variable) Variable { - if x.NbBits == 31 { + if x.NbBits <= 31 { return x } return Variable{ diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go index a16cc609fe..616e26316c 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -135,7 +135,7 @@ func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babyb montyInverse := babybear.NewF("943718400") p.matmulInternal(state, &matInternalDiagM1) for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], montyInverse) + state[i], _, _ = p.fieldApi.MulF(state[i], montyInverse) } } @@ -150,7 +150,7 @@ func (p *Poseidon2BabyBearChip) matmulInternal( } for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i], _, _ = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) state[i] = p.fieldApi.AddF(state[i], sum) } } diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index 5760a43d90..85035170d2 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -101,7 +101,7 @@ func (circuit *Circuit) Define(api frontend.API) error { case "MulV": vars[cs.Args[0][0]] = api.Mul(vars[cs.Args[1][0]], vars[cs.Args[2][0]]) case "MulF": - felts[cs.Args[0][0]] = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) + felts[cs.Args[0][0]], _, _ = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]]) case "MulE": exts[cs.Args[0][0]] = fieldAPI.MulE(exts[cs.Args[1][0]], exts[cs.Args[2][0]]) case "MulEF": @@ -195,6 +195,8 @@ func (circuit *Circuit) Define(api frontend.API) error { api.AssertIsEqual(circuit.CommitedValuesDigest, element) case "CircuitFelts2Ext": exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]]) + case "CircuitFelt2Var": + vars[cs.Args[0][0]] = fieldAPI.ReduceSlow(felts[cs.Args[1][0]]).Value case "ReduceE": exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]]) default: