Skip to content

Commit

Permalink
Fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz committed Sep 26, 2023
1 parent e434c3f commit cf94ea2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
40 changes: 25 additions & 15 deletions pecos/utils/mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
from abc import abstractmethod
from pecos.core import clib
from typing import Optional, Tuple, Any
from typing import Optional, Tuple
from ctypes import c_char_p, c_uint32, c_uint64, POINTER
import numpy as np
import os
Expand Down Expand Up @@ -95,20 +95,24 @@ class MmapHashmapBatchGetter(object):
Batch getter for MmapHashmap opened for readonly.
"""

def __init__(self, mmap_r: MmapHashmap, max_batch_size: int, threads: int=1):
if not isinstance(mmap_r, MmapHashmap):
raise ValueError(f"Should get from MmapHashmap, got {type(mmap_r)}")
if mmap_r.mode not in ["r", "r_lazy"]:
raise ValueError(f"MmapHashmap should opened for readonly, got {mmap_r.mode}")
def __init__(self, mmap_r, max_batch_size: int, threads: int = 1):
if not isinstance(mmap_r, _MmapHashmapReadOnly):
raise ValueError(f"Should get from readonly MmapHashmap, got {type(mmap_r)}")
if max_batch_size <= 0:
raise ValueError(f"Max batch size should >0, got {max_batch_size}")
if threads <= 0 and threads != -1:
raise ValueError(f"Number of threads should >0 or =-1, got {threads}")

self.mmap_r = mmap_r
self.mmap_r: Optional[_MmapHashmapReadOnly] = mmap_r
self.max_batch_size = max_batch_size
self.key_prealloc = self.mmap_r.map.get_keyalloc(max_batch_size)
self.threads_c_uint32 = c_uint32(min(os.cpu_count(), os.cpu_count() if threads == -1 else threads))
self.key_prealloc = mmap_r.get_keyalloc(max_batch_size)

# `os.cpu_count()` is not equivalent to the number of CPUs the current process can use.
# The number of usable CPUs can be obtained with `len(os.sched_getaffinity(0))`.
# NOTE: `os.sched_getaffinity` is only available on some Unix platforms (e.g. Linux).
n_usable_cpu = len(os.sched_getaffinity(0))
self.threads_c_uint32 = c_uint32(
min(n_usable_cpu, n_usable_cpu if threads == -1 else threads)
)

# Pre-allocated space for returns
self.vals = np.zeros(max_batch_size, dtype=np.uint64)
Expand All @@ -123,8 +127,12 @@ def get(self, keys, default_val):
ii) int2int: 1D numpy array of int64
2) The return value is a view of a reused buffer; use or copy the data as soon as you get it. It is not guaranteed to last.
"""
self.mmap_r.map.batch_get(
len(keys), self.key_prealloc.get_key_prealloc(keys), default_val, self.vals, self.threads_c_uint32
self.mmap_r.batch_get(
len(keys),
self.key_prealloc.get_key_prealloc(keys),
default_val,
self.vals,
self.threads_c_uint32,
)
return memoryview(self.vals)[: len(keys)]

Expand Down Expand Up @@ -159,7 +167,7 @@ def __contains__(self, key):
pass

@abstractmethod
def batch_get(self, n_keys, keys, default_val, vals):
def batch_get(self, n_keys, keys, default_val, vals, threads_c_uint32):
pass

@classmethod
Expand Down Expand Up @@ -200,7 +208,9 @@ def __getitem__(self, key_utf8):
def __contains__(self, key_utf8):
return self.fn_dict["contains"](self.map_ptr, key_utf8, len(key_utf8))

def batch_get(self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32):
def batch_get(
self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32
):
"""
Batch get values for UTF8 encoded bytes string keys.
Return values are stored in vals.
Expand All @@ -227,7 +237,7 @@ def batch_get(self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threa
keys_lens.ctypes.data_as(POINTER(c_uint32)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
threads_c_uint32
threads_c_uint32,
)
return vals

Expand Down Expand Up @@ -280,7 +290,7 @@ def batch_get(self, n_keys: int, keys, default_val: int, vals, threads_c_uint32:
keys.ctypes.data_as(POINTER(c_uint64)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
threads_c_uint32
threads_c_uint32,
)
return vals

Expand Down
4 changes: 2 additions & 2 deletions test/pecos/utils/test_mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_str2int_mmap_hashmap(tmpdir):
# Batch get with default
max_batch_size = 5
# max_batch_size > num of key
r_map_batch_getter = MmapHashmapBatchGetter(r_map, max_batch_size)
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(ks, 10).tolist() == vs
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_int2int_mmap_hashmap(tmpdir):
# Batch get with default
max_batch_size = 5
# max_batch_size > num of key
r_map_batch_getter = MmapHashmapBatchGetter(r_map, max_batch_size)
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + [1000] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
Expand Down

0 comments on commit cf94ea2

Please sign in to comment.