Skip to content

Commit

Permalink
fix indentation in scann_retr.py
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Oct 14, 2023
1 parent 9017e96 commit c20947e
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions pyterrier_dr/flex/scann_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,37 @@ def transform(self, inp):


def _scann_retriever(self, n_leaves=None, leaves_to_search=1, train_sample=None):
import scann
assert not hasattr(scann.scann_ops_pybind, 'builder'), "scann==1.0.0 required; install from wheel here: <https://github.com/google-research/google-research/blob/master/scann/docs/releases.md#scann-wheel-archive>"
dvecs, meta, = self.payload(return_docnos=False)
import scann
assert not hasattr(scann.scann_ops_pybind, 'builder'), "scann==1.0.0 required; install from wheel here: <https://github.com/google-research/google-research/blob/master/scann/docs/releases.md#scann-wheel-archive>"
dvecs, meta, = self.payload(return_docnos=False)

if n_leaves is None:
# rule of thumb: sqrt(doc_count) (from <https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md>)
n_leaves = math.ceil(math.sqrt(meta['doc_count']))
# we'll shift it to the nearest power of 2
n_leaves = int(1 << math.ceil(math.log2(n_leaves)))
if n_leaves is None:
# rule of thumb: sqrt(doc_count) (from <https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md>)
n_leaves = math.ceil(math.sqrt(meta['doc_count']))
# we'll shift it to the nearest power of 2
n_leaves = int(1 << math.ceil(math.log2(n_leaves)))

if train_sample is None:
train_sample = n_leaves * 39
train_sample = min(train_sample, meta['doc_count'])
elif 0 <= train_sample <= 1:
train_sample = math.ceil(meta['doc_count'] * train_sample)
if train_sample is None:
train_sample = n_leaves * 39
train_sample = min(train_sample, meta['doc_count'])
elif 0 <= train_sample <= 1:
train_sample = math.ceil(meta['doc_count'] * train_sample)

key = ('scann', n_leaves, train_sample)
index_name = f'scann_leaves-{n_leaves}_train-{train_sample}.scann'
if key not in self._cache:
if not os.path.exists(self.index_path/index_name):
with logger.duration(f'building scann index with {n_leaves} leaves'):
searcher = scann.ScannBuilder(dvecs, 10, "dot_product") # neighbours=10; doesn't seem to affect the model (?)
searcher = searcher.tree(num_leaves=n_leaves, num_leaves_to_search=leaves_to_search, training_sample_size=train_sample, quantize_centroids=True)
searcher = searcher.score_brute_force()
searcher = searcher.create_pybind()
with logger.duration('saving scann index'):
(self.index_path/index_name).mkdir()
searcher.serialize(str(self.index_path/index_name))
self._cache[key] = searcher
else:
with logger.duration('reading index'):
self._cache[key] = scann.scann_ops_pybind.load_searcher(dvecs, str(self.index_path/index_name))
return ScannRetriever(self, self._cache[key], leaves_to_search=leaves_to_search)
key = ('scann', n_leaves, train_sample)
index_name = f'scann_leaves-{n_leaves}_train-{train_sample}.scann'
if key not in self._cache:
if not os.path.exists(self.index_path/index_name):
with logger.duration(f'building scann index with {n_leaves} leaves'):
searcher = scann.ScannBuilder(dvecs, 10, "dot_product") # neighbours=10; doesn't seem to affect the model (?)
searcher = searcher.tree(num_leaves=n_leaves, num_leaves_to_search=leaves_to_search, training_sample_size=train_sample, quantize_centroids=True)
searcher = searcher.score_brute_force()
searcher = searcher.create_pybind()
with logger.duration('saving scann index'):
(self.index_path/index_name).mkdir()
searcher.serialize(str(self.index_path/index_name))
self._cache[key] = searcher
else:
with logger.duration('reading index'):
self._cache[key] = scann.scann_ops_pybind.load_searcher(dvecs, str(self.index_path/index_name))
return ScannRetriever(self, self._cache[key], leaves_to_search=leaves_to_search)
FlexIndex.scann_retriever = _scann_retriever

0 comments on commit c20947e

Please sign in to comment.