diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 83fd332c55..c7b54b514d 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -25,11 +25,6 @@ jobs: git update-index --assume-unchanged go.sum if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi - # hack to ensure golanglint process generated files - - name: remove "generated by" comments from generated files - run: | - find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; - # on macos: find . -type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1e21a62c65..6dcc8a68a0 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -28,11 +28,6 @@ jobs: git update-index --assume-unchanged go.sum if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi - # hack to ensure golanglint process generated files - - name: remove "generated by" comments from generated files - run: | - find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; - # on macos: find . -type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: diff --git a/.golangci.yml b/.golangci.yml index fb5d970744..0a86beffd4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -13,5 +13,7 @@ linters-settings: gosec: excludes: - G115 # Conversions from int -> uint etc. 
+issues: + exclude-generated: disable run: issues-exit-code: 1 \ No newline at end of file diff --git a/backend/groth16/bls12-377/marshal.go b/backend/groth16/bls12-377/marshal.go index d0e0f295c8..3d0d9ea30b 100644 --- a/backend/groth16/bls12-377/marshal.go +++ b/backend/groth16/bls12-377/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) @@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + 
dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bls12-377/marshal_test.go b/backend/groth16/bls12-377/marshal_test.go index 1859b21e08..4ad4a8719e 100644 --- a/backend/groth16/bls12-377/marshal_test.go +++ b/backend/groth16/bls12-377/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := 
make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bls12-377/setup.go b/backend/groth16/bls12-377/setup.go index 5b6b3af78d..24efac699e 100644 --- a/backend/groth16/bls12-377/setup.go +++ b/backend/groth16/bls12-377/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bls12-381/marshal.go b/backend/groth16/bls12-381/marshal.go index 8a34d864fd..57e0af99f3 100644 --- a/backend/groth16/bls12-381/marshal.go +++ b/backend/groth16/bls12-381/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) 
readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bls12-381/marshal_test.go b/backend/groth16/bls12-381/marshal_test.go index 83b2008995..1f62e84659 100644 --- a/backend/groth16/bls12-381/marshal_test.go +++ b/backend/groth16/bls12-381/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bls12-381/setup.go b/backend/groth16/bls12-381/setup.go index 5c2f198ff9..bd0511a58e 100644 --- a/backend/groth16/bls12-381/setup.go +++ b/backend/groth16/bls12-381/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bls24-315/marshal.go b/backend/groth16/bls24-315/marshal.go index 32cbca8368..2684cca417 100644 --- a/backend/groth16/bls24-315/marshal.go +++ b/backend/groth16/bls24-315/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } 
- for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bls24-315/marshal_test.go b/backend/groth16/bls24-315/marshal_test.go index d7e5e2d933..849c5f684e 100644 --- a/backend/groth16/bls24-315/marshal_test.go +++ b/backend/groth16/bls24-315/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bls24-315/setup.go b/backend/groth16/bls24-315/setup.go index 6b3f193999..7f8147b3c0 100644 --- a/backend/groth16/bls24-315/setup.go +++ b/backend/groth16/bls24-315/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bls24-317/marshal.go b/backend/groth16/bls24-317/marshal.go index c1aa622e0c..d322edd46e 100644 --- a/backend/groth16/bls24-317/marshal.go +++ b/backend/groth16/bls24-317/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } 
- for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bls24-317/marshal_test.go b/backend/groth16/bls24-317/marshal_test.go index b105bc2181..092b3beb4c 100644 --- a/backend/groth16/bls24-317/marshal_test.go +++ b/backend/groth16/bls24-317/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bls24-317/setup.go b/backend/groth16/bls24-317/setup.go index 53628f4c5e..c8db3321c2 100644 --- a/backend/groth16/bls24-317/setup.go +++ b/backend/groth16/bls24-317/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bn254/marshal.go b/backend/groth16/bn254/marshal.go index d1c6ba186f..5005269eba 100644 --- a/backend/groth16/bn254/marshal.go +++ b/backend/groth16/bn254/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } - for _, v := range 
toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bn254/marshal_test.go b/backend/groth16/bn254/marshal_test.go index d59c54e85f..65ded1d4bf 100644 --- a/backend/groth16/bn254/marshal_test.go +++ b/backend/groth16/bn254/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bn254/setup.go b/backend/groth16/bn254/setup.go index 13ddcd61d3..0af907355a 100644 --- a/backend/groth16/bn254/setup.go +++ b/backend/groth16/bn254/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bn254/solidity.go b/backend/groth16/bn254/solidity.go index 47793bdf7e..95dfb9a696 100644 --- a/backend/groth16/bn254/solidity.go +++ b/backend/groth16/bn254/solidity.go @@ -104,11 +104,11 @@ contract Verifier { uint256 constant PEDERSEN_G_Y_0 = {{ (fpstr $cmtVk0.G.Y.A0) }}; uint256 constant PEDERSEN_G_Y_1 = {{ (fpstr $cmtVk0.G.Y.A1) }}; - // Pedersen GSigma point in G2 in powers of i - uint256 constant PEDERSEN_GSIGMA_X_0 = {{ (fpstr $cmtVk0.GSigma.X.A0) }}; - uint256 constant PEDERSEN_GSIGMA_X_1 = {{ 
(fpstr $cmtVk0.GSigma.X.A1) }}; - uint256 constant PEDERSEN_GSIGMA_Y_0 = {{ (fpstr $cmtVk0.GSigma.Y.A0) }}; - uint256 constant PEDERSEN_GSIGMA_Y_1 = {{ (fpstr $cmtVk0.GSigma.Y.A1) }}; + // Pedersen GSigmaNeg point in G2 in powers of i + uint256 constant PEDERSEN_GSIGMANEG_X_0 = {{ (fpstr $cmtVk0.GSigmaNeg.X.A0) }}; + uint256 constant PEDERSEN_GSIGMANEG_X_1 = {{ (fpstr $cmtVk0.GSigmaNeg.X.A1) }}; + uint256 constant PEDERSEN_GSIGMANEG_Y_0 = {{ (fpstr $cmtVk0.GSigmaNeg.Y.A0) }}; + uint256 constant PEDERSEN_GSIGMANEG_Y_1 = {{ (fpstr $cmtVk0.GSigmaNeg.Y.A1) }}; {{- end }} // Constant and public input points @@ -579,10 +579,10 @@ contract Verifier { // Commitments pairings[ 0] = commitments[0]; pairings[ 1] = commitments[1]; - pairings[ 2] = PEDERSEN_GSIGMA_X_1; - pairings[ 3] = PEDERSEN_GSIGMA_X_0; - pairings[ 4] = PEDERSEN_GSIGMA_Y_1; - pairings[ 5] = PEDERSEN_GSIGMA_Y_0; + pairings[ 2] = PEDERSEN_GSIGMANEG_X_1; + pairings[ 3] = PEDERSEN_GSIGMANEG_X_0; + pairings[ 4] = PEDERSEN_GSIGMANEG_Y_1; + pairings[ 5] = PEDERSEN_GSIGMANEG_Y_0; pairings[ 6] = Px; pairings[ 7] = Py; pairings[ 8] = PEDERSEN_G_X_1; @@ -730,10 +730,10 @@ contract Verifier { let f := mload(0x40) calldatacopy(f, commitments, 0x40) // Copy Commitments - mstore(add(f, 0x40), PEDERSEN_GSIGMA_X_1) - mstore(add(f, 0x60), PEDERSEN_GSIGMA_X_0) - mstore(add(f, 0x80), PEDERSEN_GSIGMA_Y_1) - mstore(add(f, 0xa0), PEDERSEN_GSIGMA_Y_0) + mstore(add(f, 0x40), PEDERSEN_GSIGMANEG_X_1) + mstore(add(f, 0x60), PEDERSEN_GSIGMANEG_X_0) + mstore(add(f, 0x80), PEDERSEN_GSIGMANEG_Y_1) + mstore(add(f, 0xa0), PEDERSEN_GSIGMANEG_Y_0) calldatacopy(add(f, 0xc0), commitmentPok, 0x40) mstore(add(f, 0x100), PEDERSEN_G_X_1) mstore(add(f, 0x120), PEDERSEN_G_X_0) diff --git a/backend/groth16/bw6-633/marshal.go b/backend/groth16/bw6-633/marshal.go index d0d6ff3809..48b6ff1490 100644 --- a/backend/groth16/bw6-633/marshal.go +++ b/backend/groth16/bw6-633/marshal.go @@ -22,6 +22,8 @@ import ( 
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bw6-633/marshal_test.go b/backend/groth16/bw6-633/marshal_test.go index 53fa6e8573..e85f3eefcf 100644 --- a/backend/groth16/bw6-633/marshal_test.go +++ b/backend/groth16/bw6-633/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bw6-633/setup.go b/backend/groth16/bw6-633/setup.go index d26cd7c324..a569df46fb 100644 --- a/backend/groth16/bw6-633/setup.go +++ b/backend/groth16/bw6-633/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/backend/groth16/bw6-761/marshal.go b/backend/groth16/bw6-761/marshal.go index a5766ae5b3..4aebb1deb1 100644 --- a/backend/groth16/bw6-761/marshal.go +++ b/backend/groth16/bw6-761/marshal.go @@ -22,6 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" "github.com/consensys/gnark-crypto/utils/unsafe" "github.com/consensys/gnark/internal/utils" + + "fmt" "io" ) @@ -196,35 +198,39 @@ func (vk *VerifyingKey) readFrom(r io.Reader, raw bool) (int64, error) { &nbCommitments, } - for _, v := 
range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err + return dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, nbCommitments) var n int64 - for i := range vk.CommitmentKeys { + for i := 0; i < int(nbCommitments); i++ { var ( m int64 err error ) + commitmentKey := pedersen.VerifyingKey{} if raw { - m, err = vk.CommitmentKeys[i].UnsafeReadFrom(r) + m, err = commitmentKey.UnsafeReadFrom(r) } else { - m, err = vk.CommitmentKeys[i].ReadFrom(r) + m, err = commitmentKey.ReadFrom(r) } n += m if err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + vk.CommitmentKeys = append(vk.CommitmentKeys, commitmentKey) + } + if len(vk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(vk.CommitmentKeys)) } // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 if err := vk.Precompute(); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("precompute: %w", err) } return n + dec.BytesRead(), nil @@ -320,7 +326,7 @@ func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { n, err := pk.Domain.ReadFrom(r) if err != nil { - return n, err + return n, fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, decOptions...) 
@@ -344,31 +350,34 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return n + dec.BytesRead(), err + return n + dec.BytesRead(), fmt.Errorf("read nbCommitments: %w", err) } - - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range pk.CommitmentKeys { - n2, err := pk.CommitmentKeys[i].ReadFrom(r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + n2, err := cpkey.ReadFrom(r) n += n2 if err != nil { - return n, err + return n + dec.BytesRead(), fmt.Errorf("read commitment key %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return n + dec.BytesRead(), fmt.Errorf("invalid number of commitment keys. 
Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return n + dec.BytesRead(), nil @@ -451,11 +460,11 @@ func (pk *ProvingKey) WriteDump(w io.Writer) error { func (pk *ProvingKey) ReadDump(r io.Reader) error { // read the marker to fail early in case of malformed input if err := unsafe.ReadMarker(r); err != nil { - return err + return fmt.Errorf("read marker: %w", err) } if _, err := pk.Domain.ReadFrom(r); err != nil { - return err + return fmt.Errorf("read domain: %w", err) } dec := curve.NewDecoder(r, curve.NoSubgroupChecks()) @@ -479,57 +488,61 @@ func (pk *ProvingKey) ReadDump(r io.Reader) error { &pk.NbInfinityB, } - for _, v := range toDecode { + for i, v := range toDecode { if err := dec.Decode(v); err != nil { - return err + return fmt.Errorf("read field %d: %w", i, err) } } pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) if err := dec.Decode(&pk.InfinityA); err != nil { - return err + return fmt.Errorf("read InfinityA: %w", err) } if err := dec.Decode(&pk.InfinityB); err != nil { - return err + return fmt.Errorf("read InfinityB: %w", err) } if err := dec.Decode(&nbCommitments); err != nil { - return err + return fmt.Errorf("read nbCommitments: %w", err) } // read slices of points var err error pk.G1.A, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.A: %w", err) } pk.G1.B, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.B: %w", err) } pk.G1.Z, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.Z: %w", err) } pk.G1.K, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read G1.K: %w", err) } pk.G2.B, _, err = unsafe.ReadSlice[[]curve.G2Affine](r) if err != nil { - return err + return fmt.Errorf("read G2.B: %w", err) } - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) - for i := range 
pk.CommitmentKeys { - pk.CommitmentKeys[i].Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + for i := 0; i < int(nbCommitments); i++ { + cpkey := pedersen.ProvingKey{} + cpkey.Basis, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basis %d: %w", i, err) } - pk.CommitmentKeys[i].BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) + cpkey.BasisExpSigma, _, err = unsafe.ReadSlice[[]curve.G1Affine](r) if err != nil { - return err + return fmt.Errorf("read commitment basisExpSigma %d: %w", i, err) } + pk.CommitmentKeys = append(pk.CommitmentKeys, cpkey) + } + if len(pk.CommitmentKeys) != int(nbCommitments) { + return fmt.Errorf("invalid number of commitment keys. Expected %d got %d", nbCommitments, len(pk.CommitmentKeys)) } return nil diff --git a/backend/groth16/bw6-761/marshal_test.go b/backend/groth16/bw6-761/marshal_test.go index ab8fbc8667..36e299d5c2 100644 --- a/backend/groth16/bw6-761/marshal_test.go +++ b/backend/groth16/bw6-761/marshal_test.go @@ -98,7 +98,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -108,7 +107,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -175,17 +174,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = 
make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any { return new(ProvingKey) }); err != nil { diff --git a/backend/groth16/bw6-761/setup.go b/backend/groth16/bw6-761/setup.go index 0988613d36..aeb2483c14 100644 --- a/backend/groth16/bw6-761/setup.go +++ b/backend/groth16/bw6-761/setup.go @@ -291,8 +291,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) - vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + if len(commitmentBases) > 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/examples/inputpacking/doc.go b/examples/inputpacking/doc.go new file mode 100644 index 0000000000..a95ed0400f --- /dev/null +++ b/examples/inputpacking/doc.go @@ -0,0 +1,19 @@ +// Package inputpacking illustrates input packing for reducing public input. + +// Usually in a SNARK circuit there are public and private inputs. The public +// inputs are known to the prover and verifier, while the private inputs are +// known only to the prover. 
To verify the proof, the verifier needs to provide +// the public inputs as an input to the verification algorithm. +// +// However, there are several drawbacks to this approach: +// 1. The public inputs may not be of a convenient format -- this happens for example when using the non-native arithmetic where we work on limbs. +// 2. The verifier work depends on the number of public inputs -- this is a problem in case of a recursive SNARK verifier, making the recursion more expensive. +// 3. The public input needs to be provided as calldata to the Solidity verifier, which is expensive. +// +// An alternative approach however is to provide only a hash of the public +// inputs to the verifier. This way, if the verifier computes the hash of the +// inputs on its own, it can be sure that the inputs are correct and we can +// mitigate the issues. +// +// This example shows how to use this approach for both native and non-native inputs. We use the MiMC hash function. +package inputpacking diff --git a/examples/inputpacking/inputpacking_test.go b/examples/inputpacking/inputpacking_test.go new file mode 100644 index 0000000000..79ac23fbb6 --- /dev/null +++ b/examples/inputpacking/inputpacking_test.go @@ -0,0 +1,199 @@ +package inputpacking + +import ( + "crypto/rand" + "fmt" + "math/big" + + fp_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fp" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + cmimc "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" +) + +func inCircuitComputation(api frontend.API, input1, input2 frontend.Variable, expected frontend.Variable) { + res := api.Mul(input1, input2) + api.AssertIsEqual(res, expected) +} + +func 
inCircuitComputationEmulated(api frontend.API, input1, input2 emulated.Element[emulated.BN254Fp], expected emulated.Element[emulated.BN254Fp]) error { + f, err := emulated.NewField[emulated.BN254Fp](api) + if err != nil { + return err + } + res := f.Mul(&input1, &input2) + f.AssertIsEqual(res, &expected) + return nil +} + +// UnpackedCircuit represents a circuit where all public inputs are given as is +type UnpackedCircuit struct { + Input1, Input2 frontend.Variable `gnark:",public"` + EmulatedInput1, EmulatedInput2 emulated.Element[emulated.BN254Fp] `gnark:",public"` + Output frontend.Variable `gnark:",private"` + EmulatedOutput emulated.Element[emulated.BN254Fp] `gnark:",private"` +} + +func (circuit *UnpackedCircuit) Define(api frontend.API) error { + inCircuitComputation(api, circuit.Input1, circuit.Input2, circuit.Output) + return inCircuitComputationEmulated(api, circuit.EmulatedInput1, circuit.EmulatedInput2, circuit.EmulatedOutput) +} + +// PackedCircuit represents a circuit where all public inputs are given as private instead and we provide a hash of them as the only public input. +type PackedCircuit struct { + PublicHash frontend.Variable + + Input1, Input2 frontend.Variable `gnark:",private"` + EmulatedInput1, EmulatedInput2 emulated.Element[emulated.BN254Fp] `gnark:",private"` + Output frontend.Variable `gnark:",private"` + EmulatedOutput emulated.Element[emulated.BN254Fp] `gnark:",private"` +} + +func (circuit *PackedCircuit) Define(api frontend.API) error { + h, err := mimc.NewMiMC(api) + if err != nil { + return err + } + h.Write(circuit.Input1) + h.Write(circuit.Input2) + h.Write(circuit.EmulatedInput1.Limbs...) + h.Write(circuit.EmulatedInput2.Limbs...) 
+ dgst := h.Sum() + api.AssertIsEqual(dgst, circuit.PublicHash) + + inCircuitComputation(api, circuit.Input1, circuit.Input2, circuit.Output) + return inCircuitComputationEmulated(api, circuit.EmulatedInput1, circuit.EmulatedInput2, circuit.EmulatedOutput) +} + +func Example() { + modulusNative := ecc.BN254.ScalarField() + modulusEmulated := ecc.BN254.BaseField() + + // declare inputs + input1, err := rand.Int(rand.Reader, modulusNative) + if err != nil { + panic(err) + } + input2, err := rand.Int(rand.Reader, modulusNative) + if err != nil { + panic(err) + } + emulatedInput1, err := rand.Int(rand.Reader, modulusEmulated) + if err != nil { + panic(err) + } + emulatedInput2, err := rand.Int(rand.Reader, modulusEmulated) + if err != nil { + panic(err) + } + output := new(big.Int).Mul(input1, input2) + output.Mod(output, modulusNative) + emulatedOutput := new(big.Int).Mul(emulatedInput1, emulatedInput2) + emulatedOutput.Mod(emulatedOutput, modulusEmulated) + + // first we run the circuit where public inputs are not packed + assignment := &UnpackedCircuit{ + Input1: input1, + Input2: input2, + EmulatedInput1: emulated.ValueOf[emparams.BN254Fp](emulatedInput1), + EmulatedInput2: emulated.ValueOf[emparams.BN254Fp](emulatedInput2), + Output: output, + EmulatedOutput: emulated.ValueOf[emparams.BN254Fp](emulatedOutput), + } + privWit, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + publicWit, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField(), frontend.PublicOnly()) + if err != nil { + panic(err) + } + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &UnpackedCircuit{}) + if err != nil { + panic(err) + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } + proof, err := groth16.Prove(ccs, pk, privWit) + if err != nil { + panic(err) + } + err = groth16.Verify(proof, vk, publicWit) + if err != nil { + panic(err) + } + + // print the number of public inputs when we 
provide all public inputs. Note that we also count the commitment here. + fmt.Println("unpacked public variables:", ccs.GetNbPublicVariables()) + + // then we run the circuit where public inputs are packed + var buf [fr_bn254.Bytes]byte + var buf2 [fp_bn254.Bytes]byte + h := cmimc.NewMiMC() + input1.FillBytes(buf[:]) + h.Write(buf[:]) + input2.FillBytes(buf[:]) + h.Write(buf[:]) + emulatedInput1.FillBytes(buf2[:]) + h.Write(buf2[24:32]) + h.Write(buf2[16:24]) + h.Write(buf2[8:16]) + h.Write(buf2[0:8]) + emulatedInput2.FillBytes(buf2[:]) + h.Write(buf2[24:32]) + h.Write(buf2[16:24]) + h.Write(buf2[8:16]) + h.Write(buf2[0:8]) + + dgst := h.Sum(nil) + phash := new(big.Int).SetBytes(dgst) + + assignment2 := &PackedCircuit{ + PublicHash: phash, + Input1: input1, + Input2: input2, + EmulatedInput1: emulated.ValueOf[emparams.BN254Fp](emulatedInput1), + EmulatedInput2: emulated.ValueOf[emparams.BN254Fp](emulatedInput2), + Output: output, + EmulatedOutput: emulated.ValueOf[emparams.BN254Fp](emulatedOutput), + } + privWit2, err := frontend.NewWitness(assignment2, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + publicWit2, err := frontend.NewWitness(assignment2, ecc.BN254.ScalarField(), frontend.PublicOnly()) + if err != nil { + panic(err) + } + + ccs2, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PackedCircuit{}) + if err != nil { + panic(err) + } + pk2, vk2, err := groth16.Setup(ccs2) + if err != nil { + panic(err) + } + proof2, err := groth16.Prove(ccs2, pk2, privWit2) + if err != nil { + panic(err) + } + err = groth16.Verify(proof2, vk2, publicWit2) + if err != nil { + panic(err) + } + // print the number of public inputs when we provide only the hash. Note that we also count the commitment here. 
+ fmt.Println("packed public variables:", ccs2.GetNbPublicVariables()) + // output: unpacked public variables: 11 + // packed public variables: 1 +} diff --git a/examples/sudoku/doc.go b/examples/sudoku/doc.go new file mode 100644 index 0000000000..d368228cd1 --- /dev/null +++ b/examples/sudoku/doc.go @@ -0,0 +1,14 @@ +// Package sudoku implements a Sudoku circuit using gnark. +// +// [Sudoku] is a popular puzzle to fill a 9x9 grid with digits so that each +// column, each row, and each of the nine 3x3 sub-grids that compose the grid +// contain all of the digits from 1 to 9. This package provides a circuit that +// verifies a solution to a Sudoku puzzle. +// +// See the included full example on how to define the circuit, run the setup, +// generate proof and verify the proof. This example also demonstrates how to +// serialize and deserialize the values produced during setup and proof +// generation. +// +// [Sudoku]: https://en.wikipedia.org/wiki/Sudoku +package sudoku diff --git a/examples/sudoku/sudoku_example_test.go b/examples/sudoku/sudoku_example_test.go new file mode 100644 index 0000000000..1db9498e9f --- /dev/null +++ b/examples/sudoku/sudoku_example_test.go @@ -0,0 +1,254 @@ +package sudoku + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + cs_bn254 "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +// SudokuCircuit represents a Sudoku circuit. It contains two grids: the +// challenge and solution grids (named Challenge and Solution respectively). The +// challenge grid is public, while the solution grid is private. +type SudokuCircuit struct { + Challenge SudokuGrid `gnark:"Challenge,public"` + Solution SudokuGrid `gnark:"Solution,secret"` +} + +// SudokuGrid represents a 9x9 Sudoku grid in-circuit. 
+type SudokuGrid [9][9]frontend.Variable + +// Define defines the constraints of the Sudoku circuit. +func (circuit *SudokuCircuit) Define(api frontend.API) error { + // Constraint 1: Each cell value in the CompleteGrid must be between 1 and 9 + for i := 0; i < 9; i++ { + for j := 0; j < 9; j++ { + api.AssertIsLessOrEqual(circuit.Solution[i][j], 9) + api.AssertIsLessOrEqual(1, circuit.Solution[i][j]) + } + } + + // Constraint 2: Each row in the CompleteGrid must contain unique values + for i := 0; i < 9; i++ { + for j := 0; j < 9; j++ { + for k := j + 1; k < 9; k++ { + api.AssertIsDifferent(circuit.Solution[i][j], circuit.Solution[i][k]) + } + } + } + + // Constraint 3: Each column in the CompleteGrid must contain unique values + for j := 0; j < 9; j++ { + for i := 0; i < 9; i++ { + for k := i + 1; k < 9; k++ { + api.AssertIsDifferent(circuit.Solution[i][j], circuit.Solution[k][j]) + } + } + } + + // Constraint 4: Each 3x3 sub-grid in the CompleteGrid must contain unique values + for boxRow := 0; boxRow < 3; boxRow++ { + for boxCol := 0; boxCol < 3; boxCol++ { + for i := 0; i < 9; i++ { + for j := i + 1; j < 9; j++ { + row1 := boxRow*3 + i/3 + col1 := boxCol*3 + i%3 + row2 := boxRow*3 + j/3 + col2 := boxCol*3 + j%3 + api.AssertIsDifferent(circuit.Solution[row1][col1], circuit.Solution[row2][col2]) + } + } + } + } + + // Constraint 5: The values in the IncompleteGrid must match the CompleteGrid where provided + for i := 0; i < 9; i++ { + for j := 0; j < 9; j++ { + isCellGiven := api.IsZero(circuit.Challenge[i][j]) + api.AssertIsEqual(api.Select(isCellGiven, circuit.Solution[i][j], circuit.Challenge[i][j]), circuit.Solution[i][j]) + } + } + + return nil +} + +// SudokuSerialization represents a Sudoku witness out-circuit. Used for serialization. +type SudokuSerialization struct { + Grid [9][9]int `json:"grid"` +} + +// NewSudokuGrid creates a new Sudoku grid from the serialized grid. 
+func NewSudokuGrid(serialized SudokuSerialization) SudokuGrid { + var grid SudokuGrid + for i := 0; i < 9; i++ { + for j := 0; j < 9; j++ { + grid[i][j] = frontend.Variable(serialized.Grid[i][j]) + } + } + return grid +} + +// setup performs the setup phase of the Sudoku circuit +func setup(ccsWriter, pkWriter, vkWriter io.Writer) error { + // compile the circuit + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &SudokuCircuit{}) + if err != nil { + return fmt.Errorf("failed to compile circuit: %v", err) + } + // perform the setup. NB! In practice use MPC. This is currently UNSAFE + // approach. + pk, vk, err := groth16.Setup(ccs) + if err != nil { + return fmt.Errorf("failed to setup circuit: %v", err) + } + // serialize the circuit, proving key and verifying key + _, err = ccs.WriteTo(ccsWriter) + if err != nil { + return fmt.Errorf("failed to write constraint system: %v", err) + } + _, err = pk.WriteTo(pkWriter) + if err != nil { + return fmt.Errorf("failed to write proving key: %v", err) + } + _, err = vk.WriteTo(vkWriter) + if err != nil { + return fmt.Errorf("failed to write verifying key: %v", err) + } + return nil +} + +// prover performs the prover phase of the Sudoku circuit +func prover(ccsReader, challengeReader, pkReader io.Reader, proofWriter io.Writer) error { + // define the sudoku solution. This is private information known only to the + // prover. We use serialization to represent it. 
+ serializedSolution := `{"grid":[[5,3,4,6,7,8,9,1,2],[6,7,2,1,9,5,3,4,8],[1,9,8,3,4,2,5,6,7],[8,5,9,7,6,1,4,2,3],[4,2,6,8,5,3,7,9,1],[7,1,3,9,2,4,8,5,6],[9,6,1,5,3,7,2,8,4],[2,8,7,4,1,9,6,3,5],[3,4,5,2,8,6,1,7,9]]}` + var nativeSolution SudokuSerialization + err := json.Unmarshal([]byte(serializedSolution), &nativeSolution) + if err != nil { + return fmt.Errorf("failed to unmarshal solution: %v", err) + } + // deserialize the circuit, challenge and proving key + var ccs cs_bn254.R1CS + _, err = ccs.ReadFrom(ccsReader) + if err != nil { + return fmt.Errorf("failed to read constraint system: %v", err) + } + var nativeChallenge SudokuSerialization + err = json.NewDecoder(challengeReader).Decode(&nativeChallenge) + if err != nil { + return fmt.Errorf("failed to read challenge: %v", err) + } + var pk groth16_bn254.ProvingKey + _, err = pk.ReadFrom(pkReader) + if err != nil { + return fmt.Errorf("failed to read proving key: %v", err) + } + + // create the circuit assignments + assignment := &SudokuCircuit{ + Challenge: NewSudokuGrid(nativeChallenge), + Solution: NewSudokuGrid(nativeSolution), + } + // create the witness + witness, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + if err != nil { + return fmt.Errorf("failed to create witness: %v", err) + } + // generate the proof + proof, err := groth16.Prove(&ccs, &pk, witness) + if err != nil { + return fmt.Errorf("failed to generate proof: %v", err) + } + // serialize the proof + _, err = proof.WriteTo(proofWriter) + if err != nil { + return fmt.Errorf("failed to write proof: %v", err) + } + return nil +} + +// verifier performs the verifier phase of the Sudoku circuit +func verifier(challengeReader, vkReader, proofReader io.Reader) error { + // deserialize the challenge, verifying key and proof + var nativeChallenge SudokuSerialization + err := json.NewDecoder(challengeReader).Decode(&nativeChallenge) + if err != nil { + return fmt.Errorf("failed to read challenge: %v", err) + } + var vk 
groth16_bn254.VerifyingKey + _, err = vk.ReadFrom(vkReader) + if err != nil { + return fmt.Errorf("failed to read verifying key: %v", err) + } + var proof groth16_bn254.Proof + _, err = proof.ReadFrom(proofReader) + if err != nil { + return fmt.Errorf("failed to read proof: %v", err) + } + // create the circuit assignment + assignment := &SudokuCircuit{ + Challenge: NewSudokuGrid(nativeChallenge), + } + // create the public witness + pubWit, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField(), frontend.PublicOnly()) + if err != nil { + return fmt.Errorf("failed to create public witness: %v", err) + } + // verify the proof + err = groth16.Verify(&proof, &vk, pubWit) + if err != nil { + return fmt.Errorf("failed to verify proof: %v", err) + } + return nil +} + +// This example demonstrates how to implement a Sudoku challenge verification +// circuit such that the solution stays private. +// +// This example also demonstrates how to serialize and deserialize the values +// produced during setup and proof generation. +func Example() { + // define the sudoku challenge. This is public information. We use + // serialization to represent it. + serializedChallengeBytes := `{"grid":[[5,3,0,0,7,0,0,0,0],[6,0,0,1,9,5,0,0,0],[0,9,8,0,0,0,0,6,0],[8,0,0,0,6,0,0,0,3],[4,0,0,8,0,3,0,0,1],[7,0,0,0,2,0,0,0,6],[0,6,0,0,0,0,2,8,0],[0,0,0,4,1,9,0,0,5],[0,0,0,0,8,0,0,7,9]]}` + + var ( + serializedCCS bytes.Buffer + + serializedProvingKey bytes.Buffer + serializedVerifyingKey bytes.Buffer + + serializedProof bytes.Buffer + ) + + // full example + + // first we run the setup phase. This happens offline in a trusted or MPC setting + if err := setup(&serializedCCS, &serializedProvingKey, &serializedVerifyingKey); err != nil { + fmt.Println("failed to setup circuit:", err) + return + } + + // then we run the prover phase. 
This happens online + serializedChallenge := bytes.NewBufferString(serializedChallengeBytes) + if err := prover(&serializedCCS, serializedChallenge, &serializedProvingKey, &serializedProof); err != nil { + fmt.Println("failed to prove circuit:", err) + return + } + + // finally we run the verifier phase. This happens online + serializedChallenge = bytes.NewBufferString(serializedChallengeBytes) + if err := verifier(serializedChallenge, &serializedVerifyingKey, &serializedProof); err != nil { + fmt.Println("failed to verify circuit:", err) + return + } + fmt.Println("proof verified successfully!") + // Output: proof verified successfully! +} diff --git a/frontend/api.go b/frontend/api.go index c2b9d05a53..b3f218c526 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -51,18 +51,20 @@ type API interface { // Mul returns res = i1 * i2 * ... in Mul(i1, i2 Variable, in ...Variable) Variable - // DivUnchecked returns i1 / i2 . if i1 == i2 == 0, returns 0 + // DivUnchecked returns i1 / i2 + // If i1 == i2 == 0, the return value (0) is unconstrained. DivUnchecked(i1, i2 Variable) Variable // Div returns i1 / i2 + // If i2 == 0 the constraint will not be satisfied. Div(i1, i2 Variable) Variable // Inverse returns res = 1 / i1 + // If i1 == 0 the constraint will not be satisfied. Inverse(i1 Variable) Variable // --------------------------------------------------------------------------------------------- // Bit operations - // TODO @gbotrel move bit operations in std/math/bits // ToBinary unpacks a Variable in binary, // n is the number of bits to select (starting from lsb) @@ -72,29 +74,32 @@ type API interface { ToBinary(i1 Variable, n ...int) []Variable // FromBinary packs b, seen as a fr.Element in little endian + // This function constrains the bits b... 
to be boolean (0 or 1) FromBinary(b ...Variable) Variable // Xor returns a ^ b - // a and b must be 0 or 1 + // This function constrains a and b to be boolean (0 or 1) Xor(a, b Variable) Variable // Or returns a | b - // a and b must be 0 or 1 + // This function constrains a and b to be boolean (0 or 1) Or(a, b Variable) Variable - // Or returns a & b - // a and b must be 0 or 1 + // And returns a & b + // This function constrains a and b to be boolean (0 or 1) And(a, b Variable) Variable // --------------------------------------------------------------------------------------------- // Conditionals // Select if b is true, yields i1 else yields i2 + // This function constrains b to be boolean (0 or 1) Select(b Variable, i1, i2 Variable) Variable // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. + // This function constrains b0 and b1 to be boolean (0 or 1) Lookup2(b0, b1 Variable, i0, i1, i2, i3 Variable) Variable // IsZero returns 1 if a is zero, 0 otherwise @@ -119,8 +124,9 @@ type API interface { // AssertIsDifferent fails if i1 == i2 AssertIsDifferent(i1, i2 Variable) - // AssertIsBoolean fails if v != 0 and v != 1 + // AssertIsBoolean fails if v ∉ {0,1} AssertIsBoolean(i1 Variable) + // AssertIsCrumb fails if v ∉ {0,1,2,3} (crumb is a 2-bit variable; see https://en.wikipedia.org/wiki/Units_of_information) AssertIsCrumb(i1 Variable) diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index fb2397cde9..58ed49cfc8 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -35,7 +35,6 @@ import ( "github.com/consensys/gnark/std/math/bits" ) -// Add returns res = i1+i2+...in func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // separate the constant part from the variables vars, k := builder.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) @@ -105,7 +104,6 @@ func (builder 
*builder) mulAccFastTrack(a, b, c frontend.Variable) frontend.Vari return res } -// neg returns -in func (builder *builder) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) @@ -115,13 +113,11 @@ func (builder *builder) neg(in []frontend.Variable) []frontend.Variable { return res } -// Sub returns res = i1 - i2 - ...in func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := builder.neg(append([]frontend.Variable{i2}, in...)) return builder.Add(i1, r[0], r[1:]...) } -// Neg returns -i func (builder *builder) Neg(i1 frontend.Variable) frontend.Variable { if n, ok := builder.constantValue(i1); ok { n = builder.cs.Neg(n) @@ -132,7 +128,6 @@ func (builder *builder) Neg(i1 frontend.Variable) frontend.Variable { return v } -// Mul returns res = i1 * i2 * ... in func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := builder.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { @@ -157,7 +152,6 @@ func (builder *builder) mulConstant(t expr.Term, m constraint.Element) expr.Term return t } -// DivUnchecked returns i1 / i2 . 
if i1 == i2 == 0, returns 0 func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { c1, i1Constant := builder.constantValue(i1) c2, i2Constant := builder.constantValue(i2) @@ -195,7 +189,6 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable return res } -// Div returns i1 / i2 func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint builder.Inverse(i2) @@ -203,7 +196,6 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { return builder.DivUnchecked(i1, i2) } -// Inverse returns res = 1 / i1 func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { if c, ok := builder.constantValue(i1); ok { if c.IsZero() { @@ -236,11 +228,6 @@ func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { // --------------------------------------------------------------------------------------------- // Bit operations -// ToBinary unpacks a frontend.Variable in binary, -// n is the number of bits to select (starting from lsb) -// n default value is fr.Bits the number of bits needed to represent a field element -// -// The result is in little endian (first bit= lsb) func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := builder.cs.FieldBitLen() @@ -254,13 +241,10 @@ func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) } -// FromBinary packs b, seen as a fr.Element in little endian func (builder *builder) FromBinary(b ...frontend.Variable) frontend.Variable { return bits.FromBinary(builder, b) } -// Xor returns a ^ b -// a and b must be 0 or 1 func (builder *builder) Xor(a, b frontend.Variable) frontend.Variable { // pre condition: a, b must be booleans builder.AssertIsBoolean(a) @@ -335,8 +319,6 @@ func (builder *builder) Xor(a, b 
frontend.Variable) frontend.Variable { return res } -// Or returns a | b -// a and b must be 0 or 1 func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) @@ -388,8 +370,6 @@ func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { return res } -// Or returns a & b -// a and b must be 0 or 1 func (builder *builder) And(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) @@ -401,7 +381,6 @@ func (builder *builder) And(a, b frontend.Variable) frontend.Variable { // --------------------------------------------------------------------------------------------- // Conditionals -// Select if b is true, yields i1 else yields i2 func (builder *builder) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { _b, bConstant := builder.constantValue(b) @@ -424,9 +403,6 @@ func (builder *builder) Select(b frontend.Variable, i1, i2 frontend.Variable) fr return builder.Add(l, i2) } -// Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 -// and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 -// and i3 if b0=b1=1. func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // ensure that bits are actually bits. Adds no constraints if the variables // are already constrained. 
@@ -473,7 +449,6 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten } -// IsZero returns 1 if a is zero, 0 otherwise func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { if a, ok := builder.constantValue(i1); ok { if a.IsZero() { @@ -519,7 +494,6 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { return m } -// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 0 { + pk.CommitmentKeys = make([]pedersen.ProvingKey, len(commitmentBases)) + vk.CommitmentKeys = make([]pedersen.VerifyingKey, len(commitmentBases)) + } for i := range commitmentBases { comPKey, comVKey, err := pedersen.Setup(commitmentBases[i:i+1], pedersen.WithG2Point(cG2)) if err != nil { diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl index 6997043589..2e51218f01 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl @@ -82,7 +82,6 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } - vk.CommitmentKeys = []pedersen.VerifyingKey{} if withCommitment { vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) @@ -92,7 +91,7 @@ func TestVerifyingKeySerialization(t *testing.T) { for j := range bases[i] { bases[i][j] = elem elem.Add(&elem, &p1) - vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigma: p2}) + vk.CommitmentKeys = append(vk.CommitmentKeys, pedersen.VerifyingKey{G: p2, GSigmaNeg: p2}) } } assert.NoError(t, err) @@ -161,17 +160,18 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true - pedersenBasis := make([]curve.G1Affine, nbCommitment) - 
pedersenBases := make([][]curve.G1Affine, nbCommitment) - pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) - for i := range pedersenBasis { - pedersenBasis[i] = p1 - pedersenBases[i] = pedersenBasis[:i+1] - } - { - var err error - pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) - require.NoError(t, err) + if nbCommitment > 0 { + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases) + require.NoError(t, err) + } } if err := io.RoundTripCheck(&pk, func() any {return new(ProvingKey)}); err != nil { diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index fbc129594e..eb7b4efb74 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -209,42 +209,42 @@ pairing_bw6761,bls24_315,plonk,0,0 pairing_bw6761,bls24_317,plonk,0,0 pairing_bw6761,bw6_761,plonk,0,0 pairing_bw6761,bw6_633,plonk,0,0 -scalar_mul_G1_bn254,bn254,groth16,99938,159576 +scalar_mul_G1_bn254,bn254,groth16,74345,117078 scalar_mul_G1_bn254,bls12_377,groth16,0,0 scalar_mul_G1_bn254,bls12_381,groth16,0,0 scalar_mul_G1_bn254,bls24_315,groth16,0,0 scalar_mul_G1_bn254,bls24_317,groth16,0,0 scalar_mul_G1_bn254,bw6_761,groth16,0,0 scalar_mul_G1_bn254,bw6_633,groth16,0,0 -scalar_mul_G1_bn254,bn254,plonk,381115,356144 +scalar_mul_G1_bn254,bn254,plonk,278909,261995 scalar_mul_G1_bn254,bls12_377,plonk,0,0 scalar_mul_G1_bn254,bls12_381,plonk,0,0 scalar_mul_G1_bn254,bls24_315,plonk,0,0 scalar_mul_G1_bn254,bls24_317,plonk,0,0 scalar_mul_G1_bn254,bw6_761,plonk,0,0 scalar_mul_G1_bn254,bw6_633,plonk,0,0 -scalar_mul_P256,bn254,groth16,186380,301997 +scalar_mul_P256,bn254,groth16,100828,161106 scalar_mul_P256,bls12_377,groth16,0,0 scalar_mul_P256,bls12_381,groth16,0,0 scalar_mul_P256,bls24_315,groth16,0,0 
scalar_mul_P256,bls24_317,groth16,0,0 scalar_mul_P256,bw6_761,groth16,0,0 scalar_mul_P256,bw6_633,groth16,0,0 -scalar_mul_P256,bn254,plonk,737681,687661 +scalar_mul_P256,bn254,plonk,385060,359805 scalar_mul_P256,bls12_377,plonk,0,0 scalar_mul_P256,bls12_381,plonk,0,0 scalar_mul_P256,bls24_315,plonk,0,0 scalar_mul_P256,bls24_317,plonk,0,0 scalar_mul_P256,bw6_761,plonk,0,0 scalar_mul_P256,bw6_633,plonk,0,0 -scalar_mul_secp256k1,bn254,groth16,100948,161209 +scalar_mul_secp256k1,bn254,groth16,75154,118312 scalar_mul_secp256k1,bls12_377,groth16,0,0 scalar_mul_secp256k1,bls12_381,groth16,0,0 scalar_mul_secp256k1,bls24_315,groth16,0,0 scalar_mul_secp256k1,bls24_317,groth16,0,0 scalar_mul_secp256k1,bw6_761,groth16,0,0 scalar_mul_secp256k1,bw6_633,groth16,0,0 -scalar_mul_secp256k1,bn254,plonk,385109,359843 +scalar_mul_secp256k1,bn254,plonk,281870,264753 scalar_mul_secp256k1,bls12_377,plonk,0,0 scalar_mul_secp256k1,bls12_381,plonk,0,0 scalar_mul_secp256k1,bls24_315,plonk,0,0 diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index 85d289e751..5b3bce6659 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -404,32 +404,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [2]uint64 var D uint64 diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index 93664bddb2..85dcdcd893 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -582,7 +582,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -2159,32 +2158,32 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } - if qElement[0] != ^uint64(0) { - g[0] %= (qElement[0] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - } - if qElement[0] != ^uint64(0) { - g[0] %= (qElement[0] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], _ = bits.Add64(a[0], qElement[0], carry) @@ -2193,3 +2192,11 @@ 
func genFull() gopter.Gen { return genResult } } + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index 9ef47d3cda..703832e720 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -196,6 +196,96 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index 68a98e5fa9..fc02cc1279 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -18,10 +18,15 @@ package tinyfield import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,273 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + 
mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + 
genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/profile/internal/graph/graph.go b/profile/internal/graph/graph.go index 74b904c402..90def3a924 100644 --- a/profile/internal/graph/graph.go +++ b/profile/internal/graph/graph.go @@ -438,7 +438,7 @@ func newTree(prof *profile.Profile, o *Options) (g *Graph) { } } - nodes := make(Nodes, len(prof.Location)) + nodes := make(Nodes, 0, len(prof.Location)) for _, nm := range parentNodeMap { nodes = append(nodes, nm.nodes()...) } diff --git a/std/algebra/emulated/sw_emulated/hints.go b/std/algebra/emulated/sw_emulated/hints.go index 06c15d07a4..f19bfeef0f 100644 --- a/std/algebra/emulated/sw_emulated/hints.go +++ b/std/algebra/emulated/sw_emulated/hints.go @@ -1,12 +1,26 @@ package sw_emulated import ( + "crypto/elliptic" "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381_fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bn254" + bn_fp "github.com/consensys/gnark-crypto/ecc/bn254/fp" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" + bw6_fp "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + secp_fp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + stark_curve "github.com/consensys/gnark-crypto/ecc/stark-curve" + stark_fp "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/consensys/gnark-crypto/field/eisenstein" "github.com/consensys/gnark/constraint/solver" + limbs "github.com/consensys/gnark/std/internal/limbcomposition" "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" ) func init() { @@ -14,7 +28,15 @@ func init() { } func 
GetHints() []solver.Hint { - return []solver.Hint{decomposeScalarG1, decomposeScalarG1Signs, decomposeScalarG1Subscalars} + return []solver.Hint{ + decomposeScalarG1Signs, + decomposeScalarG1Subscalars, + scalarMulHint, + halfGCD, + halfGCDSigns, + halfGCDEisenstein, + halfGCDEisensteinSigns, + } } func decomposeScalarG1Subscalars(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { @@ -70,42 +92,354 @@ func decomposeScalarG1Signs(mod *big.Int, inputs []*big.Int, outputs []*big.Int) }) } -func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +// TODO @yelhousni: generalize for any supported curve as it currently supports only: +// BN254, BLS12-381, BW6-761 and Secp256k1, P256, P384 and STARK curve. +func scalarMulHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintWithNativeInput(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + if field.Cmp(elliptic.P256().Params().P) == 0 { + var fp emparams.P256Fp + var fr emparams.P256Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + curve := elliptic.P256() + // compute the resulting point [s]P + Qx, Qy := curve.ScalarMult(Px, Py, S.Bytes()) + outputs[0].Set(Qx) + outputs[1].Set(Qy) + } else if field.Cmp(elliptic.P384().Params().P) == 0 { + var fp emparams.P384Fp + var fr emparams.P384Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs 
:= inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + curve := elliptic.P384() + // compute the resulting point [s]P + Qx, Qy := curve.ScalarMult(Px, Py, S.Bytes()) + outputs[0].Set(Qx) + outputs[1].Set(Qy) + } else if field.Cmp(stark_fp.Modulus()) == 0 { + var fp emparams.STARKCurveFp + var fr emparams.STARKCurveFr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + // compute the resulting point [s]Q + var P stark_curve.G1Affine + P.X.SetBigInt(Px) + P.Y.SetBigInt(Py) + P.ScalarMultiplication(&P, S) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(bn_fp.Modulus()) == 0 { + var fp emparams.BN254Fp + var fr emparams.BN254Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + // compute the resulting point [s]Q + var P bn254.G1Affine + P.X.SetBigInt(Px) + P.Y.SetBigInt(Py) + P.ScalarMultiplication(&P, S) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + 
} else if field.Cmp(bls12381_fp.Modulus()) == 0 { + var fp emparams.BLS12381Fp + var fr emparams.BLS12381Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + // compute the resulting point [s]Q + var P bls12381.G1Affine + P.X.SetBigInt(Px) + P.Y.SetBigInt(Py) + P.ScalarMultiplication(&P, S) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(secp_fp.Modulus()) == 0 { + var fp emparams.Secp256k1Fp + var fr emparams.Secp256k1Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, fr.BitsPerLimb(), S); err != nil { + return err + + } + // compute the resulting point [s]Q + var P secp256k1.G1Affine + P.X.SetBigInt(Px) + P.Y.SetBigInt(Py) + P.ScalarMultiplication(&P, S) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(bw6_fp.Modulus()) == 0 { + var fp emparams.BW6761Fp + var fr emparams.BW6761Fr + PXLimbs := inputs[:fp.NbLimbs()] + PYLimbs := inputs[fp.NbLimbs() : 2*fp.NbLimbs()] + SLimbs := inputs[2*fp.NbLimbs():] + Px, Py, S := new(big.Int), new(big.Int), new(big.Int) + if err := limbs.Recompose(PXLimbs, fp.BitsPerLimb(), Px); err != nil { + return err + + } + if err := limbs.Recompose(PYLimbs, fp.BitsPerLimb(), Py); err != nil { + return err + + } + if err := limbs.Recompose(SLimbs, 
fr.BitsPerLimb(), S); err != nil { + return err + + } + // compute the resulting point [s]Q + var P bw6761.G1Affine + P.X.SetBigInt(Px) + P.Y.SetBigInt(Py) + P.ScalarMultiplication(&P, S) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + + } else { + return fmt.Errorf("unsupported curve") + } + + return nil + }) +} + +func halfGCDSigns(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintWithNativeOutput(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("expecting one input") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting one output") + } + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(field, inputs[0], glvBasis) + outputs[0].SetUint64(0) + if glvBasis.V1[1].Sign() == -1 { + outputs[0].SetUint64(1) + } + + return nil + }) +} + +func halfGCD(mod *big.Int, inputs, outputs []*big.Int) error { return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { - if len(inputs) != 3 { - return fmt.Errorf("expecting two inputs") + if len(inputs) != 1 { + return fmt.Errorf("expecting one input") } - if len(outputs) != 6 { + if len(outputs) != 2 { return fmt.Errorf("expecting two outputs") } glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(inputs[2], inputs[1], glvBasis) - sp := ecc.SplitScalar(inputs[0], glvBasis) - outputs[0].Set(&(sp[0])) - outputs[1].Set(&(sp[1])) - // we need the negative values for to check that s0+λ*s1 == s mod r - // output4 = s0 mod r - // output5 = s1 mod r - outputs[4].Set(outputs[0]) - outputs[5].Set(outputs[1]) + ecc.PrecomputeLattice(field, inputs[0], glvBasis) + outputs[0].Set(&glvBasis.V1[0]) + outputs[1].Set(&glvBasis.V1[1]) + // we need the absolute values for the in-circuit computations, // otherwise the negative values will be reduced modulo the SNARK scalar // field and not the emulated field. 
// output0 = |s0| mod r // output1 = |s1| mod r - // output2 = 1 if s0 is positive, 0 if s0 is negative - // output3 = 1 if s1 is positive, 0 if s0 is negative - outputs[2].SetUint64(1) + if outputs[1].Sign() == -1 { + outputs[1].Neg(outputs[1]) + } + + return nil + }) +} + +func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error { + return emulated.UnwrapHintWithNativeOutput(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two input") + } + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") + } + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(field, inputs[1], glvBasis) + r := eisenstein.ComplexNumber{ + A0: &glvBasis.V1[0], + A1: &glvBasis.V1[1], + } + sp := ecc.SplitScalar(inputs[0], glvBasis) + // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 + // so here we return -s instead of s. + s := eisenstein.ComplexNumber{ + A0: &sp[0], + A1: &sp[1], + } + s.Neg(&s) + + outputs[0].SetUint64(0) + outputs[1].SetUint64(0) + outputs[2].SetUint64(0) + outputs[3].SetUint64(0) + outputs[4].SetUint64(0) + res := eisenstein.HalfGCD(&r, &s) + s.A1.Mul(res[1].A1, inputs[1]). + Add(s.A1, res[1].A0). + Mul(s.A1, inputs[0]). + Add(s.A1, res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + s.A1.Add(s.A1, s.A0). 
+ Div(s.A1, field) + + if res[0].A0.Sign() == -1 { + outputs[0].SetUint64(1) + } + if res[0].A1.Sign() == -1 { + outputs[1].SetUint64(1) + } + if res[1].A0.Sign() == -1 { + outputs[2].SetUint64(1) + } + if res[1].A1.Sign() == -1 { + outputs[3].SetUint64(1) + } + if s.A1.Sign() == -1 { + outputs[4].SetUint64(1) + } + return nil + }) +} + +func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two input") + } + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") + } + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(field, inputs[1], glvBasis) + r := eisenstein.ComplexNumber{ + A0: &glvBasis.V1[0], + A1: &glvBasis.V1[1], + } + sp := ecc.SplitScalar(inputs[0], glvBasis) + // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 + // so here we return -s instead of s. + s := eisenstein.ComplexNumber{ + A0: &sp[0], + A1: &sp[1], + } + s.Neg(&s) + res := eisenstein.HalfGCD(&r, &s) + outputs[0].Set(res[0].A0) + outputs[1].Set(res[0].A1) + outputs[2].Set(res[1].A0) + outputs[3].Set(res[1].A1) + outputs[4].Mul(res[1].A1, inputs[1]). + Add(outputs[4], res[1].A0). + Mul(outputs[4], inputs[0]). + Add(outputs[4], res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + outputs[4].Add(outputs[4], s.A0). 
+ Div(outputs[4], field) + if outputs[0].Sign() == -1 { outputs[0].Neg(outputs[0]) - outputs[2].SetUint64(0) } - outputs[3].SetUint64(1) if outputs[1].Sign() == -1 { outputs[1].Neg(outputs[1]) - outputs[3].SetUint64(0) } - + if outputs[2].Sign() == -1 { + outputs[2].Neg(outputs[2]) + } + if outputs[3].Sign() == -1 { + outputs[3].Neg(outputs[3]) + } + if outputs[4].Sign() == -1 { + outputs[4].Neg(outputs[4]) + } return nil }) } diff --git a/std/algebra/emulated/sw_emulated/params.go b/std/algebra/emulated/sw_emulated/params.go index bf917e04df..97095f86aa 100644 --- a/std/algebra/emulated/sw_emulated/params.go +++ b/std/algebra/emulated/sw_emulated/params.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/secp256k1" + stark_curve "github.com/consensys/gnark-crypto/ecc/stark-curve" "github.com/consensys/gnark/std/math/emulated" ) @@ -133,6 +134,23 @@ func GetBW6761Params() CurveParams { } } +// GetStarkCurveParams returns the curve parameters for the STARK curve. +// When initialising new curve, use the base field [emulated.STARKCurveFp] and scalar +// field [emulated.STARKCurveFr]. +func GetStarkCurveParams() CurveParams { + _, g1aff := stark_curve.Generators() + b, _ := new(big.Int).SetString("3141592653589793238462643383279502884197169399375105820974944592307816406665", 10) + return CurveParams{ + A: big.NewInt(1), + B: b, + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeStarkCurveTable(), + Eigenvalue: nil, + ThirdRootOne: nil, + } +} + // GetCurveParams returns suitable curve parameters given the parametric type // Base as base field. It caches the parameters and modifying the values in the // parameters struct leads to undefined behaviour. 
@@ -151,18 +169,21 @@ func GetCurveParams[Base emulated.FieldParams]() CurveParams { return p384Params case emulated.BW6761Fp{}.Modulus().String(): return bw6761Params + case emulated.STARKCurveFp{}.Modulus().String(): + return starkCurveParams default: panic("no stored parameters") } } var ( - secp256k1Params CurveParams - bn254Params CurveParams - bls12381Params CurveParams - p256Params CurveParams - p384Params CurveParams - bw6761Params CurveParams + secp256k1Params CurveParams + bn254Params CurveParams + bls12381Params CurveParams + p256Params CurveParams + p384Params CurveParams + bw6761Params CurveParams + starkCurveParams CurveParams ) func init() { @@ -172,4 +193,5 @@ func init() { p256Params = GetP256Params() p384Params = GetP384Params() bw6761Params = GetBW6761Params() + starkCurveParams = GetStarkCurveParams() } diff --git a/std/algebra/emulated/sw_emulated/params_compute.go b/std/algebra/emulated/sw_emulated/params_compute.go index 88a514c7bc..af969423f6 100644 --- a/std/algebra/emulated/sw_emulated/params_compute.go +++ b/std/algebra/emulated/sw_emulated/params_compute.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/secp256k1" + stark_curve "github.com/consensys/gnark-crypto/ecc/stark-curve" ) func computeSecp256k1Table() [][2]*big.Int { @@ -157,3 +158,29 @@ func computeBW6761Table() [][2]*big.Int { } return table } + +func computeStarkCurveTable() [][2]*big.Int { + Gjac, _ := stark_curve.Generators() + table := make([][2]*big.Int, 256) + tmp := new(stark_curve.G1Jac).Set(&Gjac) + aff := new(stark_curve.G1Affine) + jac := new(stark_curve.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + 
table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 21063b5f20..433300ccda 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -507,13 +507,13 @@ func (c *Curve[B, S]) Mux(sel frontend.Variable, inputs ...*AffinePoint[B]) *Aff // ScalarMul computes [s]p and returns it. It doesn't modify p nor s. // This function doesn't check that the p is on the curve. See AssertIsOnCurve. // -// ScalarMul calls scalarMulGeneric or scalarMulGLV depending on whether an efficient endomorphism is available. +// ScalarMul calls scalarMulFakeGLV or scalarMulGLVAndFakeGLV depending on whether an efficient endomorphism is available. func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { if c.eigenvalue != nil && c.thirdRootOne != nil { - return c.scalarMulGLV(p, s, opts...) + return c.scalarMulGLVAndFakeGLV(p, s, opts...) } else { - return c.scalarMulGeneric(p, s, opts...) + return c.scalarMulFakeGLV(p, s, opts...) } } @@ -574,7 +574,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op var st S nbits := st.Modulus().BitLen()>>1 + 2 - // precompute -Q, -Φ(Q), Φ(Q) + // precompute -Q, Q, 3Q, -Φ(Q), Φ(Q), 3Φ(Q) var tableQ, tablePhiQ [3]*AffinePoint[B] negQY := c.baseApi.Neg(&Q.Y) tableQ[1] = &AffinePoint[B]{ @@ -663,7 +663,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op // note that half the points are negatives of the other half, // hence have the same X coordinates. 
- // when nbits is odd, we need to handle the first iteration separately + // when nbits is even, we need to handle the first iteration separately if nbits%2 == 0 { // Acc = [2]Acc ± Q ± Φ(Q) T := &AffinePoint[B]{ @@ -675,7 +675,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op Acc = c.double(Acc) Acc = c.add(Acc, T) } else { - // when nbits is even we start the main loop at normally nbits - 1 + // when nbits is odd we start the main loop at normally nbits - 1 nbits++ } for i := nbits - 2; i > 0; i -= 2 { @@ -726,7 +726,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op return Acc } -// scalarMulGeneric computes [s]p and returns it. It doesn't modify p nor s. +// scalarMulJoye computes [s]p and returns it. It doesn't modify p nor s. // This function doesn't check that the p is on the curve. See AssertIsOnCurve. // // ⚠️ p must not be (0,0) and s must not be 0, unless [algopts.WithCompleteArithmetic] option is set. @@ -745,7 +745,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op // [ELM03]: https://arxiv.org/pdf/math/0208038.pdf // [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf // [Joye07]: https://www.iacr.org/archive/ches2007/47270135/47270135.pdf -func (c *Curve[B, S]) scalarMulGeneric(p *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { +func (c *Curve[B, S]) scalarMulJoye(p *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { cfg, err := algopts.NewConfig(opts...) if err != nil { panic(fmt.Sprintf("parse opts: %v", err)) @@ -805,26 +805,18 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El return c.jointScalarMulGLV(p1, p2, s1, s2, opts...) } else { - return c.jointScalarMulGeneric(p1, p2, s1, s2, opts...) + return c.jointScalarMulFakeGLV(p1, p2, s1, s2, opts...) } } -// jointScalarMulGeneric computes [s1]p1 + [s2]p2. 
It doesn't modify p1, p2 nor s1, s2. +// jointScalarMulFakeGLV computes [s1]p1 + [s2]p2. It doesn't modify p1, p2 nor s1, s2. // // ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set. -func (c *Curve[B, S]) jointScalarMulGeneric(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { - cfg, err := algopts.NewConfig(opts...) - if err != nil { - panic(fmt.Sprintf("parse opts: %v", err)) - } - if cfg.CompleteArithmetic { - res1 := c.scalarMulGeneric(p1, s1, opts...) - res2 := c.scalarMulGeneric(p2, s2, opts...) - return c.AddUnified(res1, res2) - } else { - return c.jointScalarMulGenericUnsafe(p1, p2, s1, s2) - } +func (c *Curve[B, S]) jointScalarMulFakeGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { + sm1 := c.scalarMulFakeGLV(p1, s1, opts...) + sm2 := c.scalarMulFakeGLV(p2, s2, opts...) + return c.AddUnified(sm1, sm2) } // jointScalarMulGenericUnsafe computes [s1]p1 + [s2]p2 using Shamir's trick and returns it. It doesn't modify p1, p2 nor s1, s2. 
@@ -974,13 +966,14 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat tableS[3] = c.Neg(tableS[2]) f0 := c.baseApi.Mul(&tableS[0].X, c.thirdRootOne) f2 := c.baseApi.Mul(&tableS[2].X, c.thirdRootOne) + xor := c.api.Xor(selector2, selector4) tablePhiS[0] = &AffinePoint[B]{ - X: *c.baseApi.Select(c.api.Xor(selector2, selector4), f2, f0), + X: *c.baseApi.Select(xor, f2, f0), Y: *c.baseApi.Lookup2(selector2, selector4, &tableS[0].Y, &tableS[2].Y, &tableS[3].Y, &tableS[1].Y), } tablePhiS[1] = c.Neg(tablePhiS[0]) tablePhiS[2] = &AffinePoint[B]{ - X: *c.baseApi.Select(c.api.Xor(selector2, selector4), f0, f2), + X: *c.baseApi.Select(xor, f0, f2), Y: *c.baseApi.Lookup2(selector2, selector4, &tableS[2].Y, &tableS[0].Y, &tableS[1].Y, &tableS[3].Y), } tablePhiS[3] = c.Neg(tablePhiS[2]) @@ -1255,3 +1248,558 @@ func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[ return res, nil } } + +// scalarMulFakeGLV computes [s]Q and returns it. It doesn't modify Q nor s. +// It implements the "fake GLV" explained in: https://hackmd.io/@yelhousni/Hy-aWld50. +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// TODO @yelhousni: generalize for any supported curve as it currently supports only: +// P256, P384 and STARK curve. +// +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + + var selector1 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + selector1 = c.scalarApi.IsZero(s) + _s = c.scalarApi.Select(selector1, c.scalarApi.One(), s) + } + + // First we find the sub-salars s1, s2 s.t. 
s1 + s2*s = 0 mod r and s1, s2 < sqrt(r). + sd, err := c.scalarApi.NewHint(halfGCD, 2, _s) + if err != nil { + panic(fmt.Sprintf("halfGCD hint: %v", err)) + } + s1, s2 := sd[0], sd[1] + // s2 can be negative. If so, we return in the halfGCD hint -s2 + // and here compute _s2 = -s2 mod r + sign, err := c.scalarApi.NewHintWithNativeOutput(halfGCDSigns, 1, _s) + if err != nil { + panic(fmt.Sprintf("halfGCDSigns hint: %v", err)) + } + _s2 := c.scalarApi.Select(sign[0], c.scalarApi.Neg(s2), s2) + // We check that s1 + s*_s2 == 0 mod r + c.scalarApi.AssertIsEqual( + c.scalarApi.Add(s1, c.scalarApi.Mul(_s, _s2)), + c.scalarApi.Zero(), + ) + // A malicious hint can provide s1=s2=0 mod r + // So we check that _s2 is non-zero otherwise [0]([s]Q = ∀R) is always true + c.api.AssertIsEqual(c.scalarApi.IsZero(_s2), 0) + + // Then we compute the hinted scalar mul R = [s]Q + // Q coordinates are in Fp and the scalar s in Fr + // we decompose Q.X, Q.Y, s into limbs and recompose them in the hint. + var inps []frontend.Variable + inps = append(inps, Q.X.Limbs...) + inps = append(inps, Q.Y.Limbs...) + inps = append(inps, s.Limbs...) + R, err := c.baseApi.NewHintWithNativeInput(scalarMulHint, 2, inps...) 
+ if err != nil { + panic(fmt.Sprintf("scalar mul hint: %v", err)) + } + r0, r1 := R[0], R[1] + + var selector2 frontend.Variable + one := c.baseApi.One() + dummy := &AffinePoint[B]{X: *one, Y: *one} + addFn := c.Add + if cfg.CompleteArithmetic { + addFn = c.AddUnified + // if Q=(0,0) we assign a dummy (1,1) to Q and R and continue + selector2 = c.api.And(c.baseApi.IsZero(&Q.X), c.baseApi.IsZero(&Q.Y)) + Q = c.Select(selector2, dummy, Q) + r0 = c.baseApi.Select(selector2, c.baseApi.Zero(), r0) + r1 = c.baseApi.Select(selector2, &dummy.Y, r1) + } + + var st S + nbits := (st.Modulus().BitLen() + 1) / 2 + s1bits := c.scalarApi.ToBits(s1) + s2bits := c.scalarApi.ToBits(s2) + + // Precomputations: + // tableQ[0] = -Q + // tableQ[1] = Q + // tableQ[2] = [3]Q + // tableR[0] = -R or R if s2 is negative + // tableR[1] = R or -R if s2 is negative + // tableR[2] = [3]R or [-3]R if s2 is negative + var tableQ, tableR [3]*AffinePoint[B] + tableQ[1] = Q + tableQ[0] = c.Neg(Q) + tableQ[2] = c.triple(tableQ[1]) + tableR[1] = &AffinePoint[B]{ + X: *r0, + Y: *c.baseApi.Select(sign[0], c.baseApi.Neg(r1), r1), + } + tableR[0] = c.Neg(tableR[1]) + if cfg.CompleteArithmetic { + tableR[2] = c.AddUnified(tableR[1], tableR[1]) + tableR[2] = c.AddUnified(tableR[2], tableR[1]) + } else { + tableR[2] = c.triple(tableR[1]) + } + + // We should start the accumulator by the infinity point, but since affine + // formulae are incomplete we suppose that the first bits of the + // sub-scalars s1 and s2 are 1, and set: + // Acc = Q + R + Acc := c.Add(tableQ[1], tableR[1]) + + // At each iteration we need to compute: + // [2]Acc ± Q ± R. + // We can compute [2]Acc and look up the (precomputed) point P from: + // B1 = Q+R + // B2 = -Q-R + // B3 = Q-R + // B4 = -Q+R + // + // If we extend this by merging two iterations, we need to look up P and P' + // both from {B1, B2, B3, B4} and compute: + // [2]([2]Acc+P)+P' = [4]Acc + T + // where T = [2]P+P'. 
So at each (merged) iteration, we can compute [4]Acc + // and look up T from the precomputed list of points: + // + // T = [3](Q + R) + // P = B1 and P' = B1 + T1 := c.Add(tableQ[2], tableR[2]) + // T = Q + R + // P = B1 and P' = B2 + T2 := Acc + // T = [3]Q + R + // P = B1 and P' = B3 + T3 := c.Add(tableQ[2], tableR[1]) + // T = Q + [3]R + // P = B1 and P' = B4 + T4 := c.Add(tableQ[1], tableR[2]) + // T = -Q - R + // P = B2 and P' = B1 + T5 := c.Neg(T2) + // T = -[3](Q + R) + // P = B2 and P' = B2 + T6 := c.Neg(T1) + // T = -Q - [3]R + // P = B2 and P' = B3 + T7 := c.Neg(T4) + // T = -[3]Q - R + // P = B2 and P' = B4 + T8 := c.Neg(T3) + // T = [3]Q - R + // P = B3 and P' = B1 + T9 := c.Add(tableQ[2], tableR[0]) + // T = Q - [3]R + // P = B3 and P' = B2 + T11 := c.Neg(tableR[2]) + T10 := c.Add(tableQ[1], T11) + // T = [3](Q - R) + // P = B3 and P' = B3 + T11 = c.Add(tableQ[2], T11) + // T = -R + Q + // P = B3 and P' = B4 + T12 := c.Add(tableR[0], tableQ[1]) + // T = [3]R - Q + // P = B4 and P' = B1 + T13 := c.Neg(T10) + // T = R - [3]Q + // P = B4 and P' = B2 + T14 := c.Neg(T9) + // T = R - Q + // P = B4 and P' = B3 + T15 := c.Neg(T12) + // T = [3](R - Q) + // P = B4 and P' = B4 + T16 := c.Neg(T11) + // note that half of these points are negatives of the other half, + // hence have the same X coordinates. + + // When nbits is even, we need to handle the first iteration separately + if nbits%2 == 0 { + // Acc = [2]Acc ± Q ± R + T := &AffinePoint[B]{ + X: *c.baseApi.Select(c.api.Xor(s1bits[nbits-1], s2bits[nbits-1]), &T12.X, &T5.X), + Y: *c.baseApi.Lookup2(s1bits[nbits-1], s2bits[nbits-1], &T5.Y, &T12.Y, &T15.Y, &T2.Y), + } + // We don't use doubleAndAdd here as it would involve edge cases + // when bits are 00 (T==-Acc) or 11 (T==Acc). 
+ Acc = c.double(Acc) + Acc = c.add(Acc, T) + } else { + // when nbits is odd we start the main loop at normally nbits - 1 + nbits++ + } + for i := nbits - 2; i > 2; i -= 2 { + // selectorY takes values in [0,15] + selectorY := c.api.Add( + s1bits[i], + c.api.Mul(s2bits[i], 2), + c.api.Mul(s1bits[i-1], 4), + c.api.Mul(s2bits[i-1], 8), + ) + // selectorX takes values in [0,7] s.t.: + // - when selectorY < 8: selectorX = selectorY + // - when selectorY >= 8: selectorX = 15 - selectorY + selectorX := c.api.Add( + c.api.Mul(selectorY, c.api.Sub(1, c.api.Mul(s2bits[i-1], 2))), + c.api.Mul(s2bits[i-1], 15), + ) + // Bi.Y are distincts so we need a 16-to-1 multiplexer, + // but only half of the Bi.X are distinct so we need a 8-to-1. + T := &AffinePoint[B]{ + X: *c.baseApi.Mux(selectorX, + &T6.X, &T10.X, &T14.X, &T2.X, &T7.X, &T11.X, &T15.X, &T3.X, + ), + Y: *c.baseApi.Mux(selectorY, + &T6.Y, &T10.Y, &T14.Y, &T2.Y, &T7.Y, &T11.Y, &T15.Y, &T3.Y, + &T8.Y, &T12.Y, &T16.Y, &T4.Y, &T5.Y, &T9.Y, &T13.Y, &T1.Y, + ), + } + // Acc = [4]Acc + T + Acc = c.double(Acc) + Acc = c.doubleAndAdd(Acc, T) + } + + // i = 2 + // we isolate the last iteration to avoid falling into incomplete additions + // + // selectorY takes values in [0,15] + selectorY := c.api.Add( + s1bits[2], + c.api.Mul(s2bits[2], 2), + c.api.Mul(s1bits[1], 4), + c.api.Mul(s2bits[1], 8), + ) + // selectorX takes values in [0,7] s.t.: + // - when selectorY < 8: selectorX = selectorY + // - when selectorY >= 8: selectorX = 15 - selectorY + selectorX := c.api.Add( + c.api.Mul(selectorY, c.api.Sub(1, c.api.Mul(s2bits[1], 2))), + c.api.Mul(s2bits[1], 15), + ) + // Bi.Y are distincts so we need a 16-to-1 multiplexer, + // but only half of the Bi.X are distinct so we need a 8-to-1. 
+ T := &AffinePoint[B]{ + X: *c.baseApi.Mux(selectorX, + &T6.X, &T10.X, &T14.X, &T2.X, &T7.X, &T11.X, &T15.X, &T3.X, + ), + Y: *c.baseApi.Mux(selectorY, + &T6.Y, &T10.Y, &T14.Y, &T2.Y, &T7.Y, &T11.Y, &T15.Y, &T3.Y, + &T8.Y, &T12.Y, &T16.Y, &T4.Y, &T5.Y, &T9.Y, &T13.Y, &T1.Y, + ), + } + // to avoid incomplete additions we add [3]R to the precomputed T before computing [4]Acc+T + // Acc = [4]Acc + T + [3]R + T = c.add(T, tableR[2]) + Acc = c.double(Acc) + Acc = c.doubleAndAdd(Acc, T) + + // i = 0 + // subtract Q and R if the first bits are 0. + // When cfg.CompleteArithmetic is set, we use AddUnified instead of Add. + // This means when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0). + tableQ[0] = addFn(tableQ[0], Acc) + Acc = c.Select(s1bits[0], Acc, tableQ[0]) + tableR[0] = addFn(tableR[0], Acc) + Acc = c.Select(s2bits[0], Acc, tableR[0]) + + if cfg.CompleteArithmetic { + Acc = c.Select(c.api.Or(selector1, selector2), tableR[2], Acc) + } + // we added [3]R at the last iteration so the result should be + // Acc = [s1]Q + [s2]R + [3]R + // = [s1]Q + [s2*s]Q + [3]R + // = [s1+s2*s]Q + [3]R + // = [0]Q + [3]R + // = [3]R + c.AssertIsEqual(Acc, tableR[2]) + + return &AffinePoint[B]{ + X: *R[0], + Y: *R[1], + } +} + +// scalarMulGLVAndFakeGLV computes [s]P and returns it. It doesn't modify P nor s. +// It implements the "GLV + fake GLV" explained in [ethresear.ch/fake-GLV]. +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// TODO @yelhousni: generalize for any supported curve as it currently supports only: +// BN254, BLS12-381, BW6-761 and Secp256k1. 
+// +// [ethresear.ch/fake-GLV]: https://ethresear.ch/t/fake-glv-you-dont-need-an-efficient-endomorphism-to-implement-glv-like-scalar-multiplication-in-snark-circuits/20394 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + + // handle 0-scalar and (-1)-scalar cases + var selector0 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + one := c.scalarApi.One() + selector0 = c.api.Or( + c.scalarApi.IsZero(s), + c.scalarApi.IsZero( + c.scalarApi.Add(s, one)), + ) + _s = c.scalarApi.Select(selector0, one, s) + } + + // Instead of computing [s]P=Q, we check that Q-[s]P == 0. + // Checking Q - [s]P = 0 is equivalent to [v]Q + [-s*v]P = 0 for some nonzero v. + // + // The GLV curves supported in gnark have j-invariant 0, which means the eigenvalue + // of the GLV endomorphism is a primitive cube root of unity. If we write + // v, s and r as Eisenstein integers we can express the check as: + // + // [v1 + λ*v2]Q + [u1 + λ*u2]P = 0 + // [v1]Q + [v2]phi(Q) + [u1]P + [u2]phi(P) = 0 + // + // where (v1 + λ*v2)*(s1 + λ*s2) = u1 + λu2 mod (r1 + λ*r2) + // and u1, u2, v1, v2 < r^{1/4} (up to a constant factor). + // + // This can be done as follows: + // 1. decompose s into s1 + λ*s2 mod r s.t. s1, s2 < sqrt(r) (hinted classical GLV decomposition). + // 2. decompose r into r1 + λ*r2 s.t. r1, r2 < sqrt(r) (hardcoded half-GCD of λ mod r). + // 3. find u1, u2, v1, v2 < c*r^{1/4} s.t. (v1 + λ*v2)*(s1 + λ*s2) = (u1 + λ*u2) mod (r1 + λ*r2). + // This can be done through a hinted half-GCD in the number field + // K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of + // integers i.e. w is a primitive cube root of unity, f(w)=w^2+w+1=0. + // + // The hint returns u1, u2, v1, v2. 
+ // In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) mod r + sd, err := c.scalarApi.NewHint(halfGCDEisenstein, 5, _s, c.eigenvalue) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + + // Eisenstein integers real and imaginary parts can be negative. So we + // return the absolute value in the hint and negate the corresponding + // points here when needed. + signs, err := c.scalarApi.NewHintWithNativeOutput(halfGCDEisensteinSigns, 5, _s, c.eigenvalue) + if err != nil { + panic(fmt.Sprintf("halfGCDSigns hint: %v", err)) + } + isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] + + // We need to check that: + // s*(v1 + λ*v2) + u1 + λ*u2 = 0 + var st S + sv1 := c.scalarApi.Mul(_s, v1) + sλv2 := c.scalarApi.Mul(_s, c.scalarApi.Mul(c.eigenvalue, v2)) + λu2 := c.scalarApi.Mul(c.eigenvalue, u2) + zero := c.scalarApi.Zero() + + lhs1 := c.scalarApi.Select(isNegv1, zero, sv1) + lhs2 := c.scalarApi.Select(isNegv2, zero, sλv2) + lhs3 := c.scalarApi.Select(isNegu1, zero, u1) + lhs4 := c.scalarApi.Select(isNegu2, zero, λu2) + lhs := c.scalarApi.Add( + c.scalarApi.Add(lhs1, lhs2), + c.scalarApi.Add(lhs3, lhs4), + ) + + rhs1 := c.scalarApi.Select(isNegv1, sv1, zero) + rhs2 := c.scalarApi.Select(isNegv2, sλv2, zero) + rhs3 := c.scalarApi.Select(isNegu1, u1, zero) + rhs4 := c.scalarApi.Select(isNegu2, λu2, zero) + rhs := c.scalarApi.Add( + c.scalarApi.Add(rhs1, rhs2), + c.scalarApi.Add(rhs3, rhs4), + ) + + c.scalarApi.AssertIsEqual(lhs, rhs) + + // Next we compute the hinted scalar mul Q = [s]P + // P coordinates are in Fp and the scalar s in Fr + // we decompose Q.X, Q.Y, s into limbs and recompose them in the hint. + var inps []frontend.Variable + inps = append(inps, P.X.Limbs...) + inps = append(inps, P.Y.Limbs...) + inps = append(inps, s.Limbs...) + point, err := c.baseApi.NewHintWithNativeInput(scalarMulHint, 2, inps...) 
+ if err != nil { + panic(fmt.Sprintf("scalar mul hint: %v", err)) + } + Q := &AffinePoint[B]{X: *point[0], Y: *point[1]} + + // handle (0,0)-point + var _selector0, _selector1 frontend.Variable + _P := P + if cfg.CompleteArithmetic { + // if Q=(0,0) we assign a dummy point to Q and continue + Q = c.Select(selector0, &c.GeneratorMultiples()[3], Q) + // if P=(0,0) we assign a dummy point to P and continue + _selector0 = c.api.And(c.baseApi.IsZero(&P.X), c.baseApi.IsZero(&P.Y)) + _P = c.Select(_selector0, &c.GeneratorMultiples()[4], P) + // if s=±1 we assign a dummy point to Q and continue + _selector1 = c.baseApi.IsZero(c.baseApi.Sub(&P.X, &Q.X)) + Q = c.Select(_selector1, &c.GeneratorMultiples()[3], Q) + } + + // precompute -P, -Φ(P), Φ(P) + var tableP, tablePhiP [2]*AffinePoint[B] + negPY := c.baseApi.Neg(&_P.Y) + tableP[1] = &AffinePoint[B]{ + X: _P.X, + Y: *c.baseApi.Select(isNegu1, negPY, &_P.Y), + } + tableP[0] = c.Neg(tableP[1]) + tablePhiP[1] = &AffinePoint[B]{ + X: *c.baseApi.Mul(&_P.X, c.thirdRootOne), + Y: *c.baseApi.Select(isNegu2, negPY, &_P.Y), + } + tablePhiP[0] = c.Neg(tablePhiP[1]) + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]*AffinePoint[B] + negQY := c.baseApi.Neg(&Q.Y) + tableQ[1] = &AffinePoint[B]{ + X: Q.X, + Y: *c.baseApi.Select(isNegv1, negQY, &Q.Y), + } + tableQ[0] = c.Neg(tableQ[1]) + tablePhiQ[1] = &AffinePoint[B]{ + X: *c.baseApi.Mul(&Q.X, c.thirdRootOne), + Y: *c.baseApi.Select(isNegv2, negQY, &Q.Y), + } + tablePhiQ[0] = c.Neg(tablePhiQ[1]) + + // precompute -P-Q, P+Q, P-Q, -P+Q, -Φ(P)-Φ(Q), Φ(P)+Φ(Q), Φ(P)-Φ(Q), -Φ(P)+Φ(Q) + var tableS, tablePhiS [4]*AffinePoint[B] + tableS[0] = c.Add(tableP[0], tableQ[0]) + tableS[1] = c.Neg(tableS[0]) + tableS[2] = c.Add(tableP[1], tableQ[0]) + tableS[3] = c.Neg(tableS[2]) + tablePhiS[0] = c.Add(tablePhiP[0], tablePhiQ[0]) + tablePhiS[1] = c.Neg(tablePhiS[0]) + tablePhiS[2] = c.Add(tablePhiP[1], tablePhiQ[0]) + tablePhiS[3] = c.Neg(tablePhiS[2]) + + // we suppose that the first bits of 
the sub-scalars are 1 and set: + // Acc = P + Q + Φ(P) + Φ(Q) + Acc := c.Add(tableS[1], tablePhiS[1]) + B1 := Acc + // then we add G (the base point) to Acc to avoid incomplete additions in + // the loop, because when doing doubleAndAdd(Acc, Bi) as (Acc+Bi)+Acc it + // might happen that Acc==Bi or Acc==-Bi. But now we force Acc to be + // different than the stored Bi. However, at the end, Acc will not be the + // point at infinity but [2^nbits]G. + // + // N.B.: Acc cannot be equal to G, otherwise this means G = -Φ²([s+1]P) + g := c.Generator() + Acc = c.Add(Acc, g) + + // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). + // We prove that the factor is log_(3/sqrt(3)))(r). + // so we need to add 9 bits to r^{1/4}.nbits(). + nbits := st.Modulus().BitLen()>>2 + 9 + u1bits := c.scalarApi.ToBits(u1) + u2bits := c.scalarApi.ToBits(u2) + v1bits := c.scalarApi.ToBits(v1) + v2bits := c.scalarApi.ToBits(v2) + + // At each iteration we look up the point Bi from: + // B1 = +P + Q + Φ(P) + Φ(Q) + // B2 = +P + Q + Φ(P) - Φ(Q) + B2 := c.Add(tableS[1], tablePhiS[2]) + // B3 = +P + Q - Φ(P) + Φ(Q) + B3 := c.Add(tableS[1], tablePhiS[3]) + // B4 = +P + Q - Φ(P) - Φ(Q) + B4 := c.Add(tableS[1], tablePhiS[0]) + // B5 = +P - Q + Φ(P) + Φ(Q) + B5 := c.Add(tableS[2], tablePhiS[1]) + // B6 = +P - Q + Φ(P) - Φ(Q) + B6 := c.Add(tableS[2], tablePhiS[2]) + // B7 = +P - Q - Φ(P) + Φ(Q) + B7 := c.Add(tableS[2], tablePhiS[3]) + // B8 = +P - Q - Φ(P) - Φ(Q) + B8 := c.Add(tableS[2], tablePhiS[0]) + // B9 = -P + Q + Φ(P) + Φ(Q) + B9 := c.Neg(B8) + // B10 = -P + Q + Φ(P) - Φ(Q) + B10 := c.Neg(B7) + // B11 = -P + Q - Φ(P) + Φ(Q) + B11 := c.Neg(B6) + // B12 = -P + Q - Φ(P) - Φ(Q) + B12 := c.Neg(B5) + // B13 = -P - Q + Φ(P) + Φ(Q) + B13 := c.Neg(B4) + // B14 = -P - Q + Φ(P) - Φ(Q) + B14 := c.Neg(B3) + // B15 = -P - Q - Φ(P) + Φ(Q) + B15 := c.Neg(B2) + // B16 = -P - Q - Φ(P) - Φ(Q) + B16 := c.Neg(B1) + // note that half the points are negatives of the other half, + // hence have the same X 
coordinates. + + var Bi *AffinePoint[B] + for i := nbits - 1; i > 0; i-- { + // selectorY takes values in [0,15] + selectorY := c.api.Add( + u1bits[i], + c.api.Mul(u2bits[i], 2), + c.api.Mul(v1bits[i], 4), + c.api.Mul(v2bits[i], 8), + ) + // selectorX takes values in [0,7] s.t.: + // - when selectorY < 8: selectorX = selectorY + // - when selectorY >= 8: selectorX = 15 - selectorY + selectorX := c.api.Add( + c.api.Mul(selectorY, c.api.Sub(1, c.api.Mul(v2bits[i], 2))), + c.api.Mul(v2bits[i], 15), + ) + // Bi.Y are distincts so we need a 16-to-1 multiplexer, + // but only half of the Bi.X are distinct so we need a 8-to-1. + Bi = &AffinePoint[B]{ + X: *c.baseApi.Mux(selectorX, + &B16.X, &B8.X, &B14.X, &B6.X, &B12.X, &B4.X, &B10.X, &B2.X, + ), + Y: *c.baseApi.Mux(selectorY, + &B16.Y, &B8.Y, &B14.Y, &B6.Y, &B12.Y, &B4.Y, &B10.Y, &B2.Y, + &B15.Y, &B7.Y, &B13.Y, &B5.Y, &B11.Y, &B3.Y, &B9.Y, &B1.Y, + ), + } + // Acc = [2]Acc + Bi + Acc = c.doubleAndAdd(Acc, Bi) + } + + // i = 0 + // subtract the P, Q, Φ(P), Φ(Q) if the first bits are 0 + tableP[0] = c.Add(tableP[0], Acc) + Acc = c.Select(u1bits[0], Acc, tableP[0]) + tablePhiP[0] = c.Add(tablePhiP[0], Acc) + Acc = c.Select(u2bits[0], Acc, tablePhiP[0]) + tableQ[0] = c.Add(tableQ[0], Acc) + Acc = c.Select(v1bits[0], Acc, tableQ[0]) + tablePhiQ[0] = c.Add(tablePhiQ[0], Acc) + Acc = c.Select(v2bits[0], Acc, tablePhiQ[0]) + + // Acc should be now equal to [2^nbits]G + gm := c.GeneratorMultiples()[nbits-1] + if cfg.CompleteArithmetic { + Acc = c.Select(c.api.Or(c.api.Or(selector0, _selector0), _selector1), &gm, Acc) + } + c.AssertIsEqual(Acc, &gm) + + return &AffinePoint[B]{ + X: *point[0], + Y: *point[1], + } +} diff --git a/std/algebra/emulated/sw_emulated/point_test.go b/std/algebra/emulated/sw_emulated/point_test.go index 2605407348..7e96cc3a9e 100644 --- a/std/algebra/emulated/sw_emulated/point_test.go +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -17,6 +17,8 @@ import ( 
"github.com/consensys/gnark-crypto/ecc/secp256k1" fp_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" fr_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + stark_curve "github.com/consensys/gnark-crypto/ecc/stark-curve" + fr_stark "github.com/consensys/gnark-crypto/ecc/stark-curve/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/math/emulated" @@ -784,6 +786,32 @@ func TestScalarMul6(t *testing.T) { assert.NoError(err) } +func TestScalarMul7(t *testing.T) { + assert := test.NewAssert(t) + var r fr_stark.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var res stark_curve.G1Affine + _, gen := stark_curve.Generators() + res.ScalarMultiplication(&gen, s) + + circuit := ScalarMulTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{} + witness := ScalarMulTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](s), + P: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](gen.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](gen.Y), + }, + Q: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](res.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + type ScalarMulEdgeCasesTest[T, S emulated.FieldParams] struct { P, R AffinePoint[T] S emulated.Element[S] @@ -831,8 +859,8 @@ func TestScalarMulEdgeCasesEdgeCases(t *testing.T) { witness2 := ScalarMulEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), P: AffinePoint[emulated.BN254Fp]{ - X: emulated.ValueOf[emulated.BN254Fp](S.X), - Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: emulated.ValueOf[emulated.BN254Fp](g.Y), }, R: AffinePoint[emulated.BN254Fp]{ X: emulated.ValueOf[emulated.BN254Fp](infinity.X), @@ -841,6 
+869,21 @@ func TestScalarMulEdgeCasesEdgeCases(t *testing.T) { } err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) } type IsOnCurveTest[T, S emulated.FieldParams] struct { @@ -1408,7 +1451,7 @@ func (c *ScalarMulTestBounded[T, S]) Define(api frontend.API) error { if err != nil { return err } - res := cr.scalarMulGeneric(&c.P, &c.S, algopts.WithNbScalarBits(c.bits)) + res := cr.scalarMulJoye(&c.P, &c.S, algopts.WithNbScalarBits(c.bits)) cr.AssertIsEqual(res, &c.Q) return nil } @@ -1912,3 +1955,574 @@ func TestMux(t *testing.T) { err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) assert.NoError(err) } + +type ScalarMulJoyeTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulJoyeTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.scalarMulJoye(&c.P, &c.S) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestScalarMulJoye(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + px, py := p256.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulJoyeTest[emulated.P256Fp, emulated.P256Fr]{} + witness := ScalarMulJoyeTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + P: AffinePoint[emulated.P256Fp]{ + X: 
emulated.ValueOf[emulated.P256Fp](p256.Params().Gx), + Y: emulated.ValueOf[emulated.P256Fp](p256.Params().Gy), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulJoye2(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulJoyeTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulJoyeTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +// fake GLV +type ScalarMulFakeGLVTest[T, S emulated.FieldParams] struct { + Q, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulFakeGLVTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.scalarMulFakeGLV(&c.Q, &c.S) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestScalarMulFakeGLV(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + px, py := p256.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulFakeGLVTest[emulated.P256Fp, emulated.P256Fr]{} + witness := ScalarMulFakeGLVTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + Q: AffinePoint[emulated.P256Fp]{ + X: 
emulated.ValueOf[emulated.P256Fp](p256.Params().Gx), + Y: emulated.ValueOf[emulated.P256Fp](p256.Params().Gy), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulFakeGLV2(t *testing.T) { + assert := test.NewAssert(t) + p384 := elliptic.P384() + s, err := rand.Int(rand.Reader, p384.Params().N) + assert.NoError(err) + px, py := p384.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulFakeGLVTest[emulated.P384Fp, emulated.P384Fr]{} + witness := ScalarMulFakeGLVTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](s), + Q: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](p384.Params().Gx), + Y: emulated.ValueOf[emulated.P384Fp](p384.Params().Gy), + }, + R: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulFakeGLV3(t *testing.T) { + assert := test.NewAssert(t) + _, g := stark_curve.Generators() + var r fr_stark.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S stark_curve.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulFakeGLVTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{} + witness := ScalarMulFakeGLVTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](s), + Q: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](g.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](g.Y), + }, + R: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](S.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type 
ScalarMulFakeGLVEdgeCasesTest[T, S emulated.FieldParams] struct { + P, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulFakeGLVEdgeCasesTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.scalarMulFakeGLV(&c.P, &c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestScalarMulFakeGLVEdgeCasesEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + px, py := p256.ScalarBaseMult(s.Bytes()) + _, _ = p256.ScalarMult(px, py, s.Bytes()) + + circuit := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * P == (0,0) + witness2 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](new(big.Int)), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](new(big.Int)), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + 
Y: emulated.ValueOf[emulated.P256Fp](0), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulFakeGLVEdgeCasesEdgeCases2(t *testing.T) { + assert := test.NewAssert(t) + p384 := elliptic.P384() + s, err := rand.Int(rand.Reader, p384.Params().N) + assert.NoError(err) + px, py := p384.ScalarBaseMult(s.Bytes()) + _, _ = p384.ScalarMult(px, py, s.Bytes()) + + circuit := ScalarMulFakeGLVEdgeCasesTest[emulated.P384Fp, emulated.P384Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulFakeGLVEdgeCasesTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](s), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](0), + Y: emulated.ValueOf[emulated.P384Fp](0), + }, + R: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](0), + Y: emulated.ValueOf[emulated.P384Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * P == (0,0) + witness2 := ScalarMulFakeGLVEdgeCasesTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](new(big.Int)), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](py), + }, + R: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](0), + Y: emulated.ValueOf[emulated.P384Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulFakeGLVEdgeCasesTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](new(big.Int)), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](0), + Y: emulated.ValueOf[emulated.P384Fp](0), + }, + R: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](0), + Y: 
emulated.ValueOf[emulated.P384Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulFakeGLVEdgeCasesEdgeCases3(t *testing.T) { + assert := test.NewAssert(t) + _, g := stark_curve.Generators() + var r fr_stark.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S stark_curve.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulFakeGLVEdgeCasesTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulFakeGLVEdgeCasesTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](s), + P: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](0), + Y: emulated.ValueOf[emulated.STARKCurveFp](0), + }, + R: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](0), + Y: emulated.ValueOf[emulated.STARKCurveFp](0), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * P == (0,0) + witness2 := ScalarMulFakeGLVEdgeCasesTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](new(big.Int)), + P: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](S.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](S.Y), + }, + R: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](0), + Y: emulated.ValueOf[emulated.STARKCurveFp](0), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulFakeGLVEdgeCasesTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](new(big.Int)), + P: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](0), + Y: emulated.ValueOf[emulated.STARKCurveFp](0), + }, + R: AffinePoint[emulated.STARKCurveFp]{ + X:
emulated.ValueOf[emulated.STARKCurveFp](0), + Y: emulated.ValueOf[emulated.STARKCurveFp](0), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulGLVAndFakeGLVTest[T, S emulated.FieldParams] struct { + Q, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulGLVAndFakeGLVTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res1 := cr.scalarMulGLVAndFakeGLV(&c.Q, &c.S) + res2 := cr.scalarMulGLVAndFakeGLV(&c.Q, &c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res1, &c.R) + cr.AssertIsEqual(res2, &c.R) + return nil +} + +func TestScalarMulGLVAndFakeGLV(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulGLVAndFakeGLVTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulGLVAndFakeGLVTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulGLVAndFakeGLVEdgeCasesTest[T, S emulated.FieldParams] struct { + P, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulGLVAndFakeGLVEdgeCasesTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.scalarMulGLVAndFakeGLV(&c.P, &c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func 
TestScalarMulGLVAndFakeGLVEdgeCasesEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](0), + Y: emulated.ValueOf[emulated.Secp256k1Fp](0), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](0), + Y: emulated.ValueOf[emulated.Secp256k1Fp](0), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * P == (0,0) + witness2 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](new(big.Int)), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](0), + Y: emulated.ValueOf[emulated.Secp256k1Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](new(big.Int)), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](0), + Y: emulated.ValueOf[emulated.Secp256k1Fp](0), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](0), + Y: emulated.ValueOf[emulated.Secp256k1Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + 
assert.NoError(err) + + // 1 * P == P + witness4 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](big.NewInt(1)), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) + + // -1 * P == -P + witness5 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](big.NewInt(-1)), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y.Neg(&g.Y)), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulGLVAndFakeGLVEdgeCasesEdgeCases2(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bn254.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + 
assert.NoError(err) + + // 0 * P == (0,0) + witness2 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: emulated.ValueOf[emulated.BN254Fp](g.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * (0,0) == (0,0) + witness3 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](0), + Y: emulated.ValueOf[emulated.BN254Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + + // 1 * P == P + witness4 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](big.NewInt(1)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: emulated.ValueOf[emulated.BN254Fp](g.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: emulated.ValueOf[emulated.BN254Fp](g.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) + + // -1 * P == -P + witness5 := ScalarMulGLVAndFakeGLVEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](big.NewInt(-1)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: emulated.ValueOf[emulated.BN254Fp](g.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](g.X), + Y: 
emulated.ValueOf[emulated.BN254Fp](g.Y.Neg(&g.Y)), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/native/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go index 8297880fcb..84f038607a 100644 --- a/std/algebra/native/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -17,6 +17,7 @@ limitations under the License. package sw_bls12377 import ( + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -305,11 +306,11 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl } if cfg.CompleteArithmetic { - // subtract [2^N]G = (0,1) since we added H at the beginning + // subtract [2^N]H = (0,1) since we added H at the beginning Acc.AddUnified(api, G1Affine{X: 0, Y: -1}) Acc.Select(api, selector, G1Affine{X: 0, Y: 0}, Acc) } else { - // subtract [2^N]G = (0,1) since we added H at the beginning + // subtract [2^N]H = (0,1) since we added H at the beginning Acc.AddAssign(api, G1Affine{X: 0, Y: -1}) } @@ -481,24 +482,24 @@ func (P *G1Affine) jointScalarMul(api frontend.API, Q, R G1Affine, s, t frontend func (P *G1Affine) jointScalarMulUnsafe(api frontend.API, Q, R G1Affine, s, t frontend.Variable) *G1Affine { cc := getInnerCurveConfig(api.Compiler().Field()) - sd, err := api.Compiler().NewHint(decomposeScalarG1, 3, s) + sd, err := api.Compiler().NewHint(decomposeScalarG1Simple, 2, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) } s1, s2 := sd[0], sd[1] - td, err := api.Compiler().NewHint(decomposeScalarG1, 3, t) + td, err := api.Compiler().NewHint(decomposeScalarG1Simple, 2, t) if err != nil { // err is non-nil only for invalid number of inputs panic(err) } t1, t2 := td[0], td[1] - api.AssertIsEqual(api.Add(s1, api.Mul(s2, cc.lambda)), api.Add(s, api.Mul(cc.fr, sd[2]))) - api.AssertIsEqual(api.Add(t1, api.Mul(t2, cc.lambda)), api.Add(t, api.Mul(cc.fr, td[2]))) + api.AssertIsEqual(api.Add(s1, api.Mul(s2, 
cc.lambda)), s) + api.AssertIsEqual(api.Add(t1, api.Mul(t2, cc.lambda)), t) - nbits := cc.lambda.BitLen() + 1 + nbits := cc.lambda.BitLen() s1bits := api.ToBinary(s1, nbits) s2bits := api.ToBinary(s2, nbits) @@ -671,3 +672,212 @@ func (P *G1Affine) scalarBitsMul(api frontend.API, Q G1Affine, s1bits, s2bits [] return P } + +// fake-GLV +// +// N.B.: this method is more expensive than classical GLV, but it is useful for testing purposes. +func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s frontend.Variable, opts ...algopts.AlgebraOption) *G1Affine { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + cc := getInnerCurveConfig(api.Compiler().Field()) + + // handle zero-scalar + var selector0 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + selector0 = api.IsZero(s) + _s = api.Select(selector0, 1, s) + } + + // Instead of computing [s]P=Q, we check that Q-[s]P == 0. + // Checking Q - [s]P = 0 is equivalent to [v]Q + [-s*v]P = 0 for some nonzero v. + // + // The GLV curves supported in gnark have j-invariant 0, which means the eigenvalue + // of the GLV endomorphism is a primitive cube root of unity. If we write + // v, s and r as Eisenstein integers we can express the check as: + // + // [v1 + λ*v2]Q + [u1 + λ*u2]P = 0 + // [v1]Q + [v2]phi(Q) + [u1]P + [u2]phi(P) = 0 + // + // where (v1 + λ*v2)*(s1 + λ*s2) = u1 + λu2 mod (r1 + λ*r2) + // and u1, u2, v1, v2 < r^{1/4} (up to a constant factor). + // + // This can be done as follows: + // 1. decompose s into s1 + λ*s2 mod r s.t. s1, s2 < sqrt(r) (hinted classical GLV decomposition). + // 2. decompose r into r1 + λ*r2 s.t. r1, r2 < sqrt(r) (hardcoded half-GCD of λ mod r). + // 3. find u1, u2, v1, v2 < c*r^{1/4} s.t. (v1 + λ*v2)*(s1 + λ*s2) = (u1 + λ*u2) mod (r1 + λ*r2). + // This can be done through a hinted half-GCD in the number field + // K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of + // integers i.e. 
w is a primitive cube root of unity, f(w)=w^2+w+1=0. + // + // The hint returns u1, u2, v1, v2 and the quotient q. + // In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) + r*q + // + // N.B.: this check may overflow. But we don't use this method anywhere but for testing purposes. + sd, err := api.NewHint(halfGCDEisenstein, 5, _s, cc.lambda) + if err != nil { + panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) + } + u1, u2, v1, v2, q := sd[0], sd[1], sd[2], sd[3], sd[4] + + // Eisenstein integers real and imaginary parts can be negative. So we + // return the absolute value in the hint and negate the corresponding + // points here when needed. + signs, err := api.NewHint(halfGCDEisensteinSigns, 5, _s, cc.lambda) + if err != nil { + panic(fmt.Sprintf("halfGCDEisensteinSigns hint: %v", err)) + } + isNegu1, isNegu2, isNegv1, isNegv2, isNegq := signs[0], signs[1], signs[2], signs[3], signs[4] + + // We need to check that: + // s*(v1 + λ*v2) + u1 + λ*u2 - r * q = 0 + sv1 := api.Mul(_s, v1) + sλv2 := api.Mul(_s, api.Mul(cc.lambda, v2)) + λu2 := api.Mul(cc.lambda, u2) + rq := api.Mul(cc.fr, q) + + lhs1 := api.Select(isNegv1, 0, sv1) + lhs2 := api.Select(isNegv2, 0, sλv2) + lhs3 := api.Select(isNegu1, 0, u1) + lhs4 := api.Select(isNegu2, 0, λu2) + lhs5 := api.Select(isNegq, rq, 0) + lhs := api.Add( + api.Add(lhs1, lhs2), + api.Add(lhs3, lhs4), + ) + lhs = api.Add(lhs, lhs5) + + rhs1 := api.Select(isNegv1, sv1, 0) + rhs2 := api.Select(isNegv2, sλv2, 0) + rhs3 := api.Select(isNegu1, u1, 0) + rhs4 := api.Select(isNegu2, λu2, 0) + rhs5 := api.Select(isNegq, 0, rq) + rhs := api.Add( + api.Add(rhs1, rhs2), + api.Add(rhs3, rhs4), + ) + rhs = api.Add(rhs, rhs5) + + api.AssertIsEqual(lhs, rhs) + + // Next we compute the hinted scalar mul Q = [s]P + point, err := api.NewHint(scalarMulGLVG1Hint, 2, P.X, P.Y, s) + if err != nil { + panic(fmt.Sprintf("scalar mul hint: %v", err)) + } + Q := G1Affine{X: point[0], Y: point[1]} + + // handle (0,0)-point + var _selector0 
frontend.Variable + _P := P + if cfg.CompleteArithmetic { + // if Q=(0,0) we assign a dummy point to Q and continue + Q.Select(api, selector0, G1Affine{X: 1, Y: 0}, Q) + // if P=(0,0) we assign a dummy point to P and continue + _selector0 = api.And(api.IsZero(P.X), api.IsZero(P.Y)) + _P.Select(api, _selector0, G1Affine{X: 2, Y: 1}, P) + } + + // precompute -P, -Φ(P), Φ(P) + var tableP, tablePhiP [2]G1Affine + negPY := api.Neg(_P.Y) + tableP[1] = G1Affine{ + X: _P.X, + Y: api.Select(isNegu1, negPY, _P.Y), + } + tableP[0].Neg(api, tableP[1]) + tablePhiP[1] = G1Affine{ + X: api.Mul(_P.X, cc.thirdRootOne1), + Y: api.Select(isNegu2, negPY, _P.Y), + } + tablePhiP[0].Neg(api, tablePhiP[1]) + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]G1Affine + negQY := api.Neg(Q.Y) + tableQ[1] = G1Affine{ + X: Q.X, + Y: api.Select(isNegv1, negQY, Q.Y), + } + tableQ[0].Neg(api, tableQ[1]) + tablePhiQ[1] = G1Affine{ + X: api.Mul(Q.X, cc.thirdRootOne1), + Y: api.Select(isNegv2, negQY, Q.Y), + } + tablePhiQ[0].Neg(api, tablePhiQ[1]) + + // precompute -P-Q, P+Q, P-Q, -P+Q, -Φ(P)-Φ(Q), Φ(P)+Φ(Q), Φ(P)-Φ(Q), -Φ(P)+Φ(Q) + var tableS, tablePhiS [4]G1Affine + tableS[0] = tableP[0] + tableS[0].AddAssign(api, tableQ[0]) + tableS[1].Neg(api, tableS[0]) + tableS[2] = tableP[1] + tableS[2].AddAssign(api, tableQ[0]) + tableS[3].Neg(api, tableS[2]) + tablePhiS[0] = tablePhiP[0] + tablePhiS[0].AddAssign(api, tablePhiQ[0]) + tablePhiS[1].Neg(api, tablePhiS[0]) + tablePhiS[2] = tablePhiP[1] + tablePhiS[2].AddAssign(api, tablePhiQ[0]) + tablePhiS[3].Neg(api, tablePhiS[2]) + + // we suppose that the first bits of the sub-scalars are 1 and set: + // Acc = P + Q + Φ(P) + Φ(Q) + Acc := tableS[1] + Acc.AddAssign(api, tablePhiS[1]) + // When doing doubleAndAdd(Acc, B) as (Acc+B)+Acc it might happen that + // Acc==B or -B. So we add the point H=(0,1) on BLS12-377 of order 2 to it + // to avoid incomplete additions in the loop by forcing Acc to be different + // than the stored B. 
Normally, the point H should be "killed out" by the + // first doubling in the loop and the result will remain unchanged. + // However, we are using affine coordinates that do not encode the infinity + // point. Given the affine formulae, doubling (0,1) results in (0,-1). + // Since the loop size N=nbits-1 is odd the result at the end should be + // [2^N]H = H = (0,1). + H := G1Affine{X: 0, Y: 1} + Acc.AddAssign(api, H) + + // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). + // We prove that the factor is log_(3/sqrt(3)))(r). + // so we need to add 9 bits to r^{1/4}.nbits(). + nbits := cc.lambda.BitLen()>>1 + 9 // 72 + u1bits := api.ToBinary(u1, nbits) + u2bits := api.ToBinary(u2, nbits) + v1bits := api.ToBinary(v1, nbits) + v2bits := api.ToBinary(v2, nbits) + + var B G1Affine + for i := nbits - 1; i > 0; i-- { + B.X = api.Select(api.Xor(u1bits[i], v1bits[i]), tableS[2].X, tableS[0].X) + B.Y = api.Lookup2(u1bits[i], v1bits[i], tableS[0].Y, tableS[2].Y, tableS[3].Y, tableS[1].Y) + Acc.DoubleAndAdd(api, &Acc, &B) + B.X = api.Select(api.Xor(u2bits[i], v2bits[i]), tablePhiS[2].X, tablePhiS[0].X) + B.Y = api.Lookup2(u2bits[i], v2bits[i], tablePhiS[0].Y, tablePhiS[2].Y, tablePhiS[3].Y, tablePhiS[1].Y) + Acc.AddAssign(api, B) + } + + // i = 0 + // subtract the P, Q, Φ(P), Φ(Q) if the first bits are 0 + tableP[0].AddAssign(api, Acc) + Acc.Select(api, u1bits[0], Acc, tableP[0]) + tablePhiP[0].AddAssign(api, Acc) + Acc.Select(api, u2bits[0], Acc, tablePhiP[0]) + tableQ[0].AddAssign(api, Acc) + Acc.Select(api, v1bits[0], Acc, tableQ[0]) + tablePhiQ[0].AddAssign(api, Acc) + Acc.Select(api, v2bits[0], Acc, tablePhiQ[0]) + + // Acc should be now equal to H=(0,-1) + H = G1Affine{X: 0, Y: -1} + if cfg.CompleteArithmetic { + Acc.Select(api, api.Or(selector0, _selector0), H, Acc) + } + Acc.AssertIsEqual(api, H) + + R.X = point[0] + R.Y = point[1] + + return R +} diff --git a/std/algebra/native/sw_bls12377/g1_test.go b/std/algebra/native/sw_bls12377/g1_test.go index 
f8b4287ed9..af34f7256c 100644 --- a/std/algebra/native/sw_bls12377/g1_test.go +++ b/std/algebra/native/sw_bls12377/g1_test.go @@ -278,18 +278,17 @@ func TestConstantScalarMulG1(t *testing.T) { } type g1constantScalarMulEdgeCases struct { - A G1Affine - R *big.Int + A, Inf G1Affine + R *big.Int } func (circuit *g1constantScalarMulEdgeCases) Define(api frontend.API) error { expected1 := G1Affine{} expected2 := G1Affine{} - infinity := G1Affine{X: 0, Y: 0} expected1.constScalarMul(api, circuit.A, big.NewInt(0)) - expected2.constScalarMul(api, infinity, circuit.R, algopts.WithCompleteArithmetic()) - expected1.AssertIsEqual(api, infinity) - expected2.AssertIsEqual(api, infinity) + expected2.constScalarMul(api, circuit.Inf, circuit.R, algopts.WithCompleteArithmetic()) + expected1.AssertIsEqual(api, circuit.Inf) + expected2.AssertIsEqual(api, circuit.Inf) return nil } @@ -311,6 +310,9 @@ func TestConstantScalarMulG1EdgeCases(t *testing.T) { // br is a circuit parameter circuit.R = br + witness.Inf.X = 0 + witness.Inf.Y = 0 + assert := test.NewAssert(t) assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) @@ -353,18 +355,17 @@ func TestVarScalarMulG1(t *testing.T) { } type g1varScalarMulEdgeCases struct { - A G1Affine - R frontend.Variable + A, Inf G1Affine + R, Zero frontend.Variable } func (circuit *g1varScalarMulEdgeCases) Define(api frontend.API) error { expected1 := G1Affine{} expected2 := G1Affine{} - infinity := G1Affine{X: 0, Y: 0} - expected1.varScalarMul(api, circuit.A, 0, algopts.WithCompleteArithmetic()) - expected2.varScalarMul(api, infinity, circuit.R, algopts.WithCompleteArithmetic()) - expected1.AssertIsEqual(api, infinity) - expected2.AssertIsEqual(api, infinity) + expected2.varScalarMul(api, circuit.Inf, circuit.R, algopts.WithCompleteArithmetic()) + expected1.varScalarMul(api, circuit.A, circuit.Zero, algopts.WithCompleteArithmetic()) + expected1.AssertIsEqual(api, circuit.Inf) + expected2.AssertIsEqual(api, 
circuit.Inf) return nil } @@ -381,6 +382,9 @@ func TestVarScalarMulG1EdgeCases(t *testing.T) { witness.R = r.String() // assign the inputs witness.A.Assign(&a) + witness.Inf.X = 0 + witness.Inf.Y = 0 + witness.Zero = 0 assert := test.NewAssert(t) assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) @@ -626,9 +630,9 @@ func TestMultiScalarMul(t *testing.T) { } type g1JointScalarMulEdgeCases struct { - A, B G1Affine - C G1Affine `gnark:",public"` - R, S frontend.Variable + A, B, Inf G1Affine + C G1Affine `gnark:",public"` + R, S, Zero frontend.Variable } func (circuit *g1JointScalarMulEdgeCases) Define(api frontend.API) error { @@ -636,15 +640,14 @@ func (circuit *g1JointScalarMulEdgeCases) Define(api frontend.API) error { expected2 := G1Affine{} expected3 := G1Affine{} expected4 := G1Affine{} - infinity := G1Affine{X: 0, Y: 0} - expected1.jointScalarMul(api, infinity, infinity, circuit.R, circuit.S, algopts.WithCompleteArithmetic()) - expected2.jointScalarMul(api, circuit.A, circuit.B, big.NewInt(0), big.NewInt(0), algopts.WithCompleteArithmetic()) - expected3.jointScalarMul(api, circuit.A, infinity, circuit.R, circuit.S, algopts.WithCompleteArithmetic()) - expected4.jointScalarMul(api, circuit.A, circuit.B, circuit.R, big.NewInt(0), algopts.WithCompleteArithmetic()) + expected1.jointScalarMul(api, circuit.Inf, circuit.Inf, circuit.R, circuit.S, algopts.WithCompleteArithmetic()) + expected2.jointScalarMul(api, circuit.A, circuit.B, circuit.Zero, circuit.Zero, algopts.WithCompleteArithmetic()) + expected3.jointScalarMul(api, circuit.A, circuit.Inf, circuit.R, circuit.S, algopts.WithCompleteArithmetic()) + expected4.jointScalarMul(api, circuit.A, circuit.B, circuit.R, circuit.Zero, algopts.WithCompleteArithmetic()) _expected := G1Affine{} _expected.ScalarMul(api, circuit.A, circuit.R, algopts.WithCompleteArithmetic()) - expected1.AssertIsEqual(api, infinity) - expected2.AssertIsEqual(api, infinity) + 
expected1.AssertIsEqual(api, circuit.Inf) + expected2.AssertIsEqual(api, circuit.Inf) expected3.AssertIsEqual(api, _expected) expected4.AssertIsEqual(api, _expected) return nil @@ -676,6 +679,10 @@ func TestJointScalarMulG1EdgeCases(t *testing.T) { c.FromJacobian(&_a) witness.C.Assign(&c) + witness.Inf.X = 0 + witness.Inf.Y = 0 + witness.Zero = 0 + assert := test.NewAssert(t) assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) } @@ -933,3 +940,81 @@ func TestMultiScalarMulFolded(t *testing.T) { }, &assignment, ecc.BW6_761.ScalarField()) assert.NoError(err) } + +// fake GLV +type scalarMulGLVAndFakeGLV struct { + A G1Affine + C G1Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *scalarMulGLVAndFakeGLV) Define(api frontend.API) error { + expected := G1Affine{} + expected.scalarMulGLVAndFakeGLV(api, circuit.A, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestScalarMulG1GLVAndFakeGLV(t *testing.T) { + // sample random point + _a := randomPointG1() + var a, c bls12377.G1Affine + a.FromJacobian(&_a) + + // create the cs + var circuit, witness scalarMulGLVAndFakeGLV + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // assign the inputs + witness.A.Assign(&a) + // compute the result + var br big.Int + _a.ScalarMultiplication(&_a, r.BigInt(&br)) + c.FromJacobian(&_a) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) +} + +type scalarMulGLVAndFakeGLVEdgeCases struct { + A, Inf G1Affine + R, Zero, One frontend.Variable +} + +func (circuit *scalarMulGLVAndFakeGLVEdgeCases) Define(api frontend.API) error { + expected1, expected2, expected3, expected4 := G1Affine{}, G1Affine{}, G1Affine{}, G1Affine{} + expected1.varScalarMul(api, circuit.A, circuit.Zero, algopts.WithCompleteArithmetic()) + expected2.varScalarMul(api, circuit.Inf, circuit.R, 
algopts.WithCompleteArithmetic()) + expected3.varScalarMul(api, circuit.Inf, circuit.Zero, algopts.WithCompleteArithmetic()) + expected4.varScalarMul(api, circuit.A, circuit.One, algopts.WithCompleteArithmetic()) + expected1.AssertIsEqual(api, circuit.Inf) + expected2.AssertIsEqual(api, circuit.Inf) + expected3.AssertIsEqual(api, circuit.Inf) + expected4.AssertIsEqual(api, circuit.A) + return nil +} + +func TestScalarMulG1GLVAndFakeGLVEdgeCases(t *testing.T) { + // sample random point + _a := randomPointG1() + var a bls12377.G1Affine + a.FromJacobian(&_a) + + // create the cs + var circuit, witness scalarMulGLVAndFakeGLVEdgeCases + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // assign the inputs + witness.A.Assign(&a) + + witness.Inf.X = 0 + witness.Inf.Y = 0 + witness.Zero = 0 + witness.One = 1 + + assert := test.NewAssert(t) + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) +} diff --git a/std/algebra/native/sw_bls12377/hints.go b/std/algebra/native/sw_bls12377/hints.go index d59ef955ef..9c0c86333d 100644 --- a/std/algebra/native/sw_bls12377/hints.go +++ b/std/algebra/native/sw_bls12377/hints.go @@ -5,6 +5,8 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/field/eisenstein" "github.com/consensys/gnark/constraint/solver" ) @@ -13,6 +15,9 @@ func GetHints() []solver.Hint { decomposeScalarG1, decomposeScalarG1Simple, decomposeScalarG2, + scalarMulGLVG1Hint, + halfGCDEisenstein, + halfGCDEisensteinSigns, } } @@ -88,3 +93,130 @@ func decomposeScalarG2(scalarField *big.Int, inputs []*big.Int, outputs []*big.I return nil } + +func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 3 { + return fmt.Errorf("expecting three inputs") + } + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + + // compute the resulting 
point [s]Q + var P bls12377.G1Affine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + return nil +} + +func halfGCDEisenstein(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two input") + } + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") + } + cc := getInnerCurveConfig(scalarField) + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) + r := eisenstein.ComplexNumber{ + A0: &glvBasis.V1[0], + A1: &glvBasis.V1[1], + } + sp := ecc.SplitScalar(inputs[0], glvBasis) + // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 + // so here we return -s instead of s. + s := eisenstein.ComplexNumber{ + A0: &sp[0], + A1: &sp[1], + } + s.Neg(&s) + res := eisenstein.HalfGCD(&r, &s) + outputs[0].Set(res[0].A0) + outputs[1].Set(res[0].A1) + outputs[2].Set(res[1].A0) + outputs[3].Set(res[1].A1) + outputs[4].Mul(res[1].A1, inputs[1]). + Add(outputs[4], res[1].A0). + Mul(outputs[4], inputs[0]). + Add(outputs[4], res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + outputs[4].Add(outputs[4], s.A0). 
+ Div(outputs[4], cc.fr) + + if outputs[0].Sign() == -1 { + outputs[0].Neg(outputs[0]) + } + if outputs[1].Sign() == -1 { + outputs[1].Neg(outputs[1]) + } + if outputs[2].Sign() == -1 { + outputs[2].Neg(outputs[2]) + } + if outputs[3].Sign() == -1 { + outputs[3].Neg(outputs[3]) + } + if outputs[4].Sign() == -1 { + outputs[4].Neg(outputs[4]) + } + + return nil +} + +func halfGCDEisensteinSigns(scalarField *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two input") + } + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") + } + cc := getInnerCurveConfig(scalarField) + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) + r := eisenstein.ComplexNumber{ + A0: &glvBasis.V1[0], + A1: &glvBasis.V1[1], + } + sp := ecc.SplitScalar(inputs[0], glvBasis) + // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 + // so here we return -s instead of s. + s := eisenstein.ComplexNumber{ + A0: &sp[0], + A1: &sp[1], + } + s.Neg(&s) + + outputs[0].SetUint64(0) + outputs[1].SetUint64(0) + outputs[2].SetUint64(0) + outputs[3].SetUint64(0) + outputs[4].SetUint64(0) + res := eisenstein.HalfGCD(&r, &s) + s.A1.Mul(res[1].A1, inputs[1]). + Add(s.A1, res[1].A0). + Mul(s.A1, inputs[0]). + Add(s.A1, res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + s.A1.Add(s.A1, s.A0). 
+ Div(s.A1, cc.fr) + + if res[0].A0.Sign() == -1 { + outputs[0].SetUint64(1) + } + if res[0].A1.Sign() == -1 { + outputs[1].SetUint64(1) + } + if res[1].A0.Sign() == -1 { + outputs[2].SetUint64(1) + } + if res[1].A1.Sign() == -1 { + outputs[3].SetUint64(1) + } + if s.A1.Sign() == -1 { + outputs[4].SetUint64(1) + } + return nil +} diff --git a/std/algebra/native/sw_bls24315/g1.go b/std/algebra/native/sw_bls24315/g1.go index afb8cec716..062dbf23ab 100644 --- a/std/algebra/native/sw_bls24315/g1.go +++ b/std/algebra/native/sw_bls24315/g1.go @@ -225,7 +225,7 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // hence have the same X coordinates. // However when doing doubleAndAdd(Acc, B) as (Acc+B)+Acc it might happen - // that Acc==B or -B. So we add the point H=(0,1) on BLS12-377 of order 2 + // that Acc==B or -B. So we add the point H=(0,1) on BLS24-315 of order 2 // to it to avoid incomplete additions in the loop by forcing Acc to be // different than the stored B. 
Normally, the point H should be "killed // out" by the first doubling in the loop and the result will remain diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index bcc5f36119..9349e276e5 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -46,14 +46,7 @@ func (c *curve) AssertIsOnCurve(p1 Point) { } func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { var p Point - if c.endo != nil { - // TODO restore - // this is disabled until this issue is solved https://github.com/ConsenSys/gnark/issues/268 - // p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo) - p.scalarMul(c.api, &p1, scalar, c.params) - } else { - p.scalarMul(c.api, &p1, scalar, c.params) - } + p.scalarMul(c.api, &p1, scalar, c.params, c.endo) return p } func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { diff --git a/std/algebra/native/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go index b12a027c01..1aee963405 100644 --- a/std/algebra/native/twistededwards/curve_test.go +++ b/std/algebra/native/twistededwards/curve_test.go @@ -419,3 +419,26 @@ func (p *CurveParams) randomScalar() *big.Int { r, _ := rand.Int(rand.Reader, p.Order) return r } + +type varScalarMul struct { + curveID twistededwards.ID + P Point + R Point + S frontend.Variable +} + +func (circuit *varScalarMul) Define(api frontend.API) error { + + // get edwards curve curve + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + + // scalar mul + res := curve.ScalarMul(circuit.P, circuit.S) + api.AssertIsEqual(res.X, circuit.R.X) + api.AssertIsEqual(res.Y, circuit.R.Y) + + return nil +} diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go new file mode 100644 index 0000000000..90a59619cd --- /dev/null +++ b/std/algebra/native/twistededwards/hints.go @@ -0,0 +1,168 @@ +package 
twistededwards + +import ( + "errors" + "fmt" + "math/big" + "sync" + + "github.com/consensys/gnark-crypto/ecc" + edbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" + "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch" + jubjub "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" + edbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" + edbls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/twistededwards" + babyjubjub "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" + edbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" + edbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" + "github.com/consensys/gnark/constraint/solver" +) + +func GetHints() []solver.Hint { + return []solver.Hint{ + halfGCD, + scalarMulHint, + decomposeScalar, + } +} + +func init() { + solver.RegisterHint(GetHints()...) +} + +type glvParams struct { + lambda, order big.Int + glvBasis ecc.Lattice +} + +func decomposeScalar(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error { + // the efficient endomorphism exists on Bandersnatch only + if scalarField.Cmp(ecc.BLS12_381.ScalarField()) != 0 { + return errors.New("no efficient endomorphism is available on this curve") + } + var glv glvParams + var init sync.Once + init.Do(func() { + glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) + }) + + // sp[0] is always negative because, in SplitScalar(), we always round above + // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. + // Thus taking -sp[0] here and negating the point in ScalarMul(). 
+ // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) + // and not the Bandersnatch prime order (Order) and the result will be incorrect. + // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) + // and hence we can't reduce optimally the number of constraints. + sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) + res[0].Neg(&(sp[0])) + res[1].Set(&(sp[1])) + + // figure out how many times we have overflowed + res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) + res[2].Sub(res[2], inputs[0]) + res[2].Div(res[2], &glv.order) + + return nil +} + +func halfGCD(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 4 { + return fmt.Errorf("expecting four outputs") + } + glvBasis := new(ecc.Lattice) + ecc.PrecomputeLattice(inputs[1], inputs[0], glvBasis) + outputs[0].Set(&glvBasis.V1[0]) + outputs[1].Set(&glvBasis.V1[1]) + + // figure out how many times we have overflowed + // s2 * s + s1 = k*r + outputs[3].Mul(outputs[1], inputs[0]). + Add(outputs[3], outputs[0]). 
+ Div(outputs[3], inputs[1]) + + outputs[2].SetUint64(0) + if outputs[1].Sign() == -1 { + outputs[1].Neg(outputs[1]) + outputs[2].SetUint64(1) + } + + return nil +} + +func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 4 { + return fmt.Errorf("expecting four inputs") + } + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + // compute the resulting point [s]Q + if field.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + order, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + if inputs[3].Cmp(order) == 0 { + var P bandersnatch.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else { + var P jubjub.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } + } else if field.Cmp(ecc.BN254.ScalarField()) == 0 { + var P babyjubjub.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + var P edbls12377.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + var P edbls24315.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + var P edbls24317.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(ecc.BW6_761.ScalarField()) == 0 { + var P 
edbw6761.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else if field.Cmp(ecc.BW6_633.ScalarField()) == 0 { + var P edbw6633.PointAffine + P.X.SetBigInt(inputs[0]) + P.Y.SetBigInt(inputs[1]) + P.ScalarMultiplication(&P, inputs[2]) + P.X.BigInt(outputs[0]) + P.Y.BigInt(outputs[1]) + } else { + return fmt.Errorf("scalarMulHint: unknown curve") + } + return nil +} diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index dbacdb30d5..4b17d61a26 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -16,9 +16,7 @@ limitations under the License. package twistededwards -import ( - "github.com/consensys/gnark/frontend" -) +import "github.com/consensys/gnark/frontend" // neg computes the negative of a point in SNARK coordinates func (p *Point) neg(api frontend.API, p1 *Point) *Point { @@ -95,17 +93,12 @@ func (p *Point) double(api frontend.API, p1 *Point, curve *CurveParams) *Point { return p } -// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// scalarMulGeneric computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { - if len(endo) == 1 && endo[0] != nil { - // use glv - return p.scalarMulGLV(api, p1, scalar, curve, endo[0]) - } - +func (p *Point) scalarMulGeneric(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { // first unpack the scalar b := api.ToBinary(scalar) @@ -142,6 +135,15 @@ func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, return 
p } +// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { + return p.scalarMulFakeGLV(api, p1, scalar, curve) +} + // doubleBaseScalarMul computes s1*P1+s2*P2 // where P1 and P2 are points on a twisted Edwards curve // and s1, s2 scalars. @@ -172,3 +174,134 @@ func (p *Point) doubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 fron return p } + +// GLV + +// phi endomorphism √-2 ∈ 𝒪₋₈ +// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order +func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { + + xy := api.Mul(p1.X, p1.Y) + yy := api.Mul(p1.Y, p1.Y) + f := api.Sub(1, yy) + f = api.Mul(f, endo.Endo[1]) + g := api.Add(yy, endo.Endo[0]) + g = api.Mul(g, endo.Endo[0]) + h := api.Sub(yy, endo.Endo[0]) + + p.X = api.DivUnchecked(f, xy) + p.Y = api.DivUnchecked(g, h) + + return p +} + +// scalarMulGLV computes the scalar multiplication of a point on a twisted +// Edwards curve à la GLV. +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // the hints allow to decompose the scalar s into s1 and s2 such that + // s1 + λ * s2 == s mod Order, + // with λ s.t. λ² = -2 mod Order. 
+ sd, err := api.NewHint(decomposeScalar, 3, scalar) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + s1, s2 := sd[0], sd[1] + + // -s1 + λ * s2 == s + k*Order + api.AssertIsEqual(api.Sub(api.Mul(s2, endo.Lambda), s1), api.Add(scalar, api.Mul(curve.Order, sd[2]))) + + // Normally s1 and s2 are of the max size sqrt(Order) = 128 + // But in a circuit, we force s1 to be negative by rounding always above. + // This changes the size bounds to 2*sqrt(Order) = 129. + n := 129 + + b1 := api.ToBinary(s1, n) + b2 := api.ToBinary(s2, n) + + var res, _p1, p2, p3, tmp Point + _p1.neg(api, p1) + p2.phi(api, p1, curve, endo) + p3.add(api, &_p1, &p2, curve) + + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) + res.add(api, &res, &tmp, curve) + } + + p.X = res.X + p.Y = res.Y + + return p +} + +// scalarMulFakeGLV computes the scalar multiplication of a point on a twisted +// Edwards curve following https://hackmd.io/@yelhousni/Hy-aWld50 +// +// [s]p1 == q is equivalent to [s2]([s]p1 - q) = (0,1) which is [s1]p1 + [s2]q = (0,1) +// with s1, s2 < sqrt(Order) and s1 + s2 * s = 0 mod Order. 
+// +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams) *Point { + // the hints allow to decompose the scalar s into s1 and s2 such that + // s1 + s * s2 == 0 mod Order, + s, err := api.NewHint(halfGCD, 4, scalar, curve.Order) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + s1, s2, bit, k := s[0], s[1], s[2], s[3] + + // check that s1 + s2 * s == k*Order + _s2 := api.Mul(s2, scalar) + _k := api.Mul(k, curve.Order) + lhs := api.Select(bit, s1, api.Add(s1, _s2)) + rhs := api.Select(bit, api.Add(_k, _s2), _k) + api.AssertIsEqual(lhs, rhs) + + n := (curve.Order.BitLen() + 1) / 2 + b1 := api.ToBinary(s1, n) + b2 := api.ToBinary(s2, n) + + var res, p2, p3, tmp Point + q, err := api.NewHint(scalarMulHint, 2, p1.X, p1.Y, scalar, curve.Order) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + p2.X = api.Select(bit, api.Neg(q[0]), q[0]) + p2.Y = q[1] + + p3.add(api, p1, &p2, curve) + + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, p3.Y) + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, p3.Y) + res.add(api, &res, &tmp, curve) + } + + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + + p.X = q[0] + p.Y = q[1] + + return p +} diff --git a/std/algebra/native/twistededwards/scalarmul_glv.go b/std/algebra/native/twistededwards/scalarmul_glv.go deleted file mode 100644 index 7b959a2db4..0000000000 --- a/std/algebra/native/twistededwards/scalarmul_glv.go +++ /dev/null @@ -1,135 +0,0 @@ -/* -Copyright © 2022 ConsenSys Software Inc. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package twistededwards - -import ( - "errors" - "math/big" - "sync" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/frontend" -) - -// phi endomorphism √-2 ∈ 𝒪₋₈ -// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order -func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { - - xy := api.Mul(p1.X, p1.Y) - yy := api.Mul(p1.Y, p1.Y) - f := api.Sub(1, yy) - f = api.Mul(f, endo.Endo[1]) - g := api.Add(yy, endo.Endo[0]) - g = api.Mul(g, endo.Endo[0]) - h := api.Sub(yy, endo.Endo[0]) - - p.X = api.DivUnchecked(f, xy) - p.Y = api.DivUnchecked(g, h) - - return p -} - -type glvParams struct { - lambda, order big.Int - glvBasis ecc.Lattice -} - -var DecomposeScalar = func(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error { - // the efficient endomorphism exists on Bandersnatch only - if scalarField.Cmp(ecc.BLS12_381.ScalarField()) != 0 { - return errors.New("no efficient endomorphism is available on this curve") - } - var glv glvParams - var init sync.Once - init.Do(func() { - glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) - glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) - ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) - }) - - // sp[0] is always negative because, in SplitScalar(), we always round 
above - // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. - // Thus taking -sp[0] here and negating the point in ScalarMul(). - // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) - // and not the Bandersnatch prime order (Order) and the result will be incorrect. - // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) - // and hence we can't reduce optimally the number of constraints. - sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) - res[0].Neg(&(sp[0])) - res[1].Set(&(sp[1])) - - // figure out how many times we have overflowed - res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) - res[2].Sub(res[2], inputs[0]) - res[2].Div(res[2], &glv.order) - - return nil -} - -func init() { - solver.RegisterHint(DecomposeScalar) -} - -// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve -// p1: base point (as snark point) -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { - // the hints allow to decompose the scalar s into s1 and s2 such that - // s1 + λ * s2 == s mod Order, - // with λ s.t. λ² = -2 mod Order. - sd, err := api.NewHint(DecomposeScalar, 3, scalar) - if err != nil { - // err is non-nil only for invalid number of inputs - panic(err) - } - - s1, s2 := sd[0], sd[1] - - // -s1 + λ * s2 == s + k*Order - api.AssertIsEqual(api.Sub(api.Mul(s2, endo.Lambda), s1), api.Add(scalar, api.Mul(curve.Order, sd[2]))) - - // Normally s1 and s2 are of the max size sqrt(Order) = 128 - // But in a circuit, we force s1 to be negative by rounding always above. - // This changes the size bounds to 2*sqrt(Order) = 129. 
- n := 129 - - b1 := api.ToBinary(s1, n) - b2 := api.ToBinary(s2, n) - - var res, _p1, p2, p3, tmp Point - _p1.neg(api, p1) - p2.phi(api, p1, curve, endo) - p3.add(api, &_p1, &p2, curve) - - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) - - for i := n - 2; i >= 0; i-- { - res.double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) - res.add(api, &res, &tmp, curve) - } - - p.X = res.X - p.Y = res.Y - - return p -} diff --git a/std/commitments/pedersen/assignment.go b/std/commitments/pedersen/assignment.go index 8822c58187..1c36d595a2 100644 --- a/std/commitments/pedersen/assignment.go +++ b/std/commitments/pedersen/assignment.go @@ -34,35 +34,35 @@ func ValueOfVerifyingKey[G2El algebra.G2ElementT](vk any) (VerifyingKey[G2El], e return ret, fmt.Errorf("expected *ped_bls12377.VerifyingKey, got %T", vk) } s.G = sw_bls12377.NewG2Affine(tVk.G) - s.GSigma = sw_bls12377.NewG2Affine(tVk.GSigma) + s.GSigmaNeg = sw_bls12377.NewG2Affine(tVk.GSigmaNeg) case *VerifyingKey[sw_bls12381.G2Affine]: tVk, ok := vk.(*ped_bls12381.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bls12381.VerifyingKey, got %T", vk) } s.G = sw_bls12381.NewG2Affine(tVk.G) - s.GSigma = sw_bls12381.NewG2Affine(tVk.GSigma) + s.GSigmaNeg = sw_bls12381.NewG2Affine(tVk.GSigmaNeg) case *VerifyingKey[sw_bls24315.G2Affine]: tVk, ok := vk.(*ped_bls24315.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bls24315.VerifyingKey, got %T", vk) } s.G = sw_bls24315.NewG2Affine(tVk.G) - s.GSigma = sw_bls24315.NewG2Affine(tVk.GSigma) + s.GSigmaNeg = sw_bls24315.NewG2Affine(tVk.GSigmaNeg) case *VerifyingKey[sw_bw6761.G2Affine]: tVk, ok := vk.(*ped_bw6761.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bw6761.VerifyingKey, got %T", vk) } s.G = sw_bw6761.NewG2Affine(tVk.G) - s.GSigma = sw_bw6761.NewG2Affine(tVk.GSigma) + s.GSigmaNeg = 
sw_bw6761.NewG2Affine(tVk.GSigmaNeg) case *VerifyingKey[sw_bn254.G2Affine]: tVk, ok := vk.(*ped_bn254.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bn254.VerifyingKey, got %T", vk) } s.G = sw_bn254.NewG2Affine(tVk.G) - s.GSigma = sw_bn254.NewG2Affine(tVk.GSigma) + s.GSigmaNeg = sw_bn254.NewG2Affine(tVk.GSigmaNeg) default: panic(fmt.Sprintf("unknown parametric type: %T", s)) } @@ -82,35 +82,35 @@ func ValueOfVerifyingKeyFixed[G2El algebra.G2ElementT](vk any) (VerifyingKey[G2E return ret, fmt.Errorf("expected *ped_bls12377.VerifyingKey, got %T", vk) } s.G = sw_bls12377.NewG2AffineFixed(tVk.G) - s.GSigma = sw_bls12377.NewG2AffineFixed(tVk.GSigma) + s.GSigmaNeg = sw_bls12377.NewG2AffineFixed(tVk.GSigmaNeg) case *VerifyingKey[sw_bls12381.G2Affine]: tVk, ok := vk.(*ped_bls12381.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bls12381.VerifyingKey, got %T", vk) } s.G = sw_bls12381.NewG2AffineFixed(tVk.G) - s.GSigma = sw_bls12381.NewG2AffineFixed(tVk.GSigma) + s.GSigmaNeg = sw_bls12381.NewG2AffineFixed(tVk.GSigmaNeg) case *VerifyingKey[sw_bls24315.G2Affine]: tVk, ok := vk.(*ped_bls24315.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bls24315.VerifyingKey, got %T", vk) } s.G = sw_bls24315.NewG2AffineFixed(tVk.G) - s.GSigma = sw_bls24315.NewG2AffineFixed(tVk.GSigma) + s.GSigmaNeg = sw_bls24315.NewG2AffineFixed(tVk.GSigmaNeg) case *VerifyingKey[sw_bw6761.G2Affine]: tVk, ok := vk.(*ped_bw6761.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bw6761.VerifyingKey, got %T", vk) } s.G = sw_bw6761.NewG2AffineFixed(tVk.G) - s.GSigma = sw_bw6761.NewG2AffineFixed(tVk.GSigma) + s.GSigmaNeg = sw_bw6761.NewG2AffineFixed(tVk.GSigmaNeg) case *VerifyingKey[sw_bn254.G2Affine]: tVk, ok := vk.(*ped_bn254.VerifyingKey) if !ok { return ret, fmt.Errorf("expected *ped_bn254.VerifyingKey, got %T", vk) } s.G = sw_bn254.NewG2AffineFixed(tVk.G) - s.GSigma = sw_bn254.NewG2AffineFixed(tVk.GSigma) + s.GSigmaNeg = 
sw_bn254.NewG2AffineFixed(tVk.GSigmaNeg) default: return ret, fmt.Errorf("unknown parametric type: %T", s) } diff --git a/std/commitments/pedersen/verifier.go b/std/commitments/pedersen/verifier.go index b1a9afed05..77cd596be6 100644 --- a/std/commitments/pedersen/verifier.go +++ b/std/commitments/pedersen/verifier.go @@ -21,8 +21,8 @@ type KnowledgeProof[G1El algebra.G1ElementT] struct { // VerifyingKey is a verifying key for Pedersen vector commitments. type VerifyingKey[G2El algebra.G2ElementT] struct { - G G2El - GSigma G2El // (-1/σ)[G] for toxic σ + G G2El + GSigmaNeg G2El // (-1/σ)[G] for toxic σ } // Verifier verifies the knowledge proofs for a Pedersen commitments @@ -63,7 +63,7 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertCommitment(commitment Commitment[ v.pairing.AssertIsOnG1(&knowledgeProof.G1El) } - if err = v.pairing.PairingCheck([]*G1El{&commitment.G1El, &knowledgeProof.G1El}, []*G2El{&vk.GSigma, &vk.G}); err != nil { + if err = v.pairing.PairingCheck([]*G1El{&commitment.G1El, &knowledgeProof.G1El}, []*G2El{&vk.GSigmaNeg, &vk.G}); err != nil { return fmt.Errorf("pairing check failed: %w", err) } return nil diff --git a/std/evmprecompiles/01-ecrecover.go b/std/evmprecompiles/01-ecrecover.go index 0f5fdd67bb..8b2882474c 100644 --- a/std/evmprecompiles/01-ecrecover.go +++ b/std/evmprecompiles/01-ecrecover.go @@ -6,7 +6,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" - "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" ) @@ -51,8 +50,8 @@ func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], // EVM uses v \in {27, 28}, but everyone else v >= 0. Convert back v = api.Sub(v, 27) - // check that len(v) = 2 - vbits := bits.ToBinary(api, v, bits.WithNbDigits(2)) + // check that len(v) = 1 + api.AssertIsBoolean(v) // with the encoding we may have that r,s < 2*Fr (i.e. not r,s < Fr). Apply more thorough checks. 
frField.AssertIsLessOrEqual(&r, frField.Modulus()) @@ -90,10 +89,7 @@ func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], // compute R, the commitment // the signature as elements in Fr, but it actually represents elements in Fp. Convert to Fp element. rbits := frField.ToBits(&r) - rfp := fpField.FromBits(rbits...) - // compute R.X x = r+v[1]*fr - Rx := fpField.Select(vbits[1], fpField.NewElement(emfr.Modulus()), fpField.NewElement(0)) - Rx = fpField.Add(rfp, Rx) // Rx = r + v[1]*fr + Rx := fpField.FromBits(rbits...) Ry := fpField.Mul(Rx, Rx) // Ry = x^2 // compute R.y y = sqrt(x^3+7) Ry = fpField.Mul(Ry, Rx) // Ry = x^3 @@ -104,7 +100,7 @@ func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], Ry = fpField.Sqrt(Ry) // Ry = sqrt(x^3 + 7) // ensure the oddity of Ry is same as vbits[0], otherwise negate Ry Rybits := fpField.ToBits(Ry) - Ry = fpField.Select(api.Xor(vbits[0], Rybits[0]), fpField.Sub(fpField.Modulus(), Ry), Ry) + Ry = fpField.Select(api.Xor(v, Rybits[0]), fpField.Sub(fpField.Modulus(), Ry), Ry) R := sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ X: *Rx, diff --git a/std/evmprecompiles/01-ecrecover_test.go b/std/evmprecompiles/01-ecrecover_test.go index ac76ff1a33..6a4bf3085e 100644 --- a/std/evmprecompiles/01-ecrecover_test.go +++ b/std/evmprecompiles/01-ecrecover_test.go @@ -2,6 +2,7 @@ package evmprecompiles import ( "crypto/rand" + "errors" "fmt" "math/big" "testing" @@ -260,3 +261,40 @@ func TestInvalidFailureTag(t *testing.T) { err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) assert.Error(err) } + +func TestLargeV(t *testing.T) { + assert := test.NewAssert(t) + var pk ecdsa.PublicKey + msg := []byte("test") + var rE, sE fr.Element + r, s := new(big.Int), new(big.Int) + for _, v := range []uint{2, 3} { + for { + rE.SetRandom() + sE.SetRandom() + rE.BigInt(r) + sE.BigInt(s) + if err := pk.RecoverFrom(msg, v, r, s); errors.Is(err, ecdsa.ErrNoSqrtR) { + continue + } else { + 
assert.NoError(err) + break + } + } + circuit := ecrecoverCircuit{} + witness := ecrecoverCircuit{ + Message: emulated.ValueOf[emulated.Secp256k1Fr](ecdsa.HashToInt(msg)), + V: v + 27, // EVM constant + R: emulated.ValueOf[emulated.Secp256k1Fr](r), + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Strict: 0, + IsFailure: 0, + Expected: sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.Y), + }, + } + err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField()) + assert.Error(err) + } +} diff --git a/std/evmprecompiles/05-expmod.go b/std/evmprecompiles/05-expmod.go index f1fa11357f..1c8cdeb888 100644 --- a/std/evmprecompiles/05-expmod.go +++ b/std/evmprecompiles/05-expmod.go @@ -15,16 +15,21 @@ import ( // // [MODEXP]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/expmod/index.html func Expmod[P emulated.FieldParams](api frontend.API, base, exp, modulus *emulated.Element[P]) *emulated.Element[P] { - // x^0 = 1 + // x^0 = 1 (unless mod 0 or mod 1) // x mod 0 = 0 + // x mod 1 = 0 + // 0^0 = 1 (unless mod 0 or mod 1) + f, err := emulated.NewField[P](api) if err != nil { panic(fmt.Sprintf("new field: %v", err)) } - // in case modulus is zero, then need to compute with dummy values and return zero as a result + // in case modulus is zero or one, then need to compute with dummy values and return zero as a result isZeroMod := f.IsZero(modulus) + isOneMod := f.IsZero(f.Sub(modulus, f.One())) + isOneOrZeroMod := api.Or(isZeroMod, isOneMod) modulus = f.Select(isZeroMod, f.One(), modulus) res := f.ModExp(base, exp, modulus) - res = f.Select(isZeroMod, f.Zero(), res) + res = f.Select(isOneOrZeroMod, f.Zero(), res) return res } diff --git a/std/evmprecompiles/05-expmod_test.go b/std/evmprecompiles/05-expmod_test.go index 5de7d95bcb..f2cd24c905 100644 --- a/std/evmprecompiles/05-expmod_test.go +++ 
b/std/evmprecompiles/05-expmod_test.go @@ -69,7 +69,7 @@ func TestEdgeCases(t *testing.T) { base, exp, modulus, result *big.Int }{ {big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}, // 0^0 = 0 mod 0 - {big.NewInt(0), big.NewInt(0), big.NewInt(1), big.NewInt(1)}, // 0^0 = 1 mod 1 + {big.NewInt(0), big.NewInt(0), big.NewInt(1), big.NewInt(0)}, // 0^0 = 0 mod 1 {big.NewInt(0), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 0^0 = 1 mod 123 {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 0 {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 1 diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go index 55d1008cf1..736530148c 100644 --- a/std/math/emulated/emparams/emparams.go +++ b/std/math/emulated/emparams/emparams.go @@ -282,7 +282,35 @@ type BLS24315Fr struct{ fourLimbPrimeField } func (fr BLS24315Fr) Modulus() *big.Int { return ecc.BLS24_315.ScalarField() } -// Mod1e4096 provides type parametrization for emulated arithmetic: +// STARKCurveFp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x800000000000011000000000000000000000000000000000000000000000001 (base 16) +// 3618502788666131213697322783095070105623107215331596699973092056135872020481 (base 10) +// +// This is the base field of the STARK curve. +type STARKCurveFp struct{ fourLimbPrimeField } + +func (fp STARKCurveFp) Modulus() *big.Int { return ecc.STARK_CURVE.BaseField() } + +// STARKCurveFr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f (base 16) +// 3618502788666131213697322783095070105526743751716087489154079457884512865583 (base 10) +// +// This is the scalar field of the STARK curve. 
+type STARKCurveFr struct{ fourLimbPrimeField } + +func (fp STARKCurveFr) Modulus() *big.Int { return ecc.STARK_CURVE.ScalarField() } + +// Mod1e4096 provides type parametrization for emulated arithmetic: +// - limbs: 64 +// - limb width: 64 bits +// diff --git a/std/math/emulated/params.go b/std/math/emulated/params.go index 892d141d38..2b0dc9d179 100644 --- a/std/math/emulated/params.go +++ b/std/math/emulated/params.go @@ -17,6 +17,7 @@ import ( // - [BLS12381Fp] and [BLS12381Fr] // - [P256Fp] and [P256Fr] // - [P384Fp] and [P384Fr] +// - [STARKCurveFp] and [STARKCurveFr] type FieldParams interface { NbLimbs() uint // number of limbs to represent field element BitsPerLimb() uint // number of bits per limb. Top limb may contain less than limbSize bits. @@ -25,18 +26,20 @@ type ( - Goldilocks = emparams.Goldilocks - Secp256k1Fp = emparams.Secp256k1Fp - Secp256k1Fr = emparams.Secp256k1Fr - BN254Fp = emparams.BN254Fp - BN254Fr = emparams.BN254Fr - BLS12377Fp = emparams.BLS12377Fp - BLS12381Fp = emparams.BLS12381Fp - BLS12381Fr = emparams.BLS12381Fr - P256Fp = emparams.P256Fp - P256Fr = emparams.P256Fr - P384Fp = emparams.P384Fp - P384Fr = emparams.P384Fr - BW6761Fp = emparams.BW6761Fp - BW6761Fr = emparams.BW6761Fr + Goldilocks = emparams.Goldilocks + Secp256k1Fp = emparams.Secp256k1Fp + Secp256k1Fr = emparams.Secp256k1Fr + BN254Fp = emparams.BN254Fp + BN254Fr = emparams.BN254Fr + BLS12377Fp = emparams.BLS12377Fp + BLS12381Fp = emparams.BLS12381Fp + BLS12381Fr = emparams.BLS12381Fr + P256Fp = emparams.P256Fp + P256Fr = emparams.P256Fr + P384Fp = emparams.P384Fp + P384Fr = emparams.P384Fr + BW6761Fp = emparams.BW6761Fp + BW6761Fr = emparams.BW6761Fr + STARKCurveFp = emparams.STARKCurveFp + STARKCurveFr = emparams.STARKCurveFr ) diff --git a/std/math/polynomial/polynomial.go b/std/math/polynomial/polynomial.go index e09ef69ef1..59be7084f5 100644 --- a/std/math/polynomial/polynomial.go +++ b/std/math/polynomial/polynomial.go @@ -223,7 +223,7 
@@ func (p *Polynomial[FR]) InterpolateLDE(at *emulated.Element[FR], values []*emul return res } -// EvalEquals returns the evaluation +// EvalEqual returns the evaluation // // eq(x, y) = \prod (1-x)*(1-y) + x*y, // diff --git a/std/recursion/groth16/verifier.go b/std/recursion/groth16/verifier.go index ad760db917..067158b441 100644 --- a/std/recursion/groth16/verifier.go +++ b/std/recursion/groth16/verifier.go @@ -367,6 +367,7 @@ func ValueOfVerifyingKeyFixed[G1El algebra.G1ElementT, G2El algebra.G2ElementT, return ret, fmt.Errorf("commitment key[%d]: %w", i, err) } } + s.PublicAndCommitmentCommitted = tVk.PublicAndCommitmentCommitted case *VerifyingKey[sw_bls12377.G1Affine, sw_bls12377.G2Affine, sw_bls12377.GT]: tVk, ok := vk.(*groth16backend_bls12377.VerifyingKey) if !ok { @@ -394,6 +395,7 @@ func ValueOfVerifyingKeyFixed[G1El algebra.G1ElementT, G2El algebra.G2ElementT, return ret, fmt.Errorf("commitment key[%d]: %w", i, err) } } + s.PublicAndCommitmentCommitted = tVk.PublicAndCommitmentCommitted case *VerifyingKey[sw_bls12381.G1Affine, sw_bls12381.G2Affine, sw_bls12381.GTEl]: tVk, ok := vk.(*groth16backend_bls12381.VerifyingKey) if !ok { @@ -421,6 +423,7 @@ func ValueOfVerifyingKeyFixed[G1El algebra.G1ElementT, G2El algebra.G2ElementT, return ret, fmt.Errorf("commitment key[%d]: %w", i, err) } } + s.PublicAndCommitmentCommitted = tVk.PublicAndCommitmentCommitted case *VerifyingKey[sw_bls24315.G1Affine, sw_bls24315.G2Affine, sw_bls24315.GT]: tVk, ok := vk.(*groth16backend_bls24315.VerifyingKey) if !ok { @@ -448,6 +451,7 @@ func ValueOfVerifyingKeyFixed[G1El algebra.G1ElementT, G2El algebra.G2ElementT, return ret, fmt.Errorf("commitment key[%d]: %w", i, err) } } + s.PublicAndCommitmentCommitted = tVk.PublicAndCommitmentCommitted case *VerifyingKey[sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl]: tVk, ok := vk.(*groth16backend_bw6761.VerifyingKey) if !ok { @@ -475,6 +479,7 @@ func ValueOfVerifyingKeyFixed[G1El algebra.G1ElementT, G2El algebra.G2ElementT, 
return ret, fmt.Errorf("commitment key[%d]: %w", i, err) } } + s.PublicAndCommitmentCommitted = tVk.PublicAndCommitmentCommitted default: return ret, fmt.Errorf("unknown parametric type combination") } diff --git a/test/assert.go b/test/assert.go index 00f0ec9ab7..a5ed71214a 100644 --- a/test/assert.go +++ b/test/assert.go @@ -82,7 +82,7 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment assert.CheckCircuit(circuit, newOpts...) } -// ProverSucceeded is deprecated use [Assert.CheckCircuit] instead +// ProverFailed is deprecated use [Assert.CheckCircuit] instead func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment frontend.Circuit, opts ...TestingOption) { // copy the options newOpts := make([]TestingOption, len(opts), len(opts)+2) diff --git a/test/unsafekzg/kzgsrs.go b/test/unsafekzg/kzgsrs.go index ea3c674f72..c1c363307f 100644 --- a/test/unsafekzg/kzgsrs.go +++ b/test/unsafekzg/kzgsrs.go @@ -59,8 +59,8 @@ var ( memLock, fsLock sync.RWMutex ) -// NewSRS returns a pair of kzg.SRS; one in canonical form, the other in lagrange form. -// Default options use a memory cache, see Option for more details & options. +// NewSRS returns a pair of [kzg.SRS]; one in canonical form, the other in Lagrange form. +// Default options use a memory cache, see [Option] for more details & options. 
func NewSRS(ccs constraint.ConstraintSystem, opts ...Option) (canonical kzg.SRS, lagrange kzg.SRS, err error) { nbConstraints := ccs.GetNbConstraints() @@ -78,7 +78,7 @@ func NewSRS(ccs constraint.ConstraintSystem, opts ...Option) (canonical kzg.SRS, return nil, nil, err } - key := cacheKey(curveID, sizeCanonical) + key := cacheKey(curveID, sizeCanonical, cfg.toxicValue) log.Debug().Str("key", key).Msg("fetching SRS from mem cache") memLock.RLock() entry, ok := cache[key] @@ -109,17 +109,18 @@ func NewSRS(ccs constraint.ConstraintSystem, opts ...Option) (canonical kzg.SRS, log.Debug().Msg("SRS not found in cache, generating") // not in cache, generate - canonical, lagrange, err = newSRS(curveID, sizeCanonical) + canonical, lagrange, err = newSRS(curveID, sizeCanonical, cfg.toxicValue) if err != nil { return nil, nil, err } - // cache it + // cache it. When a toxic value is given, we cache the SRS only in the + // memory cache to avoid storing a weak SRS on the filesystem. memLock.Lock() cache[key] = cacheEntry{canonical, lagrange} memLock.Unlock() - if cfg.fsCache { + if cfg.fsCache && cfg.toxicValue == nil { log.Debug().Str("key", key).Str("cacheDir", cfg.cacheDir).Msg("writing SRS to fs cache") fsLock.Lock() fsWrite(key, cfg.cacheDir, canonical, lagrange) @@ -134,7 +135,10 @@ type cacheEntry struct { lagrange kzg.SRS } -func cacheKey(curveID ecc.ID, size uint64) string { +func cacheKey(curveID ecc.ID, size uint64, toxicValue *big.Int) string { + if toxicValue != nil { + return fmt.Sprintf("kzgsrs-%s-%d-toxic-%s", curveID.String(), size, toxicValue.String()) + } return fmt.Sprintf("kzgsrs-%s-%d", curveID.String(), size) } @@ -147,15 +151,18 @@ func extractCurveID(key string) (ecc.ID, error) { return ecc.IDFromString(matches[1]) } -func newSRS(curveID ecc.ID, size uint64) (kzg.SRS, kzg.SRS, error) { - - tau, err := rand.Int(rand.Reader, curveID.ScalarField()) - if err != nil { - return nil, nil, err +func newSRS(curveID ecc.ID, size uint64, tau *big.Int) (kzg.SRS,
kzg.SRS, error) { + var ( + srs kzg.SRS + err error + ) + if tau == nil { + tau, err = rand.Int(rand.Reader, curveID.ScalarField()) + if err != nil { + return nil, nil, fmt.Errorf("sample random toxic value: %w", err) + } } - var srs kzg.SRS - switch curveID { case ecc.BN254: srs, err = kzg_bn254.NewSRS(size, tau) diff --git a/test/unsafekzg/options.go b/test/unsafekzg/options.go index 3639505fde..120e17079e 100644 --- a/test/unsafekzg/options.go +++ b/test/unsafekzg/options.go @@ -1,12 +1,16 @@ package unsafekzg import ( + "crypto/sha256" + "errors" + "math/big" "os" "path/filepath" "github.com/consensys/gnark/logger" ) +// Option allows changing the behaviour of the unsafe KZG SRS generation. type Option func(*config) error // WithCacheDir enables the filesystem cache and sets the cache directory @@ -18,9 +22,49 @@ func WithFSCache() Option { } } +// WithCacheDir enables the filesystem cache and sets the cache directory +// to the provided path. +func WithCacheDir(dir string) Option { + return func(opt *config) error { + opt.fsCache = true + opt.cacheDir = dir + return nil + } +} + +// WithToxicValue sets the toxic value to the provided value. +// +// NB! This is a debug option and should not be used in production. +func WithToxicValue(toxicValue *big.Int) Option { + return func(opt *config) error { + if opt.toxicValue != nil { + return errors.New("toxic value already set") + } + opt.toxicValue = toxicValue + return nil + } +} + +// WithToxicSeed sets the toxic value to the sha256 hash of the provided seed. +// +// NB! This is a debug option and should not be used in production. 
+func WithToxicSeed(seed []byte) Option { + return func(opt *config) error { + if opt.toxicValue != nil { + return errors.New("toxic value already set") + } + h := sha256.New() + h.Write(seed) + opt.toxicValue = new(big.Int) + opt.toxicValue.SetBytes(h.Sum(nil)) + return nil + } +} + type config struct { - fsCache bool - cacheDir string + fsCache bool + cacheDir string + toxicValue *big.Int } // default options