Skip to content

Commit

Permalink
Merge pull request #16 from terrierteam/refactor
Browse files Browse the repository at this point in the history
refactoring
  • Loading branch information
seanmacavaney authored Jul 20, 2023
2 parents 3ee5b7a + 862b973 commit 53e965a
Show file tree
Hide file tree
Showing 17 changed files with 1,405 additions and 117 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ pt.init()
import pyterrier_dr
```

## Built-in Models

| Model | `.query_encoder()` | `.doc_encoder()` | `.scorer()` |
|-------|:---------------:|:-------------:|:--------:|
| [`TctColBert`](https://arxiv.org/abs/2010.11386) | ✅ | ✅ | ✅ |
| [`TasB`](https://arxiv.org/abs/2104.06967) | ✅ | ✅ | ✅ |
| [`Ance`](https://arxiv.org/abs/2007.00808) | ✅ | ✅ | ✅ |
| [`Query2Query`](https://neeva.com/blog/state-of-the-art-query2query-similarity) | ✅ | | |

## Inference

Bi-encoder models are represented as PyTerrier transformers. For instance,
Expand Down
5 changes: 5 additions & 0 deletions pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .util import SimFn, infer_device
from .indexes import DocnoFile, NilIndex, NumpyIndex, RankedLists, FaissFlat, FaissHnsw, MemIndex, TorchIndex
from .flex import FlexIndex
from .biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
from .hgf_models import HgfBiEncoder, TasB, RetroMAE
from .sbert_models import SBertBiEncoder, Ance, Query2Query
from .tctcolbert_model import TctColBert
from .electra import ElectraScorer
136 changes: 136 additions & 0 deletions pyterrier_dr/biencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from more_itertools import chunked
import numpy as np
import torch
from torch import nn
import pyterrier as pt
import pandas as pd
from . import SimFn


class BiEncoder(pt.Transformer):
    """Base class for bi-encoder (dual-encoder) dense retrieval models.

    Subclasses implement ``encode_queries`` and ``encode_docs``. The
    ``transform`` method dispatches to the scorer, query encoder, or doc
    encoder depending on which columns the input frame carries.
    """
    def __init__(self, batch_size=32, text_field='text', verbose=False):
        """
        Args:
            batch_size: number of texts encoded per model invocation.
            text_field: name of the dataframe column holding document text.
            verbose: show a progress bar during encoding when True.
        """
        super().__init__()
        self.batch_size = batch_size
        self.text_field = text_field
        self.verbose = verbose

    def encode_queries(self, texts, batch_size=None) -> np.array:
        """Encode an iterable of query strings into vectors (subclass hook)."""
        raise NotImplementedError()

    def encode_docs(self, texts, batch_size=None) -> np.array:
        """Encode an iterable of document strings into vectors (subclass hook)."""
        raise NotImplementedError()

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Dispatch to scorer / query encoder / doc encoder by input columns."""
        present = set(inp.columns)
        # Ordered by specificity: scoring needs both query and text columns,
        # so it must be checked before the single-column encoder modes.
        dispatch = [
            (['qid', 'query', self.text_field], self.scorer),
            (['query'], self.query_encoder),
            ([self.text_field], self.doc_encoder),
        ]
        for required, factory in dispatch:
            if all(col in present for col in required):
                return factory()(inp)
        # No mode matched: build an error listing each supported mode, using
        # the factory docstrings as human-readable mode names.
        message = f'Unexpected input with columns: {inp.columns}. Supports:'
        for required, factory in dispatch:
            message += f'\n - {factory.__doc__.strip()}: {required}'
        raise RuntimeError(message)

    def query_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
        """
        Query encoding
        """
        return BiQueryEncoder(self, verbose=verbose, batch_size=batch_size)

    def doc_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
        """
        Doc encoding
        """
        return BiDocEncoder(self, verbose=verbose, batch_size=batch_size)

    def scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer:
        """
        Scoring (re-ranking)
        """
        return BiScorer(self, verbose=verbose, batch_size=batch_size, sim_fn=sim_fn)

    @property
    def sim_fn(self) -> SimFn:
        """
        The similarity function to use between embeddings for this model
        """
        # Models may declare their similarity function on their HF-style
        # config; fall back to dot product when absent.
        if hasattr(self, 'config') and hasattr(self.config, 'sim_fn'):
            return SimFn(self.config.sim_fn)
        return SimFn.dot # default


class BiQueryEncoder(pt.Transformer):
    """Transformer that adds a ``query_vec`` column by encoding the ``query``
    column with the wrapped bi-encoder model.
    """
    def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None):
        """
        Args:
            bi_encoder_model: the model whose ``encode_queries`` is used.
            verbose: override the model's verbosity (None = inherit).
            batch_size: override the model's batch size (None = inherit).
        """
        super().__init__()  # fix: base Transformer was never initialized
        self.bi_encoder_model = bi_encoder_model
        self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
        self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size

    def encode(self, texts, batch_size=None) -> np.ndarray:
        """Encode query texts, defaulting to this transformer's batch size."""
        return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size)

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Return ``inp`` with a ``query_vec`` column appended."""
        assert all(c in inp.columns for c in ['query'])
        it = inp['query'].values
        # Deduplicate queries so each distinct query is encoded only once;
        # `inv` maps each input row back to its position in the unique array.
        it, inv = np.unique(it, return_inverse=True)
        if self.verbose:
            it = pt.tqdm(it, desc='Encoding Queries', unit='query')
        enc = self.encode(it)
        return inp.assign(query_vec=[enc[i] for i in inv])

    def __repr__(self):
        return f'{repr(self.bi_encoder_model)}.query_encoder()'


class BiDocEncoder(pt.Transformer):
    """Transformer that adds a ``doc_vec`` column by encoding the document
    text column with the wrapped bi-encoder model.
    """
    def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, text_field=None):
        """
        Args:
            bi_encoder_model: the model whose ``encode_docs`` is used.
            verbose: override the model's verbosity (None = inherit).
            batch_size: override the model's batch size (None = inherit).
            text_field: override the document text column name (None = inherit).
        """
        super().__init__()  # fix: base Transformer was never initialized
        self.bi_encoder_model = bi_encoder_model
        self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
        self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size
        self.text_field = text_field if text_field is not None else bi_encoder_model.text_field

    def encode(self, texts, batch_size=None) -> np.ndarray:
        """Encode document texts, defaulting to this transformer's batch size."""
        return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Return ``inp`` with a ``doc_vec`` column appended."""
        assert all(c in inp.columns for c in [self.text_field])
        it = inp[self.text_field]
        if self.verbose:
            it = pt.tqdm(it, desc='Encoding Docs', unit='doc')
        return inp.assign(doc_vec=list(self.encode(it)))

    def __repr__(self):
        return f'{repr(self.bi_encoder_model)}.doc_encoder()'


class BiScorer(pt.Transformer):
    """Transformer that scores (re-ranks) query-document pairs with the
    wrapped bi-encoder model, encoding queries/documents on the fly when
    their vectors are not already present in the input frame.
    """
    def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, text_field=None, sim_fn=None):
        """
        Args:
            bi_encoder_model: model supplying encoders and default settings.
            verbose: override the model's verbosity (None = inherit).
            batch_size: override the model's batch size (None = inherit).
            text_field: override the document text column name (None = inherit).
            sim_fn: override the similarity function (None = inherit).
        """
        super().__init__()  # fix: base Transformer was never initialized
        self.bi_encoder_model = bi_encoder_model
        self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
        self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size
        self.text_field = text_field if text_field is not None else bi_encoder_model.text_field
        self.sim_fn = sim_fn if sim_fn is not None else bi_encoder_model.sim_fn

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Return ``inp`` with ``score`` and ``rank`` columns appended.

        Raises:
            ValueError: if ``sim_fn`` is not yet supported for scoring.
        """
        assert 'query_vec' in inp.columns or 'query' in inp.columns
        assert 'doc_vec' in inp.columns or self.text_field in inp.columns
        # Reuse pre-computed vectors when present; otherwise encode now.
        if 'query_vec' in inp.columns:
            query_vec = inp['query_vec']
        else:
            query_vec = self.bi_encoder_model.query_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['query_vec']
        if 'doc_vec' in inp.columns:
            doc_vec = inp['doc_vec']
        else:
            doc_vec = self.bi_encoder_model.doc_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['doc_vec']
        if self.sim_fn == SimFn.dot:
            # Row-wise dot product: elementwise multiply aligned Series of
            # vectors, then sum each product vector.
            scores = (query_vec * doc_vec).apply(np.sum)
        else:
            raise ValueError(f'{self.sim_fn} not yet supported by BiScorer')
        outp = inp.assign(score=scores)
        return pt.model.add_ranks(outp)

    def __repr__(self):
        return f'{repr(self.bi_encoder_model)}.scorer()'
8 changes: 8 additions & 0 deletions pyterrier_dr/flex/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .core import FlexIndex, IndexingMode
from .np_retr import *
from .torch_retr import *
from .corpus_graph import *
from .faiss_retr import *
from .scann_retr import *
from .ladr import *
from .gar import *
Loading

0 comments on commit 53e965a

Please sign in to comment.