Skip to content

Commit

Permalink
misc cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Nov 23, 2024
1 parent e551ad0 commit 79074e1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 155 deletions.
10 changes: 5 additions & 5 deletions pyterrier_dr/bge_m3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from tqdm import tqdm
import pyterrier as pt
import pandas as pd
import numpy as np
import torch
import pyterrier_alpha as pta
from .biencoder import BiEncoder

class BGEM3(BiEncoder):
Expand Down Expand Up @@ -61,8 +61,8 @@ def encode(self, texts):
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])
pta.validte.columns(includes=['query'])

# check if inp is empty
if len(inp) == 0:
if self.dense:
Expand Down Expand Up @@ -102,14 +102,14 @@ def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length
self.dense = return_dense
self.sparse = return_sparse
self.multivecs = return_colbert_vecs

def encode(self, texts):
return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# check if the input dataframe contains the field(s) specified in the text_field
assert all(c in inp.columns for c in [self.bge_factory.text_field])
pta.validte.columns(includes=[self.bge_factory.text_field])
# check if inp is empty
if len(inp) == 0:
if self.dense:
Expand Down
11 changes: 7 additions & 4 deletions pyterrier_dr/biencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])
pta.validate.columns(includes=['query'])
it = inp['query'].values
it, inv = np.unique(it, return_inverse=True)
if self.verbose:
Expand All @@ -95,7 +95,7 @@ def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in [self.text_field])
pta.validate.columns(includes=[self.text_field])
it = inp[self.text_field]
if self.verbose:
it = pt.tqdm(it, desc='Encoding Docs', unit='doc')
Expand All @@ -114,8 +114,11 @@ def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, t
self.sim_fn = sim_fn if sim_fn is not None else bi_encoder_model.sim_fn

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert 'query_vec' in inp.columns or 'query' in inp.columns
assert 'doc_vec' in inp.columns or self.text_field in inp.columns
with pta.validate.any(inp) as v:
v.columns(includes=['query_vec', 'doc_vec'])
v.columns(includes=['query', 'doc_vec'])
v.columns(includes=['query_vec', self.text_field])
v.columns(includes=['query', self.text_field])
if 'query_vec' in inp.columns:
query_vec = inp['query_vec']
else:
Expand Down
150 changes: 6 additions & 144 deletions pyterrier_dr/flex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,11 @@ def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
for docno, i in it:
yield {'docno': docno, 'doc_vec': dvecs[i]}

def np_retriever(self, batch_size=None, num_results=None):
return FlexIndexNumpyRetriever(self, batch_size, num_results=num_results or self.num_results)

def torch_retriever(self, batch_size=None):
return FlexIndexTorchRetriever(self, batch_size)

def vec_loader(self):
return FlexIndexVectorLoader(self)

def scorer(self):
return FlexIndexScorer(self)

def _load_docids(self, inp):
assert 'docid' in inp.columns or 'docno' in inp.columns
if 'docid' in inp.columns:
with pta.validate.any(inp) as v:
v.columns(includes=['docid'], mode='docid')
v.columns(includes=['docno'], mode='docno')
if v.mode == 'docid':
return inp['docid'].values
docnos, config = self.payload(return_dvecs=False)
return docnos.inv[inp['docno'].values] # look up docids from docnos
Expand All @@ -130,134 +120,6 @@ def built(self):
return self.index_path.exists()


class FlexIndexNumpyRetriever(pt.Transformer):
def __init__(self, flex_index, batch_size=None, num_results=None):
self.flex_index = flex_index
self.batch_size = batch_size or 4096
self.num_results = num_results or self.flex_index.num_results

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])
docnos, dvecs, config = self.flex_index.payload()
if self.flex_index.sim_fn == SimFn.cos:
query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
elif self.flex_index.sim_fn == SimFn.dot:
pass # nothing to do
else:
raise ValueError(f'{self.flex_index.sim_fn} not supported')
num_q = query_vecs.shape[0]
res = []
ranked_lists = RankedLists(self.num_results, num_q)
batch_it = range(0, dvecs.shape[0], self.batch_size)
if self.flex_index.verbose:
batch_it = pt.tqdm(batch_it)
for idx_start in batch_it:
doc_batch = dvecs[idx_start:idx_start+self.batch_size].T
if self.flex_index.sim_fn == SimFn.cos:
doc_batch = doc_batch / np.linalg.norm(doc_batch, axis=0, keepdims=True)
scores = query_vecs @ doc_batch
dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0)
ranked_lists.update(scores, dids)
result_scores, result_dids = ranked_lists.results()
result_docnos = [docnos.fwd[d] for d in result_dids]
cols = {
'score': np.concatenate(result_scores),
'docno': np.concatenate(result_docnos),
'docid': np.concatenate(result_dids),
'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]),
}
idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores))))
for col in inp.columns:
if col != 'query_vec':
cols[col] = inp[col][idxs].values
return pd.DataFrame(cols)


