diff --git a/pyterrier_dr/bge_m3.py b/pyterrier_dr/bge_m3.py
index 8e79ad5..7dbad16 100644
--- a/pyterrier_dr/bge_m3.py
+++ b/pyterrier_dr/bge_m3.py
@@ -1,8 +1,8 @@
-from tqdm import tqdm
 import pyterrier as pt
 import pandas as pd
 import numpy as np
 import torch
+import pyterrier_alpha as pta
 from .biencoder import BiEncoder
 
 class BGEM3(BiEncoder):
@@ -61,8 +61,8 @@ def encode(self, texts):
             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        assert all(c in inp.columns for c in ['query'])
-
+        pta.validate.columns(inp, includes=['query'])
+
         # check if inp is empty
         if len(inp) == 0:
             if self.dense:
@@ -102,14 +102,14 @@ def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length
         self.dense = return_dense
         self.sparse = return_sparse
         self.multivecs = return_colbert_vecs
-
+
     def encode(self, texts):
         return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
         # check if the input dataframe contains the field(s) specified in the text_field
-        assert all(c in inp.columns for c in [self.bge_factory.text_field])
+        pta.validate.columns(inp, includes=[self.bge_factory.text_field])
         # check if inp is empty
         if len(inp) == 0:
             if self.dense:
diff --git a/pyterrier_dr/biencoder.py b/pyterrier_dr/biencoder.py
index 09dd3c9..1430748 100644
--- a/pyterrier_dr/biencoder.py
+++ b/pyterrier_dr/biencoder.py
@@ -72,7 +72,7 @@ def encode(self, texts, batch_size=None) -> np.array:
         return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size)
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        assert all(c in inp.columns for c in ['query'])
+        pta.validate.columns(inp, includes=['query'])
         it = inp['query'].values
         it, inv = np.unique(it, return_inverse=True)
         if self.verbose:
@@ -95,7 +95,7 @@ def encode(self, texts, batch_size=None) -> np.array:
         return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        assert all(c in inp.columns for c in [self.text_field])
+        pta.validate.columns(inp, includes=[self.text_field])
         it = inp[self.text_field]
         if self.verbose:
             it = pt.tqdm(it, desc='Encoding Docs', unit='doc')
@@ -114,8 +114,11 @@ def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, t
         self.sim_fn = sim_fn if sim_fn is not None else bi_encoder_model.sim_fn
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        assert 'query_vec' in inp.columns or 'query' in inp.columns
-        assert 'doc_vec' in inp.columns or self.text_field in inp.columns
+        with pta.validate.any(inp) as v:
+            v.columns(includes=['query_vec', 'doc_vec'])
+            v.columns(includes=['query', 'doc_vec'])
+            v.columns(includes=['query_vec', self.text_field])
+            v.columns(includes=['query', self.text_field])
         if 'query_vec' in inp.columns:
             query_vec = inp['query_vec']
         else:
diff --git a/pyterrier_dr/flex/core.py b/pyterrier_dr/flex/core.py
index 66d500a..57fdea1 100644
--- a/pyterrier_dr/flex/core.py
+++ b/pyterrier_dr/flex/core.py
@@ -107,21 +107,11 @@ def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
         for docno, i in it:
             yield {'docno': docno, 'doc_vec': dvecs[i]}
 
-    def np_retriever(self, batch_size=None, num_results=None):
-        return FlexIndexNumpyRetriever(self, batch_size, num_results=num_results or self.num_results)
-
-    def torch_retriever(self, batch_size=None):
-        return FlexIndexTorchRetriever(self, batch_size)
-
-    def vec_loader(self):
-        return FlexIndexVectorLoader(self)
-
-    def scorer(self):
-        return FlexIndexScorer(self)
-
     def _load_docids(self, inp):
-        assert 'docid' in inp.columns or 'docno' in inp.columns
-        if 'docid' in inp.columns:
+        with pta.validate.any(inp) as v:
+            v.columns(includes=['docid'], mode='docid')
+            v.columns(includes=['docno'], mode='docno')
+        if v.mode == 'docid':
             return inp['docid'].values
         docnos, config = self.payload(return_dvecs=False)
         return docnos.inv[inp['docno'].values] # look up docids from docnos
@@ -130,134 +120,6 @@ def built(self):
         return self.index_path.exists()
 
 
-class FlexIndexNumpyRetriever(pt.Transformer):
-    def __init__(self, flex_index, batch_size=None, num_results=None):
-        self.flex_index = flex_index
-        self.batch_size = batch_size or 4096
-        self.num_results = num_results or self.flex_index.num_results
-
-    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        inp = inp.reset_index(drop=True)
-        query_vecs = np.stack(inp['query_vec'])
-        docnos, dvecs, config = self.flex_index.payload()
-        if self.flex_index.sim_fn == SimFn.cos:
-            query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
-        elif self.flex_index.sim_fn == SimFn.dot:
-            pass # nothing to do
-        else:
-            raise ValueError(f'{self.flex_index.sim_fn} not supported')
-        num_q = query_vecs.shape[0]
-        res = []
-        ranked_lists = RankedLists(self.num_results, num_q)
-        batch_it = range(0, dvecs.shape[0], self.batch_size)
-        if self.flex_index.verbose:
-            batch_it = pt.tqdm(batch_it)
-        for idx_start in batch_it:
-            doc_batch = dvecs[idx_start:idx_start+self.batch_size].T
-            if self.flex_index.sim_fn == SimFn.cos:
-                doc_batch = doc_batch / np.linalg.norm(doc_batch, axis=0, keepdims=True)
-            scores = query_vecs @ doc_batch
-            dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0)
-            ranked_lists.update(scores, dids)
-        result_scores, result_dids = ranked_lists.results()
-        result_docnos = [docnos.fwd[d] for d in result_dids]
-        cols = {
-            'score': np.concatenate(result_scores),
-            'docno': np.concatenate(result_docnos),
-            'docid': np.concatenate(result_dids),
-            'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]),
-        }
-        idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores))))
-        for col in inp.columns:
-            if col != 'query_vec':
-                cols[col] = inp[col][idxs].values
-        return pd.DataFrame(cols)
-
-
-class FlexIndexTorchRetriever(pt.Transformer):
-    def __init__(self, flex_index, batch_size=None):
-        self.flex_index = flex_index
-        self.batch_size = batch_size or 4096
-        docnos, meta, = flex_index.payload(return_dvecs=False)
-        SType, TType, CTType, SIZE = torch.FloatStorage, torch.FloatTensor, torch.cuda.FloatTensor, 4
-        self._cpu_data = TType(SType.from_file(str(self.flex_index.index_path/'vecs.f4'), size=meta['doc_count'] * meta['vec_size'])).reshape(meta['doc_count'], meta['vec_size'])
-        self._cuda_data = CTType(size=(self.batch_size, meta['vec_size']), device='cuda')
-        self._docnos = docnos
-
-    def transform(self, inp):
-        columns = set(inp.columns)
-        assert all(f in columns for f in ['qid', 'query_vec']), "TorchIndex expects columns ['qid', 'query_vec'] when used in a pipeline"
-        query_vecs = np.stack(inp['query_vec'])
-        query_vecs = torch.from_numpy(query_vecs).cuda() # TODO: can this go directly to CUDA? device='cuda' doesn't work
-
-        step = self._cuda_data.shape[0]
-        it = range(0, self._cpu_data.shape[0], step)
-        if self.flex_index.verbose:
-            it = pt.tqdm(it, desc='TorchIndex scoring', unit='docbatch')
-
-        ranked_lists = RankedLists(self.flex_index.num_results, query_vecs.shape[0])
-        for start_idx in it:
-            end_idx = start_idx + step
-            batch = self._cpu_data[start_idx:end_idx]
-            bsize = batch.shape[0]
-            self._cuda_data[:bsize] = batch
-
-            scores = query_vecs @ self._cuda_data[:bsize].T
-            if scores.shape[0] > self.flex_index.num_results:
-                scores, dids = torch.topk(scores, k=self.flex_index.num_results, dim=1)
-            else:
-                scores, dids = torch.sort(scores, dim=1, descending=False)
-            scores = scores.cpu().float().numpy()
-
-            dids = (dids + start_idx).cpu().numpy()
-            ranked_lists.update(scores, dids)
-
-        result_scores, result_dids = ranked_lists.results()
-        result_docnos = self._docnos.fwd[result_dids]
-        res = []
-        for query, scores, docnos in zip(inp.itertuples(index=False), result_scores, result_docnos):
-            for score, docno in zip(scores, docnos):
-                res.append((*query, docno, score))
-        res = pd.DataFrame(res, columns=list(query._fields) + ['docno', 'score'])
-        res = res[~res.score.isna()]
-        res = add_ranks(res)
-        return res
-
-
 def _load_dvecs(flex_index, inp):
-    assert 'docid' in inp.columns or 'docno' in inp.columns
-    docnos, dvecs, config = flex_index.payload()
-    if 'docid' in inp.columns:
-        docids = inp['docid'].values
-    else:
-        docids = docnos.inv[inp['docno'].values]
-    return dvecs[docids]
-
-
-class FlexIndexVectorLoader(pt.Transformer):
-    def __init__(self, flex_index):
-        self.flex_index = flex_index
-
-    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        return inp.assign(doc_vec=list(_load_dvecs(self.flex_index, inp)))
-
-
-class FlexIndexScorer(pt.Transformer):
-    def __init__(self, flex_index):
-        self.flex_index = flex_index
-
-    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
-        assert 'query_vec' in inp.columns
-        doc_vecs = _load_dvecs(self.flex_index, inp)
-        query_vecs = np.stack(inp['query_vec'])
-        if self.flex_index.sim_fn == SimFn.cos:
-            query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
-            doc_vecs = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
-            scores = (query_vecs * doc_vecs).sum(axis=1)
-        elif self.flex_index.sim_fn == SimFn.dot:
-            scores = (query_vecs * doc_vecs).sum(axis=1)
-        else:
-            raise ValueError(f'{self.flex_index.sim_fn} not supported')
-        res = inp.assign(score=scores)
-        res = add_ranks(res)
-        return res
+    dvecs, config = flex_index.payload(return_docnos=False)
+    return dvecs[flex_index._load_docids(inp)]
diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py
index 6d9fb9a..b949725 100644
--- a/pyterrier_dr/flex/np_retr.py
+++ b/pyterrier_dr/flex/np_retr.py
@@ -82,8 +82,9 @@ def score(self, query_vecs, docids):
             raise ValueError(f'{self.flex_index.sim_fn} not supported')
 
     def transform(self, inp):
-        assert 'query_vec' in inp.columns
-        assert 'docno' in inp.columns or 'docid' in inp.columns
+        with pta.validate.any(inp) as v:
+            v.columns(includes=['query_vec', 'docno'])
+            v.columns(includes=['query_vec', 'docid'])
         inp = inp.reset_index(drop=True)
         res_idxs = []
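
Note for reviewers: below is a minimal sketch (not part of the patch) of how the pyterrier-alpha validation calls introduced here behave. The toy DataFrame and its values are illustrative assumptions; the API usage mirrors the patch itself.

import pandas as pd
import pyterrier_alpha as pta

# Hypothetical input frame: one query row that also carries a 'docno'.
inp = pd.DataFrame([{'qid': 'q1', 'query': 'chemical reactions', 'docno': 'd1'}])

# Single-spec form (query/doc encoders above): raises a descriptive error
# if 'query' is missing, replacing the old bare `assert`.
pta.validate.columns(inp, includes=['query'])

# Multi-spec form (_load_docids above): validation passes if any spec
# matches, and `v.mode` records which spec matched for use after the block.
with pta.validate.any(inp) as v:
    v.columns(includes=['docid'], mode='docid')
    v.columns(includes=['docno'], mode='docno')
assert v.mode == 'docno'  # this frame has 'docno' but no 'docid'

This is why _load_docids can branch on v.mode instead of re-checking inp.columns.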