Skip to content

Commit

Permalink
Update IVF_PQ to set memory_budget in constructor, support preload fe…
Browse files Browse the repository at this point in the history
…ature_vectors and metadata only modes (#518)
  • Loading branch information
jparismorgan authored Sep 16, 2024
1 parent b88d4ba commit 35acf35
Show file tree
Hide file tree
Showing 12 changed files with 620 additions and 589 deletions.
5 changes: 5 additions & 0 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,20 @@ def _query_with_driver(
def query_udf(index_type, index_open_kwargs, query_kwargs):
from tiledb.vector_search.flat_index import FlatIndex
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
from tiledb.vector_search.ivf_pq_index import IVFPQIndex
from tiledb.vector_search.vamana_index import VamanaIndex

# Open index
if index_type == "FLAT":
index = FlatIndex(**index_open_kwargs)
elif index_type == "IVF_FLAT":
index = IVFFlatIndex(**index_open_kwargs)
elif index_type == "IVF_PQ":
index = IVFPQIndex(**index_open_kwargs)
elif index_type == "VAMANA":
index = VamanaIndex(**index_open_kwargs)
else:
raise ValueError(f"Unsupported index_type: {index_type}")

# Query index
return index.query(**query_kwargs)
Expand Down
49 changes: 36 additions & 13 deletions apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ class IVFPQIndex(index.Index):
If not provided, all index data are loaded in main memory.
Otherwise, no index data are loaded in main memory and this memory budget is
applied during queries.
preload_k_factor_vectors: bool
When using `k_factor` in a query, we first query for `k_factor * k` pq-encoded vectors,
and then do a re-ranking step using the original input vectors for the top `k` vectors.
If `True`, we will load all the input vectors in main memory. This can only be used with
`memory_budget` set to `-1`, and is useful when the input vectors are small enough to fit in
memory and you want to speed up re-ranking.
open_for_remote_query_execution: bool
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
Expand All @@ -48,15 +54,26 @@ def __init__(
config: Optional[Mapping[str, Any]] = None,
timestamp=None,
memory_budget: int = -1,
preload_k_factor_vectors: bool = False,
open_for_remote_query_execution: bool = False,
group: tiledb.Group = None,
**kwargs,
):
if preload_k_factor_vectors and memory_budget != -1:
raise ValueError(
"preload_k_factor_vectors can only be used with memory_budget set to -1."
)
if preload_k_factor_vectors and open_for_remote_query_execution:
raise ValueError(
"preload_k_factor_vectors can only be used with open_for_remote_query_execution set to False."
)

self.index_open_kwargs = {
"uri": uri,
"config": config,
"timestamp": timestamp,
"memory_budget": memory_budget,
"preload_k_factor_vectors": preload_k_factor_vectors,
}
self.index_open_kwargs.update(kwargs)
self.index_type = INDEX_TYPE
Expand All @@ -67,8 +84,21 @@ def __init__(
open_for_remote_query_execution=open_for_remote_query_execution,
group=group,
)
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index`` as `None` because we need to be able to call index.dimensions().
self.index = vspy.IndexIVFPQ(self.ctx, uri, to_temporal_policy(timestamp))
strategy = (
vspy.IndexLoadStrategy.PQ_INDEX_AND_RERANKING_VECTORS
if preload_k_factor_vectors
else vspy.IndexLoadStrategy.PQ_OOC
if open_for_remote_query_execution
or (memory_budget != -1 and memory_budget != 0)
else vspy.IndexLoadStrategy.PQ_INDEX
)
self.index = vspy.IndexIVFPQ(
self.ctx,
uri,
strategy,
0 if memory_budget == -1 else memory_budget,
to_temporal_policy(timestamp),
)
self.db_uri = self.group[
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
].uri
Expand Down Expand Up @@ -127,16 +157,9 @@ def query_internal(
if not queries.flags.f_contiguous:
queries = queries.copy(order="F")
queries_feature_vector_array = vspy.FeatureVectorArray(queries)

if self.memory_budget == -1:
distances, ids = self.index.query_infinite_ram(
queries_feature_vector_array, k, nprobe, k_factor
)
else:
distances, ids = self.index.query_finite_ram(
queries_feature_vector_array, k, nprobe, self.memory_budget, k_factor
)

distances, ids = self.index.query(
queries_feature_vector_array, k=k, nprobe=nprobe, k_factor=k_factor
)
return np.array(distances, copy=False), np.array(ids, copy=False)


Expand Down Expand Up @@ -203,7 +226,7 @@ def create(
id_type=np.dtype(np.uint64).name,
partitioning_index_type=np.dtype(np.uint64).name,
dimensions=dimensions,
n_list=partitions if (partitions is not None and partitions is not -1) else 0,
n_list=partitions if (partitions is not None and partitions != -1) else 0,
num_subspaces=num_subspaces,
distance_metric=int(distance_metric),
)
Expand Down
9 changes: 9 additions & 0 deletions apis/python/src/tiledb/vector_search/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "detail/linalg/tdb_matrix.h"
#include "detail/linalg/tdb_partitioned_matrix.h"
#include "detail/time/temporal_policy.h"
#include "index/index_defs.h"
#include "utils/seeder.h"

namespace py = pybind11;
Expand Down Expand Up @@ -1096,6 +1097,14 @@ PYBIND11_MODULE(_tiledbvspy, m) {
.value("L2", DistanceMetric::L2)
.export_values();

py::enum_<IndexLoadStrategy>(m, "IndexLoadStrategy")
.value("PQ_OOC", IndexLoadStrategy::PQ_OOC)
.value("PQ_INDEX", IndexLoadStrategy::PQ_INDEX)
.value(
"PQ_INDEX_AND_RERANKING_VECTORS",
IndexLoadStrategy::PQ_INDEX_AND_RERANKING_VECTORS)
.export_values();

/* === Module inits === */

init_kmeans(m);
Expand Down
66 changes: 25 additions & 41 deletions apis/python/src/tiledb/vector_search/type_erased_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,8 @@ void init_type_erased_module(py::module_& m) {
.def("dimensions", &IndexFlatL2::dimensions)
.def(
"query",
[](IndexFlatL2& index,
const FeatureVectorArray& vectors,
size_t top_k) {
auto r = index.query(vectors, top_k);
[](IndexFlatL2& index, const FeatureVectorArray& vectors, size_t k) {
auto r = index.query(vectors, k);
return make_python_pair(std::move(r));
});

Expand Down Expand Up @@ -422,13 +420,13 @@ void init_type_erased_module(py::module_& m) {
"query",
[](IndexVamana& index,
const FeatureVectorArray& vectors,
size_t top_k,
size_t k,
uint32_t l_search) {
auto r = index.query(vectors, top_k, l_search);
auto r = index.query(vectors, k, l_search);
return make_python_pair(std::move(r));
},
py::arg("vectors"),
py::arg("top_k"),
py::arg("k"),
py::arg("l_search"))
.def(
"write_index",
Expand Down Expand Up @@ -467,12 +465,21 @@ void init_type_erased_module(py::module_& m) {
[](IndexIVFPQ& instance,
const tiledb::Context& ctx,
const std::string& group_uri,
IndexLoadStrategy index_load_strategy,
size_t memory_budget,
std::optional<TemporalPolicy> temporal_policy) {
new (&instance) IndexIVFPQ(ctx, group_uri, temporal_policy);
new (&instance) IndexIVFPQ(
ctx,
group_uri,
index_load_strategy,
memory_budget,
temporal_policy);
},
py::keep_alive<1, 2>(), // IndexIVFPQ should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_load_strategy") = IndexLoadStrategy::PQ_INDEX,
py::arg("memory_budget") = 0,
py::arg("temporal_policy") = std::nullopt)
.def(
"__init__",
Expand All @@ -494,41 +501,18 @@ void init_type_erased_module(py::module_& m) {
},
py::arg("vectors"))
.def(
"query_infinite_ram",
[](IndexIVFPQ& index,
const FeatureVectorArray& vectors,
size_t top_k,
size_t nprobe,
float k_factor) {
auto r = index.query(
QueryType::InfiniteRAM, vectors, top_k, nprobe, 0, k_factor);
return make_python_pair(std::move(r));
},
py::arg("vectors"),
py::arg("top_k"),
py::arg("nprobe"),
py::arg("k_factor") = 1.f)
.def(
"query_finite_ram",
"query",
[](IndexIVFPQ& index,
const FeatureVectorArray& vectors,
size_t top_k,
size_t k,
size_t nprobe,
size_t memory_budget,
float k_factor) {
auto r = index.query(
QueryType::FiniteRAM,
vectors,
top_k,
nprobe,
memory_budget,
k_factor);
auto r = index.query(vectors, k, nprobe, k_factor);
return make_python_pair(std::move(r));
},
py::arg("vectors"),
py::arg("top_k"),
py::arg("k"),
py::arg("nprobe"),
py::arg("memory_budget"),
py::arg("k_factor") = 1.f)
.def(
"write_index",
Expand Down Expand Up @@ -603,24 +587,24 @@ void init_type_erased_module(py::module_& m) {
"query_infinite_ram",
[](IndexIVFFlat& index,
const FeatureVectorArray& query,
size_t top_k,
size_t k,
size_t nprobe) {
auto r = index.query_infinite_ram(query, top_k, nprobe);
auto r = index.query_infinite_ram(query, k, nprobe);
return make_python_pair(std::move(r));
}) // , py::arg("vectors"), py::arg("top_k") = 1, py::arg("nprobe")
}) // , py::arg("vectors"), py::arg("k") = 1, py::arg("nprobe")
// = 10)
.def(
"query_finite_ram",
[](IndexIVFFlat& index,
const FeatureVectorArray& query,
size_t top_k,
size_t k,
size_t nprobe,
size_t upper_bound) {
auto r = index.query_finite_ram(query, top_k, nprobe, upper_bound);
auto r = index.query_finite_ram(query, k, nprobe, upper_bound);
return make_python_pair(std::move(r));
},
py::arg("vectors"),
py::arg("top_k") = 1,
py::arg("k") = 1,
py::arg("nprobe") = 10,
py::arg("upper_bound") = 0)
.def("feature_type_string", &IndexIVFFlat::feature_type_string)
Expand Down
24 changes: 19 additions & 5 deletions apis/python/test/local-benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from tiledb.vector_search.index import Index
from tiledb.vector_search.ingestion import TrainingSamplingPolicy
from tiledb.vector_search.ingestion import ingest
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
from tiledb.vector_search.ivf_pq_index import IVFPQIndex
from tiledb.vector_search.utils import load_fvecs
from tiledb.vector_search.vamana_index import VamanaIndex


class RemoteURIType(Enum):
Expand Down Expand Up @@ -252,7 +255,7 @@ def save_charts(self):
plt.xlabel("Average Query Accuracy")
plt.ylabel("Time (seconds)")
plt.title(f"Ingestion Time vs Average Query Accuracy {sift_string()}")
for idx, timer in self.timers:
for idx, timer in enumerate(self.timers):
timer.add_data_to_ingestion_time_vs_average_query_accuracy(
markers[idx % len(markers)]
)
Expand All @@ -265,7 +268,7 @@ def save_charts(self):
plt.xlabel("Accuracy")
plt.ylabel("Time (seconds)")
plt.title(f"Query Time vs Accuracy {sift_string()}")
for idx, timer in self.timers:
for idx, timer in enumerate(self.timers):
timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
plt.legend()
plt.savefig(os.path.join(RESULTS_DIR, "query_time_vs_accuracy.png"))
Expand Down Expand Up @@ -295,6 +298,7 @@ def download_and_extract(url, download_path, extract_path):


def get_uri(tag):
global config
index_name = f"index_{tag.replace('=', '_')}"
index_uri = ""
if REMOTE_URI_TYPE == RemoteURIType.LOCAL:
Expand Down Expand Up @@ -346,7 +350,7 @@ def benchmark_ivf_flat():
index_uri = get_uri(tag)

timer.start(tag, TimerMode.INGESTION)
index = ingest(
ingest(
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
Expand All @@ -356,6 +360,10 @@ def benchmark_ivf_flat():
)
ingest_time = timer.stop(tag, TimerMode.INGESTION)

# The index returned by ingest() automatically has memory_budget=1000000 set. Open
# a fresh index so it's clear what config is being used.
index = IVFFlatIndex(index_uri, config)

for nprobe in [1, 2, 3, 4, 5, 10, 20]:
timer.start(tag, TimerMode.QUERY)
_, result = index.query(queries, k=k, nprobe=nprobe)
Expand Down Expand Up @@ -386,7 +394,7 @@ def benchmark_vamana():
index_uri = get_uri(tag)

timer.start(tag, TimerMode.INGESTION)
index = ingest(
ingest(
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
Expand All @@ -397,6 +405,8 @@ def benchmark_vamana():
)
ingest_time = timer.stop(tag, TimerMode.INGESTION)

index = VamanaIndex(index_uri, config)

for l_search in [k, k + 50, k + 100, k + 200, k + 400]:
timer.start(tag, TimerMode.QUERY)
_, result = index.query(queries, k=k, l_search=l_search)
Expand Down Expand Up @@ -429,7 +439,7 @@ def benchmark_ivf_pq():
index_uri = get_uri(tag)

timer.start(tag, TimerMode.INGESTION)
index = ingest(
ingest(
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
Expand All @@ -440,6 +450,10 @@ def benchmark_ivf_pq():
)
ingest_time = timer.stop(tag, TimerMode.INGESTION)

# The index returned by ingest() automatically has memory_budget=1000000 set. Open
# a fresh index so it's clear what config is being used.
index = IVFPQIndex(index_uri, config)

for nprobe in [5, 10, 20, 40, 60]:
timer.start(tag, TimerMode.QUERY)
_, result = index.query(
Expand Down
1 change: 1 addition & 0 deletions apis/python/test/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def run_cloud_test(self, index_uri, index_type, index_class):
input_vectors_per_work_item=5000,
config=tiledb.cloud.Config().dict(),
mode=Mode.BATCH,
verbose=True,
)
tiledb_index_uri = groups.info(index_uri).tiledb_uri

Expand Down
Loading

0 comments on commit 35acf35

Please sign in to comment.