Skip to content

Commit

Permalink
add E5 as SBertBiEncoder (#30)
Browse files Browse the repository at this point in the history
* add E5 as SBertBiEncoder

* fix test

* use functools.partialmethod()

* fix test case

* initial sketch of sbert docs

* docs

* fix test

* add e5 variants

---------

Co-authored-by: Sean MacAvaney <[email protected]>
Co-authored-by: jinyuan <[email protected]>
  • Loading branch information
3 people authored Dec 2, 2024
1 parent 5071f23 commit 9c86daa
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyterrier_dr.flex import FlexIndex
from pyterrier_dr.biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
from pyterrier_dr.hgf_models import HgfBiEncoder, TasB, RetroMAE
from pyterrier_dr.sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from pyterrier_dr.sbert_models import SBertBiEncoder, Ance, Query2Query, GTR, E5
from pyterrier_dr.tctcolbert_model import TctColBert
from pyterrier_dr.electra import ElectraScorer
from pyterrier_dr.bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
Expand All @@ -14,5 +14,5 @@

__all__ = ["FlexIndex", "DocnoFile", "NilIndex", "NumpyIndex", "RankedLists", "FaissFlat", "FaissHnsw", "MemIndex", "TorchIndex",
"BiEncoder", "BiQueryEncoder", "BiDocEncoder", "BiScorer", "HgfBiEncoder", "TasB", "RetroMAE", "SBertBiEncoder", "Ance",
"Query2Query", "GTR", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
"Query2Query", "GTR", "E5", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
"SimFn", "infer_device", "AveragePrf", "VectorPrf"]
1 change: 1 addition & 0 deletions pyterrier_dr/pt_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ Import ``pyterrier_dr``, load a pre-built index and model, and retrieve:
:maxdepth: 1

prf
sbert
24 changes: 24 additions & 0 deletions pyterrier_dr/pt_docs/sbert.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Using Sentence Transformer models for Dense retrieval in PyTerrier
==================================================================

With PyTerrier_DR, it's easy to support Sentence Transformer (formerly called SentenceBERT)
models, e.g. from HuggingFace, for dense retrieval.

The base class is ``SBertBiEncoder('huggingface/path')``.

There are easy to remember classes for a number of standard models:
- ANCE - an early single-representation dense retrieval model: ``Ance.firstp()``
- GTR - a dense retrieval model based on the T5 pre-trained encoder: ``GTR.base()``
- `E5 <https://huggingface.co/intfloat/e5-base-v2>`_: ``E5.base()``
- Query2Query - a query similarity model: ``Query2Query()``

The standard pyterrier_dr pipelines can be used:

Indexing::

    model = pyterrier_dr.GTR.base()
    index = pyterrier_dr.FlexIndex('gtr.flex')
    pipe = (model >> index)
    pipe.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())

Retrieval::

    pipe.search("chemical reactions")
21 changes: 19 additions & 2 deletions pyterrier_dr/sbert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@
import torch
import pyterrier as pt
from transformers import AutoConfig
from functools import partialmethod
from .biencoder import BiEncoder, BiQueryEncoder
from .util import Variants
from tqdm import tqdm

def _sbert_encode(self, texts, batch_size=None, prompt=None, normalize_embeddings=False):
    """Encode a collection of texts into dense vectors with the wrapped
    Sentence Transformers model (``self.model``).

    Args:
        texts: an iterable of strings. A ``tqdm``-wrapped iterable is
            unwrapped (its own bar is disabled) and progress reporting is
            delegated to sentence-transformers instead.
        batch_size: encoding batch size; falls back to ``self.batch_size``.
        prompt: optional prefix prepended to every text (e.g. ``'query: '``
            / ``'passage: '`` as required by E5-style models).
        normalize_embeddings: if True, ask the model to L2-normalize the
            returned vectors.

    Returns:
        A numpy array of embeddings; an empty ``(0, 0)`` array when there
        are no input texts.
    """
    # NOTE(review): the diff view interleaved the pre-change signature and
    # return with the new ones; this is the reconstructed post-change body
    # with the stale duplicate return removed (it made the keyword-passing
    # return unreachable).
    show_progress = False
    if isinstance(texts, tqdm):
        # Take over progress reporting: silence the caller's bar and let
        # sentence-transformers render its own.
        texts.disable = True
        show_progress = True
    # Materialize so the prompt pass and len() work for any iterable input.
    texts = list(texts)
    if prompt is not None:
        texts = [prompt + t for t in texts]
    if len(texts) == 0:
        # Nothing to encode; return an empty placeholder array.
        return np.empty(shape=(0, 0))
    return self.model.encode(texts,
        batch_size=batch_size or self.batch_size,
        show_progress_bar=show_progress,
        normalize_embeddings=normalize_embeddings,
    )


class SBertBiEncoder(BiEncoder):
Expand Down Expand Up @@ -52,6 +59,16 @@ class Ance(_SBertBiEncoder):
'firstp': 'sentence-transformers/msmarco-roberta-base-ance-firstp',
}

class E5(_SBertBiEncoder):
    """E5 dense retrieval models (``intfloat/e5-*-v2``) as SBert bi-encoders.

    Queries and documents are encoded with the asymmetric prefixes
    ``'query: '`` and ``'passage: '`` respectively, and embeddings are
    L2-normalized, as shown by the ``partialmethod`` bindings below.
    """

    # Bind _sbert_encode with the E5-specific prompt for each input side;
    # both sides request normalized embeddings.
    encode_queries = partialmethod(_sbert_encode, prompt='query: ', normalize_embeddings=True)
    encode_docs = partialmethod(_sbert_encode, prompt='passage: ', normalize_embeddings=True)

    # Named model-size variants; e.g. E5.base() loads 'intfloat/e5-base-v2'.
    VARIANTS = {
        'base' : 'intfloat/e5-base-v2',
        'small': 'intfloat/e5-small-v2',
        'large': 'intfloat/e5-large-v2',
    }

class GTR(_SBertBiEncoder):
VARIANTS = {
Expand Down
7 changes: 7 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ def test_ance(self):
from pyterrier_dr import Ance
self._base_test(Ance.firstp())

def test_e5(self):
    """Smoke-test the E5 bi-encoder: run the shared base checks, then
    score one topically-related query/passage pair and require a
    positive score.
    """
    from pyterrier_dr import E5
    testmodel = E5.base()  # loads 'intfloat/e5-base-v2'
    self._base_test(testmodel)
    # Single-row scoring frame: query paired with a related passage.
    inp = pd.DataFrame([{'qid': 'q1', 'query' : 'chemical reactions', 'docno' : 'd2', 'text' : 'professor proton mixed the chemical'}])
    # NOTE(review): assumes related text scores > 0 under E5's normalized
    # embeddings (similarity-like score) — confirm against the scorer.
    self.assertTrue(testmodel(inp).iloc[0]['score'] > 0)

def test_tasb(self):
from pyterrier_dr import TasB
self._base_test(TasB.dot())
Expand Down

0 comments on commit 9c86daa

Please sign in to comment.