Skip to content

Commit

Permalink
use up-to-date Cuda driver functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pedro-devv committed Apr 8, 2024
1 parent f10c977 commit 959a682
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions crates/memonitor-sys/cuda/src/memonitor.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ typedef void* mod_type;
#endif

typedef int CUresult;
struct CUctx_st;
typedef struct CUctx_st *CUcontext;
typedef int CUdevice_v1;
typedef CUdevice_v1 *CUdevice;
struct CUctx_opaque;
typedef struct CUctx_opaque *CUcontext;
typedef int CUdevice_opaque;
typedef CUdevice_opaque *CUdevice;
struct CUexecAffinityParam {
int type;
unsigned int count;
};

typedef CUresult(*cuInit_type)(unsigned int);

typedef CUresult (*cuCtxCreate_type)(CUcontext *, unsigned int, CUdevice);
typedef CUresult (*cuCtxCreate_type)(CUcontext *, struct CUexecAffinityParam *, int, unsigned int, CUdevice);

typedef CUresult (*cuCtxDestroy_type)(CUcontext);

Expand All @@ -43,7 +47,7 @@ typedef CUresult (*cuMemGetInfo_type)(size_t *, size_t *);

struct Device {
CUdevice handle;
CUdevice_v1 inner;
CUdevice_opaque inner;
};

static mod_type module = NULL;
Expand All @@ -69,16 +73,16 @@ int cu_init() {
}

cuInit = (cuInit_type) GetProcAddress(module, "cuInit");
cuCtxCreate = (cuCtxCreate_type) GetProcAddress(module, "cuCtxCreate");
cuCtxDestroy = (cuCtxDestroy_type) GetProcAddress(module, "cuCtxDestroy");
cuCtxPopCurrent = (cuCtxPopCurrent_type) GetProcAddress(module, "cuCtxPopCurrent");
cuCtxPushCurrent = (cuCtxPushCurrent_type) GetProcAddress(module, "cuCtxPushCurrent");
cuCtxCreate = (cuCtxCreate_type) GetProcAddress(module, "cuCtxCreate_v3");
cuCtxDestroy = (cuCtxDestroy_type) GetProcAddress(module, "cuCtxDestroy_v2");
cuCtxPopCurrent = (cuCtxPopCurrent_type) GetProcAddress(module, "cuCtxPopCurrent_v2");
cuCtxPushCurrent = (cuCtxPushCurrent_type) GetProcAddress(module, "cuCtxPushCurrent_v2");
cuCtxSetCurrent = (cuCtxSetCurrent_type) GetProcAddress(module, "cuCtxSetCurrent");
cuDeviceGetCount = (cuDeviceGetCount_type) GetProcAddress(module, "cuDeviceGetCount");
cuDeviceGet = (cuDeviceGet_type) GetProcAddress(module, "cuDeviceGet");
cuDeviceGetName = (cuDeviceGetName_type) GetProcAddress(module, "cuDeviceGetName");
cuDeviceTotalMem = (cuDeviceTotalMem_type) GetProcAddress(module, "cuDeviceTotalMem");
cuMemGetInfo = (cuMemGetInfo_type) GetProcAddress(module, "cuMemGetInfo");
cuDeviceTotalMem = (cuDeviceTotalMem_type) GetProcAddress(module, "cuDeviceTotalMem_v2");
cuMemGetInfo = (cuMemGetInfo_type) GetProcAddress(module, "cuMemGetInfo_v2");
#elif defined(__APPLE__)
return -1;
#else
Expand All @@ -88,16 +92,16 @@ int cu_init() {
}

cuInit = (cuInit_type) dlsym(module, "cuInit");
cuCtxCreate = (cuCtxCreate_type) dlsym(module, "cuCtxCreate");
cuCtxDestroy = (cuCtxDestroy_type) dlsym(module, "cuCtxDestroy");
cuCtxPopCurrent = (cuCtxPopCurrent_type) dlsym(module, "cuCtxPopCurrent");
cuCtxPushCurrent = (cuCtxPushCurrent_type) dlsym(module, "cuCtxPushCurrent");
cuCtxCreate = (cuCtxCreate_type) dlsym(module, "cuCtxCreate_v3");
cuCtxDestroy = (cuCtxDestroy_type) dlsym(module, "cuCtxDestroy_v2");
cuCtxPopCurrent = (cuCtxPopCurrent_type) dlsym(module, "cuCtxPopCurrent_v2");
cuCtxPushCurrent = (cuCtxPushCurrent_type) dlsym(module, "cuCtxPushCurrent_v2");
cuCtxSetCurrent = (cuCtxSetCurrent_type) dlsym(module, "cuCtxSetCurrent");
cuDeviceGetCount = (cuDeviceGetCount_type) dlsym(module, "cuDeviceGetCount");
cuDeviceGet = (cuDeviceGet_type) dlsym(module, "cuDeviceGet");
cuDeviceGetName = (cuDeviceGetName_type) dlsym(module, "cuDeviceGetName");
cuDeviceTotalMem = (cuDeviceTotalMem_type) dlsym(module, "cuDeviceTotalMem");
cuMemGetInfo = (cuMemGetInfo_type) dlsym(module, "cuMemGetInfo");
cuDeviceTotalMem = (cuDeviceTotalMem_type) dlsym(module, "cuDeviceTotalMem_v2");
cuMemGetInfo = (cuMemGetInfo_type) dlsym(module, "cuMemGetInfo_v2");
#endif

CUresult res = cuInit(0);
Expand Down Expand Up @@ -159,7 +163,7 @@ struct cu_Devices cu_list_devices() {
return invalid_devices;
}

res = cuCtxCreate(&ctx_handles[d], 0, device_handles[d].handle);
res = cuCtxCreate(&ctx_handles[d], NULL, 0, 0, device_handles[d].handle);
if (res != 0) {
free(device_handles);
free(ctx_handles);
Expand Down

0 comments on commit 959a682

Please sign in to comment.