class FlexIndexTorchRetriever(pt.Transformer):
def __init__(self, flex_index, batch_size=None):
self.flex_index = flex_index
self.batch_size = batch_size or 4096
docnos, meta, = flex_index.payload(return_dvecs=False)
SType, TType, CTType, SIZE = torch.FloatStorage, torch.FloatTensor, torch.cuda.FloatTensor, 4
self._cpu_data = TType(SType.from_file(str(self.flex_index.index_path/'vecs.f4'), size=meta['doc_count'] * meta['vec_size'])).reshape(meta['doc_count'], meta['vec_size'])
self._cuda_data = CTType(size=(self.batch_size, meta['vec_size']), device='cuda')
self._docnos = docnos

def transform(self, inp):
columns = set(inp.columns)
assert all(f in columns for f in ['qid', 'query_vec']), "TorchIndex expects columns ['qid', 'query_vec'] when used in a pipeline"
query_vecs = np.stack(inp['query_vec'])
query_vecs = torch.from_numpy(query_vecs).cuda() # TODO: can this go directly to CUDA? device='cuda' doesn't work

step = self._cuda_data.shape[0]
it = range(0, self._cpu_data.shape[0], step)
if self.flex_index.verbose:
it = pt.tqdm(it, desc='TorchIndex scoring', unit='docbatch')

ranked_lists = RankedLists(self.flex_index.num_results, query_vecs.shape[0])
for start_idx in it:
end_idx = start_idx + step
batch = self._cpu_data[start_idx:end_idx]
bsize = batch.shape[0]
self._cuda_data[:bsize] = batch

scores = query_vecs @ self._cuda_data[:bsize].T
if scores.shape[0] > self.flex_index.num_results:
scores, dids = torch.topk(scores, k=self.flex_index.num_results, dim=1)
else:
scores, dids = torch.sort(scores, dim=1, descending=False)
scores = scores.cpu().float().numpy()

dids = (dids + start_idx).cpu().numpy()
ranked_lists.update(scores, dids)

result_scores, result_dids = ranked_lists.results()
result_docnos = self._docnos.fwd[result_dids]
res = []
for query, scores, docnos in zip(inp.itertuples(index=False), result_scores, result_docnos):
for score, docno in zip(scores, docnos):
res.append((*query, docno, score))
res = pd.DataFrame(res, columns=list(query._fields) + ['docno', 'score'])
res = res[~res.score.isna()]
res = add_ranks(res)
return res


def _load_dvecs(flex_index, inp):
assert 'docid' in inp.columns or 'docno' in inp.columns
docnos, dvecs, config = flex_index.payload()
if 'docid' in inp.columns:
docids = inp['docid'].values
else:
docids = docnos.inv[inp['docno'].values]
return dvecs[docids]


class FlexIndexVectorLoader(pt.Transformer):
def __init__(self, flex_index):
self.flex_index = flex_index

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
return inp.assign(doc_vec=list(_load_dvecs(self.flex_index, inp)))


class FlexIndexScorer(pt.Transformer):
def __init__(self, flex_index):
self.flex_index = flex_index

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert 'query_vec' in inp.columns
doc_vecs = _load_dvecs(self.flex_index, inp)
query_vecs = np.stack(inp['query_vec'])
if self.flex_index.sim_fn == SimFn.cos:
query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
doc_vecs = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
scores = (query_vecs * doc_vecs).sum(axis=1)
elif self.flex_index.sim_fn == SimFn.dot:
scores = (query_vecs * doc_vecs).sum(axis=1)
else:
raise ValueError(f'{self.flex_index.sim_fn} not supported')
res = inp.assign(score=scores)
res = add_ranks(res)
return res
dvecs, config = flex_index.payload(return_docnos=False)
return dvecs[flex_index._load_docids(inp)]
5 changes: 3 additions & 2 deletions pyterrier_dr/flex/np_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def score(self, query_vecs, docids):
raise ValueError(f'{self.flex_index.sim_fn} not supported')

def transform(self, inp):
assert 'query_vec' in inp.columns
assert 'docno' in inp.columns or 'docid' in inp.columns
with pta.validate.any(inp) as v:
v.columns(includes=['query_vec', 'docno'])
v.columns(includes=['query_vec', 'docid'])
inp = inp.reset_index(drop=True)

res_idxs = []
Expand Down

0 comments on commit 79074e1

Please sign in to comment.