diff --git a/pyterrier_dr/__init__.py b/pyterrier_dr/__init__.py
index d860373..90bede9 100644
--- a/pyterrier_dr/__init__.py
+++ b/pyterrier_dr/__init__.py
@@ -5,7 +5,7 @@ from pyterrier_dr.flex import FlexIndex
 from pyterrier_dr.biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
 from pyterrier_dr.hgf_models import HgfBiEncoder, TasB, RetroMAE
-from pyterrier_dr.sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
+from pyterrier_dr.sbert_models import SBertBiEncoder, Ance, Query2Query, GTR, E5
 from pyterrier_dr.tctcolbert_model import TctColBert
 from pyterrier_dr.electra import ElectraScorer
 from pyterrier_dr.bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
@@ -14,5 +14,5 @@
 __all__ = ["FlexIndex", "DocnoFile", "NilIndex", "NumpyIndex", "RankedLists", "FaissFlat", "FaissHnsw", "MemIndex", "TorchIndex",
     "BiEncoder", "BiQueryEncoder", "BiDocEncoder", "BiScorer", "HgfBiEncoder", "TasB", "RetroMAE", "SBertBiEncoder", "Ance",
-    "Query2Query", "GTR", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
+    "Query2Query", "GTR", "E5", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
     "SimFn", "infer_device", "AveragePrf", "VectorPrf"]
diff --git a/pyterrier_dr/pt_docs/index.rst b/pyterrier_dr/pt_docs/index.rst
index e6dfc55..92328c2 100644
--- a/pyterrier_dr/pt_docs/index.rst
+++ b/pyterrier_dr/pt_docs/index.rst
@@ -42,3 +42,4 @@ Import ``pyterrier_dr``, load a pre-built index and model, and retrieve:
    :maxdepth: 1
 
    prf
+   sbert
diff --git a/pyterrier_dr/pt_docs/sbert.rst b/pyterrier_dr/pt_docs/sbert.rst
new file mode 100644
index 0000000..fb32ea7
--- /dev/null
+++ b/pyterrier_dr/pt_docs/sbert.rst
@@ -0,0 +1,27 @@
+Using Sentence Transformer models for dense retrieval in PyTerrier
+==================================================================
+
+With PyTerrier_DR, it is easy to use Sentence Transformer (formerly known as SentenceBERT)
+models, e.g. from HuggingFace, for dense retrieval.
+
+The base class is ``SBertBiEncoder('huggingface/path')``.
+
+There are easy-to-remember classes for a number of standard models:
+
+ - ANCE - an early single-representation dense retrieval model: ``Ance.firstp()``
+ - GTR - a dense retrieval model based on the T5 pre-trained encoder: ``GTR.base()``
+ - E5 - a dense retrieval model trained with weakly-supervised contrastive pre-training: ``E5.base()``
+ - Query2Query - a query similarity model: ``Query2Query()``
+
+The standard pyterrier_dr pipelines can be used:
+
+Indexing::
+
+    model = pyterrier_dr.GTR.base()
+    index = pyterrier_dr.FlexIndex('gtr.flex')
+    pipe = (model >> index)
+    pipe.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())
+
+Retrieval::
+
+    pipe.search("chemical reactions")
diff --git a/pyterrier_dr/sbert_models.py b/pyterrier_dr/sbert_models.py
index 24166cd..00ad8be 100644
--- a/pyterrier_dr/sbert_models.py
+++ b/pyterrier_dr/sbert_models.py
@@ -2,19 +2,26 @@
 import torch
 import pyterrier as pt
 from transformers import AutoConfig
+from functools import partialmethod
 from .biencoder import BiEncoder, BiQueryEncoder
 from .util import Variants
 from tqdm import tqdm
 
-def _sbert_encode(self, texts, batch_size=None):
+def _sbert_encode(self, texts, batch_size=None, prompt=None, normalize_embeddings=False):
     show_progress = False
     if isinstance(texts, tqdm):
         texts.disable = True
         show_progress = True
         texts = list(texts)
+    if prompt is not None:
+        texts = [prompt + t for t in texts]
     if len(texts) == 0:
         return np.empty(shape=(0, 0))
-    return self.model.encode(texts, batch_size=batch_size or self.batch_size, show_progress_bar=show_progress)
+    return self.model.encode(texts,
+        batch_size=batch_size or self.batch_size,
+        show_progress_bar=show_progress,
+        normalize_embeddings=normalize_embeddings
+    )
 
 
 class SBertBiEncoder(BiEncoder):
@@ -52,6 +59,16 @@ class Ance(_SBertBiEncoder):
         'firstp': 'sentence-transformers/msmarco-roberta-base-ance-firstp',
     }
 
+class E5(_SBertBiEncoder):
+
+    encode_queries = partialmethod(_sbert_encode, prompt='query: ', normalize_embeddings=True)
+    encode_docs = partialmethod(_sbert_encode, prompt='passage: ', normalize_embeddings=True)
+
+    VARIANTS = {
+        'base': 'intfloat/e5-base-v2',
+        'small': 'intfloat/e5-small-v2',
+        'large': 'intfloat/e5-large-v2',
+    }
 
 class GTR(_SBertBiEncoder):
     VARIANTS = {
diff --git a/tests/test_models.py b/tests/test_models.py
index 29b9c27..ff88b94 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -152,6 +152,13 @@ def test_ance(self):
         from pyterrier_dr import Ance
         self._base_test(Ance.firstp())
 
+    def test_e5(self):
+        from pyterrier_dr import E5
+        testmodel = E5.base()
+        self._base_test(testmodel)
+        inp = pd.DataFrame([{'qid': 'q1', 'query': 'chemical reactions', 'docno': 'd2', 'text': 'professor proton mixed the chemical'}])
+        self.assertTrue(testmodel(inp).iloc[0]['score'] > 0)
+
     def test_tasb(self):
         from pyterrier_dr import TasB
         self._base_test(TasB.dot())
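For reviewers, a minimal end-to-end sketch of how the new ``E5`` class is used: it combines
the indexing and retrieval pipeline from the new ``sbert.rst`` page with the re-ranking
pattern exercised by ``test_e5``. The ``irds:vaswani`` dataset and the ``e5.flex`` index
path are illustrative assumptions, not part of this diff::

    import pandas as pd
    import pyterrier as pt
    import pyterrier_dr

    # E5.base() loads intfloat/e5-base-v2; per this diff, encode_queries and
    # encode_docs prepend the 'query: ' / 'passage: ' prompts and normalize embeddings
    model = pyterrier_dr.E5.base()
    index = pyterrier_dr.FlexIndex('e5.flex')  # assumed index path

    # Indexing: encode each document and store its vector in the FlexIndex
    pipe = model >> index
    pipe.index(pt.get_dataset('irds:vaswani').get_corpus_iter())

    # Retrieval: the same pipeline encodes the query and scores it against the index
    print(pipe.search("chemical reactions").head())

    # Re-ranking: given both 'query' and 'text' columns, the model scores each pair,
    # as in the new test_e5 test
    inp = pd.DataFrame([{'qid': 'q1', 'query': 'chemical reactions',
                         'docno': 'd2', 'text': 'professor proton mixed the chemical'}])
    print(model(inp))

Because E5 embeddings are normalized, dot-product and cosine similarity produce identical rankings here.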