Skip to content

Commit

Permalink
flatnav integration into FlexIndex (#32)
Browse files Browse the repository at this point in the history
* initial implementation

* test

* fix test

* documentation

* documentation
  • Loading branch information
seanmacavaney authored Dec 4, 2024
1 parent 4dae8e6 commit 7444fec
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyterrier_dr/flex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from pyterrier_dr.flex import torch_retr
from pyterrier_dr.flex import corpus_graph
from pyterrier_dr.flex import faiss_retr
from pyterrier_dr.flex import flatnav_retr
from pyterrier_dr.flex import scann_retr
from pyterrier_dr.flex import ladr
from pyterrier_dr.flex import gar
from pyterrier_dr.flex import voyager_retr

__all__ = ["FlexIndex", "IndexingMode", "np_retr", "torch_retr", "corpus_graph", "faiss_retr", "scann_retr", "ladr", "gar", "voyager_retr"]
__all__ = ["FlexIndex", "IndexingMode", "np_retr", "torch_retr", "corpus_graph", "faiss_retr", "flatnav_retr", "scann_retr", "ladr", "gar", "voyager_retr"]
125 changes: 125 additions & 0 deletions pyterrier_dr/flex/flatnav_retr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import numpy as np
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
import pyterrier_dr
from pyterrier_dr import SimFn
from . import FlexIndex


class FlatNavRetriever(pt.Transformer):
    """Retrieves results from a ``flatnav`` graph index built over a :class:`~pyterrier_dr.FlexIndex`."""

    def __init__(self, flex_index, flatnav_index, *, threads=16, ef_search=100, num_initializations=100, num_results=1000, qbatch=64, drop_query_vec=False, verbose=False):
        """
        Args:
            flex_index: the :class:`~pyterrier_dr.FlexIndex` providing docnos for the indexed vectors
            flatnav_index: a constructed ``flatnav`` index over the same vectors
            threads (int): number of threads used during search
            ef_search (int): size of the candidate list during search; higher is slower but more accurate
            num_initializations (int): number of random entry points tried during search
            num_results (int): maximum number of results returned per query
            qbatch (int): number of queries searched per batch
            drop_query_vec (bool): whether to drop the ``query_vec`` column from the output
            verbose (bool): whether to display a progress bar over query batches
        """
        self.flex_index = flex_index
        self.flatnav_index = flatnav_index
        self.threads = threads
        self.ef_search = ef_search
        # Fix: this argument was previously accepted but silently discarded.
        self.num_initializations = num_initializations
        self.num_results = num_results
        self.qbatch = qbatch
        self.drop_query_vec = drop_query_vec
        self.verbose = verbose

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Performs approximate nearest-neighbor retrieval for each ``query_vec`` in ``inp``.

        Returns a frame with ``docno``, ``docid``, ``score`` (negated distance, so higher is
        better), and ``rank`` columns merged onto the input queries.
        """
        pta.validate.query_frame(inp, extra_columns=['query_vec'])
        inp = inp.reset_index(drop=True)
        docnos, config = self.flex_index.payload(return_dvecs=False)
        query_vecs = np.stack(inp['query_vec'])
        query_vecs = query_vecs.copy()  # flatnav may require a writable, contiguous buffer
        num_q = query_vecs.shape[0]
        QBATCH = self.qbatch
        it = range(0, num_q, QBATCH)
        # Fix: honor this transformer's own verbose flag (it was stored but never used;
        # the progress bar previously keyed off flex_index.verbose instead).
        if self.verbose:
            it = pt.tqdm(it, unit='qbatch')
        self.flatnav_index.set_num_threads(self.threads)

        result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
        for qidx in it:
            scores, dids = self.flatnav_index.search(
                queries=query_vecs[qidx:qidx+QBATCH],
                ef_search=self.ef_search,
                K=min(self.num_results, len(self.flex_index)),
                num_initializations=self.num_initializations,
            )
            scores = -scores  # flatnav returns distances; negate so that higher = better
            for s, d in zip(scores, dids):
                mask = d != -1  # -1 marks unfilled result slots; drop them
                d = d[mask]
                s = s[mask]
                result.extend({
                    'docno': docnos.fwd[d],
                    'docid': d,
                    'score': s,
                    'rank': np.arange(d.shape[0]),
                })

        if self.drop_query_vec:
            inp = inp.drop(columns='query_vec')
        return result.to_df(inp)


def _flatnav_retriever(self,
    k: int = 32,
    *,
    ef_search: int = 100,
    num_initializations: int = 100,
    ef_construction: int = 100,
    threads: int = 16,
    num_results: int = 1000,
    cache: bool = True,
    qbatch: int = 64,
    drop_query_vec: bool = False,
    verbose: bool = False,
) -> pt.Transformer:
    """Returns a retriever that searches over a flatnav index.

    Args:
        k (int): the maximum number of edges per document in the index
        ef_search (int): the size of the list during searches. Higher values are slower but more accurate.
        num_initializations (int): the number of random initializations to use during search.
        ef_construction (int): the size of the list during graph construction. Higher values are slower but more accurate.
        threads (int): the number of threads to use
        num_results (int): the number of results to return per query
        cache (bool): whether to cache the index to disk
        qbatch (int): the number of queries to search at once
        drop_query_vec (bool): whether to drop the query_vec column after retrieval
        verbose (bool): whether to show progress bars

    Returns:
        :class:`~pyterrier.Transformer`: A retriever over the flatnav index.

    Raises:
        ValueError: if the index's similarity function is not supported by flatnav.

    .. note::
        This transformer requires the ``flatnav`` package to be installed. Instructions are available
        in the `flatnav repository <https://github.com/BlaiseMuhirwa/flatnav>`__.

    .. cite:: arxiv:2412.01940
        :citation: Munyampirwa et al. Down with the Hierarchy: The 'H' in HNSW Stands for "Hubs". arXiv 2024.
        :link: https://arxiv.org/abs/2412.01940
    """
    pyterrier_dr.util.assert_flatnav()
    import flatnav

    # Built indexes are memoized by their construction parameters so repeated calls reuse them.
    key = ('flatnav', k, ef_construction)
    index_name = f'{k}_ef-{ef_construction}-{str(self.sim_fn)}.flatnav'
    if key not in self._cache:
        dvecs, meta = self.payload(return_docnos=False)
        if not os.path.exists(self.index_path/index_name):
            # Fail with a clear message for unsupported similarity functions
            # (previously this surfaced as a bare KeyError from the mapping below).
            if self.sim_fn not in (SimFn.dot,):
                raise ValueError(f'flatnav_retriever does not support sim_fn={self.sim_fn}')
            distance_type = {
                SimFn.dot: 'angular',  # flatnav's name for inner-product distance
            }[self.sim_fn]
            idx = flatnav.index.create(
                distance_type=distance_type,
                index_data_type=flatnav.data_type.DataType.float32,
                dim=dvecs.shape[1],
                dataset_size=dvecs.shape[0],
                max_edges_per_node=k,
                verbose=verbose,  # fix: was hard-coded to True, ignoring the verbose argument
                collect_stats=True,
            )
            idx.set_num_threads(threads)
            idx.add(data=np.array(dvecs), ef_construction=ef_construction)
            if cache:
                idx.save(str(self.index_path/index_name))
            self._cache[key] = idx
        else:
            # NOTE(review): the load path assumes an inner-product (IndexIPFloat) index on disk,
            # which matches the only distance_type built above — revisit if more SimFns are added.
            self._cache[key] = flatnav.index.IndexIPFloat.load_index(str(self.index_path/index_name))
            self._cache[key].set_data_type(flatnav.data_type.DataType.float32)
    return FlatNavRetriever(self, self._cache[key], threads=threads, ef_search=ef_search, num_initializations=num_initializations, num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec, verbose=verbose)
FlexIndex.flatnav_retriever = _flatnav_retriever
3 changes: 2 additions & 1 deletion pyterrier_dr/pt_docs/indexing-retrieval.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ API Documentation
.. automethod:: faiss_flat_retriever
.. automethod:: faiss_hnsw_retriever
.. automethod:: faiss_ivf_retriever
.. automethod:: flatnav_retriever
.. automethod:: scann_retriever
.. automethod:: voyager_retriever

Re-Ranking
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Results can be re-ranked using indexed vectors using :meth:`scorer`. (:meth:`np_scorer` and :meth:`torch_scorer` are
available as specific implemenations, if needed.)
available as specific implementations, if needed.)

:meth:`gar`, :meth:`ladr_proactive`, and :meth:`ladr_adaptive` are *adaptive* re-ranking approaches that pull in other
documents from the corpus that may be relevant.
Expand Down
8 changes: 8 additions & 0 deletions pyterrier_dr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ def voyager_available():

def assert_voyager():
assert voyager_available(), "voyager required; install with `pip install voyager`"


def flatnav_available():
    """Returns True if the ``flatnav`` package is importable, False otherwise."""
    return package_available('flatnav')


def assert_flatnav():
    """Raises an AssertionError if the ``flatnav`` package is not installed.

    Follows the same convention as the sibling ``assert_voyager`` helper.
    NOTE(review): ``assert`` is stripped under ``python -O``; acceptable here only
    because the whole family of availability checks uses the same pattern.
    """
    assert flatnav_available(), "flatnav required; install with instructions here: https://github.com/BlaiseMuhirwa/flatnav"
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pytest
pytest-subtests
pytest-cov
pytest-json-report
git+https://github.com/terrierteam/pyterrier_adaptive
voyager
FlagEmbedding
faiss-cpu
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ torch
numpy>=1.21.0, <2.0.0
npids
sentence_transformers
pyterrier-adaptive
4 changes: 4 additions & 0 deletions tests/test_flexindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def test_faiss_ivf_retriever(self):
def test_scann_retriever(self):
self._test_retr(FlexIndex.scann_retriever, exact=False)

@unittest.skipIf(not pyterrier_dr.util.flatnav_available(), "flatnav not available")
def test_flatnav_retriever(self):
    # flatnav is a graph-based approximate NN index, so results need not match
    # exhaustive retrieval exactly — hence exact=False.
    self._test_retr(FlexIndex.flatnav_retriever, exact=False)

def test_np_retriever(self):
self._test_retr(FlexIndex.np_retriever)

Expand Down

0 comments on commit 7444fec

Please sign in to comment.