diff --git a/pyterrier_dr/hgf_models.py b/pyterrier_dr/hgf_models.py
index 280e80e..7e04df4 100644
--- a/pyterrier_dr/hgf_models.py
+++ b/pyterrier_dr/hgf_models.py
@@ -56,8 +56,8 @@ def __repr__(self):
 class _HgfBiEncoder(HgfBiEncoder, metaclass=Variants):
     VARIANTS: dict = None

-    def __init__(self, model_name, batch_size=32, text_field='text', verbose=False, device=None):
-        self.model_name = model_name
-        model = AutoModel.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        config = AutoConfig.from_pretrained(model_name)
+    def __init__(self, model_name=None, batch_size=32, text_field='text', verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
+        model = AutoModel.from_pretrained(self.model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        config = AutoConfig.from_pretrained(self.model_name)
@@ -71,15 +71,31 @@ def __repr__(self):


 class TasB(_HgfBiEncoder):
+    """Dense encoder for TAS-B (Topic Aware Sampling, Balanced).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/sigir/HofstatterLYLH21
+
+    .. automethod:: dot()
+    """
     VARIANTS = {
         'dot': 'sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco',
     }


 class RetroMAE(_HgfBiEncoder):
+    """Dense encoder for RetroMAE (Masked Auto-Encoder).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/emnlp/XiaoLSC22
+
+    .. automethod:: msmarco_finetune()
+    .. automethod:: msmarco_distill()
+    .. automethod:: wiki_bookscorpus_beir()
+    """
     VARIANTS = {
-        #'wiki_bookscorpus': 'Shitao/RetroMAE', # only pre-trained
-        #'msmarco': 'Shitao/RetroMAE_MSMARCO', # only pre-trained
         'msmarco_finetune': 'Shitao/RetroMAE_MSMARCO_finetune',
         'msmarco_distill': 'Shitao/RetroMAE_MSMARCO_distill',
         'wiki_bookscorpus_beir': 'Shitao/RetroMAE_BEIR',
diff --git a/pyterrier_dr/pt_docs/encoding.rst b/pyterrier_dr/pt_docs/encoding.rst
new file mode 100644
index 0000000..8c96894
--- /dev/null
+++ b/pyterrier_dr/pt_docs/encoding.rst
@@ -0,0 +1,24 @@
+Encoding
+==================================================================
+
+Sentence Transformers
+------------------------------------------------------------------
+
+With pyterrier_dr, it's easy to use Sentence Transformers (formerly called SentenceBERT)
+models, e.g. from HuggingFace, for dense retrieval.
+
+The base class is ``SBertBiEncoder('huggingface/path')``.
+
+Pretrained Encoders
+------------------------------------------------------------------
+
+These classes are convenience aliases to popular dense encoding models.
+
+.. autoclass:: pyterrier_dr.Ance()
+.. autoclass:: pyterrier_dr.BGEM3()
+.. autoclass:: pyterrier_dr.CDE()
+.. autoclass:: pyterrier_dr.E5()
+.. autoclass:: pyterrier_dr.GTR()
+.. autoclass:: pyterrier_dr.Query2Query()
+.. autoclass:: pyterrier_dr.RetroMAE()
+.. autoclass:: pyterrier_dr.TasB()
diff --git a/pyterrier_dr/pt_docs/index.rst b/pyterrier_dr/pt_docs/index.rst
index 40a8d2e..ea6061c 100644
--- a/pyterrier_dr/pt_docs/index.rst
+++ b/pyterrier_dr/pt_docs/index.rst
@@ -16,6 +16,6 @@ This functionality is covered in more detail in the following pages:
    :maxdepth: 1

    overview
-   sbert
+   encoding
    indexing-retrieval
    prf
diff --git a/pyterrier_dr/pt_docs/sbert.rst b/pyterrier_dr/pt_docs/sbert.rst
deleted file mode 100644
index f0e0e77..0000000
--- a/pyterrier_dr/pt_docs/sbert.rst
+++ /dev/null
@@ -1,26 +0,0 @@
-Encoding: Sentence Transformers
-==================================================================
-
-With pyterrier_dr, its easy to support Sentence Transformer (formerly called SentenceBERT)
-models, e.g. from HuggingFace, for dense retrieval.
-
-The base class is ``SBertBiEncoder('huggingface/path')``;
-
-There are easy to remember classes for a number of standard models:
- - ANCE - an early single-representation dense retrieval model: ``Ance.firstp()``
- - GTR - a dense retrieval model based on the T5 pre-trained encoder: ``GTR.base()``
- - `E5 `__: ``E5.base()``
- - Query2Query - a query similarity model: Query2Query()
-
-The standard pyterrier_dr pipelines can be used:
-
-Indexing::
-
-    model = pyterrier_dr.GTR.base()
-    index = pyterrier_dr.FlexIndex('gtr.flex')
-    pipe = (model >> index)
-    pipe.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())
-
-Retrieval::
-
-    pipe.search("chemical reactions")
diff --git a/pyterrier_dr/sbert_models.py b/pyterrier_dr/sbert_models.py
index 00ad8be..dc6a5ef 100644
--- a/pyterrier_dr/sbert_models.py
+++ b/pyterrier_dr/sbert_models.py
@@ -55,11 +55,29 @@ def __repr__(self):


 class Ance(_SBertBiEncoder):
+    """Dense encoder for ANCE (Approximate nearest neighbor Negative Contrastive Learning).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/iclr/XiongXLTLBAO21
+
+    .. automethod:: firstp()
+    """
     VARIANTS = {
         'firstp': 'sentence-transformers/msmarco-roberta-base-ance-firstp',
     }


 class E5(_SBertBiEncoder):
+    """Dense encoder for E5 (EmbEddings from bidirEctional Encoder rEpresentations).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: journals/corr/abs-2212-03533
+
+    .. automethod:: base()
+    .. automethod:: small()
+    .. automethod:: large()
+    """
     encode_queries = partialmethod(_sbert_encode, prompt='query: ', normalize_embeddings=True)
     encode_docs = partialmethod(_sbert_encode, prompt='passage: ', normalize_embeddings=True)
@@ -71,6 +89,17 @@ class E5(_SBertBiEncoder):
     }

 class GTR(_SBertBiEncoder):
+    """Dense encoder for GTR (Generalizable T5-based dense Retrievers).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/emnlp/Ni0LDAMZLHCY22
+
+    .. automethod:: base()
+    .. automethod:: large()
+    .. automethod:: xl()
+    .. automethod:: xxl()
+    """
     VARIANTS = {
         'base': 'sentence-transformers/gtr-t5-base',
         'large': 'sentence-transformers/gtr-t5-large',
@@ -78,10 +107,19 @@ class GTR(_SBertBiEncoder):
         'xxl': 'sentence-transformers/gtr-t5-xxl',
     }

-class Query2Query(pt.Transformer):
-    DEFAULT_MODEL_NAME = 'neeva/query2query'
-    def __init__(self, model_name=DEFAULT_MODEL_NAME, batch_size=32, verbose=False, device=None):
-        self.model_name = model_name
+class Query2Query(pt.Transformer, metaclass=Variants):
+    """Dense query encoder model for query similarity.
+
+    Note that this encoder only provides a :meth:`~pyterrier_dr.BiEncoder.query_encoder` (no document encoder or scorer).
+
+    .. cite:: query2query
+        :citation: Bathwal and Samdani. State-of-the-art Query2Query Similarity. 2022.
+        :link: https://web.archive.org/web/20220923212754/https://neeva.com/blog/state-of-the-art-query2query-similarity
+
+    .. automethod:: base()
+    """
+    def __init__(self, model_name=None, batch_size=32, verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = torch.device(device)
@@ -90,6 +128,10 @@ def __init__(self, model_name=DEFAULT_MODEL_NAME, batch_size=32, verbose=False,
         self.batch_size = batch_size
         self.verbose = verbose

+    VARIANTS = {
+        'base': 'neeva/query2query',
+    }
+
     encode = _sbert_encode
     transform = BiQueryEncoder.transform
     __repr__ = _SBertBiEncoder.__repr__
diff --git a/pyterrier_dr/util.py b/pyterrier_dr/util.py
index fba631b..526274f 100644
--- a/pyterrier_dr/util.py
+++ b/pyterrier_dr/util.py
@@ -10,8 +10,12 @@ class SimFn(Enum):
 class Variants(type):
     def __getattr__(cls, name):
         if name in cls.VARIANTS:
+            @staticmethod
             def wrapped(*args, **kwargs):
                 return cls(cls.VARIANTS[name], *args, **kwargs)
+            wrapped.__doc__ = f"``{cls.VARIANTS[name]}``"
+            if name == next(iter(cls.VARIANTS)):
+                wrapped.__doc__ = '*(default)* ' + wrapped.__doc__
             return wrapped

     def __init__(self, *args, **kwargs):
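
The pipeline walkthrough that previously lived in ``sbert.rst`` has no direct counterpart in the new ``encoding.rst``, so a brief sketch may help connect the encoder classes documented above to an end-to-end pipeline. It reuses the removed GTR example (the ``gtr.flex`` index path and the MS MARCO passage corpus are illustrative choices, not prescribed by this change) and adds one line showing the default-variant behavior introduced via ``Variants``::

    import pyterrier as pt
    import pyterrier_dr

    # Variant constructors such as GTR.base() are generated by the Variants
    # metaclass from the VARIANTS mapping documented above.
    model = pyterrier_dr.GTR.base()  # 'sentence-transformers/gtr-t5-base'
    index = pyterrier_dr.FlexIndex('gtr.flex')
    pipe = (model >> index)

    # Indexing
    pipe.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())

    # Retrieval
    pipe.search("chemical reactions")

    # With this change, omitting the model name falls back to the first
    # (default) variant, e.g. TasB() is equivalent to TasB.dot().
    encoder = pyterrier_dr.TasB()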