Skip to content

Commit

Permalink
Make Go bridge work on Windows.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jul 29, 2024
1 parent b8d5a37 commit 5da8a72
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 30 deletions.
2 changes: 2 additions & 0 deletions go/cgo_sppark.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ typedef struct {
_GoString_ message;
} GoError;

#ifndef _WIN32
__attribute__((weak)) // required with go1.18 and earlier
#endif
void toGoError(GoError *go_err, Error c_err);

#define WRAP_ERR(ret_t, func, ...) __attribute__((section("_sppark"), used)) \
Expand Down
117 changes: 88 additions & 29 deletions go/sppark.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ package sppark
// #ifndef GO_CGO_EXPORT_PROLOGUE_H
// #ifdef _WIN32
// # include <windows.h>
// # include <stdio.h>
// static char hex_from_nibble(unsigned char nibble)
// {
// int mask = (9 - (nibble &= 0xf)) >> 31;
// return (char)(nibble + ((('a'-10) & mask) | ('0' & ~mask)));
// }
// #else
// # include <dlfcn.h>
// # include <errno.h>
Expand All @@ -18,18 +22,6 @@ package sppark
//
// #include "cgo_sppark.h"
//
// void toGoString(_GoString_ *, char *);
//
// void toGoError(GoError *go_err, Error c_err)
// {
// go_err->code = c_err.code;
// if (c_err.message != NULL) {
// toGoString(&go_err->message, c_err.message);
// free(c_err.message);
// c_err.message = NULL;
// }
// }
//
// typedef struct {
// void *ptr;
// } gpu_ptr_t;
Expand All @@ -48,6 +40,21 @@ package sppark
// WRAP(_Bool, cuda_available, void)
// { return (*cuda_available.call)(); }
//
// WRAP(void, drop_error_message, char *ptr)
// { (*drop_error_message.call)(ptr); }
//
// void toGoString(_GoString_ *, char *);
//
// void toGoError(GoError *go_err, Error c_err)
// {
// go_err->code = c_err.code;
// if (c_err.message != NULL) {
// toGoString(&go_err->message, c_err.message);
// go_drop_error_message(c_err.message);
// c_err.message = NULL;
// }
// }
//
// typedef struct {
// void *value;
// const char *name;
Expand Down Expand Up @@ -84,17 +91,31 @@ package sppark
// break;
// }
// #ifdef __SPPARK_CGO_DEBUG__
// printf("%p %s\n", sym->value, sym->name);
// fprintf(stderr, "%p %s\n", sym->value, sym->name);
// #endif
// }
// }
// }
//
// if (h == NULL) {
// #ifdef _WIN32
// static char buf[24];
// snprintf(buf, sizeof(buf), "WIN32 Error #0x%x", GetLastError());
// toGoString(err, buf);
// char *msg = NULL;
// unsigned int win_err = GetLastError();
// if (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM|FORMAT_MESSAGE_ALLOCATE_BUFFER,
// NULL, win_err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
// (char *)&msg, 0, NULL)) {
// toGoString(err, msg);
// LocalFree(msg);
// } else {
// static char buf[24] = "WIN32 Error #0x";
// msg = buf + 15;
// do {
// *msg++ = hex_from_nibble(win_err>>4);
// *msg++ = hex_from_nibble(win_err);
// } while (win_err >>= 8);
// *msg = '\0';
// toGoString(err, buf);
// }
// if (hmod) FreeLibrary(hmod);
// #else
// toGoString(err, dlerror());
Expand Down Expand Up @@ -133,6 +154,11 @@ func init() {
}

