diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py
index 72c7736c1..cb7c58805 100644
--- a/dask_cuda/utils.py
+++ b/dask_cuda/utils.py
@@ -20,6 +20,15 @@ def nvtx_annotate(message=None, color="blue", domain=None):
         yield
 
 
+@toolz.memoize
+def _is_tegra():
+    import os
+
+    return os.path.isdir("/sys/class/tegra-firmware/") or os.path.isfile(
+        "/etc/nv_tegra_release"
+    )
+
+
 class CPUAffinity:
     def __init__(self, cores):
         self.cores = cores
@@ -96,8 +105,13 @@ def get_cpu_count():
 
 @toolz.memoize
 def get_gpu_count():
-    pynvml.nvmlInit()
-    return pynvml.nvmlDeviceGetCount()
+    if _is_tegra():
+        import numba.cuda
+
+        return len(numba.cuda.gpus)
+    else:
+        pynvml.nvmlInit()
+        return pynvml.nvmlDeviceGetCount()
 
 
 def get_cpu_affinity(device_index):
@@ -125,21 +139,24 @@ def get_cpu_affinity(device_index):
      40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
      60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
     """
-    pynvml.nvmlInit()
-
-    try:
-        # Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
-        affinity = pynvml.nvmlDeviceGetCpuAffinity(
-            pynvml.nvmlDeviceGetHandleByIndex(device_index),
-            math.ceil(get_cpu_count() / 64),
-        )
-        return unpack_bitmask(affinity)
-    except pynvml.NVMLError:
-        warnings.warn(
-            "Cannot get CPU affinity for device with index %d, setting default affinity"
-            % device_index
-        )
+    if _is_tegra():
         return list(range(get_cpu_count()))
+    else:
+        pynvml.nvmlInit()
+
+        try:
+            # Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
+            affinity = pynvml.nvmlDeviceGetCpuAffinity(
+                pynvml.nvmlDeviceGetHandleByIndex(device_index),
+                math.ceil(get_cpu_count() / 64),
+            )
+            return unpack_bitmask(affinity)
+        except pynvml.NVMLError:
+            warnings.warn(
+                "Cannot get CPU affinity for device with index %d, setting default affinity"
+                % device_index
+            )
+            return list(range(get_cpu_count()))
 
 
 def get_n_gpus():
@@ -153,10 +170,22 @@ def get_device_total_memory(index=0):
     """
     Return total memory of CUDA device with index
     """
-    pynvml.nvmlInit()
-    return pynvml.nvmlDeviceGetMemoryInfo(
-        pynvml.nvmlDeviceGetHandleByIndex(index)
-    ).total
+    if _is_tegra():
+        from ctypes import byref, c_size_t
+        import numba.cuda
+
+        driver = numba.cuda.driver.Driver()
+
+        numba.cuda.current_context()
+        free = c_size_t()
+        total = c_size_t()
+        driver.cuMemGetInfo(byref(free), byref(total))
+        return total.value
+    else:
+        pynvml.nvmlInit()
+        return pynvml.nvmlDeviceGetMemoryInfo(
+            pynvml.nvmlDeviceGetHandleByIndex(index)
+        ).total
 
 
 def get_ucx_net_devices(