replacing ir_datasets's logger.pbar with pt.tqdm
seanmacavaney committed Nov 23, 2024
1 parent 6b8e79b commit aaa868c
Showing 6 changed files with 11 additions and 18 deletions.
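The change applies a single pattern across all six files: progress bars previously created through ir_datasets' logging helper (logger = ir_datasets.log.easy() followed by logger.pbar(...)) now go through PyTerrier's bundled tqdm wrapper, and the ir_datasets import is dropped wherever it existed only for progress reporting. A minimal sketch of the before/after pattern (the workload iterable is a placeholder):

    import pyterrier as pt

    work_items = range(10)  # placeholder workload

    # Before (removed by this commit):
    #   import ir_datasets
    #   logger = ir_datasets.log.easy()
    #   for item in logger.pbar(work_items, unit='batch'): ...

    # After: pt.tqdm wraps the iterable the same way logger.pbar did
    for item in pt.tqdm(work_items, desc='indexing', unit='batch'):
        pass  # process one item; the bar advances as the loop iterates
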
2 changes: 0 additions & 2 deletions pyterrier_dr/flex/core.py
@@ -10,10 +10,8 @@
 from npids import Lookup
 from enum import Enum
 from .. import SimFn
-import ir_datasets
 import pyterrier_alpha as pta
 
-logger = ir_datasets.log.easy()
 
 class IndexingMode(Enum):
     create = "create"
5 changes: 2 additions & 3 deletions pyterrier_dr/flex/corpus_graph.py
@@ -4,12 +4,11 @@
 import ir_datasets
 import torch
 import numpy as np
+import pyterrier as pt
 import pyterrier_dr
 from ..indexes import TorchRankedLists
 from . import FlexIndex
 
-logger = ir_datasets.log.easy()
-
 
 def _corpus_graph(self, k=16, batch_size=8192):
     from pyterrier_adaptive import CorpusGraph
@@ -43,7 +42,7 @@ def _build_corpus_graph(flex_index, k, out_dir, batch_size):
     weights_path = out_dir/'weights.f16.np'
     device = pyterrier_dr.util.infer_device()
     dtype = torch.half if device.type == 'cuda' else torch.float
-    with logger.pbar_raw(total=int((num_chunks+1)*num_chunks/2), unit='chunk', smoothing=1) as pbar, \
+    with pt.tqdm(total=int((num_chunks+1)*num_chunks/2), unit='chunk', smoothing=1) as pbar, \
         ir_datasets.util.finialized_file(str(edges_path), 'wb') as fe, \
         ir_datasets.util.finialized_file(str(weights_path), 'wb') as fw:
         for i in range(num_chunks):
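The corpus_graph.py hunk above covers the one manual-update call site: logger.pbar_raw was used as a context manager with an explicit total, and pt.tqdm supports the same usage. A hedged sketch of that pattern (the chunk count is illustrative; the total counts the chunk pairs (i, j) with j >= i, hence n(n+1)/2):

    import pyterrier as pt

    num_chunks = 4  # illustrative; the real code derives this from the corpus size

    # pt.tqdm(total=...) as a context manager mirrors logger.pbar_raw:
    # progress advances via explicit pbar.update() calls rather than by iteration.
    with pt.tqdm(total=int((num_chunks + 1) * num_chunks / 2), unit='chunk', smoothing=1) as pbar:
        for i in range(num_chunks):
            for j in range(i, num_chunks):  # one unit of work per chunk pair
                pbar.update(1)
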
8 changes: 4 additions & 4 deletions pyterrier_dr/flex/faiss_retr.py
@@ -39,7 +39,7 @@ def transform(self, inp):
         self.faiss_index.hnsw.search_bounded_queue = self.search_bounded_queue
         it = range(0, num_q, QBATCH)
         if self.flex_index.verbose:
-            it = logger.pbar(it, unit='qbatch')
+            it = pt.tqdm(it, unit='qbatch')
 
         result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
         for qidx in it:
@@ -96,7 +96,7 @@ def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16,
     dvecs, meta = self.payload(return_docnos=False)
     if not os.path.exists(self.index_path/index_name):
         idx = faiss.IndexHNSWFlat(meta['vec_size'], neighbours, faiss.METRIC_INNER_PRODUCT)
-        for start_idx in logger.pbar(range(0, dvecs.shape[0], 4096), desc='indexing', unit='batch'):
+        for start_idx in pt.tqdm(range(0, dvecs.shape[0], 4096), desc='indexing', unit='batch'):
             idx.add(np.array(dvecs[start_idx:start_idx+4096]))
         idx.storage = faiss.IndexFlatIP(meta['vec_size'])  # clear storage ; we can use faiss_flat here instead so we don't keep an extra copy
     if cache:
@@ -133,7 +133,7 @@ def _build_hnsw_graph(hnsw, out_dir):
     weights_path = out_dir/'weights.f16.np'
     with ir_datasets.util.finialized_file(str(edges_path), 'wb') as fe, \
          ir_datasets.util.finialized_file(str(weights_path), 'wb') as fw:
-        for did in logger.pbar(range(num_docs), unit='doc', smoothing=1):
+        for did in pt.tqdm(range(num_docs), unit='doc', smoothing=1):
             start = hnsw.offsets.at(did)
             dids = [hnsw.neighbors.at(i) for i in range(start, start+lvl_0_size)]
             dids = [(d if d != -1 else did) for d in dids]  # replace with self if missing value
@@ -186,7 +186,7 @@ def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_pro
         train = _sample_train(self, train_sample)
         with logger.duration(f'training ivf with {n_list} posting lists'):
             idx.train(train)
-        for start_idx in logger.pbar(range(0, dvecs.shape[0], 4096), desc='indexing', unit='batch'):
+        for start_idx in pt.tqdm(range(0, dvecs.shape[0], 4096), desc='indexing', unit='batch'):
             idx.add(np.array(dvecs[start_idx:start_idx+4096]))
         if cache:
             with logger.duration('caching index'):
6 changes: 2 additions & 4 deletions pyterrier_dr/flex/ladr.py
@@ -2,9 +2,7 @@
 import pyterrier as pt
 import pyterrier_alpha as pta
 from . import FlexIndex
-import ir_datasets
 
-logger = ir_datasets.log.easy()
 
 class LadrPreemptive(pt.Transformer):
     def __init__(self, flex_index, graph, dense_scorer, hops=1, drop_query_vec=False):
@@ -25,7 +23,7 @@ def transform(self, inp):

         it = iter(inp.groupby('qid'))
         if self.flex_index.verbose:
-            it = logger.pbar(it)
+            it = pt.tqdm(it)
         for qid, df in it:
             qdata = {col: [df[col].iloc[0]] for col in qcols}
             docids = docnos.inv[df['docno'].values]
@@ -79,7 +77,7 @@ def transform(self, inp):

         it = iter(inp.groupby('qid'))
         if self.flex_index.verbose:
-            it = logger.pbar(it)
+            it = pt.tqdm(it)
         for qid, df in it:
             qdata = {col: [df[col].iloc[0]] for col in qcols}
             query_vecs = df['query_vec'].iloc[0].reshape(1, -1)
4 changes: 1 addition & 3 deletions pyterrier_dr/flex/np_retr.py
@@ -4,10 +4,8 @@
 from .. import SimFn
 from ..indexes import RankedLists
 from . import FlexIndex
-import ir_datasets
 import pyterrier_alpha as pta
 
-logger = ir_datasets.log.easy()
 
 class NumpyRetriever(pt.Transformer):
     def __init__(self, flex_index, num_results=1000, batch_size=None, drop_query_vec=False):
@@ -90,7 +88,7 @@ def transform(self, inp):
         res_idxs = []
         res_scores = []
         res_ranks = []
-        for qid, df in logger.pbar(inp.groupby('qid')):
+        for qid, df in pt.tqdm(inp.groupby('qid')):
             docids = self.flex_index._load_docids(df)
             query_vecs = df['query_vec'].iloc[0].reshape(1, -1)
             scores = self.score(query_vecs, docids).reshape(-1)
4 changes: 2 additions & 2 deletions pyterrier_dr/flex/voyager_retr.py
@@ -29,7 +29,7 @@ def transform(self, inp):
         QBATCH = self.qbatch
         it = range(0, num_q, QBATCH)
         if self.flex_index.verbose:
-            it = logger.pbar(it, unit='qbatch')
+            it = pt.tqdm(it, unit='qbatch')
         for qidx in it:
             qvec_batch = query_vecs[qidx:qidx+QBATCH]
             neighbor_ids, distances = self.voyager_index.query(qvec_batch, self.flex_index.num_results, self.query_ef)
@@ -74,7 +74,7 @@ def _voyager_retriever(self, neighbours=12, ef_construction=200, random_seed=1,
             print(index.ef)
             it = range(0, meta['doc_count'], BATCH_SIZE)
             if self.verbose:
-                it = logger.pbar(it, desc='building index', unit='dbatch')
+                it = pt.tqdm(it, desc='building index', unit='dbatch')
             for idx in it:
                 index.add_items(dvecs[idx:idx+BATCH_SIZE])
             with logger.duration('saving index'):
