
Commit

some model documentation
seanmacavaney committed Dec 3, 2024
1 parent e537a64 commit 9942273
Showing 6 changed files with 95 additions and 35 deletions.
24 changes: 20 additions & 4 deletions pyterrier_dr/hgf_models.py
@@ -56,8 +56,8 @@ def __repr__(self):

 class _HgfBiEncoder(HgfBiEncoder, metaclass=Variants):
     VARIANTS: dict = None
-    def __init__(self, model_name, batch_size=32, text_field='text', verbose=False, device=None):
-        self.model_name = model_name
+    def __init__(self, model_name=None, batch_size=32, text_field='text', verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
         model = AutoModel.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         config = AutoConfig.from_pretrained(model_name)
@@ -71,15 +71,31 @@ def __repr__(self):


 class TasB(_HgfBiEncoder):
+    """Dense encoder for TAS-B (Topic Aware Sampling, Balanced).
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+    .. cite.dblp:: conf/sigir/HofstatterLYLH21
+    .. automethod:: dot()
+    """
     VARIANTS = {
         'dot': 'sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco',
     }


 class RetroMAE(_HgfBiEncoder):
+    """Dense encoder for RetroMAE (Masked Auto-Encoder).
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+    .. cite.dblp:: conf/emnlp/XiaoLSC22
+    .. automethod:: msmarco_finetune()
+    .. automethod:: msmarco_distill()
+    .. automethod:: wiki_bookscorpus_beir()
+    """
     VARIANTS = {
         #'wiki_bookscorpus': 'Shitao/RetroMAE', # only pre-trained
         #'msmarco': 'Shitao/RetroMAE_MSMARCO', # only pre-trained
         'msmarco_finetune': 'Shitao/RetroMAE_MSMARCO_finetune',
         'msmarco_distill': 'Shitao/RetroMAE_MSMARCO_distill',
         'wiki_bookscorpus_beir': 'Shitao/RetroMAE_BEIR',
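The net effect for the HGF encoders: a model can now be constructed without naming a checkpoint, falling back to the first entry in its VARIANTS table. A rough usage sketch (variant names and checkpoint paths taken from the tables above; models download on construction):

from pyterrier_dr import TasB, RetroMAE

# Named variant constructors resolve to the checkpoints in VARIANTS:
tasb = TasB.dot()                      # sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco
retromae = RetroMAE.msmarco_distill()  # Shitao/RetroMAE_MSMARCO_distill

# With model_name now optional, the no-argument form picks the first
# ("default") variant instead of raising a TypeError:
retromae = RetroMAE()                  # same checkpoint as RetroMAE.msmarco_finetune()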
24 changes: 24 additions & 0 deletions pyterrier_dr/pt_docs/encoding.rst
@@ -0,0 +1,24 @@
Encoding
==================================================================

Sentence Transformers
------------------------------------------------------------------

With pyterrier_dr, it's easy to use Sentence Transformers (formerly called SentenceBERT)
models, e.g. from HuggingFace, for dense retrieval.

The base class is ``SBertBiEncoder('huggingface/path')``.

Pretrained Encoders
------------------------------------------------------------------

These classes are convenience aliases to popular dense encoding models.

.. autoclass:: pyterrier_dr.Ance()
.. autoclass:: pyterrier_dr.BGEM3()
.. autoclass:: pyterrier_dr.CDE()
.. autoclass:: pyterrier_dr.E5()
.. autoclass:: pyterrier_dr.GTR()
.. autoclass:: pyterrier_dr.Query2Query()
.. autoclass:: pyterrier_dr.RetroMAE()
.. autoclass:: pyterrier_dr.TasB()
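For illustration, a minimal sketch of what this page documents (the checkpoint path is the one listed for Ance elsewhere in this commit):

from pyterrier_dr import SBertBiEncoder, Ance

# Any Sentence Transformers checkpoint can be loaded by its HuggingFace path:
enc = SBertBiEncoder('sentence-transformers/msmarco-roberta-base-ance-firstp')

# The pretrained-encoder aliases wrap well-known checkpoints; this is equivalent:
enc = Ance.firstp()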
2 changes: 1 addition & 1 deletion pyterrier_dr/pt_docs/index.rst
@@ -16,6 +16,6 @@ This functionality is covered in more detail in the following pages:
:maxdepth: 1

    overview
-   sbert
+   encoding
    indexing-retrieval
    prf
26 changes: 0 additions & 26 deletions pyterrier_dr/pt_docs/sbert.rst

This file was deleted.

50 changes: 46 additions & 4 deletions pyterrier_dr/sbert_models.py
@@ -55,11 +55,29 @@ def __repr__(self):


 class Ance(_SBertBiEncoder):
+    """Dense encoder for ANCE (Approximate nearest neighbor Negative Contrastive Learning).
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+    .. cite.dblp:: conf/iclr/XiongXLTLBAO21
+    .. automethod:: firstp()
+    """
     VARIANTS = {
         'firstp': 'sentence-transformers/msmarco-roberta-base-ance-firstp',
     }

 class E5(_SBertBiEncoder):
+    """Dense encoder for E5 (EmbEddings from bidirEctional Encoder rEpresentations).
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+    .. cite.dblp:: journals/corr/abs-2212-03533
+    .. automethod:: base()
+    .. automethod:: small()
+    .. automethod:: large()
+    """

     encode_queries = partialmethod(_sbert_encode, prompt='query: ', normalize_embeddings=True)
     encode_docs = partialmethod(_sbert_encode, prompt='passage: ', normalize_embeddings=True)
@@ -71,17 +89,37 @@ class E5(_SBertBiEncoder):
}

 class GTR(_SBertBiEncoder):
+    """Dense encoder for GTR (Generalizable T5-based dense Retrievers).
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+    .. cite.dblp:: conf/emnlp/Ni0LDAMZLHCY22
+    .. automethod:: base()
+    .. automethod:: large()
+    .. automethod:: xl()
+    .. automethod:: xxl()
+    """
     VARIANTS = {
         'base': 'sentence-transformers/gtr-t5-base',
         'large': 'sentence-transformers/gtr-t5-large',
         'xl': 'sentence-transformers/gtr-t5-xl',
         'xxl': 'sentence-transformers/gtr-t5-xxl',
     }

-class Query2Query(pt.Transformer):
-    DEFAULT_MODEL_NAME = 'neeva/query2query'
-    def __init__(self, model_name=DEFAULT_MODEL_NAME, batch_size=32, verbose=False, device=None):
-        self.model_name = model_name
+class Query2Query(pt.Transformer, metaclass=Variants):
+    """Dense query encoder model for query similarity.
+    Note that this encoder only provides a :meth:`~pyterrier_dr.BiEncoder.query_encoder` (no document encoder or scorer).
+    .. cite:: query2query
+        :citation: Bathwal and Samdani. State-of-the-art Query2Query Similarity. 2022.
+        :link: https://web.archive.org/web/20220923212754/https://neeva.com/blog/state-of-the-art-query2query-similarity
+    .. automethod:: base()
+    """
+    def __init__(self, model_name=None, batch_size=32, verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = torch.device(device)
@@ -90,6 +128,10 @@ def __init__(self, model_name=DEFAULT_MODEL_NAME, batch_size=32, verbose=False,
         self.batch_size = batch_size
         self.verbose = verbose

+    VARIANTS = {
+        'base': 'neeva/query2query',
+    }

     encode = _sbert_encode
     transform = BiQueryEncoder.transform
     __repr__ = _SBertBiEncoder.__repr__
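Since Query2Query only provides a query encoder, the transformer maps a frame of queries to query embeddings. A minimal sketch (assuming the usual pyterrier_dr convention of adding a query_vec column):

import pandas as pd
from pyterrier_dr import Query2Query

q2q = Query2Query.base()  # neeva/query2query, the (default) variant

queries = pd.DataFrame([
    {'qid': '1', 'query': 'cheap flights to rome'},
    {'qid': '2', 'query': 'low-cost airfare rome'},
])
encoded = q2q(queries)  # appends a query_vec column with one embedding per query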
4 changes: 4 additions & 0 deletions pyterrier_dr/util.py
@@ -10,8 +10,12 @@ class SimFn(Enum):
 class Variants(type):
     def __getattr__(cls, name):
         if name in cls.VARIANTS:
+            @staticmethod
             def wrapped(*args, **kwargs):
                 return cls(cls.VARIANTS[name], *args, **kwargs)
+            wrapped.__doc__ = f"``{cls.VARIANTS[name]}``"
+            if name == next(iter(cls.VARIANTS)):
+                wrapped.__doc__ = '*(default)* ' + wrapped.__doc__
             return wrapped

     def __init__(self, *args, **kwargs):
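The practical effect of these four added lines: each generated variant constructor now carries a docstring naming its checkpoint (with the first entry flagged as the default), which is what the new ``.. automethod::`` directives above render. A quick sketch of the behaviour:

from pyterrier_dr import RetroMAE

# Variants.__getattr__ synthesises one constructor per VARIANTS entry; the
# docstring it sets is what Sphinx picks up:
print(RetroMAE.msmarco_finetune.__doc__)
# *(default)* ``Shitao/RetroMAE_MSMARCO_finetune``
print(RetroMAE.msmarco_distill.__doc__)
# ``Shitao/RetroMAE_MSMARCO_distill``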
