diff --git a/pyterrier_dr/flex/faiss_retr.py b/pyterrier_dr/flex/faiss_retr.py index 0ad207d..cd6f033 100644 --- a/pyterrier_dr/flex/faiss_retr.py +++ b/pyterrier_dr/flex/faiss_retr.py @@ -1,36 +1,34 @@ import json -import pandas as pd import math import struct import os import pyterrier as pt -import itertools import numpy as np import tempfile import ir_datasets import pyterrier_dr +import pyterrier_alpha as pta from . import FlexIndex logger = ir_datasets.log.easy() class FaissRetriever(pt.Indexer): - def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64): + def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64, drop_query_vec=False): self.flex_index = flex_index self.faiss_index = faiss_index self.n_probe = n_probe self.ef_search = ef_search self.search_bounded_queue = search_bounded_queue self.qbatch = qbatch + self.drop_query_vec = drop_query_vec def transform(self, inp): + pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) - assert all(f in inp.columns for f in ['qid', 'query_vec']) docnos, config = self.flex_index.payload(return_dvecs=False) query_vecs = np.stack(inp['query_vec']) query_vecs = query_vecs.copy() - idxs = [] - res = {'docid': [], 'score': [], 'rank': []} num_q = query_vecs.shape[0] QBATCH = self.qbatch if self.n_probe is not None: @@ -42,25 +40,27 @@ def transform(self, inp): it = range(0, num_q, QBATCH) if self.flex_index.verbose: it = logger.pbar(it, unit='qbatch') + + result = pta.DataFrameBuilder(['score', 'docno', 'docid', 'rank']) for qidx in it: scores, dids = self.faiss_index.search(query_vecs[qidx:qidx+QBATCH], self.flex_index.num_results) - for i, (s, d) in enumerate(zip(scores, dids)): + for s, d in zip(scores, dids): mask = d != -1 d = d[mask] s = s[mask] - res['docid'].append(d) - res['score'].append(s) - res['rank'].append(np.arange(d.shape[0])) - idxs.extend(itertools.repeat(qidx+i, d.shape[0])) - res = {k: np.concatenate(v) for k, v in res.items()} - res['docno'] = docnos.fwd[res['docid']] - for col in inp.columns: - if col != 'query_vec': - res[col] = inp[col][idxs].values - return pd.DataFrame(res) - - -def _faiss_flat_retriever(self, gpu=False, qbatch=64): + result.extend({ + 'score': s, + 'docno': docnos.fwd[d], + 'docid': d, + 'rank': np.arange(d.shape[0]), + }) + + if self.drop_query_vec: + inp = inp.drop(columns='query_vec') + return result.to_df(inp) + + +def _faiss_flat_retriever(self, gpu=False, qbatch=64, drop_query_vec=False): pyterrier_dr.util.assert_faiss() import faiss if 'faiss_flat' not in self._cache: @@ -80,12 +80,12 @@ def _faiss_flat_retriever(self, gpu=False, qbatch=64): co = faiss.GpuMultipleClonerOptions() co.shard = True self._cache['faiss_flat_gpu'] = faiss.index_cpu_to_all_gpus(self._faiss_flat, co=co) - return FaissRetriever(self, self._cache['faiss_flat_gpu']) - return FaissRetriever(self, self._cache['faiss_flat'], qbatch=qbatch) + return FaissRetriever(self, self._cache['faiss_flat_gpu'], drop_query_vec=drop_query_vec) + return FaissRetriever(self, self._cache['faiss_flat'], qbatch=qbatch, drop_query_vec=drop_query_vec) FlexIndex.faiss_flat_retriever = _faiss_flat_retriever -def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, cache=True, search_bounded_queue=True, qbatch=64): +def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, cache=True, search_bounded_queue=True, qbatch=64, drop_query_vec=False): pyterrier_dr.util.assert_faiss() import faiss meta, = self.payload(return_dvecs=False, return_docnos=False) @@ -107,7 +107,7 @@ def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, with logger.duration('reading hnsw table'): self._cache[key] = faiss.read_index(str(self.index_path/index_name)) self._cache[key].storage = self.faiss_flat_retriever().faiss_index - return FaissRetriever(self, self._cache[key], ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch) + return FaissRetriever(self, self._cache[key], ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch, drop_query_vec=drop_query_vec) FlexIndex.faiss_hnsw_retriever = _faiss_hnsw_retriever @@ -154,7 +154,7 @@ def _sample_train(index, count=None): idxs = np.random.RandomState(0).choice(dvecs.shape[0], size=count, replace=False) return dvecs[idxs] -def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_probe=1): +def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_probe=1, drop_query_vec=False): pyterrier_dr.util.assert_faiss() import faiss meta, = self.payload(return_dvecs=False, return_docnos=False) @@ -197,5 +197,5 @@ def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_pro else: with logger.duration('reading index'): self._cache[key] = faiss.read_index(str(self.index_path/index_name)) - return FaissRetriever(self, self._cache[key], n_probe=n_probe) + return FaissRetriever(self, self._cache[key], n_probe=n_probe, drop_query_vec=drop_query_vec) FlexIndex.faiss_ivf_retriever = _faiss_ivf_retriever diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py index 27a65d0..77de0ab 100644 --- a/pyterrier_dr/flex/np_retr.py +++ b/pyterrier_dr/flex/np_retr.py @@ -1,4 +1,3 @@ -import itertools import pyterrier as pt import numpy as np import pandas as pd diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py index 952ae01..b1f9a66 100644 --- a/tests/test_flexindex.py +++ b/tests/test_flexindex.py @@ -113,25 +113,34 @@ def _test_retr(self, Retr, exact=True, test_smaller=True): @unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available") def test_faiss_flat_retriever(self): - self._test_retr(FlexIndex.faiss_flat_retriever) + with self.subTest('drop_query_vec=True'): + self._test_retr(functools.partial(FlexIndex.faiss_flat_retriever, drop_query_vec=True)) + with self.subTest('drop_query_vec=False'): + self._test_retr(functools.partial(FlexIndex.faiss_flat_retriever, drop_query_vec=False)) @unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available") def test_faiss_hnsw_retriever(self): - self._test_retr(FlexIndex.faiss_hnsw_retriever, exact=False) + with self.subTest('drop_query_vec=True'): + self._test_retr(functools.partial(FlexIndex.faiss_hnsw_retriever, drop_query_vec=True)) + with self.subTest('drop_query_vec=False'): + self._test_retr(functools.partial(FlexIndex.faiss_hnsw_retriever, drop_query_vec=False)) @unittest.skipIf(not pyterrier_dr.util.faiss_available(), "faiss not available") def test_faiss_ivf_retriever(self): - self._test_retr(FlexIndex.faiss_ivf_retriever, exact=False) + with self.subTest('drop_query_vec=True'): + self._test_retr(functools.partial(FlexIndex.faiss_ivf_retriever, drop_query_vec=True)) + with self.subTest('drop_query_vec=False'): + self._test_retr(functools.partial(FlexIndex.faiss_ivf_retriever, drop_query_vec=False)) @unittest.skipIf(not pyterrier_dr.util.scann_available(), "scann not available") def test_scann_retriever(self): self._test_retr(FlexIndex.scann_retriever, exact=False) def test_np_retriever(self): - self._test_retr(FlexIndex.np_retriever) - - def test_np_retriever_drop_query_vec(self): - self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True)) + with self.subTest('drop_query_vec=True'): + self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True)) + with self.subTest('drop_query_vec=False'): + self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=False)) def test_torch_retriever(self): self._test_retr(FlexIndex.torch_retriever)