Skip to content

Commit

Permalink
Add poc/ntt-cuda/go test.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jul 24, 2024
1 parent 4823fc1 commit 30418c8
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 1 deletion.
2 changes: 1 addition & 1 deletion poc/ntt-cuda/cuda/ntt_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#include <ntt/ntt.cuh>

extern "C"
SPPARK_FFI
RustError::by_value compute_ntt(size_t device_id,
fr_t* inout, uint32_t lg_domain_size,
NTT::InputOutputOrder ntt_order,
Expand Down
9 changes: 9 additions & 0 deletions poc/ntt-cuda/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module poc_cu

go 1.18

replace github.com/supranational/sppark => ../..

require github.com/supranational/sppark v0.0.0-00010101000000-000000000000

require github.com/supranational/blst v0.3.13 // indirect
2 changes: 2 additions & 0 deletions poc/ntt-cuda/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/supranational/blst v0.3.13 h1:AYeSxdOMacwu7FBmpfloBz5pbFXDmJL33RuwnKtmTjk=
github.com/supranational/blst v0.3.13/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw=
37 changes: 37 additions & 0 deletions poc/ntt-cuda/go/cgo_sppark.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef __CGO_SPPARK_H__
#define __CGO_SPPARK_H__

typedef struct {
int code;
char *message;
} Error;

typedef struct {
int code;
_GoString_ message;
} GoError;

__attribute__((weak)) // required with go1.18 and earlier
void toGoError(GoError *go_err, Error c_err);

#define WRAP_ERR(ret_t, func, ...) __attribute__((section("_sppark"), used)) \
static struct { Error (*call)(__VA_ARGS__); const char *name; } \
func = { NULL, #func }; \
static void go_##func(GoError *go_err, __VA_ARGS__)
#if 0
// For example in the import "C" section:
//
// #include "cgo_sppark.h"
// WRAP_ERR(Error, cuda_func, type1 arg1, type2 arg2)
// { toGoError(go_err, (*cuda_func.call)(arg1, arg2); }
//
// and then on the Go side:
...
var err C.GoError
C.go_cuda_func(&err, arg1, arg2)
if err.code != 0 {
panic(err.message)
}
...
#endif
#endif
68 changes: 68 additions & 0 deletions poc/ntt-cuda/go/goldilocks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package goldilocks

// #include "cgo_sppark.h"
// #include <stdint.h>
//
// typedef enum { NN, NR, RN, RR } InputOutputOrder;
// typedef enum { forward, inverse } Direction;
// typedef enum { standard, coset } Type;
//
// WRAP_ERR(Error, compute_ntt, size_t device_id,
// uint64_t inout[], uint32_t lg_domain_size,
// InputOutputOrder order, Direction direction,
// Type type)
// { toGoError(go_err, (*compute_ntt.call)(device_id, inout, lg_domain_size,
// order, direction, type));
// }
import "C"

import (
"math/bits"
sppark "github.com/supranational/sppark/go"
)

func init() {
sppark.Load("../cuda/ntt_api.cu", "-arch=native", "-DFEATURE_GOLDILOCKS")
}

func NTT(device_id int, inout []uint64,
order InputOutputOrder, direction Direction, optional ...Type) {
kind := Standard
if len(optional) > 0 {
kind = optional[0]
}

domainSz := len(inout)
if (domainSz & (domainSz-1)) != 0 {
panic("invalid |inout| size")
}
lgDomainSz := bits.Len(uint(domainSz)) - 1

var err C.GoError
C.go_compute_ntt(&err, C.size_t(device_id),
(*C.ulong)(&inout[0]), C.uint(lgDomainSz),
order, direction, kind)
if err.code != 0 {
panic(err.message)
}
}

type InputOutputOrder = C.InputOutputOrder
const (
NN InputOutputOrder = iota
NR
RN
RR
)

type Direction = C.Direction
const (
Forward Direction = iota
Inverse
)

type Type = C.Type
const (
Standard Type = iota
Coset
)
32 changes: 32 additions & 0 deletions poc/ntt-cuda/go/goldilocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package goldilocks

import (
sppark "github.com/supranational/sppark/go"
"crypto/rand"
"math/big"
"reflect"
"testing"
)

func TestSelfConsistency(t *testing.T) {
var mod big.Int
mod.SetUint64(0xffffffff00000001)

ref := make([]uint64, 1<<20)
for i := range ref {
if rnd, err := rand.Int(rand.Reader, &mod); err == nil {
ref[i] = rnd.Uint64()
}
}

work := append([]uint64{}, ref...)

NTT(0, work, NR, Forward)
NTT(0, work, RN, Inverse)

if !reflect.DeepEqual(ref, work) {
t.Errorf("result mismatch")
} else if err := sppark.Exfiltrate("."); err != nil {
t.Error(err)
}
}

0 comments on commit 30418c8

Please sign in to comment.