Commit fee706a

update bgem-m3 multi-vec columns and README

andreaschari committed Nov 19, 2024
1 parent 701fe54 commit fee706a

Showing 4 changed files with 19 additions and 29 deletions.
README.md (6 changes: 1 addition & 5 deletions)

@@ -195,11 +195,7 @@ What encodings are returned by both `query_multi_encoder()` and `doc_multi_encoder()`

 ### Dependencies

-The BGE-M3 Encoder requires the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) library. You can install it using pip or install it as part of the `bgem3` dependency of `pyterrier_dr` (see Installation section):
-
-```bash
-pip install -U FlagEmbedding
-```
+The BGE-M3 Encoder requires the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) library. You can install it as part of the `bgem3` dependency of `pyterrier_dr` (see Installation section).

 ### Indexing
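For context, a minimal sketch of picking up the encoder once the extra is installed; the `BGEM3` import path and constructor call are assumptions based on this repository's layout, not part of the diff:

```python
# Assumes the package was installed with the extra declared in setup.py:
#   pip install pyterrier-dr[bgem3]   (pulls in FlagEmbedding automatically)
from pyterrier_dr import BGEM3  # class name assumed from pyterrier_dr/bge_m3.py

factory = BGEM3(model_name='BAAI/bge-m3')  # default checkpoint per bge_m3.py
```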
pyterrier_dr/bge_m3.py (24 changes: 9 additions & 15 deletions)

@@ -17,7 +17,7 @@ def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, tex
         try:
             from FlagEmbedding import BGEM3FlagModel
         except ImportError as e:
-            raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install -U FlagEmbedding'")
+            raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'")

         self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)
@@ -33,13 +33,13 @@ def encode_docs(self, texts, batch_size=None):
         return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
                                  return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']

-    # Only does single_vec encoding
+    # Only does dense (single_vec) encoding
     def query_encoder(self, verbose=None, batch_size=None):
         return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)
     def doc_encoder(self, verbose=None, batch_size=None):
         return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)

-    # Can do dense, sparse and colbert encodings
+    # Does all three BGE-M3 encodings: dense, sparse and colbert(multivec)
     def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
         return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)
     def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
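To illustrate the distinction the updated comments draw, a hedged sketch of what each factory method produces (reusing the assumed `BGEM3` factory from the sketch above; column names follow this commit):

```python
import pandas as pd

queries = pd.DataFrame([{'qid': '1', 'query': 'what does bge-m3 return'}])

dense_only = factory.query_encoder()   # adds just a `query_vec` column
multi = factory.query_multi_encoder()  # adds `query_vec`, `query_toks` and `query_embs`

# Each representation can also be toggled individually:
sparse_only = factory.query_multi_encoder(return_dense=False, return_colbert_vecs=False)
```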
@@ -70,7 +70,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
             if self.sparse:
                 inp = inp.assign(query_toks=[])
             if self.multivecs:
-                inp = inp.assign(query_embs_toks=[])
+                inp = inp.assign(query_embs=[])
             return inp

         it = inp['query'].values
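The empty-input branch keeps the output schema stable: a zero-row frame comes back with the same renamed columns as a populated one. Continuing the sketch above:

```python
import pandas as pd

empty = factory.query_multi_encoder()(pd.DataFrame(columns=['qid', 'query']))
len(empty)  # 0
# Both renamed columns are still present, per the updated tests below:
'query_embs' in empty.columns and 'query_toks' in empty.columns  # True
```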
@@ -80,16 +80,13 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
         bgem3_results = self.encode(it)

         if self.dense:
-            query_vec = [bgem3_results['dense_vecs'][i] for i in inv]
-            inp = inp.assign(query_vec=query_vec)
+            inp = inp.assign(query_vec=[bgem3_results['dense_vecs'][i] for i in inv])
         if self.sparse:
             # for sparse convert ids to the actual tokens
             query_toks = self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
             inp = inp.assign(query_toks=query_toks)
         if self.multivecs:
-            query_embs_toks = [bgem3_results['colbert_vecs'][i] for i in inv]
-            inp = inp.assign(query_embs_toks=query_embs_toks)
-
+            inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inv])
         return inp

     def __repr__(self):
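After this hunk the ColBERT-style vectors travel under `query_embs` rather than `query_embs_toks`. Continuing the sketch, what each output column holds (dtypes per the updated tests; shapes illustrative):

```python
enc = factory.query_multi_encoder()(queries)
enc['query_vec'].iloc[0]   # np.float32 vector (dense representation)
enc['query_toks'].iloc[0]  # dict of token -> lexical weight (sparse representation)
enc['query_embs'].iloc[0]  # np.float32 matrix of per-token vectors (multi-vec)
```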
@@ -120,7 +117,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
             if self.sparse:
                 inp = inp.assign(toks=[])
             if self.multivecs:
-                inp = inp.assign(doc_embs_toks=[])
+                inp = inp.assign(doc_embs=[])
             return inp

         it = inp[self.bge_factory.text_field]
@@ -129,17 +126,14 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
         bgem3_results = self.encode(it)

         if self.dense:
-            doc_vec = bgem3_results['dense_vecs']
-            inp = inp.assign(doc_vec=list(doc_vec))
+            inp = inp.assign(doc_vec=list(bgem3_results['dense_vecs']))
         if self.sparse:
             toks = bgem3_results['lexical_weights']
             # for sparse convert ids to the actual tokens
             toks = self.bge_factory.model.convert_id_to_token(toks)
             inp = inp.assign(toks=toks)
         if self.multivecs:
-            doc_embs_toks = bgem3_results['colbert_vecs']
-            inp = inp.assign(doc_embs_toks=list(doc_embs_toks))
-
+            inp = inp.assign(doc_embs=list(bgem3_results['colbert_vecs']))
         return inp

     def __repr__(self):
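The document side mirrors the query side, with `doc_vec`, `toks` and the renamed `doc_embs`. Continuing the sketch:

```python
import pandas as pd

docs = pd.DataFrame([{'docno': 'd1', 'text': 'BGE-M3 yields dense, sparse and multi-vector encodings.'}])
enc_docs = factory.doc_multi_encoder()(docs)
sorted(enc_docs.columns)  # ['doc_embs', 'doc_vec', 'docno', 'text', 'toks']
```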
setup.py (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ def get_version(rel_path):
     packages=setuptools.find_packages(),
     install_requires=requirements,
     extras_require={
-        ['bgem3']: ['FlagEmbedding'],
+        'bgem3': ['FlagEmbedding'],
     },
     python_requires='>=3.6',
     entry_points={
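The one-line setup.py fix is load-bearing: a list is unhashable, so the old spelling raised a `TypeError` the moment the `extras_require` dict literal was evaluated. A minimal illustration:

```python
# Old spelling fails as soon as the dict literal is built,
# because a list cannot be a dict key:
#   {['bgem3']: ['FlagEmbedding']}  ->  TypeError: unhashable type: 'list'

# Fixed spelling: the extra's name is a plain string key,
# which is what enables `pip install pyterrier-dr[bgem3]`.
extras_require = {'bgem3': ['FlagEmbedding']}
```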
tests/test_models.py (16 changes: 8 additions & 8 deletions)

@@ -117,33 +117,33 @@ def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
                 enc_topics = model(topics)
                 self.assertEqual(len(enc_topics), len(topics))
                 self.assertTrue('query_toks' in enc_topics.columns)
-                self.assertTrue('query_embs_toks' in enc_topics.columns)
+                self.assertTrue('query_embs' in enc_topics.columns)
                 self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
                 self.assertEqual(enc_topics.query_toks.dtype, object)
                 self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
-                self.assertEqual(enc_topics.query_embs_toks.dtype, object)
-                self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs_toks))
+                self.assertEqual(enc_topics.query_embs.dtype, object)
+                self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs))
             with self.subTest('query_multivec_encoder empty'):
                 enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
                 self.assertEqual(len(enc_topics_empty), 0)
                 self.assertTrue('query_toks' in enc_topics_empty.columns)
-                self.assertTrue('query_embs_toks' in enc_topics_empty.columns)
+                self.assertTrue('query_embs' in enc_topics_empty.columns)
         if test_doc_multivec_encoder:
             with self.subTest('doc_multi_encoder'):
                 enc_docs = model(pd.DataFrame(docs_df))
                 self.assertEqual(len(enc_docs), len(docs_df))
                 self.assertTrue('toks' in enc_docs.columns)
-                self.assertTrue('doc_embs_toks' in enc_docs.columns)
+                self.assertTrue('doc_embs' in enc_docs.columns)
                 self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
                 self.assertEqual(enc_docs.toks.dtype, object)
                 self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
-                self.assertEqual(enc_docs.doc_embs_toks.dtype, object)
-                self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs_toks))
+                self.assertEqual(enc_docs.doc_embs.dtype, object)
+                self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs))
             with self.subTest('doc_multi_encoder empty'):
                 enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
                 self.assertEqual(len(enc_docs_empty), 0)
                 self.assertTrue('toks' in enc_docs_empty.columns)
-                self.assertTrue('doc_embs_toks' in enc_docs_empty.columns)
+                self.assertTrue('doc_embs' in enc_docs_empty.columns)

     def test_tct(self):
         from pyterrier_dr import TctColBert
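For local verification, the helper operates on standard PyTerrier frames; a hedged example of inputs that satisfy the asserted schema (contents illustrative, not taken from the diff):

```python
import pandas as pd

topics = pd.DataFrame([{'qid': '1', 'query': 'multi-vector retrieval'}])
docs_df = pd.DataFrame([{'docno': 'd1', 'text': 'bge-m3 produces three encodings'}])
```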
