Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bulk query APIs #342

Merged
merged 12 commits into from
Dec 16, 2024
38 changes: 38 additions & 0 deletions rtree/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,44 @@ def free_error_msg_ptr(result, func, cargs):
rt.Index_NearestNeighbors_id.restype = ctypes.c_int
rt.Index_NearestNeighbors_id.errcheck = check_return # type: ignore

try:
rt.Index_NearestNeighbors_id_v.argtypes = [
ctypes.c_void_p,
ctypes.c_int64,
ctypes.c_int64,
ctypes.c_uint32,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_int64),
]
rt.Index_NearestNeighbors_id_v.restype = ctypes.c_int
rt.Index_NearestNeighbors_id_v.errcheck = check_return # type: ignore

rt.Index_Intersects_id_v.argtypes = [
ctypes.c_void_p,
ctypes.c_int64,
ctypes.c_uint32,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_int64),
]
rt.Index_Intersects_id_v.restype = ctypes.c_int
rt.Index_Intersects_id_v.errcheck = check_return # type: ignore
except AttributeError:
pass


rt.Index_GetLeaves.argtypes = [
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_uint32),
Expand Down
109 changes: 109 additions & 0 deletions rtree/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,108 @@ def nearest(

return self._get_ids(it, p_num_results.contents.value)

def intersection_v(self, mins, maxs):
import numpy as np

assert mins.shape == maxs.shape
assert mins.strides == maxs.strides

# Cast
mins = mins.astype(np.float64)
maxs = maxs.astype(np.float64)

# Extract counts
n, d = mins.shape

# Compute strides
d_i_stri = mins.strides[0] // mins.itemsize
d_j_stri = mins.strides[1] // mins.itemsize

ids = np.empty(2 * n, dtype=np.int64)
counts = np.empty(n, dtype=np.uint64)
nr = ctypes.c_int64(0)
offn, offi = 0, 0

while True:
core.rt.Index_Intersects_id_v(
self.handle,
n - offn,
d,
len(ids),
d_i_stri,
d_j_stri,
mins[offn:].ctypes.data,
maxs[offn:].ctypes.data,
ids[offi:].ctypes.data,
counts[offn:].ctypes.data,
ctypes.byref(nr),
)

# If we got the expected nuber of results then return
if nr.value == n - offn:
return ids[: counts.sum()], counts
# Otherwise, if our array is too small then resize
else:
offi += counts[offn : offn + nr.value].sum()
offn += nr.value

ids = ids.resize(2 * len(ids), refcheck=False)

def nearest_v(
self, mins, maxs, num_results=1, strict=False, return_max_dists=False
):
import numpy as np

assert mins.shape == maxs.shape
assert mins.strides == maxs.strides

# Cast
mins = mins.astype(np.float64)
maxs = maxs.astype(np.float64)

# Extract counts
n, d = mins.shape

# Compute strides
d_i_stri = mins.strides[0] // mins.itemsize
d_j_stri = mins.strides[1] // mins.itemsize

ids = np.empty(n * num_results, dtype=np.int64)
counts = np.empty(n, dtype=np.uint64)
dists = np.empty(n) if return_max_dists else None
nr = ctypes.c_int64(0)
offn, offi = 0, 0

while True:
core.rt.Index_NearestNeighbors_id_v(
self.handle,
num_results if not strict else -num_results,
n - offn,
d,
len(ids),
d_i_stri,
d_j_stri,
mins[offn:].ctypes.data,
maxs[offn:].ctypes.data,
ids[offi:].ctypes.data,
counts[offn:].ctypes.data,
dists[offn:].ctypes.data if return_max_dists else None,
ctypes.byref(nr),
)

# If we got the expected nuber of results then return
if nr.value == n - offn:
if return_max_dists:
return ids[: counts.sum()], counts, dists
else:
return ids[: counts.sum()], counts
# Otherwise, if our array is too small then resize
else:
offi += counts[offn : offn + nr.value].sum()
offn += nr.value

ids = ids.resize(2 * len(ids), refcheck=False)

def _nearestTP(self, coordinates, velocities, times, num_results=1, objects=False):
p_mins, p_maxs = self.get_coordinate_pointers(coordinates)
pv_mins, pv_maxs = self.get_coordinate_pointers(velocities)
Expand Down Expand Up @@ -1538,6 +1640,13 @@ def initialize_from_dict(self, state: dict[str, Any]) -> None:
if v is not None:
setattr(self, k, v)

# Consistency checks
if "near_minimum_overlap_factor" not in state:
nmof = self.near_minimum_overlap_factor
ilc = min(self.index_capacity, self.leaf_capacity)
if nmof >= ilc:
self.near_minimum_overlap_factor = ilc // 3 + 1

def __getstate__(self) -> dict[Any, Any]:
return self.as_dict()

Expand Down
Loading