Skip to content

Commit

Permalink
adjust 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ajit283 committed Dec 7, 2024
1 parent f30b57c commit b1fdcb6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 203 deletions.
206 changes: 33 additions & 173 deletions python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,222 +15,82 @@
# limitations under the License.
#
import argparse
import importlib
import os
import sys
import warnings

from .utils import memmap_bin_file, suffix_from_dtype, write_bin


def import_with_fallback(primary_lib, secondary_lib=None, alias=None):
"""
Attempt to import a primary library, with an optional fallback to a
secondary library.
Optionally assigns the imported module to a global alias.
Parameters
----------
primary_lib : str
Name of the primary library to import.
secondary_lib : str, optional
Name of the secondary library to use as a fallback. If `None`,
no fallback is attempted.
alias : str, optional
Alias to assign the imported module globally.
Returns
-------
module or None
The imported module if successful; otherwise, `None`.
Examples
--------
>>> xp = import_with_fallback('cupy', 'numpy')
>>> mod = import_with_fallback('nonexistent_lib')
>>> if mod is None:
... print("Library not found.")
"""
try:
module = importlib.import_module(primary_lib)
except ImportError:
if secondary_lib is not None:
try:
module = importlib.import_module(secondary_lib)
except ImportError:
module = None
else:
module = None
if alias and module is not None:
globals()[alias] = module
return module


xp = import_with_fallback("cupy", "numpy")
rmm = import_with_fallback("rmm")
gpu_system = False
import cupy as cp
import numpy as np
import rmm
from pylibraft.common import DeviceResources
from rmm.allocators.cupy import rmm_cupy_allocator

from cuvs.neighbors.brute_force import build, search

def force_fallback_to_numpy():
global xp, gpu_system
xp = import_with_fallback("numpy")
gpu_system = False
warnings.warn(
"Consider using a GPU-based system to greatly accelerate "
" generating groundtruths using cuVS."
)


if rmm is not None:
gpu_system = True
try:
from pylibraft.common import DeviceResources
from rmm.allocators.cupy import rmm_cupy_allocator

from cuvs.neighbors.brute_force import build, search
except ImportError:
# RMM is available, cupy is available, but cuVS is not
force_fallback_to_numpy()
else:
# No RMM, no cuVS, but cupy is available
force_fallback_to_numpy()
from .utils import memmap_bin_file, suffix_from_dtype, write_bin


