Skip to content

Commit

Permalink
drop_query_vec for np_retriever
Browse files (browse the repository at this point in the history)
  • Loading branch information
seanmacavaney committed Nov 22, 2024
1 parent 5013ffd commit 6cde1f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
39 changes: 21 additions & 18 deletions pyterrier_dr/flex/np_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
from ..indexes import RankedLists
from . import FlexIndex
import ir_datasets
import pyterrier_alpha as pta

logger = ir_datasets.log.easy()

class NumpyRetriever(pt.Transformer):
def __init__(self, flex_index, num_results=1000, batch_size=None):
def __init__(self, flex_index, num_results=1000, batch_size=None, drop_query_vec=False):
self.flex_index = flex_index
self.num_results = num_results
self.batch_size = batch_size or 4096
self.drop_query_vec = drop_query_vec

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])
docnos, dvecs, config = self.flex_index.payload()
Expand All @@ -37,19 +40,19 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
scores = query_vecs @ doc_batch
dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0)
ranked_lists.update(scores, dids)
result_scores, result_dids = ranked_lists.results()
result_docnos = docnos.fwd[result_dids]
cols = {
'score': np.concatenate(result_scores),
'docno': np.concatenate(result_docnos),
'docid': np.concatenate(result_dids),
'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]),
}
idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores))))
for col in inp.columns:
if col != 'query_vec':
cols[col] = inp[col][idxs].values
return pd.DataFrame(cols)

result = pta.DataFrameBuilder(['score', 'docno', 'docid', 'rank'])
for scores, dids in zip(*ranked_lists.results()):
result.extend({
'score': scores,
'docno': docnos.fwd[dids],
'docid': dids,
'rank': np.arange(len(scores)),
})

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')
return result.to_df(inp)


class NumpyVectorLoader(pt.Transformer):
Expand Down Expand Up @@ -108,18 +111,18 @@ def transform(self, inp):
return res


def _np_retriever(self, num_results=1000, batch_size=None):
return NumpyRetriever(self, num_results=num_results, batch_size=batch_size)
def _np_retriever(self, num_results=1000, batch_size=None, drop_query_vec=False):
return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec)
FlexIndex.np_retriever = _np_retriever


def _np_vec_loader(self):
return NumpyVectorLoader(self)
return NumpyVectorLoader(self)
FlexIndex.np_vec_loader = _np_vec_loader
FlexIndex.vec_loader = _np_vec_loader # default vec_loader


def _np_scorer(self, num_results=None):
return NumpyScorer(self, num_results)
return NumpyScorer(self, num_results)
FlexIndex.np_scorer = _np_scorer
FlexIndex.scorer = _np_scorer # default scorer
4 changes: 4 additions & 0 deletions tests/test_flexindex.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import tempfile
import unittest
import numpy as np
Expand Down Expand Up @@ -129,6 +130,9 @@ def test_scann_retriever(self):
def test_np_retriever(self):
self._test_retr(FlexIndex.np_retriever)

def test_np_retriever_drop_query_vec(self):
self._test_retr(functools.partial(FlexIndex.np_retriever, drop_query_vec=True))

def test_torch_retriever(self):
self._test_retr(FlexIndex.torch_retriever)

Expand Down

0 comments on commit 6cde1f1

Please sign in to comment.