From 7444fec4a1bbaa39ebf15dd0cc9e90de96d8dea3 Mon Sep 17 00:00:00 2001
From: Sean MacAvaney
Date: Wed, 4 Dec 2024 15:42:00 +0000
Subject: [PATCH] flatnav integration into FlexIndex (#32)

* initial implementation

* test

* fix test

* documentation

* documentation
---
 pyterrier_dr/flex/__init__.py               |   3 +-
 pyterrier_dr/flex/flatnav_retr.py           | 125 ++++++++++++++++++++
 pyterrier_dr/pt_docs/indexing-retrieval.rst |   3 +-
 pyterrier_dr/util.py                        |   8 ++
 requirements-dev.txt                        |   1 -
 requirements.txt                            |   1 +
 tests/test_flexindex.py                     |   4 +
 7 files changed, 142 insertions(+), 3 deletions(-)
 create mode 100644 pyterrier_dr/flex/flatnav_retr.py

diff --git a/pyterrier_dr/flex/__init__.py b/pyterrier_dr/flex/__init__.py
index 88e79e3..b736036 100644
--- a/pyterrier_dr/flex/__init__.py
+++ b/pyterrier_dr/flex/__init__.py
@@ -3,9 +3,10 @@
 from pyterrier_dr.flex import torch_retr
 from pyterrier_dr.flex import corpus_graph
 from pyterrier_dr.flex import faiss_retr
+from pyterrier_dr.flex import flatnav_retr
 from pyterrier_dr.flex import scann_retr
 from pyterrier_dr.flex import ladr
 from pyterrier_dr.flex import gar
 from pyterrier_dr.flex import voyager_retr
 
-__all__ = ["FlexIndex", "IndexingMode", "np_retr", "torch_retr", "corpus_graph", "faiss_retr", "scann_retr", "ladr", "gar", "voyager_retr"]
+__all__ = ["FlexIndex", "IndexingMode", "np_retr", "torch_retr", "corpus_graph", "faiss_retr", "flatnav_retr", "scann_retr", "ladr", "gar", "voyager_retr"]
diff --git a/pyterrier_dr/flex/flatnav_retr.py b/pyterrier_dr/flex/flatnav_retr.py
new file mode 100644
index 0000000..a16441f
--- /dev/null
+++ b/pyterrier_dr/flex/flatnav_retr.py
@@ -0,0 +1,125 @@
+import os
+import numpy as np
+import pandas as pd
+import pyterrier as pt
+import pyterrier_alpha as pta
+import pyterrier_dr
+from pyterrier_dr import SimFn
+from . import FlexIndex
+
+
+class FlatNavRetriever(pt.Transformer):
+    def __init__(self, flex_index, flatnav_index, *, threads=16, ef_search=100, num_initializations=100, num_results=1000, qbatch=64, drop_query_vec=False, verbose=False):
+        self.flex_index = flex_index
+        self.flatnav_index = flatnav_index
+        self.threads = threads
+        self.ef_search = ef_search
+        self.num_results = num_results
+        self.qbatch = qbatch
+        self.drop_query_vec = drop_query_vec
+        self.verbose = verbose
+
+    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
+        pta.validate.query_frame(inp, extra_columns=['query_vec'])
+        inp = inp.reset_index(drop=True)
+        docnos, config = self.flex_index.payload(return_dvecs=False)
+        query_vecs = np.stack(inp['query_vec'])
+        query_vecs = query_vecs.copy()
+        num_q = query_vecs.shape[0]
+        QBATCH = self.qbatch
+        it = range(0, num_q, QBATCH)
+        if self.flex_index.verbose:
+            it = pt.tqdm(it, unit='qbatch')
+        self.flatnav_index.set_num_threads(self.threads)
+
+        result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
+        for qidx in it:
+            scores, dids = self.flatnav_index.search(
+                queries=query_vecs[qidx:qidx+QBATCH],
+                ef_search=self.ef_search,
+                K=min(self.num_results, len(self.flex_index)),
+            )
+            scores = -scores # distances -> scores
+            for s, d in zip(scores, dids):
+                mask = d != -1
+                d = d[mask]
+                s = s[mask]
+                result.extend({
+                    'docno': docnos.fwd[d],
+                    'docid': d,
+                    'score': s,
+                    'rank': np.arange(d.shape[0]),
+                })
+
+        if self.drop_query_vec:
+            inp = inp.drop(columns='query_vec')
+        return result.to_df(inp)
+
+
+def _flatnav_retriever(self,
+    k: int = 32,
+    *,
+    ef_search: int = 100,
+    num_initializations: int = 100,
+    ef_construction: int = 100,
+    threads: int = 16,
+    num_results: int = 1000,
+    cache: bool = True,
+    qbatch: int = 64,
+    drop_query_vec: bool = False,
+    verbose: bool = False,
+) -> pt.Transformer:
+    """Returns a retriever that searches over a flatnav index.
+
+    Args:
+        k (int): the maximum number of edges per document in the index
+        ef_search (int): the size of the list during searches. Higher values are slower but more accurate.
+        num_initializations (int): the number of random initializations to use during search.
+        ef_construction (int): the size of the list during graph construction. Higher values are slower but more accurate.
+        threads (int): the number of threads to use
+        num_results (int): the number of results to return per query
+        cache (bool): whether to cache the index to disk
+        qbatch (int): the number of queries to search at once
+        drop_query_vec (bool): whether to drop the query_vec column after retrieval
+        verbose (bool): whether to show progress bars
+
+    .. note::
+        This transformer requires the ``flatnav`` package to be installed. Instructions are available
+        in the `flatnav repository <https://github.com/BlaiseMuhirwa/flatnav>`__.
+
+    .. cite:: arxiv:2412.01940
+        :citation: Munyampirwa et al. Down with the Hierarchy: The 'H' in HNSW Stands for "Hubs". arXiv 2024.
+        :link: https://arxiv.org/abs/2412.01940
+    """
+    pyterrier_dr.util.assert_flatnav()
+    import flatnav
+
+    key = ('flatnav', k, ef_construction)
+    index_name = f'{k}_ef-{ef_construction}-{str(self.sim_fn)}.flatnav'
+    if key not in self._cache:
+        dvecs, meta = self.payload(return_docnos=False)
+        if not os.path.exists(self.index_path/index_name):
+            distance_type = {
+                SimFn.dot: 'angular',
+            }[self.sim_fn]
+            idx = flatnav.index.create(
+                distance_type=distance_type,
+                index_data_type=flatnav.data_type.DataType.float32,
+                dim=dvecs.shape[1],
+                dataset_size=dvecs.shape[0],
+                max_edges_per_node=k,
+                verbose=True,
+                collect_stats=True,
+            )
+            idx.set_num_threads(threads)
+            idx.add(data=np.array(dvecs), ef_construction=ef_construction)
+            # for start_idx in pt.tqdm(range(0, dvecs.shape[0], 4096), desc='indexing flatnav', unit='batch'):
+            #     idx.add(data=np.array(dvecs[start_idx:start_idx+4096]), ef_construction=ef_construction)
+            if cache:
+                idx.save(str(self.index_path/index_name))
+            self._cache[key] = idx
+        else:
+            self._cache[key] = flatnav.index.IndexIPFloat.load_index(str(self.index_path/index_name))
+            self._cache[key].set_data_type(flatnav.data_type.DataType.float32)
+    return FlatNavRetriever(self, self._cache[key], threads=threads, ef_search=ef_search, num_initializations=num_initializations, num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec, verbose=verbose)
+FlexIndex.flatnav_retriever = _flatnav_retriever
diff --git a/pyterrier_dr/pt_docs/indexing-retrieval.rst b/pyterrier_dr/pt_docs/indexing-retrieval.rst
index a9e29fe..13f2286 100644
--- a/pyterrier_dr/pt_docs/indexing-retrieval.rst
+++ b/pyterrier_dr/pt_docs/indexing-retrieval.rst
@@ -37,6 +37,7 @@ API Documentation
    .. automethod:: faiss_flat_retriever
    .. automethod:: faiss_hnsw_retriever
    .. automethod:: faiss_ivf_retriever
+   .. automethod:: flatnav_retriever
    .. automethod:: scann_retriever
    .. automethod:: voyager_retriever
 
@@ -44,7 +45,7 @@ API Documentation
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Results can be re-ranked using indexed vectors using :meth:`scorer`. (:meth:`np_scorer` and :meth:`torch_scorer` are
-available as specific implemenations, if needed.)
+available as specific implementations, if needed.)
 
 :meth:`gar`, :meth:`ladr_proactive`, and :meth:`ladr_adaptive` are *adaptive* re-ranking approaches that pull in
 other documents from the corpus that may be relevant.
