diff --git a/pecos/utils/mmap_hashmap_util.py b/pecos/utils/mmap_hashmap_util.py index 2138ec7..e8eeb2e 100644 --- a/pecos/utils/mmap_hashmap_util.py +++ b/pecos/utils/mmap_hashmap_util.py @@ -11,7 +11,7 @@ import logging from abc import abstractmethod from pecos.core import clib -from typing import Optional, Tuple, Any +from typing import Optional, Tuple from ctypes import c_char_p, c_uint32, c_uint64, POINTER import numpy as np import os @@ -95,20 +95,24 @@ class MmapHashmapBatchGetter(object): Batch getter for MmapHashmap opened for readonly. """ - def __init__(self, mmap_r: MmapHashmap, max_batch_size: int, threads: int=1): - if not isinstance(mmap_r, MmapHashmap): - raise ValueError(f"Should get from MmapHashmap, got {type(mmap_r)}") - if mmap_r.mode not in ["r", "r_lazy"]: - raise ValueError(f"MmapHashmap should opened for readonly, got {mmap_r.mode}") + def __init__(self, mmap_r, max_batch_size: int, threads: int = 1): + if not isinstance(mmap_r, _MmapHashmapReadOnly): + raise ValueError(f"Should get from readonly MmapHashmap, got {type(mmap_r)}") if max_batch_size <= 0: raise ValueError(f"Max batch size should >0, got {max_batch_size}") if threads <= 0 and threads != -1: raise ValueError(f"Number of threads should >0 or =-1, got {threads}") - self.mmap_r = mmap_r + self.mmap_r: Optional[_MmapHashmapReadOnly] = mmap_r self.max_batch_size = max_batch_size - self.key_prealloc = self.mmap_r.map.get_keyalloc(max_batch_size) - self.threads_c_uint32 = c_uint32(min(os.cpu_count(), os.cpu_count() if threads == -1 else threads)) + self.key_prealloc = mmap_r.get_keyalloc(max_batch_size) + + # `os.cpu_count()` is not equivalent to the number of CPUs the current process can use. + # The number of usable CPUs can be obtained with len(os.sched_getaffinity(0)) + n_usable_cpu = len(os.sched_getaffinity(0)) + self.threads_c_uint32 = c_uint32( + min(n_usable_cpu, n_usable_cpu if threads == -1 else threads) + ) # Pre-allocated space for returns self.vals = np.zeros(max_batch_size, dtype=np.uint64) @@ -123,8 +127,12 @@ def get(self, keys, default_val): ii) int2int: 1D numpy array of int64 2) The return is a reused buffer, use or copy the data once you get it. It is not guaranteed to last. """ - self.mmap_r.map.batch_get( - len(keys), self.key_prealloc.get_key_prealloc(keys), default_val, self.vals, self.threads_c_uint32 + self.mmap_r.batch_get( + len(keys), + self.key_prealloc.get_key_prealloc(keys), + default_val, + self.vals, + self.threads_c_uint32, ) return memoryview(self.vals)[: len(keys)] @@ -159,7 +167,7 @@ def __contains__(self, key): pass @abstractmethod - def batch_get(self, n_keys, keys, default_val, vals): + def batch_get(self, n_keys, keys, default_val, vals, threads_c_uint32): pass @classmethod @@ -200,7 +208,9 @@ def __getitem__(self, key_utf8): def __contains__(self, key_utf8): return self.fn_dict["contains"](self.map_ptr, key_utf8, len(key_utf8)) - def batch_get(self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32): + def batch_get( + self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32 + ): """ Batch get values for UTF8 encoded bytes string keys. Return values are stored in vals. @@ -227,7 +237,7 @@ def batch_get(self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threa keys_lens.ctypes.data_as(POINTER(c_uint32)), default_val, vals.ctypes.data_as(POINTER(c_uint64)), - threads_c_uint32 + threads_c_uint32, ) return vals @@ -280,7 +290,7 @@ def batch_get(self, n_keys: int, keys, default_val: int, vals, threads_c_uint32: keys.ctypes.data_as(POINTER(c_uint64)), default_val, vals.ctypes.data_as(POINTER(c_uint64)), - threads_c_uint32 + threads_c_uint32, ) return vals diff --git a/test/pecos/utils/test_mmap_hashmap_util.py b/test/pecos/utils/test_mmap_hashmap_util.py index 4c7a622..59cf020 100644 --- a/test/pecos/utils/test_mmap_hashmap_util.py +++ b/test/pecos/utils/test_mmap_hashmap_util.py @@ -48,7 +48,7 @@ def test_str2int_mmap_hashmap(tmpdir): # Batch get with default max_batch_size = 5 # max_batch_size > num of key - r_map_batch_getter = MmapHashmapBatchGetter(r_map, max_batch_size) + r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size) ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] # Non-exist key vs = list(kv_dict.values()) + [10] assert r_map_batch_getter.get(ks, 10).tolist() == vs @@ -99,7 +99,7 @@ def test_int2int_mmap_hashmap(tmpdir): # Batch get with default max_batch_size = 5 # max_batch_size > num of key - r_map_batch_getter = MmapHashmapBatchGetter(r_map, max_batch_size) + r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size) ks = list(kv_dict.keys()) + [1000] # Non-exist key vs = list(kv_dict.values()) + [10] assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs