From b18f425e7e74234a6a218a3d9f86d54cab5a8291 Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Thu, 14 Sep 2023 12:34:48 +0300 Subject: [PATCH] Bump icicle, add simple point copy to device functions --- curves/bls12377/ntt_test.go | 16 ++++++++-------- curves/bn254/gnark.go | 10 ++-------- curves/bn254/ntt_test.go | 16 ++++++++-------- curves/bn254/utils.go | 27 +++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 2 ++ 6 files changed, 48 insertions(+), 25 deletions(-) diff --git a/curves/bls12377/ntt_test.go b/curves/bls12377/ntt_test.go index 31dc79e..f7aaf2b 100644 --- a/curves/bls12377/ntt_test.go +++ b/curves/bls12377/ntt_test.go @@ -58,7 +58,7 @@ func TestNttBN254CompareToGnarkDIF(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.DIF, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -83,7 +83,7 @@ func TestNttBN254CompareToGnarkDIT(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.DIT, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -108,7 +108,7 @@ func TestINttBN254CompareToGnarkDIT(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, true, icicle.DIT, 0) + icicle.Ntt(&nttResult, true, 0) assert.NotEqual(t, nttResult, scalars) frResScalars := make([]fr.Element, len(frScalars)) // Make a new slice with the same length @@ -136,7 +136,7 @@ func TestINttBN254CompareToGnarkDIF(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, true, icicle.DIF, 0) + icicle.Ntt(&nttResult, true, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -160,14 +160,14 @@ func TestNttBN254(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.NONE, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) inttResult := make([]icicle.G1ScalarField, len(nttResult)) copy(inttResult, nttResult) assert.Equal(t, inttResult, nttResult) - icicle.Ntt(&inttResult, true, icicle.NONE, 0) + icicle.Ntt(&inttResult, true, 0) assert.Equal(t, inttResult, scalars) } @@ -203,7 +203,7 @@ func TestNttBatchBN254(t *testing.T) { nttResultVecOfVec = append(nttResultVecOfVec, clone) // Call the ntt_bls12377 function - icicle.Ntt(&nttResultVecOfVec[i], false, icicle.NONE, 0) + icicle.Ntt(&nttResultVecOfVec[i], false, 0) } assert.NotEqual(t, nttBatchResult, scalars) @@ -227,7 +227,7 @@ func BenchmarkNTT(b *testing.B) { nttResult := make([]icicle.G1ScalarField, len(scalars)) // Make a new slice with the same length copy(nttResult, scalars) for n := 0; n < b.N; n++ { - icicle.Ntt(&nttResult, false, icicle.NONE, 0) + icicle.Ntt(&nttResult, false, 0) } }) } diff --git a/curves/bn254/gnark.go b/curves/bn254/gnark.go index 79143b1..408d178 100644 --- a/curves/bn254/gnark.go +++ b/curves/bn254/gnark.go @@ -15,8 +15,8 @@ import ( ) type OnDeviceData struct { - p unsafe.Pointer - size int + P unsafe.Pointer + Size int } func INttOnDevice(scalars_d, twiddles_d, cosetPowers_d unsafe.Pointer, size, sizeBytes int, isCoset bool) unsafe.Pointer { @@ -35,8 +35,6 @@ func NttOnDevice(scalars_out, scalars_d, twiddles_d, coset_powers_d unsafe.Point } icicle.ReverseScalars(scalars_out, size) - - return } func MsmOnDevice(scalars_d, points_d unsafe.Pointer, count int, convert bool) (curve.G1Jac, unsafe.Pointer, error) { @@ -84,8 +82,6 @@ func PolyOps(a_d, b_d, c_d, den_d unsafe.Pointer, size int) { if ret != 0 { fmt.Print("Vector mult a*den issue") } - - return } func MontConvOnDevice(scalars_d unsafe.Pointer, size int, is_into bool) { @@ -94,8 +90,6 @@ func MontConvOnDevice(scalars_d unsafe.Pointer, size int, is_into bool) { } else { icicle.FromMontgomery(scalars_d, size) } - - return } func CopyToDevice(scalars []fr.Element, bytes int, copyDone chan unsafe.Pointer) { diff --git a/curves/bn254/ntt_test.go b/curves/bn254/ntt_test.go index 164ea0f..537db78 100644 --- a/curves/bn254/ntt_test.go +++ b/curves/bn254/ntt_test.go @@ -59,7 +59,7 @@ func TestNttBN254CompareToGnarkDIF(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.DIF, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -84,7 +84,7 @@ func TestNttBN254CompareToGnarkDIT(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.DIT, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -109,7 +109,7 @@ func TestINttBN254CompareToGnarkDIT(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, true, icicle.DIT, 0) + icicle.Ntt(&nttResult, true, 0) assert.NotEqual(t, nttResult, scalars) frResScalars := make([]fr.Element, len(frScalars)) // Make a new slice with the same length @@ -137,7 +137,7 @@ func TestINttBN254CompareToGnarkDIF(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, true, icicle.DIF, 0) + icicle.Ntt(&nttResult, true, 0) assert.NotEqual(t, nttResult, scalars) domain := fft.NewDomain(uint64(len(scalars))) @@ -161,14 +161,14 @@ func TestNttBN254(t *testing.T) { copy(nttResult, scalars) assert.Equal(t, nttResult, scalars) - icicle.Ntt(&nttResult, false, icicle.NONE, 0) + icicle.Ntt(&nttResult, false, 0) assert.NotEqual(t, nttResult, scalars) inttResult := make([]icicle.G1ScalarField, len(nttResult)) copy(inttResult, nttResult) assert.Equal(t, inttResult, nttResult) - icicle.Ntt(&inttResult, true, icicle.NONE, 0) + icicle.Ntt(&inttResult, true, 0) assert.Equal(t, inttResult, scalars) } @@ -204,7 +204,7 @@ func TestNttBatchBN254(t *testing.T) { nttResultVecOfVec = append(nttResultVecOfVec, clone) // Call the ntt_bn254 function - icicle.Ntt(&nttResultVecOfVec[i], false, icicle.NONE, 0) + icicle.Ntt(&nttResultVecOfVec[i], false, 0) } assert.NotEqual(t, nttBatchResult, scalars) @@ -228,7 +228,7 @@ func BenchmarkNTT(b *testing.B) { nttResult := make([]icicle.G1ScalarField, len(scalars)) // Make a new slice with the same length copy(nttResult, scalars) for n := 0; n < b.N; n++ { - icicle.Ntt(&nttResult, false, icicle.NONE, 0) + icicle.Ntt(&nttResult, false, 0) } }) } diff --git a/curves/bn254/utils.go b/curves/bn254/utils.go index 99200c0..ac0de0e 100644 --- a/curves/bn254/utils.go +++ b/curves/bn254/utils.go @@ -2,13 +2,40 @@ package bn254 import ( "fmt" + "unsafe" "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fp" "github.com/consensys/gnark-crypto/ecc/bn254/fr" icicle "github.com/ingonyama-zk/icicle/goicicle/curves/bn254" + goicicle "github.com/ingonyama-zk/icicle/goicicle" ) +func CopyPointsToDevice(points []bn254.G1Affine, pointsBytes int, copyDone chan unsafe.Pointer) { + if pointsBytes == 0 { + copyDone <- nil + } else { + devicePtr, _ := goicicle.CudaMalloc(pointsBytes) + iciclePoints := BatchConvertFromG1Affine(points) + goicicle.CudaMemCpyHtoD[icicle.G1PointAffine](devicePtr, iciclePoints, pointsBytes) + + copyDone <- devicePtr + } +} + +func CopyG2PointsToDevice(points []bn254.G2Affine, pointsBytes int, copyDone chan unsafe.Pointer) { + if pointsBytes == 0 { + copyDone <- nil + } else { + devicePtr, _ := goicicle.CudaMalloc(pointsBytes) + iciclePoints := BatchConvertFromG2Affine(points) + goicicle.CudaMemCpyHtoD[icicle.G2PointAffine](devicePtr, iciclePoints, pointsBytes) + + copyDone <- devicePtr + } +} + + func ScalarToGnarkFr(f *icicle.G1ScalarField) *fr.Element { fb := f.ToBytesLe() var b32 [32]byte diff --git a/go.mod b/go.mod index 6dfd09e..2c53510 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/consensys/gnark-crypto v0.11.0 - github.com/ingonyama-zk/icicle v0.0.0-20230817090824-e04bd928e658 + github.com/ingonyama-zk/icicle v0.0.0-20230907052343-04e5ff5d1af4 github.com/stretchr/testify v1.8.3 ) diff --git a/go.sum b/go.sum index bbe4edf..40d8be2 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/ingonyama-zk/icicle v0.0.0-20230816133820-ad1e4822526b h1:O2KFw7nh8JM github.com/ingonyama-zk/icicle v0.0.0-20230816133820-ad1e4822526b/go.mod h1:9JiPcJeIAUoVYrgVgBO6tVy6ZXIaPXZtz/y2Mvg6jbk= github.com/ingonyama-zk/icicle v0.0.0-20230817090824-e04bd928e658 h1:B9HOONNwcKYpUGKQN1/ASBGZUlvr1V3vlW2uFG/JPHg= github.com/ingonyama-zk/icicle v0.0.0-20230817090824-e04bd928e658/go.mod h1:9JiPcJeIAUoVYrgVgBO6tVy6ZXIaPXZtz/y2Mvg6jbk= +github.com/ingonyama-zk/icicle v0.0.0-20230907052343-04e5ff5d1af4 h1:3Va/VmQ+KZ0pI8eLRnS1j25eFzTQvjgfh6o85xbChcM= +github.com/ingonyama-zk/icicle v0.0.0-20230907052343-04e5ff5d1af4/go.mod h1:kAK8/EoN7fUEmakzgZIYdWy1a2rBnpCaZLqSHwZWxEk= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU=