Skip to content

Commit

Permalink
Add batch get for MmapHashmap
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz committed Sep 26, 2023
1 parent 7643ebe commit 067eed8
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 3 deletions.
19 changes: 19 additions & 0 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,15 @@ def link_mmap_hashmap_methods(self):
c_uint64, # key int64
],
}
batch_key_args_dict = {
"str2int": [
c_void_p, # List of pointer of key string
POINTER(c_uint32), # List of length of key string
],
"int2int": [
POINTER(c_uint64), # List of key int64
],
}
self.mmap_map_fn_dict = {}

for map_type in map_type_list:
Expand Down Expand Up @@ -1760,6 +1769,16 @@ def link_mmap_hashmap_methods(self):
local_fn_dict[fn_name], c_uint64, [c_void_p] + key_args_dict[map_type] + [c_uint64]
)

fn_name = "batch_get_w_default"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
local_fn_dict[fn_name],
None,
[c_void_p, c_uint32]
+ batch_key_args_dict[map_type] # noqa: W503
+ [c_uint64, POINTER(c_uint64)], # noqa: W503
)

fn_name = "contains"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
Expand Down
6 changes: 6 additions & 0 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,12 @@ extern "C" {
MMAP_MAP_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* key, uint32_t key_len), KEY_SINGLE_ARG(key, key_len))
MMAP_MAP_GET_W_DEFAULT(int2int, uint64_t key, key)

#define MMAP_MAP_BATCH_GET_W_DEFAULT(SUFFIX, KEY, FUNC_CALL_KEY) \
uint64_t mmap_hashmap_batch_get_w_default_ ## SUFFIX (void* map_ptr, const uint32_t n_key, KEY, uint64_t def_val, uint64_t* vals) { \
static_cast<mmap_hashmap_ ## SUFFIX *>(map_ptr)->batch_get_w_default(n_key, FUNC_CALL_KEY, def_val, vals); }
MMAP_MAP_BATCH_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* const* keys, const uint32_t* keys_lens), KEY_SINGLE_ARG(keys, keys_lens))
MMAP_MAP_BATCH_GET_W_DEFAULT(int2int, const uint64_t* key, key)

// Contains
#define MMAP_MAP_CONTAINS(SUFFIX, KEY, FUNC_CALL_KEY) \
bool mmap_hashmap_contains_ ## SUFFIX (void* map_ptr, KEY) { \
Expand Down
16 changes: 16 additions & 0 deletions pecos/core/utils/mmap_hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef __MMAP_ANKERL_HASHMAP_H__
#define __MMAP_ANKERL_HASHMAP_H__

#include <omp.h>
#include "third_party/ankerl/unordered_dense.h"
#include "mmap_util.hpp"


namespace pecos {
namespace mmap_hashmap {

Expand Down Expand Up @@ -374,6 +376,13 @@ class Str2IntMap {
} catch (...) { return def_val;}
}

void batch_get_w_default(const uint32_t n_key, const char* const* keys, const uint32_t* keys_lens, const uint64_t def_val, uint64_t* vals) {
#pragma omp parallel for schedule(static, 1)
for (uint32_t i=0; i<n_key; ++i) {
vals[i] = get_w_default(keys[i], keys_lens[i], def_val);
}
}

bool contains(const char* key, uint32_t key_len) {
return map.contains(std::string_view(key, key_len));
}
Expand Down Expand Up @@ -403,6 +412,13 @@ class Int2IntMap {
} catch (...) { return def_val;}
}

void batch_get_w_default(const uint32_t n_key, const uint64_t* keys, const uint64_t def_val, uint64_t* vals) {
#pragma omp parallel for schedule(static, 1)
for (uint32_t i=0; i<n_key; ++i) {
vals[i] = get_w_default(keys[i], def_val);
}
}

bool contains(uint64_t key) { return map.contains(key); }

size_t size() { return map.size(); }
Expand Down
121 changes: 120 additions & 1 deletion pecos/utils/mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import logging
from abc import abstractmethod
from pecos.core import clib
from typing import Optional
from typing import Optional, Tuple, Any
from ctypes import c_char_p, c_uint32, c_uint64, POINTER
import numpy as np


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,6 +90,75 @@ def __del__(self):
self.close()


class MmapHashmapBatchGetter(object):
"""
Batch getter for MmapHashmap opened for readonly.
"""

def __init__(self, mmap_r: MmapHashmap, batch_size: int):
if not isinstance(mmap_r, MmapHashmap):
raise ValueError(f"Should get from MmapHashmap, got {type(mmap_r)}")
if mmap_r.mode not in ["r", "r_lazy"]:
raise ValueError(f"MmapHashmap should opened for readonly, got {mmap_r.mode}")
if batch_size <= 0:
raise ValueError(f"Batch size should >0, got {batch_size}")

self.mmap_r = mmap_r
self.batch_size = batch_size

self.key_prealloc = None # type: Any
if mmap_r.map_type == "str2int":
self.key_prealloc = _Str2IntBatchGetterKeyPreAlloc(batch_size)
elif mmap_r.map_type == "int2int":
self.key_prealloc = _Int2IntBatchGetterKeyPreAlloc(batch_size)
else:
raise NotImplementedError(f"map_type={mmap_r.map_type} is not implemented.")

# Pre-allocated space for returns
self.vals = np.zeros(batch_size, dtype=np.uint64)

def get(self, keys, default_val):
"""
Batch get multiple keys' values. For non-exist keys, `default_val` is returned.
NOTE:
1) Make sure keys given is compatible with the `MmapHashmap` `batch_get` type.
i) str2int: List of UTF8 encoded strings
ii) int2int: 1D numpy array of int64
2) The return is a reused buffer, use or copy the data once you get it. It is not guaranteed to last.
"""
self.mmap_r.map.batch_get(self.key_prealloc.get_key_prealloc(keys), default_val, self.vals)
return self.vals


class _Str2IntBatchGetterKeyPreAlloc(object):
"""
Batch getter for Str2Int MmapHashmap opened for readonly.
"""

def __init__(self, batch_size: int):
self.keys_ptr = (c_char_p * batch_size)()
self.keys_lens = np.zeros(batch_size, dtype=np.uint32)

def get_key_prealloc(self, keys_utf8):
self.keys_ptr[:] = keys_utf8
self.keys_lens.flat[:] = [len(k) for k in keys_utf8]

return (self.keys_ptr, self.keys_lens)


class _Int2IntBatchGetterKeyPreAlloc(object):
"""
Dummy key pre-allocate for Int2Int MmapHashmap.
"""

def __init__(self, batch_size: int):
pass

def get_key_prealloc(self, keys):
return keys


class _MmapHashmapBase(object):
"""Base class for methods shared by all modes"""

Expand Down Expand Up @@ -135,6 +206,7 @@ def get(self, key_utf8, default_val):
"""
Args:
key_utf8: UTF8 encoded bytes string key
default_val: Default value for key not found
"""
return self.fn_dict["get_w_default"](
self.map_ptr,
Expand All @@ -149,6 +221,34 @@ def __getitem__(self, key_utf8):
def __contains__(self, key_utf8):
return self.fn_dict["contains"](self.map_ptr, key_utf8, len(key_utf8))

def batch_get(self, keys_utf8: Tuple, default_val: int, vals):
"""
Batch get values for UTF8 encoded bytes string keys.
Return values are stored in vals.
How to make inputs from UTF8 encoded bytes string keys List `keys_utf8`:
> keys_ptr = (c_char_p * n_keys)()
> keys_ptr[:] = keys_utf8
> keys_lens = np.array([len(k) for k in keys_utf8], dtype=np.uint32)
Args:
keys_utf8: Tuple of (keys_ptr, keys_lens)
keys_ptr: List of UTF8 encoded bytes string keys' pointers
keys_lens: 1D Int32 Numpy array of string keys' lengths
default_val: Default value for key not found
vals: 1D Int64 Numpy array to return results
"""
keys_ptr, keys_lens = keys_utf8
self.fn_dict["batch_get_w_default"](
self.map_ptr,
len(keys_ptr),
keys_ptr,
keys_lens.ctypes.data_as(POINTER(c_uint32)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
)
return vals


class _MmapHashmapInt2IntReadOnly(_MmapHashmapReadOnly):
def get(self, key, default_val):
Expand All @@ -160,6 +260,25 @@ def __getitem__(self, key):
def __contains__(self, key):
return self.fn_dict["contains"](self.map_ptr, key)

def batch_get(self, keys, default_val, vals):
"""
Batch get values for Int64 keys.
Return values are stored in vals.
Args:
keys: 1D Int64 Numpy array
default_val: Default value for key not found
vals: 1D Int64 Numpy array to return results
"""
self.fn_dict["batch_get_w_default"](
self.map_ptr,
len(keys),
keys.ctypes.data_as(POINTER(c_uint64)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
)
return vals


class _MmapHashmapWrite(_MmapHashmapBase):
"""Base class for methods shared by all write modes"""
Expand Down
15 changes: 13 additions & 2 deletions test/pecos/utils/test_mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_str2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter

map_dir = tmpdir.join("str2int_mmap").realpath().strpath
kv_dict = {"aaaa".encode("utf-8"): 2, "bb".encode("utf-8"): 3}
Expand All @@ -38,6 +38,11 @@ def test_str2int_mmap_hashmap(tmpdir):
for k, v in kv_dict.items():
assert r_map.map.get(k, 10) == v
assert r_map.map.get("ccccc".encode("utf-8"), 10) == 10
# Batch get with default
r_map_batch_getter = MmapHashmapBatchGetter(r_map, 3)
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# Contains
for k, _ in kv_dict.items():
assert k in r_map.map
Expand All @@ -47,7 +52,8 @@ def test_str2int_mmap_hashmap(tmpdir):


def test_int2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter
import numpy as np

map_dir = tmpdir.join("int2int_mmap").realpath().strpath
kv_dict = {10: 2, 20: 3}
Expand All @@ -73,6 +79,11 @@ def test_int2int_mmap_hashmap(tmpdir):
for k, v in kv_dict.items():
assert r_map.map.get(k, 10) == v
assert r_map.map.get(1000, 10) == 10
# Batch get with default
r_map_batch_getter = MmapHashmapBatchGetter(r_map, 3)
ks = list(kv_dict.keys()) + [1000] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# Contains
for k, _ in kv_dict.items():
assert k in r_map.map
Expand Down

0 comments on commit 067eed8

Please sign in to comment.