Skip to content

Commit

Permalink
Add k_factor to local-benchmarks (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan authored Sep 9, 2024
1 parent 701b9cc commit 02ce860
Showing 1 changed file with 41 additions and 32 deletions.
73 changes: 41 additions & 32 deletions apis/python/test/local-benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class RemoteURIType(Enum):
)


def sift_string():
    """Return the dataset label used as a suffix in chart titles.

    The label is "(SIFT 10K)" when the module-level ``USE_SIFT_SMALL`` flag
    is truthy, and "(SIFT 1M)" otherwise.
    """
    if USE_SIFT_SMALL:
        return "(SIFT 10K)"
    return "(SIFT 1M)"


class TimerMode(Enum):
INGESTION = "ingestion"
QUERY = "query"
Expand Down Expand Up @@ -202,7 +206,9 @@ def save_charts(self):
plt.figure(figsize=(20, 12))
plt.xlabel("Average Query Accuracy")
plt.ylabel("Time (seconds)")
plt.title(f"{self.name}: Ingestion Time vs Average Query Accuracy")
plt.title(
f"{self.name}: Ingestion Time vs Average Query Accuracy {sift_string()}"
)
self.add_data_to_ingestion_time_vs_average_query_accuracy()
plt.legend()
plt.savefig(
Expand All @@ -214,7 +220,7 @@ def save_charts(self):
plt.figure(figsize=(20, 12))
plt.xlabel("Accuracy")
plt.ylabel("Time (seconds)")
plt.title(f"{self.name}: Query Time vs Accuracy")
plt.title(f"{self.name}: Query Time vs Accuracy {sift_string()}")
self.add_data_to_query_time_vs_accuracy()
plt.legend()
plt.savefig(
Expand Down Expand Up @@ -245,7 +251,7 @@ def save_charts(self):
plt.figure(figsize=(20, 12))
plt.xlabel("Average Query Accuracy")
plt.ylabel("Time (seconds)")
plt.title("Ingestion Time vs Average Query Accuracy")
plt.title(f"Ingestion Time vs Average Query Accuracy {sift_string()}")
for idx, timer in self.timers:
timer.add_data_to_ingestion_time_vs_average_query_accuracy(
markers[idx % len(markers)]
Expand All @@ -258,7 +264,7 @@ def save_charts(self):
plt.figure(figsize=(20, 12))
plt.xlabel("Accuracy")
plt.ylabel("Time (seconds)")
plt.title("Query Time vs Accuracy")
plt.title(f"Query Time vs Accuracy {sift_string()}")
for idx, timer in self.timers:
timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
plt.legend()
Expand Down Expand Up @@ -414,35 +420,38 @@ def benchmark_ivf_pq():
dimensions = queries.shape[1]
gt_i, gt_d = get_groundtruth_ivec(SIFT_GROUNDTRUTH_PATH, k=k, nqueries=len(queries))

for partitions in [50]:
for num_subspaces in [dimensions / 2, dimensions / 4, dimensions / 8]:
tag = f"{index_type}_partitions={partitions}_num_subspaces={num_subspaces}"
logger.info(f"Running {tag}")

index_uri = get_uri(tag)

timer.start(tag, TimerMode.INGESTION)
index = ingest(
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
config=config,
partitions=partitions,
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
num_subspaces=num_subspaces,
)
ingest_time = timer.stop(tag, TimerMode.INGESTION)

for nprobe in [5, 10, 20, 40, 60]:
timer.start(tag, TimerMode.QUERY)
_, result = index.query(queries, k=k, nprobe=nprobe)
query_time = timer.stop(tag, TimerMode.QUERY)
acc = timer.accuracy(tag, accuracy(result, gt_i))
logger.info(
f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
for partitions in [200]:
for num_subspaces in [dimensions / 4]:
for k_factor in [1, 1.5, 2, 4, 8, 16]:
tag = f"{index_type}_partitions={partitions}_num_subspaces={num_subspaces}_k_factor={k_factor}"
logger.info(f"Running {tag}")

index_uri = get_uri(tag)

timer.start(tag, TimerMode.INGESTION)
index = ingest(
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
config=config,
partitions=partitions,
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
num_subspaces=num_subspaces,
)

cleanup_uri(index_uri)
ingest_time = timer.stop(tag, TimerMode.INGESTION)

for nprobe in [5, 10, 20, 40, 60]:
timer.start(tag, TimerMode.QUERY)
_, result = index.query(
queries, k=k, nprobe=nprobe, k_factor=k_factor
)
query_time = timer.stop(tag, TimerMode.QUERY)
acc = timer.accuracy(tag, accuracy(result, gt_i))
logger.info(
f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
)

cleanup_uri(index_uri)

timer.save_and_print_results()

Expand Down

0 comments on commit 02ce860

Please sign in to comment.