Skip to content

Commit

Permalink
Add BGE-M3 Encoder (#22)
Browse files Browse the repository at this point in the history
* add encoder for bge-m3 embeddings

* fix return of _BGEM3_encode and return np array directly

* move encode arguments to BGEM3Encoder attributes

* add refactor

* fix some typos + add ability to return token text for sparse encoder

* remove multiple encoder class; merge into single class

* Update README.md

* remove "Factory" suffix

* update README

* refactor BGE-M3 implementation

* fix message

* update readme

* add bgem3 dependency

* refactor sparse transform output to dict; add encode_queries & encode_docs support

* refactor to handle empty inp; fix sparse output bug; rename multi-vector output column

* add tests for bgem3

* add FlagEmbedding to test bge model

* update bgem-m3 multi-vec columns and README

---------

Co-authored-by: Craig Macdonald <[email protected]>
  • Loading branch information
andreaschari and cmacdonald authored Nov 22, 2024
1 parent 5013ffd commit 621ec9d
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 9 deletions.
72 changes: 63 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,43 @@

This provides various Dense Retrieval functionality for [PyTerrier](https://github.com/terrier-org/pyterrier).


## Installation

This repostory can be installed using pip.
This repository can be installed using pip.

```bash
pip install pyterrier-dr
```

If you want the latest version of `pyterrier_dr`, you can install direct from the Github repo:

```bash
pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git
```

if you want to use the BGE-M3 encoder with `pyterrier_dr`, you can install the package with the `bgem3` dependency:

```bash
pip install pyterrier-dr[bgem3]
```

---
You'll also need to install FAISS.

On Colab:

!pip install faiss-cpu
On Anaconda:
```bash
!pip install faiss-cpu
```

# CPU-only version
$ conda install -c pytorch faiss-cpu
On Anaconda:

# GPU(+CPU) version
$ conda install -c pytorch faiss-gpu
```bash
# CPU-only version
conda install -c pytorch faiss-cpu
# GPU(+CPU) version
conda install -c pytorch faiss-gpu
```

You can then import the package and PyTerrier in Python:

Expand All @@ -40,6 +55,7 @@ import pyterrier_dr
| [`TasB`](https://arxiv.org/abs/2104.06967) ||||
| [`Ance`](https://arxiv.org/abs/2007.00808) ||||
| [`Query2Query`](https://neeva.com/blog/state-of-the-art-query2query-similarity) || | |
| [`BGE-M3`](https://arxiv.org/abs/2402.03216) ||||

## Inference

Expand Down Expand Up @@ -166,6 +182,44 @@ retr_pipeline = model >> index.faiss_hnsw_retriever()
# ...
```

## BGE-M3 Encoder

`pyterrier_dr` also supports using BGE-M3 for indexing and retrieval with the following encoders:

1. `query_encoder()`: Encodes queries into single-vector representations only.
2. `doc_encoder()`: Encodes documents into single-vector representations only.
3. `query_multi_encoder()`: Allows user to encode queries in dense, sparse or multi-vector representations.
4. `doc_multi_encoder()`: Allows user to encode documents in dense, sparse or multi-vector representations.

What encodings are returned by both `query_multi_encoder()` and `doc_multi_encoder()` can be controlled by the `return_dense`, `return_sparse` and `return_colbert_vecs` parameters. By default, all three are set to `True`.

### Dependencies

The BGE-M3 Encoder requires the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) library. You can install it as part of the `bgem3` dependency of `pyterrier_dr` (see Installation section).

### Indexing

```python
factory = BGEM3(batch_size=32, max_length=1024, verbose=True)
encoder = factory.doc_encoder()

index = FlexIndex(f"mmarco/v2/fr_bgem3", verbose=True)
indexing_pipeline = encoder >> index

indexing_pipeline.index(pt.get_dataset(f"irds:mmarco/v2/fr").get_corpus_iter())
```

### Retrieval

```python
factory = BGEM3(batch_size=32, max_length=1024)
encoder = factory.query_encoder()

index = FlexIndex(f"mmarco/v2/fr_bgem3", verbose=True)

pipeline = encoder >> idx.np_retriever()
```

## References

- PyTerrier: PyTerrier: Declarative Experimentation in Python from BM25 to Dense Retrieval (Macdonald et al, CIKM 2021)
Expand Down
1 change: 1 addition & 0 deletions pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from .tctcolbert_model import TctColBert
from .electra import ElectraScorer
from .bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
from .cde import CDE, CDECache
140 changes: 140 additions & 0 deletions pyterrier_dr/bge_m3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from tqdm import tqdm
import pyterrier as pt
import pandas as pd
import numpy as np
import torch
from .biencoder import BiEncoder

class BGEM3(BiEncoder):
def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
super().__init__(batch_size, text_field, verbose)
self.model_name = model_name
self.use_fp16 = use_fp16
self.max_length = max_length
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(device)
try:
from FlagEmbedding import BGEM3FlagModel
except ImportError as e:
raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'")

self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)


def __repr__(self):
return f'BGEM3({repr(self.model_name)})'

def encode_queries(self, texts, batch_size=None):
return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']

def encode_docs(self, texts, batch_size=None):
return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']

# Only does dense (single_vec) encoding
def query_encoder(self, verbose=None, batch_size=None):
return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)
def doc_encoder(self, verbose=None, batch_size=None):
return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)

# Does all three BGE-M3 encodings: dense, sparse and colbert(multivec)
def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)
def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)

class BGEM3QueryEncoder(pt.Transformer):
def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
self.bge_factory = bge_factory
self.verbose = verbose if verbose is not None else bge_factory.verbose
self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
self.max_length = max_length if max_length is not None else bge_factory.max_length

self.dense = return_dense
self.sparse = return_sparse
self.multivecs = return_colbert_vecs

def encode(self, texts):
return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])

# check if inp is empty
if len(inp) == 0:
if self.dense:
inp = inp.assign(query_vec=[])
if self.sparse:
inp = inp.assign(query_toks=[])
if self.multivecs:
inp = inp.assign(query_embs=[])
return inp

it = inp['query'].values
it, inv = np.unique(it, return_inverse=True)
if self.verbose:
it = pt.tqdm(it, desc='Encoding Queries', unit='query')
bgem3_results = self.encode(it)

if self.dense:
inp = inp.assign(query_vec=[bgem3_results['dense_vecs'][i] for i in inv])
if self.sparse:
# for sparse convert ids to the actual tokens
query_toks = self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
inp = inp.assign(query_toks=query_toks)
if self.multivecs:
inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inv])
return inp

def __repr__(self):
return f'{repr(self.bge_factory)}.query_encoder()'

class BGEM3DocEncoder(pt.Transformer):
def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
self.bge_factory = bge_factory
self.verbose = verbose if verbose is not None else bge_factory.verbose
self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
self.max_length = max_length if max_length is not None else bge_factory.max_length

self.dense = return_dense
self.sparse = return_sparse
self.multivecs = return_colbert_vecs

def encode(self, texts):
return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# check if the input dataframe contains the field(s) specified in the text_field
assert all(c in inp.columns for c in [self.bge_factory.text_field])
# check if inp is empty
if len(inp) == 0:
if self.dense:
inp = inp.assign(doc_vec=[])
if self.sparse:
inp = inp.assign(toks=[])
if self.multivecs:
inp = inp.assign(doc_embs=[])
return inp

it = inp[self.bge_factory.text_field]
if self.verbose:
it = pt.tqdm(it, desc='Encoding Documents', unit='doc')
bgem3_results = self.encode(it)

if self.dense:
inp = inp.assign(doc_vec=list(bgem3_results['dense_vecs']))
if self.sparse:
toks = bgem3_results['lexical_weights']
# for sparse convert ids to the actual tokens
toks = self.bge_factory.model.convert_id_to_token(toks)
inp = inp.assign(toks=toks)
if self.multivecs:
inp = inp.assign(doc_embs=list(bgem3_results['colbert_vecs']))
return inp

def __repr__(self):
return f'{repr(self.bge_factory)}.doc_encoder()'
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-subtests
git+https://github.com/terrierteam/pyterrier_adaptive
voyager
FlagEmbedding
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def get_version(rel_path):
long_description_content_type="text/markdown",
packages=setuptools.find_packages(),
install_requires=requirements,
extras_require={
'bgem3': ['FlagEmbedding'],
},
python_requires='>=3.6',
entry_points={
'pyterrier.artifact': [
Expand Down
50 changes: 50 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,46 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test
self.assertTrue('docno' in retr_res.columns)
self.assertTrue('score' in retr_res.columns)
self.assertTrue('rank' in retr_res.columns)

def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
dataset = pt.get_dataset('irds:vaswani')

docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
docs_df = pd.DataFrame(docs)

if test_query_multivec_encoder:
with self.subTest('query_multivec_encoder'):
topics = dataset.get_topics()
enc_topics = model(topics)
self.assertEqual(len(enc_topics), len(topics))
self.assertTrue('query_toks' in enc_topics.columns)
self.assertTrue('query_embs' in enc_topics.columns)
self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
self.assertEqual(enc_topics.query_toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
self.assertEqual(enc_topics.query_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs))
with self.subTest('query_multivec_encoder empty'):
enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
self.assertEqual(len(enc_topics_empty), 0)
self.assertTrue('query_toks' in enc_topics_empty.columns)
self.assertTrue('query_embs' in enc_topics_empty.columns)
if test_doc_multivec_encoder:
with self.subTest('doc_multi_encoder'):
enc_docs = model(pd.DataFrame(docs_df))
self.assertEqual(len(enc_docs), len(docs_df))
self.assertTrue('toks' in enc_docs.columns)
self.assertTrue('doc_embs' in enc_docs.columns)
self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
self.assertEqual(enc_docs.toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
self.assertEqual(enc_docs.doc_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs))
with self.subTest('doc_multi_encoder empty'):
enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
self.assertEqual(len(enc_docs_empty), 0)
self.assertTrue('toks' in enc_docs_empty.columns)
self.assertTrue('doc_embs' in enc_docs_empty.columns)

def test_tct(self):
from pyterrier_dr import TctColBert
Expand All @@ -129,6 +169,16 @@ def test_query2query(self):
from pyterrier_dr import Query2Query
self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

def test_bgem3(self):
from pyterrier_dr import BGEM3
# create BGEM3 instance
bgem3 = BGEM3(max_length=1024)

self._base_test(bgem3.query_multi_encoder(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
self._base_test(bgem3.doc_multi_encoder(), test_query_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)
def setUp(self):
import pyterrier as pt
if not pt.started():
Expand Down

0 comments on commit 621ec9d

Please sign in to comment.