Skip to content

Commit

Permalink
faiss retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Nov 22, 2024
1 parent 6cde1f1 commit 0baba13
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 34 deletions.
52 changes: 26 additions & 26 deletions pyterrier_dr/flex/faiss_retr.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
import json
import pandas as pd
import math
import struct
import os
import pyterrier as pt
import itertools
import numpy as np
import tempfile
import ir_datasets
import pyterrier_dr
import pyterrier_alpha as pta
from . import FlexIndex

logger = ir_datasets.log.easy()


class FaissRetriever(pt.Indexer):
def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64):
def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64, drop_query_vec=False):
self.flex_index = flex_index
self.faiss_index = faiss_index
self.n_probe = n_probe
self.ef_search = ef_search
self.search_bounded_queue = search_bounded_queue
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
assert all(f in inp.columns for f in ['qid', 'query_vec'])
docnos, config = self.flex_index.payload(return_dvecs=False)
query_vecs = np.stack(inp['query_vec'])
query_vecs = query_vecs.copy()
idxs = []
res = {'docid': [], 'score': [], 'rank': []}
num_q = query_vecs.shape[0]
QBATCH = self.qbatch
if self.n_probe is not None:
Expand All @@ -42,25 +40,27 @@ def transform(self, inp):
it = range(0, num_q, QBATCH)
if self.flex_index.verbose:
it = logger.pbar(it, unit='qbatch')

result = pta.DataFrameBuilder(['score', 'docno', 'docid', 'rank'])
for qidx in it:
scores, dids = self.faiss_index.search(query_vecs[qidx:qidx+QBATCH], self.flex_index.num_results)
for i, (s, d) in enumerate(zip(scores, dids)):
for s, d in zip(scores, dids):
mask = d != -1
d = d[mask]
s = s[mask]
res['docid'].append(d)
res['score'].append(s)
res['rank'].append(np.arange(d.shape[0]))
idxs.extend(itertools.repeat(qidx+i, d.shape[0]))
res = {k: np.concatenate(v) for k, v in res.items()}
res['docno'] = docnos.fwd[res['docid']]
for col in inp.columns:
if col != 'query_vec':
res[col] = inp[col][idxs].values
return pd.DataFrame(res)


def _faiss_flat_retriever(self, gpu=False, qbatch=64):
result.extend({
'score': s,
'docno': docnos.fwd[d],
'docid': d,
'rank': np.arange(d.shape[0]),
})

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')
return result.to_df(inp)


def _faiss_flat_retriever(self, gpu=False, qbatch=64, drop_query_vec=False):
pyterrier_dr.util.assert_faiss()
import faiss
if 'faiss_flat' not in self._cache:
Expand All @@ -80,12 +80,12 @@ def _faiss_flat_retriever(self, gpu=False, qbatch=64):
co = faiss.GpuMultipleClonerOptions()
co.shard = True
self._cache['faiss_flat_gpu'] = faiss.index_cpu_to_all_gpus(self._faiss_flat, co=co)
return FaissRetriever(self, self._cache['faiss_flat_gpu'])
return FaissRetriever(self, self._cache['faiss_flat'], qbatch=qbatch)
return FaissRetriever(self, self._cache['faiss_flat_gpu'], drop_query_vec=drop_query_vec)
return FaissRetriever(self, self._cache['faiss_flat'], qbatch=qbatch, drop_query_vec=drop_query_vec)
FlexIndex.faiss_flat_retriever = _faiss_flat_retriever


def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, cache=True, search_bounded_queue=True, qbatch=64):
def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, cache=True, search_bounded_queue=True, qbatch=64, drop_query_vec=False):
pyterrier_dr.util.assert_faiss()
import faiss
meta, = self.payload(return_dvecs=False, return_docnos=False)
Expand All @@ -107,7 +107,7 @@ def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16,
with logger.duration('reading hnsw table'):
self._cache[key] = faiss.read_index(str(self.index_path/index_name))
self._cache[key].storage = self.faiss_flat_retriever().faiss_index
return FaissRetriever(self, self._cache[key], ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch)
return FaissRetriever(self, self._cache[key], ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch, drop_query_vec=drop_query_vec)
FlexIndex.faiss_hnsw_retriever = _faiss_hnsw_retriever


Expand Down Expand Up @@ -154,7 +154,7 @@ def _sample_train(index, count=None):
idxs = np.random.RandomState(0).choice(dvecs.shape[0], size=count, replace=False)
return dvecs[idxs]

def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_probe=1):
def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_probe=1, drop_query_vec=False):
pyterrier_dr.util.assert_faiss()
import faiss
meta, = self.payload(return_dvecs=False, return_docnos=False)
Expand Down Expand Up @@ -197,5 +197,5 @@ def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_pro
else:
with logger.duration('reading index'):
self._cache[key] = faiss.read_index(str(self.index_path/index_name))
return FaissRetriever(self, self._cache[key], n_probe=n_probe)
return FaissRetriever(self, self._cache[key], n_probe=n_probe, drop_query_vec=drop_query_vec)
FlexIndex.faiss_ivf_retriever = _faiss_ivf_retriever
1 change: 0 additions & 1 deletion pyterrier_dr/flex/np_retr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import pyterrier as pt
import numpy as np
import pandas as pd
Expand Down
23 changes: 16 additions & 7 deletions tests/test_flexindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,34 @@ def _test_retr(self, Retr, exact=True, test_smaller=True):

@unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available")
def test_faiss_flat_retriever(self):
self._test_retr(FlexIndex.faiss_flat_retriever)
with self.subTest('drop_query_vec=True'):
self._test_retr(functools.partial(FlexIndex.faiss_flat_retriever, drop_query_vec=True))
with self.subTest('drop_query_vec=False'):
self._test_retr(functools.partial(FlexIndex.faiss_flat_retriever, drop_query_vec=False))

@unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available")
def test_faiss_hnsw_retriever(self):
self._test_retr(FlexIndex.faiss_hnsw_retriever, exact=False)
with self.subTest('drop_query_vec=True'):
self._test_retr(functools.partial(FlexIndex.faiss_hnsw_retriever, drop_query_vec=True))
with self.subTest('drop_query_vec=False'):
self._test_retr(functools.partial(FlexIndex.faiss_hnsw_retriever, drop_query_vec=False))

@unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available")
def test_faiss_ivf_retriever(self):
self._test_retr(FlexIndex.faiss_ivf_retriever, exact=False)
with self.subTest('drop_query_vec=True'):
self._test_retr(functools.partial(FlexIndex.faiss_ivf_retriever, drop_query_vec=True))
with self.subTest('drop_query_vec=False'):
self._test_retr(functools.partial(FlexIndex.faiss_ivf_retriever, drop_query_vec=False))

@unittest.skipIf(not pyterrier_dr.util.scann_available(), "scann not available")
def test_scann_retriever(self):
self._test_retr(FlexIndex.scann_retriever, exact=False)

def test_np_retriever(self):
self._test_retr(FlexIndex.np_retriever)

def test_np_retriever_drop_query_vec(self):
self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True))
with self.subTest('drop_query_vec=True'):
self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True))
with self.subTest('drop_query_vec=False'):
self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=False))

def test_torch_retriever(self):
self._test_retr(FlexIndex.torch_retriever)
Expand Down

0 comments on commit 0baba13

Please sign in to comment.