Skip to content

Commit

Permalink
Merge pull request #1131 from Consensys/perf/toom3-r1cs
Browse files Browse the repository at this point in the history
Perf: Toom-3 for Fp6 in R1CS
  • Loading branch information
yelhousni authored May 15, 2024
2 parents 5c74bd7 + 63ba64e commit e204083
Show file tree
Hide file tree
Showing 10 changed files with 572 additions and 16 deletions.
Binary file modified internal/stats/latest.stats
Binary file not shown.
143 changes: 143 additions & 0 deletions std/algebra/emulated/fields_bls12381/e6.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package fields_bls12381

import (
"math/big"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/frontendtype"
)

type E6 struct {
Expand Down Expand Up @@ -79,7 +82,116 @@ func (e Ext6) Sub(x, y *E6) *E6 {
}
}

// Mul multiplies two E6 elmts
func (e Ext6) Mul(x, y *E6) *E6 {
if ft, ok := e.api.(frontendtype.FrontendTyper); ok {
switch ft.FrontendType() {
case frontendtype.R1CS:
return e.mulToom3OverKaratsuba(x, y)
case frontendtype.SCS:
return e.mulKaratsubaOverKaratsuba(x, y)
}
}
return e.mulKaratsubaOverKaratsuba(x, y)
}

func (e Ext6) mulToom3OverKaratsuba(x, y *E6) *E6 {
// Toom-Cook-3x over Karatsuba:
// We start by computing five interpolation points – these are evaluations of
// the product x(u)y(u) with u ∈ {0, ±1, 2, ∞}:
//
// v0 = x(0)y(0) = x.A0 * y.A0
// v1 = x(1)y(1) = (x.A0 + x.A1 + x.A2)(y.A0 + y.A1 + y.A2)
// v2 = x(−1)y(−1) = (x.A0 − x.A1 + x.A2)(y.A0 − y.A1 + y.A2)
// v3 = x(2)y(2) = (x.A0 + 2x.A1 + 4x.A2)(y.A0 + 2y.A1 + 4y.A2)
// v4 = x(∞)y(∞) = x.A2 * y.A2

v0 := e.Ext2.Mul(&x.B0, &y.B0)

t1 := e.Ext2.Add(&x.B0, &x.B2)
t2 := e.Ext2.Add(&y.B0, &y.B2)
t3 := e.Ext2.Add(t2, &y.B1)
v1 := e.Ext2.Add(t1, &x.B1)
v1 = e.Ext2.Mul(v1, t3)

t3 = e.Ext2.Sub(t2, &y.B1)
v2 := e.Ext2.Sub(t1, &x.B1)
v2 = e.Ext2.Mul(v2, t3)

t1 = e.Ext2.MulByConstElement(&x.B1, big.NewInt(2))
t2 = e.Ext2.MulByConstElement(&x.B2, big.NewInt(4))
v3 := e.Ext2.Add(t1, t2)
v3 = e.Ext2.Add(v3, &x.B0)
t1 = e.Ext2.MulByConstElement(&y.B1, big.NewInt(2))
t2 = e.Ext2.MulByConstElement(&y.B2, big.NewInt(4))
t3 = e.Ext2.Add(t1, t2)
t3 = e.Ext2.Add(t3, &y.B0)
v3 = e.Ext2.Mul(v3, t3)

v4 := e.Ext2.Mul(&x.B2, &y.B2)

// Then the interpolation is performed as:
//
// a0 = v0 + β((1/2)v0 − (1/2)v1 − (1/6)v2 + (1/6)v3 − 2v4)
// a1 = −(1/2)v0 + v1 − (1/3)v2 − (1/6)v3 + 2v4 + βv4
// a2 = −v0 + (1/2)v1 + (1/2)v2 − v4
//
// where β is the cubic non-residue.
//
// In-circuit, we compute 6*x*y as
// c0 = 6v0 + β(3v0 − 3v1 − v2 + v3 − 12v4)
// a1 = -(3v0 + 2v2 + v3) + 6(v1 + 2v4 + βv4)
// a2 = 3(v1 + v2 - 2(v0 + v4))
//
// and then divide a0, a1 and a2 by 6 using a hint.

a0 := e.Ext2.MulByConstElement(v0, big.NewInt(6))
t1 = e.Ext2.Sub(v0, v1)
t1 = e.Ext2.MulByConstElement(t1, big.NewInt(3))
t1 = e.Ext2.Sub(t1, v2)
t1 = e.Ext2.Add(t1, v3)
t2 = e.Ext2.MulByConstElement(v4, big.NewInt(12))
t1 = e.Ext2.Sub(t1, t2)
t1 = e.Ext2.MulByNonResidue(t1)
a0 = e.Ext2.Add(a0, t1)

a1 := e.Ext2.MulByConstElement(v0, big.NewInt(3))
t1 = e.Ext2.MulByConstElement(v2, big.NewInt(2))
a1 = e.Ext2.Add(a1, t1)
a1 = e.Ext2.Add(a1, v3)
t1 = e.Ext2.MulByConstElement(v4, big.NewInt(2))
t1 = e.Ext2.Add(t1, v1)
t2 = e.Ext2.MulByNonResidue(v4)
t1 = e.Ext2.Add(t1, t2)
t1 = e.Ext2.MulByConstElement(t1, big.NewInt(6))
a1 = e.Ext2.Sub(t1, a1)

a2 := e.Ext2.Add(v1, v2)
a2 = e.Ext2.MulByConstElement(a2, big.NewInt(3))
t1 = e.Ext2.Add(v0, v4)
t1 = e.Ext2.MulByConstElement(t1, big.NewInt(6))
a2 = e.Ext2.Sub(a2, t1)

res := e.divE6By6([6]*baseEl{&a0.A0, &a0.A1, &a1.A0, &a1.A1, &a2.A0, &a2.A1})
return &E6{
B0: E2{
A0: *res[0],
A1: *res[1],
},
B1: E2{
A0: *res[2],
A1: *res[3],
},
B2: E2{
A0: *res[4],
A1: *res[5],
},
}
}

