Skip to content

Commit

Permalink
address review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald committed Dec 9, 2024
1 parent 655f060 commit d513f6b
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyterrier_dr/flex/faiss_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search
self.num_results = num_results

def fuse_rank_cutoff(self, k):
return None # disable fusion for ANN
if k < self.num_results:
return FaissRetriever(self.flex_index, self.faiss_index,
n_probe=self.n_probe, ef_search=self.ef_search, search_bounded_queue=self.search_bounded_queue,
Expand Down
1 change: 1 addition & 0 deletions pyterrier_dr/flex/flatnav_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, flex_index, flatnav_index, *, threads=16, ef_search=100, num_
self.verbose = verbose

def fuse_rank_cutoff(self, k):
return None # disable fusion for ANN
if k < self.num_results:
return FlatNavRetriever(self.flex_index, self.flatnav_index,
num_results=k, ef_search=self.ef_search,
Expand Down
5 changes: 5 additions & 0 deletions pyterrier_dr/flex/gar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def __init__(self, flex_index, graph, score_fn, batch_size=128, num_results=1000
self.num_results = num_results
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
    """Fuse with a downstream rank cutoff of ``k``.

    If ``k`` is smaller than the currently-configured ``num_results``, return a
    new ``FlexGar`` that retrieves only ``k`` results directly (same index,
    graph, scoring function, batch size, and query-vector handling).
    Otherwise return ``None`` to signal that no fusion is possible.
    """
    if k >= self.num_results:
        return None  # cutoff would not reduce work; decline fusion
    return FlexGar(
        self.flex_index,
        self.graph,
        score_fn=self.score_fn,
        num_results=k,
        batch_size=self.batch_size,
        drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.result_frame(inp, extra_columns=['query_vec', 'score'])

Expand Down
5 changes: 5 additions & 0 deletions pyterrier_dr/flex/ladr.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def __init__(self, flex_index, graph, dense_scorer, num_results=1000, depth=100,
self.max_hops = max_hops
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
    """Fuse with a downstream rank cutoff of ``k``.

    If ``k`` is smaller than the currently-configured ``num_results``, return a
    new ``LadrAdaptive`` that retrieves only ``k`` results directly (preserving
    the graph, dense scorer, depth, max_hops, and query-vector handling).
    Otherwise return ``None`` to signal that no fusion is possible.
    """
    # Fix: the committed version had a stray blank/mis-indented line here that
    # made ruff fail with "unindent does not match any outer indentation level".
    if k < self.num_results:
        return LadrAdaptive(self.flex_index, self.graph, self.dense_scorer,
                            num_results=k, depth=self.depth, max_hops=self.max_hops,
                            drop_query_vec=self.drop_query_vec)

def transform(self, inp):
pta.validate.result_frame(inp, extra_columns=['query_vec'])
docnos, config = self.flex_index.payload(return_dvecs=False)
Expand Down
1 change: 1 addition & 0 deletions pyterrier_dr/flex/scann_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, flex_index, scann_index, num_results=1000, leaves_to_search=N
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
    """Decline rank-cutoff fusion for this retriever.

    Fusion is deliberately disabled for this ANN retriever (per the
    "disable fusion for ANN" note in the original), so this always
    returns ``None``; presumably because cutting ``num_results`` on an
    approximate index can change which documents are found, not merely
    how many are kept — TODO confirm rationale with maintainers.
    """
    # Fix: removed the ScannRetriever re-construction branch that followed the
    # unconditional `return None` — it was unreachable dead code.
    return None

Expand Down

0 comments on commit d513f6b

Please sign in to comment.