func Load(baseName string, options ...string) {
goArch := runtime.GOARCH
if goArch != "amd64" && goArch != "arm64" {
log.Panicf("%s: unsupported GOARCH", goArch)
}

baseName = strings.TrimSuffix(baseName, filepath.Ext(baseName))

var dst, src string
Expand Down Expand Up @@ -181,16 +207,32 @@ func Load(baseName string, options ...string) {
func build(dst string, src string, custom_args ...string) bool {
var args []string

cc, ok := os.LookupEnv("CC")
if !ok {
cc = "gcc" // default for CGO, alternatively examine `go env`...
}

cmd := exec.Command(cc, "-O2", "-fPIC", "-c",
filepath.Join(blst.SrcRoot, "build", "assembly.S"),
filepath.Join(blst.SrcRoot, "src", "cpuid.c"))
if out, err := cmd.CombinedOutput(); err != nil {
log.Fatal(cmd.String(), "\n", string(out))
return false
}

defer os.Remove("assembly.o")
defer os.Remove("cpuid.o")

args = append(args, "-shared", "-o", dst, src)
args = append(args, "-I" + SrcRoot)
args = append(args, filepath.Join(SrcRoot, "util", "all_gpus.cpp"))
args = append(args, "-I" + filepath.Join(blst.SrcRoot, "src"))
args = append(args, filepath.Join(blst.SrcRoot, "build", "assembly.S"))
args = append(args, filepath.Join(blst.SrcRoot, "src", "cpuid.c"))
args = append(args, "-DTAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE")
if runtime.GOOS == "windows" {
args = append(args, "-ccbin=clang-cl")
} else {
args = append(args, filepath.Join(SrcRoot, "util", "all_gpus.cpp"))
args = append(args, "assembly.o", "cpuid.o")
if runtime.GOOS != "windows" {
if cxx, ok := os.LookupEnv("CXX"); ok {
args = append(args, "-ccbin", cxx)
}
args = append(args, "-Xcompiler", "-fPIC,-fvisibility=hidden")
args = append(args, "-Xlinker", "-Bsymbolic")
}
Expand All @@ -217,13 +259,19 @@ func build(dst string, src string, custom_args ...string) bool {
nvcc = sccache
}

cmd := exec.Command(nvcc, args...)
cmd = exec.Command(nvcc, args...)

if out, err := cmd.CombinedOutput(); err != nil {
out, err := cmd.CombinedOutput()
if err != nil {
log.Fatal(cmd.String(), "\n", string(out))
return false
}

if cgo_cflags, ok := os.LookupEnv("CGO_CFLAGS");
ok && strings.Contains(cgo_cflags, "__SPPARK_CGO_DEBUG__") {
log.Print(cmd.String(), "\n", string(out))
}

return true
}

Expand All @@ -250,18 +298,29 @@ func Exfiltrate(optional ...string) error {
if err != nil {
return err
}
fout, err := os.OpenFile(filepath.Join(dir, filepath.Base(file)),
os.O_WRONLY|os.O_CREATE, 0644)

dst := filepath.Join(dir, filepath.Base(file))
fout, err := os.OpenFile(dst,
os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
finp.Close()
return err
}

finpStat, _ := finp.Stat()
foutStat, _ := fout.Stat()
if !os.SameFile(finpStat, foutStat) {
fout.Close()
if !os.SameFile(finpStat, foutStat) {
fout, err = os.OpenFile(dst,
os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
if err != nil {
finp.Close()
return err
}
log.Print("copying ", file)
io.Copy(fout, finp)
fout.Close()
}
fout.Close()
finp.Close()
}

Expand Down
2 changes: 2 additions & 0 deletions poc/go/cgo_sppark.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ typedef struct {
_GoString_ message;
} GoError;

#ifndef _WIN32
__attribute__((weak)) // required with go1.18 and earlier
#endif
void toGoError(GoError *go_err, Error c_err);

#define WRAP_ERR(ret_t, func, ...) __attribute__((section("_sppark"), used)) \
Expand Down
3 changes: 3 additions & 0 deletions poc/go/poc.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include <stdio.h>
#include <string.h>
#ifdef _MSC_VER
# define strdup _strdup
#endif

__global__ void kernel()
{
Expand Down
2 changes: 2 additions & 0 deletions poc/ntt-cuda/go/cgo_sppark.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ typedef struct {
_GoString_ message;
} GoError;

#ifndef _WIN32
__attribute__((weak)) // required with go1.18 and earlier
#endif
void toGoError(GoError *go_err, Error c_err);

#define WRAP_ERR(ret_t, func, ...) __attribute__((section("_sppark"), used)) \
Expand Down
2 changes: 1 addition & 1 deletion poc/ntt-cuda/go/goldilocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NTT(device_id int, inout []uint64,

var err C.GoError
C.go_compute_ntt(&err, C.size_t(device_id),
(*C.ulong)(&inout[0]), C.uint(lgDomainSz),
(*C.uint64_t)(&inout[0]), C.uint(lgDomainSz),
order, direction, kind)
if err.code != 0 {
panic(err.message)
Expand Down
5 changes: 5 additions & 0 deletions util/all_gpus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,8 @@ SPPARK_FFI gpu_ptr_t<void>::by_value clone_gpu_ptr_t(const gpu_ptr_t<void>& rhs)
#ifdef __clang__
# pragma clang diagnostic pop
#endif

#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
SPPARK_FFI void drop_error_message(char *ptr)
{ free(ptr); }
#endif

0 comments on commit 5da8a72

Please sign in to comment.