func (e Ext6) mulKaratsubaOverKaratsuba(x, y *E6) *E6 {
// Karatsuba over Karatsuba:
// Algorithm 13 from https://eprint.iacr.org/2010/354.pdf
t0 := e.Ext2.Mul(&x.B0, &y.B0)
t1 := e.Ext2.Mul(&x.B1, &y.B1)
t2 := e.Ext2.Mul(&x.B2, &y.B2)
Expand Down Expand Up @@ -307,6 +419,37 @@ func (e Ext6) DivUnchecked(x, y *E6) *E6 {
return &div
}

func (e Ext6) divE6By6(x [6]*baseEl) [6]*baseEl {
res, err := e.fp.NewHint(divE6By6Hint, 6, x[0], x[1], x[2], x[3], x[4], x[5])
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}

y0 := *res[0]
y1 := *res[1]
y2 := *res[2]
y3 := *res[3]
y4 := *res[4]
y5 := *res[5]

// xi == 6 * yi
x0 := e.fp.MulConst(&y0, big.NewInt(6))
x1 := e.fp.MulConst(&y1, big.NewInt(6))
x2 := e.fp.MulConst(&y2, big.NewInt(6))
x3 := e.fp.MulConst(&y3, big.NewInt(6))
x4 := e.fp.MulConst(&y4, big.NewInt(6))
x5 := e.fp.MulConst(&y5, big.NewInt(6))
e.fp.AssertIsEqual(x[0], x0)
e.fp.AssertIsEqual(x[1], x1)
e.fp.AssertIsEqual(x[2], x2)
e.fp.AssertIsEqual(x[3], x3)
e.fp.AssertIsEqual(x[4], x4)
e.fp.AssertIsEqual(x[5], x5)

return [6]*baseEl{&y0, &y1, &y2, &y3, &y4, &y5}
}

func (e Ext6) Select(selector frontend.Variable, z1, z0 *E6) *E6 {
b0 := e.Ext2.Select(selector, &z1.B0, &z0.B0)
b1 := e.Ext2.Select(selector, &z1.B1, &z0.B1)
Expand Down
33 changes: 33 additions & 0 deletions std/algebra/emulated/fields_bls12381/e6_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,39 @@ func TestMulFp6(t *testing.T) {

}

type e6MulVariant struct {
A, B, C E6
}

func (circuit *e6MulVariant) Define(api frontend.API) error {
e := NewExt6(api)
expected1 := e.mulKaratsubaOverKaratsuba(&circuit.A, &circuit.B)
expected2 := e.mulToom3OverKaratsuba(&circuit.A, &circuit.B)
e.AssertIsEqual(expected1, &circuit.C)
e.AssertIsEqual(expected2, &circuit.C)
return nil
}

func TestMulFp6Variants(t *testing.T) {

assert := test.NewAssert(t)
// witness values
var a, b, c bls12381.E6
_, _ = a.SetRandom()
_, _ = b.SetRandom()
c.Mul(&a, &b)

witness := e6Mul{
A: FromE6(&a),
B: FromE6(&b),
C: FromE6(&c),
}

err := test.IsSolved(&e6Mul{}, &witness, ecc.BN254.ScalarField())
assert.NoError(err)

}

type e6Square struct {
A, C E6
}
Expand Down
32 changes: 32 additions & 0 deletions std/algebra/emulated/fields_bls12381/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math/big"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/std/math/emulated"
)
Expand All @@ -22,6 +23,7 @@ func GetHints() []solver.Hint {
divE6Hint,
inverseE6Hint,
squareTorusHint,
divE6By6Hint,
// E12
divE12Hint,
inverseE12Hint,
Expand Down Expand Up @@ -149,6 +151,36 @@ func squareTorusHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int)
})
}

func divE6By6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return emulated.UnwrapHint(nativeInputs, nativeOutputs,
func(mod *big.Int, inputs, outputs []*big.Int) error {
var a, c bls12381.E6

a.B0.A0.SetBigInt(inputs[0])
a.B0.A1.SetBigInt(inputs[1])
a.B1.A0.SetBigInt(inputs[2])
a.B1.A1.SetBigInt(inputs[3])
a.B2.A0.SetBigInt(inputs[4])
a.B2.A1.SetBigInt(inputs[5])

var sixInv fp.Element
sixInv.SetString("6")
sixInv.Inverse(&sixInv)
c.B0.MulByElement(&a.B0, &sixInv)
c.B1.MulByElement(&a.B1, &sixInv)
c.B2.MulByElement(&a.B2, &sixInv)

c.B0.A0.BigInt(outputs[0])
c.B0.A1.BigInt(outputs[1])
c.B1.A0.BigInt(outputs[2])
c.B1.A1.BigInt(outputs[3])
c.B2.A0.BigInt(outputs[4])
c.B2.A1.BigInt(outputs[5])

return nil
})
}

// E12 hints
func inverseE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return emulated.UnwrapHint(nativeInputs, nativeOutputs,
Expand Down
Loading

0 comments on commit e204083

Please sign in to comment.