Compilation (#33)
* examples of compile operations

* fix fusion cutoff - hat tip to Sean

* add for both prfs

* add a unit test for compilation

* address review feedback

* and this one

* syntax fix
cmacdonald authored Dec 9, 2024
1 parent fa1e04c commit 74a4e92
Showing 11 changed files with 122 additions and 1 deletion.
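The changes below add a fuse_rank_cutoff hook to the retrievers and a compile() method to the two PRF transformers, so that PyTerrier pipeline compilation can push rank cutoffs into the retrieval stage. A minimal usage sketch, modelled on the new unit test (the index path is a placeholder; the index is assumed to have been built elsewhere):

import pyterrier_dr
from pyterrier_dr import FlexIndex

index = FlexIndex('path/to/index')  # assumes an index built elsewhere
retr = index.retriever()            # num_results defaults to 1000

# retrieve, load doc vectors, run PRF over the top 3 docs, then re-retrieve keeping 2 results
pipe = retr >> index.vec_loader() >> pyterrier_dr.AveragePrf(k=3) >> (retr % 2)

# compile() rewrites the pipeline: the k=3 needed by PRF is fused into the first
# retriever and the % 2 cutoff into the final one, without changing the results
pipe_opt = pipe.compile()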
3 changes: 3 additions & 0 deletions pyterrier_dr/biencoder.py
@@ -143,6 +143,9 @@ def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, t

def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)

def fuse_rank_cutoff(self, k):
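# the cutoff is re-emitted before this transformer (rather than absorbed), so
# compilation can continue fusing it with earlier stages of the pipeline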
return pt.RankCutoff(k) >> self

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.columns(inp, includes=[self.text_field])
7 changes: 7 additions & 0 deletions pyterrier_dr/flex/faiss_retr.py
@@ -25,6 +25,13 @@ def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search
self.drop_query_vec = drop_query_vec
self.num_results = num_results

def fuse_rank_cutoff(self, k):
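# for approximate (ANN) search, lowering num_results can change which documents
# are returned, not just how many, so the cutoff is not fused here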
return None # disable fusion for ANN
if k < self.num_results:
return FaissRetriever(self.flex_index, self.faiss_index,
n_probe=self.n_probe, ef_search=self.ef_search, search_bounded_queue=self.search_bounded_queue,
num_results=k, qbatch=self.qbatch, drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
9 changes: 9 additions & 0 deletions pyterrier_dr/flex/flatnav_retr.py
@@ -20,6 +20,15 @@ def __init__(self, flex_index, flatnav_index, *, threads=16, ef_search=100, num_
self.num_initializations = num_initializations
self.verbose = verbose

def fuse_rank_cutoff(self, k):
return None # disable fusion for ANN
if k < self.num_results:
return FlatNavRetriever(self.flex_index, self.flatnav_index,
num_results=k, ef_search=self.ef_search,
qbatch=self.qbatch, num_initializations=self.num_initializations,
drop_query_vec=self.drop_query_vec, verbose=self.verbose)


def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
5 changes: 5 additions & 0 deletions pyterrier_dr/flex/gar.py
@@ -15,6 +15,11 @@ def __init__(self, flex_index, graph, score_fn, batch_size=128, num_results=1000
self.num_results = num_results
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
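# only rewrite when the cutoff is tighter than the configured num_results;
# otherwise fall through (None), leaving the original pipeline unchanged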
if k < self.num_results:
return FlexGar(self.flex_index, self.graph, score_fn=self.score_fn,
num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.result_frame(inp, extra_columns=['query_vec', 'score'])

5 changes: 5 additions & 0 deletions pyterrier_dr/flex/ladr.py
@@ -91,6 +91,11 @@ def __init__(self, flex_index, graph, dense_scorer, num_results=1000, depth=100,
self.max_hops = max_hops
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return LadrAdaptive(self.flex_index, self.graph, self.dense_scorer,
num_results=k, depth=self.depth, max_hops=self.max_hops, drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.result_frame(inp, extra_columns=['query_vec'])
docnos, config = self.flex_index.payload(return_dvecs=False)
8 changes: 7 additions & 1 deletion pyterrier_dr/flex/np_retr.py
@@ -21,6 +21,10 @@ def __init__(self,
self.batch_size = batch_size or 4096
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return NumpyRetriever(self.flex_index, num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
@@ -67,7 +71,9 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
docids = self.flex_index._load_docids(inp)
dvecs, config = self.flex_index.payload(return_docnos=False)
return inp.assign(doc_vec=list(dvecs[docids]))


def fuse_rank_cutoff(self, k):
return pt.RankCutoff(k) >> self

class NumpyScorer(pt.Transformer):
def __init__(self, flex_index: FlexIndex, *, num_results: Optional[int] = None):
5 changes: 5 additions & 0 deletions pyterrier_dr/flex/scann_retr.py
@@ -20,6 +20,11 @@ def __init__(self, flex_index, scann_index, num_results=1000, leaves_to_search=N
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
return None # disable fusion for ANN
if k < self.num_results:
return ScannRetriever(self.flex_index, self.scann_index, num_results=k, leaves_to_search=self.leaves_to_search, qbatch=self.qbatch, drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
9 changes: 9 additions & 0 deletions pyterrier_dr/flex/torch_retr.py
@@ -51,6 +51,15 @@ def __init__(self,
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return TorchRetriever(
self.flex_index,
self.torch_vecs,
num_results=k,
qbatch=self.qbatch,
drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
11 changes: 11 additions & 0 deletions pyterrier_dr/flex/voyager_retr.py
@@ -18,6 +18,17 @@ def __init__(self, flex_index, voyager_index, query_ef=None, num_results=1000, q
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
return None # disable fusion for ANN
if k < self.num_results:
return VoyagerRetriever(
self.flex_index,
self.voyager_index,
query_ef=self.query_ef,
num_results=k,
qbatch=self.qbatch,
drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
6 changes: 6 additions & 0 deletions pyterrier_dr/prf.py
@@ -34,6 +34,9 @@ def __init__(self,
self.beta = beta
self.k = k

def compile(self) -> pt.Transformer:
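# PRF only consumes the top-k results, so the input can be cut at rank k before
# this transformer, letting the cutoff be fused further up the pipeline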
return pt.RankCutoff(self.k) >> self

@pta.transform.by_query(add_ranks=False)
def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
"""Performs Vector PRF on the input dataframe."""
@@ -76,6 +79,9 @@ def __init__(
):
self.k = k

def compile(self) -> pt.Transformer:
return pt.RankCutoff(self.k) >> self

@pta.transform.by_query(add_ranks=False)
def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
"""Performs Average PRF on the input dataframe."""
55 changes: 55 additions & 0 deletions tests/test_compile.py
@@ -0,0 +1,55 @@
import tempfile
import unittest
import numpy as np
import pandas as pd
import pyterrier_dr
from pyterrier_dr import FlexIndex

class TestFlexIndex(unittest.TestCase):

def _generate_data(self, count=2000, dim=100):
def random_unit_vec():
v = np.random.rand(dim).astype(np.float32)
return v / np.linalg.norm(v)
return [
{'docno': str(i), 'doc_vec': random_unit_vec()}
for i in range(count)
]

def test_compilation_with_rank_and_averageprf(self):
self._test_compilation_with_rank_and_prf(pyterrier_dr.AveragePrf)

def test_compilation_with_rank_and_vectorprf(self):
self._test_compilation_with_rank_and_prf(pyterrier_dr.VectorPrf)

def _test_compilation_with_rank_and_prf(self, prf_clz):

with tempfile.TemporaryDirectory() as destdir:
index = FlexIndex(destdir+'/index')
dataset = self._generate_data(count=2000)
index.index(dataset)

retr = index.retriever()
queries = pd.DataFrame([
{'qid': '0', 'query_vec': dataset[0]['doc_vec']},
{'qid': '1', 'query_vec': dataset[1]['doc_vec']},
])

pipe1 = retr >> prf_clz(k=3) >> retr
pipe1_opt = pipe1.compile()
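# PRF's k=3 should be fused into the first retriever; the final retriever keeps its default of 1000 results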
self.assertEqual(3, pipe1_opt[0].num_results)
self.assertEqual(1000, pipe1_opt[-1].num_results)
# NB: pipe1 wouldn't actually work for PRF, as doc_vecs are not present; however, the compilation is still valid

pipe2 = retr >> index.vec_loader() >> pyterrier_dr.AveragePrf(k=3) >> (retr % 2)
pipe2_opt = pipe2.compile()
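# PRF's k=3 is again fused into the first retriever, and the explicit % 2 cutoff into the final one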
self.assertEqual(3, pipe2_opt[0].num_results)
self.assertEqual(2, pipe2_opt[-1].num_results)

res2 = pipe2(queries)
res2_opt = pipe2_opt(queries)

pd.testing.assert_frame_equal(res2, res2_opt)