diff --git a/pyterrier_dr/util.py b/pyterrier_dr/util.py
index e4e1635..8b9ccda 100644
--- a/pyterrier_dr/util.py
+++ b/pyterrier_dr/util.py
@@ -63,3 +63,11 @@ def voyager_available():
 
 def assert_voyager():
     assert voyager_available(), "voyager required; install with `pip install voyager`"
+
+
+def flatnav_available():
+    return package_available('flatnav')
+
+
+def assert_flatnav():
+    assert flatnav_available(), "flatnav required; install with instructions here: https://github.com/BlaiseMuhirwa/flatnav"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 7350e9c..f7af1f0 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,7 +2,6 @@ pytest
 pytest-subtests
 pytest-cov
 pytest-json-report
-git+https://github.com/terrierteam/pyterrier_adaptive
 voyager
 FlagEmbedding
 faiss-cpu
diff --git a/requirements.txt b/requirements.txt
index ba2e698..21818a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ torch
 numpy>=1.21.0, <2.0.0
 npids
 sentence_transformers
+pyterrier-adaptive
diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py
index 517c18d..516f137 100644
--- a/tests/test_flexindex.py
+++ b/tests/test_flexindex.py
@@ -134,6 +134,10 @@ def test_faiss_ivf_retriever(self):
 
     def test_scann_retriever(self):
         self._test_retr(FlexIndex.scann_retriever, exact=False)
 
+    @unittest.skipIf(not pyterrier_dr.util.flatnav_available(), "flatnav not available")
+    def test_flatnav_retriever(self):
+        self._test_retr(FlexIndex.flatnav_retriever, exact=False)
+
     def test_np_retriever(self):
         self._test_retr(FlexIndex.np_retriever)