Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: gnark range check remainder of babybear reduction #1213

Closed
wants to merge 20 commits into from
Closed
22 changes: 17 additions & 5 deletions recursion/circuit/src/fri.rs
Original file line number Diff line number Diff line change
@@ -67,15 +67,19 @@ pub fn verify_two_adic_pcs<C: Config>(

let log_global_max_height = proof.fri_proof.commit_phase_commits.len() + config.log_blowup;

// The powers of alpha, where the ith element is alpha^i.
let mut alpha_pows: Vec<Ext<C::F, C::EF>> =
vec![builder.eval(SymbolicExt::from_f(C::EF::one()))];

let reduced_openings = proof
.query_openings
.iter()
.zip(&fri_challenges.query_indices)
.map(|(query_opening, &index)| {
let mut ro: [Ext<C::F, C::EF>; 32] =
[builder.eval(SymbolicExt::from_f(C::EF::zero())); 32];
let mut alpha_pow: [Ext<C::F, C::EF>; 32] =
[builder.eval(SymbolicExt::from_f(C::EF::one())); 32];
// An array of the current power for each log_height.
let mut log_height_pow = [0usize; 32];

for (batch_opening, round) in izip!(query_opening.clone(), &rounds) {
let batch_commit = round.batch_commit;
@@ -125,10 +129,17 @@ pub fn verify_two_adic_pcs<C: Config>(
let mut acc: Ext<C::F, C::EF> =
builder.eval(SymbolicExt::from_f(C::EF::zero()));
for (p_at_x, &p_at_z) in izip!(mat_opening.clone(), ps_at_z) {
acc =
builder.eval(acc + (alpha_pow[log_height] * (p_at_z - p_at_x[0])));
alpha_pow[log_height] = builder.eval(alpha_pow[log_height] * alpha);
let pow = log_height_pow[log_height];
// Fill in any missing powers of alpha.
(alpha_pows.len()..pow + 1).for_each(|_| {
let new_alpha = builder.eval(*alpha_pows.last().unwrap() * alpha);
builder.reduce_e(new_alpha);
alpha_pows.push(new_alpha);
});
acc = builder.eval(acc + (alpha_pows[pow] * (p_at_z - p_at_x[0])));
log_height_pow[log_height] += 1;
}
// builder.reduce_e(acc);
ro[log_height] = builder.eval(ro[log_height] + acc / (*z - x));
}
}
@@ -233,6 +244,7 @@ pub fn verify_query<C: Config>(
folded_eval = builder
.eval(evals_ext[0] + (beta - xs[0]) * (evals_ext[1] - evals_ext[0]) / (xs[1] - xs[0]));
x = builder.eval(x * x);
// builder.reduce_e(x);
offset += 1;
}

3 changes: 3 additions & 0 deletions recursion/circuit/src/stark.rs
Original file line number Diff line number Diff line change
@@ -86,6 +86,8 @@ where

let zeta = challenger.sample_ext(builder);

// builder.reduce_e(zeta);

let num_shard_chips = opened_values.chips.len();
let mut trace_domains = Vec::new();
let mut quotient_domains = Vec::new();
@@ -140,6 +142,7 @@ where

let mut trace_points = Vec::new();
let zeta_next = domain.next_point(builder, zeta);
// builder.reduce_e(zeta_next);
trace_points.push(zeta);
trace_points.push(zeta_next);

4 changes: 4 additions & 0 deletions recursion/compiler/src/constraints/mod.rs
Original file line number Diff line number Diff line change
@@ -358,6 +358,10 @@ impl<C: Config + Debug> ConstraintCompiler<C> {
vec![a[3].id()],
],
}),
DslIr::ReduceE(a) => constraints.push(Constraint {
opcode: ConstraintOpcode::ReduceE,
args: vec![vec![a.id()]],
}),
_ => panic!("unsupported {:?}", instruction),
};
}
1 change: 1 addition & 0 deletions recursion/compiler/src/constraints/opcodes.rs
Original file line number Diff line number Diff line change
@@ -47,4 +47,5 @@ pub enum ConstraintOpcode {
CommitCommitedValuesDigest,
CircuitFelts2Ext,
PermuteBabyBear,
ReduceE,
}
4 changes: 4 additions & 0 deletions recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
@@ -515,6 +515,10 @@ impl<C: Config> Builder<C> {
.push(DslIr::CircuitCommitCommitedValuesDigest(var));
}

pub fn reduce_e(&mut self, ext: Ext<C::F, C::EF>) {
self.operations.push(DslIr::ReduceE(ext));
}

