Skip to content

Commit

Permalink
more applications of pta.validate
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Nov 23, 2024
1 parent e4c6e81 commit e551ad0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 33 deletions.
34 changes: 15 additions & 19 deletions pyterrier_dr/biencoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from more_itertools import chunked
import numpy as np
import torch
from torch import nn
import pyterrier as pt
import pandas as pd
import pyterrier_alpha as pta
from . import SimFn


Expand All @@ -21,22 +19,20 @@ def encode_docs(self, texts, batch_size=None) -> np.array:
raise NotImplementedError()

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
columns = set(inp.columns)
modes = [
(['qid', 'query', self.text_field], self.scorer),
(['qid', 'query_vec', self.text_field], self.scorer),
(['qid', 'query', 'doc_vec'], self.scorer),
(['qid', 'query_vec', 'doc_vec'], self.scorer),
(['query'], self.query_encoder),
([self.text_field], self.doc_encoder),
]
for fields, fn in modes:
if all(f in columns for f in fields):
return fn()(inp)
message = f'Unexpected input with columns: {inp.columns}. Supports:'
for fields, fn in modes:
message += f'\n - {fn.__doc__.strip()}: {fields}'
raise RuntimeError(message)
with pta.validate.any(inp) as v:
v.columns(includes=['query', self.text_field], mode='scorer')
v.columns(includes=['query_vec', self.text_field], mode='scorer')
v.columns(includes=['query', 'doc_vec'], mode='scorer')
v.columns(includes=['query_vec', 'doc_vec'], mode='scorer')
v.columns(includes=['query'], mode='query_encoder')
v.columns(includes=[self.text_field], mode='doc_encoder')

if v.mode == 'scorer':
return self.scorer()(inp)
elif v.mode == 'query_encoder':
return self.query_encoder()(inp)
elif v.mode == 'doc_encoder':
return self.doc_encoder()(inp)

def query_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
Expand Down
22 changes: 8 additions & 14 deletions pyterrier_dr/flex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..indexes import RankedLists
import ir_datasets
import torch
from pyterrier_alpha import Artifact
import pyterrier_alpha as pta

logger = ir_datasets.log.easy()

Expand All @@ -24,7 +24,7 @@ class IndexingMode(Enum):
# append???


class FlexIndex(Artifact, pt.Indexer):
class FlexIndex(pta.Artifact, pt.Indexer):
def __init__(self, index_path, num_results=1000, sim_fn=SimFn.dot, indexing_mode=IndexingMode.create, verbose=True):
super().__init__(index_path)
self.index_path = Path(index_path)
Expand Down Expand Up @@ -88,18 +88,12 @@ def index(self, inp):
json.dump({"type": "dense_index", "format": "flex", "vec_size": vec_size, "doc_count": count}, f_meta)

def transform(self, inp):
columns = set(inp.columns)
modes = [
(['qid', 'query_vec'], self.np_retriever, "performing exhaustive saerch with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster"),
]
for fields, fn, note in modes:
if all(f in columns for f in fields):
warn(f'based on input columns {list(columns)}, {note}')
return fn()(inp)
message = f'Unexpected input with columns: {inp.columns}. Supports:'
for fields, fn in modes:
message += f'\n - {fn.__doc__.strip()}: {fields}'
raise RuntimeError(message)
with pta.validate.any(inp) as v:
v.query_frame(extra_columns=['query_vec'], mode='np_retriever')

if v.mode == 'np_retriever':
warn("performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster")
return self.np_retriever()(inp)

def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
docnos, dvecs, meta = self.payload()
Expand Down

0 comments on commit e551ad0

Please sign in to comment.