diff --git a/crates/memonitor-sys/cuda/src/memonitor.c b/crates/memonitor-sys/cuda/src/memonitor.c index f8e7cc3..5fe67d3 100644 --- a/crates/memonitor-sys/cuda/src/memonitor.c +++ b/crates/memonitor-sys/cuda/src/memonitor.c @@ -14,14 +14,18 @@ typedef void* mod_type; #endif typedef int CUresult; -struct CUctx_st; -typedef struct CUctx_st *CUcontext; -typedef int CUdevice_v1; -typedef CUdevice_v1 *CUdevice; +struct CUctx_opaque; +typedef struct CUctx_opaque *CUcontext; +typedef int CUdevice_opaque; +typedef CUdevice_opaque *CUdevice; +struct CUexecAffinityParam { + int type; + unsigned int count; +}; typedef CUresult(*cuInit_type)(unsigned int); -typedef CUresult (*cuCtxCreate_type)(CUcontext *, unsigned int, CUdevice); +typedef CUresult (*cuCtxCreate_type)(CUcontext *, struct CUexecAffinityParam *, int, unsigned int, CUdevice); typedef CUresult (*cuCtxDestroy_type)(CUcontext); @@ -43,7 +47,7 @@ typedef CUresult (*cuMemGetInfo_type)(size_t *, size_t *); struct Device { CUdevice handle; - CUdevice_v1 inner; + CUdevice_opaque inner; }; static mod_type module = NULL; @@ -69,16 +73,16 @@ int cu_init() { } cuInit = (cuInit_type) GetProcAddress(module, "cuInit"); - cuCtxCreate = (cuCtxCreate_type) GetProcAddress(module, "cuCtxCreate"); - cuCtxDestroy = (cuCtxDestroy_type) GetProcAddress(module, "cuCtxDestroy"); - cuCtxPopCurrent = (cuCtxPopCurrent_type) GetProcAddress(module, "cuCtxPopCurrent"); - cuCtxPushCurrent = (cuCtxPushCurrent_type) GetProcAddress(module, "cuCtxPushCurrent"); + cuCtxCreate = (cuCtxCreate_type) GetProcAddress(module, "cuCtxCreate_v3"); + cuCtxDestroy = (cuCtxDestroy_type) GetProcAddress(module, "cuCtxDestroy_v2"); + cuCtxPopCurrent = (cuCtxPopCurrent_type) GetProcAddress(module, "cuCtxPopCurrent_v2"); + cuCtxPushCurrent = (cuCtxPushCurrent_type) GetProcAddress(module, "cuCtxPushCurrent_v2"); cuCtxSetCurrent = (cuCtxSetCurrent_type) GetProcAddress(module, "cuCtxSetCurrent"); cuDeviceGetCount = (cuDeviceGetCount_type) GetProcAddress(module, "cuDeviceGetCount"); cuDeviceGet = (cuDeviceGet_type) GetProcAddress(module, "cuDeviceGet"); cuDeviceGetName = (cuDeviceGetName_type) GetProcAddress(module, "cuDeviceGetName"); - cuDeviceTotalMem = (cuDeviceTotalMem_type) GetProcAddress(module, "cuDeviceTotalMem"); - cuMemGetInfo = (cuMemGetInfo_type) GetProcAddress(module, "cuMemGetInfo"); + cuDeviceTotalMem = (cuDeviceTotalMem_type) GetProcAddress(module, "cuDeviceTotalMem_v2"); + cuMemGetInfo = (cuMemGetInfo_type) GetProcAddress(module, "cuMemGetInfo_v2"); #elif defined(__APPLE__) return -1; #else @@ -88,16 +92,16 @@ int cu_init() { } cuInit = (cuInit_type) dlsym(module, "cuInit"); - cuCtxCreate = (cuCtxCreate_type) dlsym(module, "cuCtxCreate"); - cuCtxDestroy = (cuCtxDestroy_type) dlsym(module, "cuCtxDestroy"); - cuCtxPopCurrent = (cuCtxPopCurrent_type) dlsym(module, "cuCtxPopCurrent"); - cuCtxPushCurrent = (cuCtxPushCurrent_type) dlsym(module, "cuCtxPushCurrent"); + cuCtxCreate = (cuCtxCreate_type) dlsym(module, "cuCtxCreate_v3"); + cuCtxDestroy = (cuCtxDestroy_type) dlsym(module, "cuCtxDestroy_v2"); + cuCtxPopCurrent = (cuCtxPopCurrent_type) dlsym(module, "cuCtxPopCurrent_v2"); + cuCtxPushCurrent = (cuCtxPushCurrent_type) dlsym(module, "cuCtxPushCurrent_v2"); cuCtxSetCurrent = (cuCtxSetCurrent_type) dlsym(module, "cuCtxSetCurrent"); cuDeviceGetCount = (cuDeviceGetCount_type) dlsym(module, "cuDeviceGetCount"); cuDeviceGet = (cuDeviceGet_type) dlsym(module, "cuDeviceGet"); cuDeviceGetName = (cuDeviceGetName_type) dlsym(module, "cuDeviceGetName"); - cuDeviceTotalMem = (cuDeviceTotalMem_type) dlsym(module, "cuDeviceTotalMem"); - cuMemGetInfo = (cuMemGetInfo_type) dlsym(module, "cuMemGetInfo"); + cuDeviceTotalMem = (cuDeviceTotalMem_type) dlsym(module, "cuDeviceTotalMem_v2"); + cuMemGetInfo = (cuMemGetInfo_type) dlsym(module, "cuMemGetInfo_v2"); #endif CUresult res = cuInit(0); @@ -159,7 +163,7 @@ struct cu_Devices cu_list_devices() { return invalid_devices; } - res = cuCtxCreate(&ctx_handles[d], 0, device_handles[d].handle); + res = cuCtxCreate(&ctx_handles[d], NULL, 0, 0, device_handles[d].handle); if (res != 0) { free(device_handles); free(ctx_handles);