pub fn cycle_tracker(&mut self, name: &str) {
self.operations.push(DslIr::CycleTracker(name.to_string()));
}
3 changes: 3 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
@@ -183,6 +183,9 @@ pub enum DslIr<C: Config> {
/// Store extension field at address
StoreE(Ext<C::F, C::EF>, Ptr<C::N>, MemIndex<C::N>),

/// Force reduction of field elements in circuit.
ReduceE(Ext<C::F, C::EF>),

// Bits.
/// Decompose a variable into size bits (bits = num2bits(var, size)). Should only be used when target is a gnark circuit.
CircuitNum2BitsV(Var<C::N>, usize, Vec<Var<C::N>>),
69 changes: 68 additions & 1 deletion recursion/gnark-ffi/go/sp1/babybear/babybear.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@ package babybear
import "C"

import (
"fmt"
"math"
"math/big"

"github.com/consensys/gnark/constraint/solver"
@@ -20,6 +22,7 @@ func init() {
solver.RegisterHint(InvFHint)
solver.RegisterHint(InvEHint)
solver.RegisterHint(ReduceHint)
solver.RegisterHint(SplitLimbsHint)
}

type Variable struct {
@@ -284,7 +287,40 @@ func (p *Chip) reduceWithMaxBits(x frontend.Variable, maxNbBits uint64) frontend
p.rangeChecker.Check(quotient, int(maxNbBits-31))

remainder := result[1]
p.rangeChecker.Check(remainder, 31)

// Check that the remainder has size less than the BabyBear modulus, by decomposing it into a 27
// bit limb and a 4 bit limb.
new_result, new_err := p.api.Compiler().NewHint(SplitLimbsHint, 2, remainder)
if new_err != nil {
panic(new_err)
}

lowLimb := new_result[0]
highLimb := new_result[1]

// Check that the hint is correct.
p.api.AssertIsEqual(
p.api.Add(
p.api.Mul(highLimb, frontend.Variable(uint64(math.Pow(2, 27)))),
lowLimb,
),
remainder,
)
p.rangeChecker.Check(highLimb, 4)
p.rangeChecker.Check(lowLimb, 27)

// If the most significant bits are all 1, then we need to check that the least significant bits
// are all zero in order for element to be less than the BabyBear modulus. Otherwise, we don't
// need to do any checks, since we already know that the element is less than the BabyBear modulus.
shouldCheck := p.api.IsZero(p.api.Sub(highLimb, uint64(math.Pow(2, 4))-1))
p.api.AssertIsEqual(
p.api.Select(
shouldCheck,
lowLimb,
frontend.Variable(0),
),
frontend.Variable(0),
)

p.api.AssertIsEqual(x, p.api.Add(p.api.Mul(quotient, modulus), result[1]))

@@ -304,13 +340,44 @@ func ReduceHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
return nil
}

func (p *Chip) ReduceE(x ExtensionVariable) ExtensionVariable {
for i := 0; i < 4; i++ {
x.Value[i] = p.ReduceSlow(x.Value[i])
}
return x
}

func InvFHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
a := C.uint(inputs[0].Uint64())
ainv := C.babybearinv(a)
results[0].SetUint64(uint64(ainv))
return nil
}

// The hint used to split a BabyBear Variable into a 4 bit limb (the most significant bits) and a
// 27 bit limb.
func SplitLimbsHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
if len(inputs) != 1 {
panic("SplitLimbsHint expects 1 input operand")
}

// The BabyBear field element
input := inputs[0]

if input.Cmp(modulus) == 0 || input.Cmp(modulus) == 1 {
return fmt.Errorf("input is not in the field")
}

two_27 := big.NewInt(int64(math.Pow(2, 27)))

// The least significant bits
results[0] = new(big.Int).Rem(input, two_27)
// The most significant bits
results[1] = new(big.Int).Quo(input, two_27)

return nil
}

func InvEHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
a := C.uint(inputs[0].Uint64())
b := C.uint(inputs[1].Uint64())
2 changes: 2 additions & 0 deletions recursion/gnark-ffi/go/sp1/sp1.go
Original file line number Diff line number Diff line change
@@ -195,6 +195,8 @@ func (circuit *Circuit) Define(api frontend.API) error {
api.AssertIsEqual(circuit.CommitedValuesDigest, element)
case "CircuitFelts2Ext":
exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]])
case "ReduceE":
exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]])
default:
return fmt.Errorf("unhandled opcode: %s", cs.Opcode)
}