diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py index ff9a21d..27a65d0 100644 --- a/pyterrier_dr/flex/np_retr.py +++ b/pyterrier_dr/flex/np_retr.py @@ -6,16 +6,19 @@ from ..indexes import RankedLists from . import FlexIndex import ir_datasets +import pyterrier_alpha as pta logger = ir_datasets.log.easy() class NumpyRetriever(pt.Transformer): - def __init__(self, flex_index, num_results=1000, batch_size=None): + def __init__(self, flex_index, num_results=1000, batch_size=None, drop_query_vec=False): self.flex_index = flex_index self.num_results = num_results self.batch_size = batch_size or 4096 + self.drop_query_vec = drop_query_vec def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) query_vecs = np.stack(inp['query_vec']) docnos, dvecs, config = self.flex_index.payload() @@ -37,19 +40,19 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: scores = query_vecs @ doc_batch dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0) ranked_lists.update(scores, dids) - result_scores, result_dids = ranked_lists.results() - result_docnos = docnos.fwd[result_dids] - cols = { - 'score': np.concatenate(result_scores), - 'docno': np.concatenate(result_docnos), - 'docid': np.concatenate(result_dids), - 'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]), - } - idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores)))) - for col in inp.columns: - if col != 'query_vec': - cols[col] = inp[col][idxs].values - return pd.DataFrame(cols) + + result = pta.DataFrameBuilder(['score', 'docno', 'docid', 'rank']) + for scores, dids in zip(*ranked_lists.results()): + result.extend({ + 'score': scores, + 'docno': docnos.fwd[dids], + 'docid': dids, + 'rank': np.arange(len(scores)), + }) + + if self.drop_query_vec: + inp = inp.drop(columns='query_vec') + return result.to_df(inp) class NumpyVectorLoader(pt.Transformer): @@ -108,18 +111,18 @@ def transform(self, inp): return res -def _np_retriever(self, num_results=1000, batch_size=None): - return NumpyRetriever(self, num_results=num_results, batch_size=batch_size) +def _np_retriever(self, num_results=1000, batch_size=None, drop_query_vec=False): + return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec) FlexIndex.np_retriever = _np_retriever def _np_vec_loader(self): - return NumpyVectorLoader(self) + return NumpyVectorLoader(self) FlexIndex.np_vec_loader = _np_vec_loader FlexIndex.vec_loader = _np_vec_loader # default vec_loader def _np_scorer(self, num_results=None): - return NumpyScorer(self, num_results) + return NumpyScorer(self, num_results) FlexIndex.np_scorer = _np_scorer FlexIndex.scorer = _np_scorer # default scorer diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py index 9d5e99b..952ae01 100644 --- a/tests/test_flexindex.py +++ b/tests/test_flexindex.py @@ -1,3 +1,4 @@ +import functools import tempfile import unittest import numpy as np @@ -129,6 +130,9 @@ def test_scann_retriever(self): def test_np_retriever(self): self._test_retr(FlexIndex.np_retriever) + def test_np_retriever_drop_query_vec(self): + self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True)) + def test_torch_retriever(self): self._test_retr(FlexIndex.torch_retriever)