diff --git a/recursion/circuit/src/fri.rs b/recursion/circuit/src/fri.rs index 6f8c0b9b4b..fd63d9f4c7 100644 --- a/recursion/circuit/src/fri.rs +++ b/recursion/circuit/src/fri.rs @@ -67,6 +67,10 @@ pub fn verify_two_adic_pcs( let log_global_max_height = proof.fri_proof.commit_phase_commits.len() + config.log_blowup; + // The powers of alpha, where the ith element is alpha^i. + let mut alpha_pows: Vec> = + vec![builder.eval(SymbolicExt::from_f(C::EF::one()))]; + let reduced_openings = proof .query_openings .iter() @@ -74,8 +78,8 @@ pub fn verify_two_adic_pcs( .map(|(query_opening, &index)| { let mut ro: [Ext; 32] = [builder.eval(SymbolicExt::from_f(C::EF::zero())); 32]; - let mut alpha_pow: [Ext; 32] = - [builder.eval(SymbolicExt::from_f(C::EF::one())); 32]; + // An array of the current power for each log_height. + let mut log_height_pow = [0usize; 32]; for (batch_opening, round) in izip!(query_opening.clone(), &rounds) { let batch_commit = round.batch_commit; @@ -125,10 +129,17 @@ pub fn verify_two_adic_pcs( let mut acc: Ext = builder.eval(SymbolicExt::from_f(C::EF::zero())); for (p_at_x, &p_at_z) in izip!(mat_opening.clone(), ps_at_z) { - acc = - builder.eval(acc + (alpha_pow[log_height] * (p_at_z - p_at_x[0]))); - alpha_pow[log_height] = builder.eval(alpha_pow[log_height] * alpha); + let pow = log_height_pow[log_height]; + // Fill in any missing powers of alpha. + (alpha_pows.len()..pow + 1).for_each(|_| { + let new_alpha = builder.eval(*alpha_pows.last().unwrap() * alpha); + builder.reduce_e(new_alpha); + alpha_pows.push(new_alpha); + }); + acc = builder.eval(acc + (alpha_pows[pow] * (p_at_z - p_at_x[0]))); + log_height_pow[log_height] += 1; } + // builder.reduce_e(acc); ro[log_height] = builder.eval(ro[log_height] + acc / (*z - x)); } } @@ -233,6 +244,7 @@ pub fn verify_query( folded_eval = builder .eval(evals_ext[0] + (beta - xs[0]) * (evals_ext[1] - evals_ext[0]) / (xs[1] - xs[0])); x = builder.eval(x * x); + // builder.reduce_e(x); offset += 1; } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index c2986c62a4..1874941530 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -86,6 +86,8 @@ where let zeta = challenger.sample_ext(builder); + // builder.reduce_e(zeta); + let num_shard_chips = opened_values.chips.len(); let mut trace_domains = Vec::new(); let mut quotient_domains = Vec::new(); @@ -140,6 +142,7 @@ where let mut trace_points = Vec::new(); let zeta_next = domain.next_point(builder, zeta); + // builder.reduce_e(zeta_next); trace_points.push(zeta); trace_points.push(zeta_next); diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index eb43951358..612aee4a46 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -358,6 +358,10 @@ impl ConstraintCompiler { vec![a[3].id()], ], }), + DslIr::ReduceE(a) => constraints.push(Constraint { + opcode: ConstraintOpcode::ReduceE, + args: vec![vec![a.id()]], + }), _ => panic!("unsupported {:?}", instruction), }; } diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 4911e0f108..02ed47eea5 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -47,4 +47,5 @@ pub enum ConstraintOpcode { CommitCommitedValuesDigest, CircuitFelts2Ext, PermuteBabyBear, + ReduceE, } diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index e6480d6632..f9236b8970 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -515,6 +515,10 @@ impl Builder { .push(DslIr::CircuitCommitCommitedValuesDigest(var)); } + pub fn reduce_e(&mut self, ext: Ext) { + self.operations.push(DslIr::ReduceE(ext)); + } + pub fn cycle_tracker(&mut self, name: &str) { self.operations.push(DslIr::CycleTracker(name.to_string())); } diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index ab0cb67098..636f4c86a2 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -183,6 +183,9 @@ pub enum DslIr { /// Store extension field at address StoreE(Ext, Ptr, MemIndex), + /// Force reduction of field elements in circuit. + ReduceE(Ext), + // Bits. /// Decompose a variable into size bits (bits = num2bits(var, size)). Should only be used when target is a gnark circuit. CircuitNum2BitsV(Var, usize, Vec>), diff --git a/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 066edd2653..445765aad3 100644 --- a/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -6,6 +6,8 @@ package babybear import "C" import ( + "fmt" + "math" "math/big" "github.com/consensys/gnark/constraint/solver" @@ -20,6 +22,7 @@ func init() { solver.RegisterHint(InvFHint) solver.RegisterHint(InvEHint) solver.RegisterHint(ReduceHint) + solver.RegisterHint(SplitLimbsHint) } type Variable struct { @@ -284,7 +287,40 @@ func (p *Chip) reduceWithMaxBits(x frontend.Variable, maxNbBits uint64) frontend p.rangeChecker.Check(quotient, int(maxNbBits-31)) remainder := result[1] - p.rangeChecker.Check(remainder, 31) + + // Check that the remainder has size less than the BabyBear modulus, by decomposing it into a 27 + // bit limb and a 4 bit limb. + new_result, new_err := p.api.Compiler().NewHint(SplitLimbsHint, 2, remainder) + if new_err != nil { + panic(new_err) + } + + lowLimb := new_result[0] + highLimb := new_result[1] + + // Check that the hint is correct. + p.api.AssertIsEqual( + p.api.Add( + p.api.Mul(highLimb, frontend.Variable(uint64(math.Pow(2, 27)))), + lowLimb, + ), + remainder, + ) + p.rangeChecker.Check(highLimb, 4) + p.rangeChecker.Check(lowLimb, 27) + + // If the most significant bits are all 1, then we need to check that the least significant bits + // are all zero in order for element to be less than the BabyBear modulus. Otherwise, we don't + // need to do any checks, since we already know that the element is less than the BabyBear modulus. + shouldCheck := p.api.IsZero(p.api.Sub(highLimb, uint64(math.Pow(2, 4))-1)) + p.api.AssertIsEqual( + p.api.Select( + shouldCheck, + lowLimb, + frontend.Variable(0), + ), + frontend.Variable(0), + ) p.api.AssertIsEqual(x, p.api.Add(p.api.Mul(quotient, modulus), result[1])) @@ -304,6 +340,13 @@ func ReduceHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { return nil } +func (p *Chip) ReduceE(x ExtensionVariable) ExtensionVariable { + for i := 0; i < 4; i++ { + x.Value[i] = p.ReduceSlow(x.Value[i]) + } + return x +} + func InvFHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { a := C.uint(inputs[0].Uint64()) ainv := C.babybearinv(a) @@ -311,6 +354,30 @@ func InvFHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { return nil } +// The hint used to split a BabyBear Variable into a 4 bit limb (the most significant bits) and a +// 27 bit limb. +func SplitLimbsHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + if len(inputs) != 1 { + panic("SplitLimbsHint expects 1 input operand") + } + + // The BabyBear field element + input := inputs[0] + + if input.Cmp(modulus) == 0 || input.Cmp(modulus) == 1 { + return fmt.Errorf("input is not in the field") + } + + two_27 := big.NewInt(int64(math.Pow(2, 27))) + + // The least significant bits + results[0] = new(big.Int).Rem(input, two_27) + // The most significant bits + results[1] = new(big.Int).Quo(input, two_27) + + return nil +} + func InvEHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { a := C.uint(inputs[0].Uint64()) b := C.uint(inputs[1].Uint64()) diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index a7fe4b651c..5760a43d90 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -195,6 +195,8 @@ func (circuit *Circuit) Define(api frontend.API) error { api.AssertIsEqual(circuit.CommitedValuesDigest, element) case "CircuitFelts2Ext": exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]]) + case "ReduceE": + exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]]) default: return fmt.Errorf("unhandled opcode: %s", cs.Opcode) }