From 30418c8359197b2e8117d93fe1b4685bea528d0e Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Sun, 14 Jul 2024 19:50:47 +0200 Subject: [PATCH] Add poc/ntt-cuda/go test. --- poc/ntt-cuda/cuda/ntt_api.cu | 2 +- poc/ntt-cuda/go.mod | 9 ++++ poc/ntt-cuda/go.sum | 2 + poc/ntt-cuda/go/cgo_sppark.h | 37 ++++++++++++++++ poc/ntt-cuda/go/goldilocks.go | 68 ++++++++++++++++++++++++++++++ poc/ntt-cuda/go/goldilocks_test.go | 32 ++++++++++++++ 6 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 poc/ntt-cuda/go.mod create mode 100644 poc/ntt-cuda/go.sum create mode 100644 poc/ntt-cuda/go/cgo_sppark.h create mode 100644 poc/ntt-cuda/go/goldilocks.go create mode 100644 poc/ntt-cuda/go/goldilocks_test.go diff --git a/poc/ntt-cuda/cuda/ntt_api.cu b/poc/ntt-cuda/cuda/ntt_api.cu index 4f1cb6c..f801895 100644 --- a/poc/ntt-cuda/cuda/ntt_api.cu +++ b/poc/ntt-cuda/cuda/ntt_api.cu @@ -24,7 +24,7 @@ #include -extern "C" +SPPARK_FFI RustError::by_value compute_ntt(size_t device_id, fr_t* inout, uint32_t lg_domain_size, NTT::InputOutputOrder ntt_order, diff --git a/poc/ntt-cuda/go.mod b/poc/ntt-cuda/go.mod new file mode 100644 index 0000000..67dada8 --- /dev/null +++ b/poc/ntt-cuda/go.mod @@ -0,0 +1,9 @@ +module poc_cu + +go 1.18 + +replace github.com/supranational/sppark => ../.. + +require github.com/supranational/sppark v0.0.0-00010101000000-000000000000 + +require github.com/supranational/blst v0.3.13 // indirect diff --git a/poc/ntt-cuda/go.sum b/poc/ntt-cuda/go.sum new file mode 100644 index 0000000..8900369 --- /dev/null +++ b/poc/ntt-cuda/go.sum @@ -0,0 +1,2 @@ +github.com/supranational/blst v0.3.13 h1:AYeSxdOMacwu7FBmpfloBz5pbFXDmJL33RuwnKtmTjk= +github.com/supranational/blst v0.3.13/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= diff --git a/poc/ntt-cuda/go/cgo_sppark.h b/poc/ntt-cuda/go/cgo_sppark.h new file mode 100644 index 0000000..413b2a3 --- /dev/null +++ b/poc/ntt-cuda/go/cgo_sppark.h @@ -0,0 +1,37 @@ +#ifndef __CGO_SPPARK_H__ +#define __CGO_SPPARK_H__ + +typedef struct { + int code; + char *message; +} Error; + +typedef struct { + int code; + _GoString_ message; +} GoError; + +__attribute__((weak)) // required with go1.18 and earlier +void toGoError(GoError *go_err, Error c_err); + +#define WRAP_ERR(ret_t, func, ...) __attribute__((section("_sppark"), used)) \ + static struct { Error (*call)(__VA_ARGS__); const char *name; } \ + func = { NULL, #func }; \ + static void go_##func(GoError *go_err, __VA_ARGS__) +#if 0 +// For example in the import "C" section: +// +// #include "cgo_sppark.h" +// WRAP_ERR(Error, cuda_func, type1 arg1, type2 arg2) +// { toGoError(go_err, (*cuda_func.call)(arg1, arg2); } +// +// and then on the Go side: + ... + var err C.GoError + C.go_cuda_func(&err, arg1, arg2) + if err.code != 0 { + panic(err.message) + } + ... +#endif +#endif diff --git a/poc/ntt-cuda/go/goldilocks.go b/poc/ntt-cuda/go/goldilocks.go new file mode 100644 index 0000000..5e7bf31 --- /dev/null +++ b/poc/ntt-cuda/go/goldilocks.go @@ -0,0 +1,68 @@ +package goldilocks + +// #include "cgo_sppark.h" +// #include +// +// typedef enum { NN, NR, RN, RR } InputOutputOrder; +// typedef enum { forward, inverse } Direction; +// typedef enum { standard, coset } Type; +// +// WRAP_ERR(Error, compute_ntt, size_t device_id, +// uint64_t inout[], uint32_t lg_domain_size, +// InputOutputOrder order, Direction direction, +// Type type) +// { toGoError(go_err, (*compute_ntt.call)(device_id, inout, lg_domain_size, +// order, direction, type)); +// } +import "C" + +import ( + "math/bits" + sppark "github.com/supranational/sppark/go" +) + +func init() { + sppark.Load("../cuda/ntt_api.cu", "-arch=native", "-DFEATURE_GOLDILOCKS") +} + +func NTT(device_id int, inout []uint64, + order InputOutputOrder, direction Direction, optional ...Type) { + kind := Standard + if len(optional) > 0 { + kind = optional[0] + } + + domainSz := len(inout) + if (domainSz & (domainSz-1)) != 0 { + panic("invalid |inout| size") + } + lgDomainSz := bits.Len(uint(domainSz)) - 1 + + var err C.GoError + C.go_compute_ntt(&err, C.size_t(device_id), + (*C.ulong)(&inout[0]), C.uint(lgDomainSz), + order, direction, kind) + if err.code != 0 { + panic(err.message) + } +} + +type InputOutputOrder = C.InputOutputOrder +const ( + NN InputOutputOrder = iota + NR + RN + RR +) + +type Direction = C.Direction +const ( + Forward Direction = iota + Inverse +) + +type Type = C.Type +const ( + Standard Type = iota + Coset +) diff --git a/poc/ntt-cuda/go/goldilocks_test.go b/poc/ntt-cuda/go/goldilocks_test.go new file mode 100644 index 0000000..08d777c --- /dev/null +++ b/poc/ntt-cuda/go/goldilocks_test.go @@ -0,0 +1,32 @@ +package goldilocks + +import ( + sppark "github.com/supranational/sppark/go" + "crypto/rand" + "math/big" + "reflect" + "testing" +) + +func TestSelfConsistency(t *testing.T) { + var mod big.Int + mod.SetUint64(0xffffffff00000001) + + ref := make([]uint64, 1<<20) + for i := range ref { + if rnd, err := rand.Int(rand.Reader, &mod); err == nil { + ref[i] = rnd.Uint64() + } + } + + work := append([]uint64{}, ref...) + + NTT(0, work, NR, Forward) + NTT(0, work, RN, Inverse) + + if !reflect.DeepEqual(ref, work) { + t.Errorf("result mismatch") + } else if err := sppark.Exfiltrate("."); err != nil { + t.Error(err) + } +}