Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: optimistic reduction, efficient felt2var, two effective reduces #1270

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions recursion/circuit/src/challenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ pub fn reduce_32<C: Config>(builder: &mut Builder<C>, vals: &[Felt<C::F>]) -> Va
let mut power = C::N::one();
let result: Var<C::N> = builder.eval(C::N::zero());
for val in vals.iter() {
let bits = builder.num2bits_f_circuit(*val);
let val = builder.bits2num_v_circuit(&bits);
let val = builder.felt2var_circuit(*val);
builder.assign(result, result + val * power);
power *= C::N::from_canonical_u64(1u64 << 32);
}
Expand Down
3 changes: 3 additions & 0 deletions recursion/circuit/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ pub fn verify_query<C: Config>(
let index_sibling: Var<_> = builder.eval(one - index_bits.clone()[offset]);
let index_pair = &index_bits[(offset + 1)..];

builder.reduce_e(folded_eval);
builder.reduce_e(step.sibling_value);

let evals_ext = [
builder.select_ef(index_sibling, folded_eval, step.sibling_value),
builder.select_ef(index_sibling, step.sibling_value, folded_eval),
Expand Down
4 changes: 4 additions & 0 deletions recursion/compiler/src/constraints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ impl<C: Config + Debug> ConstraintCompiler<C> {
opcode: ConstraintOpcode::ReduceE,
args: vec![vec![a.id()]],
}),
DslIr::CircuitFelt2Var(a, b) => constraints.push(Constraint {
opcode: ConstraintOpcode::CircuitFelt2Var,
args: vec![vec![b.id()], vec![a.id()]],
}),
_ => panic!("unsupported {:?}", instruction),
};
}
Expand Down
1 change: 1 addition & 0 deletions recursion/compiler/src/constraints/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub enum ConstraintOpcode {
CommitVkeyHash,
CommitCommitedValuesDigest,
CircuitFelts2Ext,
CircuitFelt2Var,
PermuteBabyBear,
ReduceE,
}
6 changes: 6 additions & 0 deletions recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,12 @@ impl<C: Config> Builder<C> {
self.operations.push(DslIr::ReduceE(ext));
}

pub fn felt2var_circuit(&mut self, felt: Felt<C::F>) -> Var<C::N> {
let var = self.uninit();
self.operations.push(DslIr::CircuitFelt2Var(felt, var));
var
}

pub fn cycle_tracker(&mut self, name: &str) {
self.operations.push(DslIr::CycleTracker(name.to_string()));
}
Expand Down
2 changes: 2 additions & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ pub enum DslIr<C: Config> {
CircuitNum2BitsV(Var<C::N>, usize, Vec<Var<C::N>>),
/// Decompose a field element into bits (bits = num2bits(felt)). Should only be used when target is a gnark circuit.
CircuitNum2BitsF(Felt<C::F>, Vec<Var<C::N>>),
/// Convert a Felt to a Var in a circuit. Avoids decomposing to bits and then reconstructing.
CircuitFelt2Var(Felt<C::F>, Var<C::N>),

// Hashing.
/// Permutes an array of baby bear elements using Poseidon2 (output = p2_permute(array)).
Expand Down
56 changes: 42 additions & 14 deletions recursion/gnark-ffi/go/sp1/babybear/babybear.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,19 @@ func Zero() Variable {
}
}

func One() Variable {
return Variable{
Value: frontend.Variable("1"),
NbBits: 1,
}
}

func NewF(value string) Variable {
if value == "0" {
return Zero()
} else if value == "1" {
return One()
}
return Variable{
Value: frontend.Variable(value),
NbBits: 31,
Expand Down Expand Up @@ -90,11 +102,22 @@ func (c *Chip) SubF(a, b Variable) Variable {
return c.AddF(a, negB)
}

func (c *Chip) MulF(a, b Variable) Variable {
return c.reduceFast(Variable{
Value: c.api.Mul(a.Value, b.Value),
NbBits: a.NbBits + b.NbBits,
})
func (c *Chip) MulF(a, b Variable) (Variable, Variable, Variable) {
varC := a
varD := b

for varC.NbBits+varD.NbBits > 250 {
if varC.NbBits > varD.NbBits {
varC = Variable{Value: c.reduceWithMaxBits(varC.Value, uint64(varC.NbBits)), NbBits: 31}
} else {
varD = Variable{Value: c.reduceWithMaxBits(varD.Value, uint64(varD.NbBits)), NbBits: 31}
}
}

return Variable{
Value: c.api.Mul(varC.Value, varD.Value),
NbBits: varC.NbBits + varD.NbBits,
}, varC, varD
}

func (c *Chip) MulFConst(a Variable, b int) Variable {
Expand Down Expand Up @@ -131,7 +154,7 @@ func (c *Chip) invF(in Variable) Variable {
Value: result[0],
NbBits: 31,
}
product := c.MulF(in, xinv)
product, _, _ := c.MulF(in, xinv)
c.AssertIsEqualF(product, NewF("1"))

return xinv
Expand Down Expand Up @@ -207,25 +230,30 @@ func (c *Chip) MulE(a, b ExtensionVariable) ExtensionVariable {
Zero(),
Zero(),
}
varA := a.Value
varB := b.Value

for i := 0; i < 4; i++ {
for j := 0; j < 4; j++ {
newVal, newA, newB := c.MulF(varA[i], varB[j])
if i+j >= 4 {
v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(c.MulF(a.Value[i], b.Value[j]), 11))
v2[i+j-4] = c.AddF(v2[i+j-4], c.MulFConst(newVal, 11))
} else {
v2[i+j] = c.AddF(v2[i+j], c.MulF(a.Value[i], b.Value[j]))
v2[i+j] = c.AddF(v2[i+j], newVal)
}
varA[i] = newA
varB[j] = newB
}
}

return ExtensionVariable{Value: v2}
}

func (c *Chip) MulEF(a ExtensionVariable, b Variable) ExtensionVariable {
v1 := c.MulF(a.Value[0], b)
v2 := c.MulF(a.Value[1], b)
v3 := c.MulF(a.Value[2], b)
v4 := c.MulF(a.Value[3], b)
v1, _, newB := c.MulF(a.Value[0], b)
v2, _, newB := c.MulF(a.Value[1], newB)
v3, _, newB := c.MulF(a.Value[2], newB)
v4, _, _ := c.MulF(a.Value[3], newB)
return ExtensionVariable{Value: [4]Variable{v1, v2, v3, v4}}
}

Expand Down Expand Up @@ -273,7 +301,7 @@ func (c *Chip) ToBinary(in Variable) []frontend.Variable {
}

func (p *Chip) reduceFast(x Variable) Variable {
if x.NbBits >= uint(126) {
if x.NbBits >= uint(252) {
return Variable{
Value: p.reduceWithMaxBits(x.Value, uint64(x.NbBits)),
NbBits: 31,
Expand All @@ -283,7 +311,7 @@ func (p *Chip) reduceFast(x Variable) Variable {
}

func (p *Chip) ReduceSlow(x Variable) Variable {
if x.NbBits == 31 {
if x.NbBits <= 31 {
return x
}
return Variable{
Expand Down
4 changes: 2 additions & 2 deletions recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babyb
montyInverse := babybear.NewF("943718400")
p.matmulInternal(state, &matInternalDiagM1)
for i := 0; i < BABYBEAR_WIDTH; i++ {
state[i] = p.fieldApi.MulF(state[i], montyInverse)
state[i], _, _ = p.fieldApi.MulF(state[i], montyInverse)
}

}
Expand All @@ -150,7 +150,7 @@ func (p *Poseidon2BabyBearChip) matmulInternal(
}

for i := 0; i < BABYBEAR_WIDTH; i++ {
state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i])
state[i], _, _ = p.fieldApi.MulF(state[i], matInternalDiagM1[i])
state[i] = p.fieldApi.AddF(state[i], sum)
}
}
4 changes: 3 additions & 1 deletion recursion/gnark-ffi/go/sp1/sp1.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (circuit *Circuit) Define(api frontend.API) error {
case "MulV":
vars[cs.Args[0][0]] = api.Mul(vars[cs.Args[1][0]], vars[cs.Args[2][0]])
case "MulF":
felts[cs.Args[0][0]] = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]])
felts[cs.Args[0][0]], _, _ = fieldAPI.MulF(felts[cs.Args[1][0]], felts[cs.Args[2][0]])
case "MulE":
exts[cs.Args[0][0]] = fieldAPI.MulE(exts[cs.Args[1][0]], exts[cs.Args[2][0]])
case "MulEF":
Expand Down Expand Up @@ -195,6 +195,8 @@ func (circuit *Circuit) Define(api frontend.API) error {
api.AssertIsEqual(circuit.CommitedValuesDigest, element)
case "CircuitFelts2Ext":
exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]])
case "CircuitFelt2Var":
vars[cs.Args[0][0]] = fieldAPI.ReduceSlow(felts[cs.Args[1][0]]).Value
case "ReduceE":
exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]])
default:
Expand Down
Loading