def generate_random_queries(n_queries, n_features, dtype=xp.float32):
def generate_random_queries(n_queries, n_features, dtype=np.float32):
print("Generating random queries")
if xp.issubdtype(dtype, xp.integer):
queries = xp.random.randint(
if np.issubdtype(dtype, np.integer):
queries = cp.random.randint(
0, 255, size=(n_queries, n_features), dtype=dtype
)
else:
queries = xp.random.uniform(size=(n_queries, n_features)).astype(dtype)
queries = cp.random.uniform(size=(n_queries, n_features)).astype(dtype)
return queries


def choose_random_queries(dataset, n_queries):
print("Choosing random vector from dataset as query vectors")
query_idx = xp.random.choice(
query_idx = np.random.choice(
dataset.shape[0], size=(n_queries,), replace=False
)
return dataset[query_idx, :]


def cpu_search(dataset, queries, k, metric="squeclidean"):
"""
Find the k nearest neighbors for each query point in the dataset using the
specified metric.
Parameters
----------
dataset : numpy.ndarray
An array of shape (n_samples, n_features) representing the dataset.
queries : numpy.ndarray
An array of shape (n_queries, n_features) representing the query
points.
k : int
The number of nearest neighbors to find.
metric : str, optional
The distance metric to use. Can be 'squeclidean' or 'inner_product'.
Default is 'squeclidean'.
Returns
-------
distances : numpy.ndarray
An array of shape (n_queries, k) containing the distances
(for 'squeclidean') or similarities
(for 'inner_product') to the k nearest neighbors for each query.
indices : numpy.ndarray
An array of shape (n_queries, k) containing the indices of the
k nearest neighbors in the dataset for each query.
"""
if metric == "squeclidean":
diff = queries[:, xp.newaxis, :] - dataset[xp.newaxis, :, :]
dist_sq = xp.sum(diff**2, axis=2) # Shape: (n_queries, n_samples)

indices = xp.argpartition(dist_sq, kth=k - 1, axis=1)[:, :k]
distances = xp.take_along_axis(dist_sq, indices, axis=1)

sorted_idx = xp.argsort(distances, axis=1)
distances = xp.take_along_axis(distances, sorted_idx, axis=1)
indices = xp.take_along_axis(indices, sorted_idx, axis=1)

elif metric == "inner_product":
similarities = xp.dot(
queries, dataset.T
) # Shape: (n_queries, n_samples)

neg_similarities = -similarities
indices = xp.argpartition(neg_similarities, kth=k - 1, axis=1)[:, :k]
distances = xp.take_along_axis(similarities, indices, axis=1)

sorted_idx = xp.argsort(-distances, axis=1)

else:
raise ValueError(
"Unsupported metric in cuvs-bench-cpu. "
"Use 'squeclidean' or 'inner_product' or use the GPU package"
"to use any distance supported by cuVS."
)

distances = xp.take_along_axis(distances, sorted_idx, axis=1)
indices = xp.take_along_axis(indices, sorted_idx, axis=1)

return distances, indices


def calc_truth(dataset, queries, k, metric="sqeuclidean"):
resources = DeviceResources()
n_samples = dataset.shape[0]
n = 500000 # batch size for processing neighbors
i = 0
indices = None
distances = None
queries = xp.asarray(queries, dtype=xp.float32)

if gpu_system:
resources = DeviceResources()
queries = cp.asarray(queries, dtype=cp.float32)

while i < n_samples:
print("Step {0}/{1}:".format(i // n, n_samples // n))
n_batch = n if i + n <= n_samples else n_samples - i

X = xp.asarray(dataset[i : i + n_batch, :], xp.float32)
X = cp.asarray(dataset[i : i + n_batch, :], cp.float32)

if gpu_system:
index = build(X, metric=metric, resources=resources)
D, Ind = search(index, queries, k, resources=resources)
resources.sync()
else:
D, Ind = cpu_search(X, queries, metric=metric)
index = build(X, metric=metric, resources=resources)
D, Ind = search(index, queries, k, resources=resources)
resources.sync()

D, Ind = xp.asarray(D), xp.asarray(Ind)
D, Ind = cp.asarray(D), cp.asarray(Ind)
Ind += i # shift neighbor index by offset i

if distances is None:
distances = D
indices = Ind
else:
distances = xp.concatenate([distances, D], axis=1)
indices = xp.concatenate([indices, Ind], axis=1)
idx = xp.argsort(distances, axis=1)[:, :k]
distances = xp.take_along_axis(distances, idx, axis=1)
indices = xp.take_along_axis(indices, idx, axis=1)
distances = cp.concatenate([distances, D], axis=1)
indices = cp.concatenate([indices, Ind], axis=1)
idx = cp.argsort(distances, axis=1)[:, :k]
distances = cp.take_along_axis(distances, idx, axis=1)
indices = cp.take_along_axis(indices, idx, axis=1)

i += n_batch

return distances, indices


def main():
if gpu_system and xp.__name__ == "cupy":
pool = rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(), initial_pool_size=2**30
)
rmm.mr.set_current_device_resource(pool)
xp.cuda.set_allocator(rmm_cupy_allocator)
else:
# RMM is available, but cupy is not
force_fallback_to_numpy()
pool = rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(), initial_pool_size=2**30
)
rmm.mr.set_current_device_resource(pool)
cp.cuda.set_allocator(rmm_cupy_allocator)

parser = argparse.ArgumentParser(
prog="generate_groundtruth",
Expand Down Expand Up @@ -337,7 +197,7 @@ def main():
"Dataset size {:6.1f} GB, shape {}, dtype {}".format(
dataset.size * dataset.dtype.itemsize / 1e9,
dataset.shape,
xp.dtype(dtype),
np.dtype(dtype),
)
)

Expand Down Expand Up @@ -370,13 +230,13 @@ def main():

write_bin(
os.path.join(args.output, "groundtruth.neighbors.ibin"),
indices.astype(xp.uint32),
indices.astype(np.uint32),
)
write_bin(
os.path.join(args.output, "groundtruth.distances.fbin"),
distances.astype(xp.float32),
distances.astype(np.float32),
)


if __name__ == "__main__":
main()
main()
75 changes: 45 additions & 30 deletions python/cuvs_bench/cuvs_bench/run/data_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import os
import traceback
import warnings

import pandas as pd

Expand Down Expand Up @@ -169,6 +170,44 @@ def convert_json_to_csv_build(dataset, dataset_path):
traceback.print_exc()


def append_build_data(write, build_file):
"""
Append build data to the search DataFrame.
Parameters
----------
write : pandas.DataFrame
The DataFrame containing the search data to which build
data will be appended.
build_file : str
The file path to the build CSV file.
"""
if os.path.exists(build_file):
build_df = pd.read_csv(build_file)
write_ncols = len(write.columns)
# Initialize columns for build data
build_columns = [
"build time",
"build threads",
"build cpu_time",
"build GPU",
]
write = write.assign(**{col: None for col in build_columns})
# Append additional columns if available
for col_name in build_df.columns[6:]:
write[col_name] = None
# Match build rows with search rows by index_name
for s_index, search_row in write.iterrows():
for b_index, build_row in build_df.iterrows():
if search_row["index_name"] == build_row["index_name"]:
write.iloc[s_index, write_ncols:] = build_row[2:].values
break
else:
warnings.warn(
f"Build CSV not found for {build_file}, build params not appended."
)


def convert_json_to_csv_search(dataset, dataset_path):
"""
Convert search JSON files to CSV format.
Expand All @@ -193,7 +232,7 @@ def convert_json_to_csv_search(dataset, dataset_path):
)
algo_name = clean_algo_name(algo_name)
df["name"] = df["name"].str.split("/").str[0]
write = pd.DataFrame(
write_data = pd.DataFrame(
{
"algo_name": [algo_name] * len(df),
"index_name": df["name"],
Expand All @@ -203,35 +242,11 @@ def convert_json_to_csv_search(dataset, dataset_path):
}
)
# Append build data
for name in df:
if name not in skip_search_cols:
write[name] = df[name]
if os.path.exists(build_file):
build_df = pd.read_csv(build_file)
write_ncols = len(write.columns)
write["build time"] = None
write["build threads"] = None
write["build cpu_time"] = None
write["build GPU"] = None

for col_idx in range(6, len(build_df.columns)):
col_name = build_df.columns[col_idx]
write[col_name] = None

for s_index, search_row in write.iterrows():
for b_index, build_row in build_df.iterrows():
if search_row["index_name"] == build_row["index_name"]:
write.iloc[s_index, write_ncols] = build_df.iloc[
b_index, 2
]
write.iloc[
s_index, write_ncols + 1 :
] = build_df.iloc[b_index, 3:]
break
append_build_data(write_data, build_file)
# Write search data and compute frontiers
write.to_csv(file.replace(".json", ",raw.csv"), index=False)
write_frontier(file, write, "throughput")
write_frontier(file, write, "latency")
write_data.to_csv(file.replace(".json", ",raw.csv"), index=False)
write_frontier(file, write_data, "throughput")
write_frontier(file, write_data, "latency")
except Exception as e:
print(f"Error processing search file {file}: {e}. Skipping...")
traceback.print_exc()
Expand Down Expand Up @@ -308,4 +323,4 @@ def write_frontier(file, write_data, metric):
(e.g., "throughput", "latency").
"""
frontier_data = get_frontier(write_data, metric)
frontier_data.to_csv(file.replace(".json", f",{metric}.csv"), index=False)
frontier_data.to_csv(file.replace(".json", f",{metric}.csv"), index=False)

0 comments on commit b1fdcb6

Please sign in to comment.