Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

contribution: infinity-integration #834

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ format: ## Running code formatter: black and isort
@find src -name "*.pyi" ! -name "*_pb2*" -exec black --pyi --config pyproject.toml {} \;
@echo "(ruff) Running fix only..."
@ruff check src docs tests --fix-only
format-check: ## Checking formatting without modifying files: isort, black and ruff
	@echo "(isort) Checking import order..."
	@isort --check .
	@echo "(black) Checking code formatting..."
	@black --config pyproject.toml --check src tests docs
	@echo "(ruff) Linting development project..."
	@# No --fix-only here: a check target must report violations, not rewrite files.
	@ruff check src docs tests
lint: ## Running lint checker: ruff
@echo "(ruff) Linting development project..."
@ruff check src docs tests
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dynamic = ["version", "readme"]
[project.optional-dependencies]
all = [
"sentence-transformers",
"infinity_emb[all]",
]

[tool.setuptools]
Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-xdist[psutil]
pytest-asyncio
llama_index
pytest-asyncio
140 changes: 116 additions & 24 deletions src/ragas/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import typing as t
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List

Expand All @@ -15,6 +15,9 @@

DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"

if t.TYPE_CHECKING:
from torch import Tensor


class BaseRagasEmbeddings(Embeddings, ABC):
run_config: RunConfig
Expand All @@ -26,7 +29,7 @@ async def embed_text(self, text: str, is_async=True) -> List[float]:
async def embed_texts(
self, texts: List[str], is_async: bool = True
) -> t.List[t.List[float]]:
if is_async:
if is_async and hasattr(self, "aembed_documents"):
aembed_documents_with_retry = add_async_retry(
self.aembed_documents, self.run_config
)
Expand All @@ -41,6 +44,9 @@ async def embed_texts(
def set_run_config(self, run_config: RunConfig):
self.run_config = run_config

@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]: ...


class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
def __init__(
Expand All @@ -60,7 +66,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_query(self, text: str) -> List[float]:
return await self.embeddings.aembed_query(text)

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: List[str]):
return await self.embeddings.aembed_documents(texts)

def set_run_config(self, run_config: RunConfig):
Expand Down Expand Up @@ -110,47 +116,133 @@ def __post_init__(self):
)

if self.is_cross_encoder:
self.model = sentence_transformers.CrossEncoder(
self._ce = sentence_transformers.CrossEncoder(
self.model_name, **self.model_kwargs
)
self.model = self._ce
self.is_cross_encoder = True
else:
self.model = sentence_transformers.SentenceTransformer(
self._st = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
self.model = self._st
self.is_cross_encoder = False

# ensure outputs are tensors
if "convert_to_tensor" not in self.encode_kwargs:
self.encode_kwargs["convert_to_tensor"] = True
self.encode_kwargs["convert_to_tensor"] = True

def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
from sentence_transformers.SentenceTransformer import SentenceTransformer
from torch import Tensor

assert isinstance(
self.model, SentenceTransformer
), "Model is not of the type Bi-encoder"
embeddings = self.model.encode(
assert not self.is_cross_encoder, "Model is not of the type Bi-encoder"
embeddings: Tensor = self._st.encode( # type: ignore
texts, normalize_embeddings=True, **self.encode_kwargs
)

assert isinstance(embeddings, Tensor)
return embeddings.tolist()

def predict(self, texts: List[List[str]]) -> List[List[float]]:
from sentence_transformers.cross_encoder import CrossEncoder
from torch import Tensor
assert self.is_cross_encoder, "Model is not of the type CrossEncoder"
predictions: Tensor = self.model.predict(texts, **self.encode_kwargs) # type: ignore
return predictions.tolist()


@dataclass
class InfinityEmbeddings(BaseRagasEmbeddings):
"""Infinity embeddings using infinity_emb package.

usage:
```python
embedding_engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5")
async with embedding_engine:
embeddings = await embedding_engine.aembed_documents(
["Paris is in France", "The capital of France is Paris", "Infintiy batches embeddings on the fly"]
)

assert isinstance(
self.model, CrossEncoder
), "Model is not of the type CrossEncoder"
reranking_engine = InfinityEmbeddings(model_name="BAAI/bge-reranker-base")
async with reranking_engine:
rankings = await reranking_engine.arerank("Where is Paris?", ["Paris is in France", "I don't know the capital of Paris.", "Dummy sentence"])
```
"""

predictions = self.model.predict(texts, **self.encode_kwargs)
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
infinity_engine_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
"""infinity engine keyword arguments.
{
batch_size: int = 64
revision: str | None = None,
trust_remote_code: bool = True,
engine: str = torch | optimum | ctranslate2
model_warmup: bool = False
vector_disk_cache_path: str = ""
device: Device | str = "auto"
lengths_via_tokenize: bool = False
}
"""

assert isinstance(predictions, Tensor)
return predictions.tolist()
def __post_init__(self):
    """Build the Infinity engine for ``model_name``.

    ``infinity_emb`` is imported here rather than at module level, so the
    package is only required when this class is actually instantiated. The
    engine is created from ``model_name`` and ``infinity_engine_kwargs`` and
    stored on ``self.engine``.

    Raises:
        ImportError: if the ``infinity_emb`` package cannot be imported.
    """
    try:
        import infinity_emb
    except ImportError as exc:
        # The requirement is quoted so the hint can be pasted into a shell
        # as-is; unquoted, `>=0.0.32` is parsed as an output redirection.
        raise ImportError(
            "Could not import infinity_emb python package. "
            'Please install it with `pip install "infinity-emb[torch,optimum]>=0.0.32"`.'
        ) from exc
    self.engine = infinity_emb.AsyncEmbeddingEngine(
        model_name_or_path=self.model_name, **self.infinity_engine_kwargs
    )

def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Synchronous embedding entry point; always unsupported for Infinity.

    Raises:
        NotImplementedError: on every call.
    """
    message = "Infinity embeddings does not support sync embeddings"
    raise NotImplementedError(message)

def embed_query(self, text: str) -> List[float]:
    """Embed one query by delegating to ``embed_documents``.

    The text is wrapped in a one-element list and the first resulting vector
    is returned. Any exception raised by ``embed_documents`` propagates.
    """
    vectors = self.embed_documents([text])
    return vectors[0]

async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]:
    """Embed ``texts`` with the Infinity engine and return plain Python lists.

    Raises:
        ValueError: if the engine does not list ``embed`` in its capabilities.
    """
    # Calling the context-manager hook directly starts the engine on demand,
    # so this method also works outside an ``async with`` block.
    await self.__aenter__()

    supports_embed = "embed" in self.engine.capabilities
    if not supports_embed:
        message = (
            f"Model={self.model_name} does not have `embed` capability, but only {self.engine.capabilities}. "
            "Try a different model, e.g. `model_name=BAAI/bge-small-en-v1.5`"
        )
        raise ValueError(message)

    # The engine returns a pair; only the first element (the vectors) is used.
    result = await self.engine.embed(sentences=texts)
    vectors, _ = result
    return np.array(vectors).tolist()

async def aembed_query(self, text: str) -> t.List[float]:
    """Embed a single query asynchronously.

    Delegates to ``aembed_documents`` with a one-element list and returns the
    first vector of the result.
    """
    return (await self.aembed_documents([text]))[0]

async def arerank(self, query: str, docs: List[str]) -> List[float]:
    """Score ``docs`` against ``query`` using the engine's reranker.

    Returns the rankings exactly as produced by ``engine.rerank``.

    Raises:
        ValueError: if the engine does not list ``rerank`` in its capabilities.
    """
    # Start the engine on demand so the method works outside ``async with``.
    await self.__aenter__()

    supports_rerank = "rerank" in self.engine.capabilities
    if not supports_rerank:
        message = (
            f"Model={self.model_name} does not have `rerank` capability, but only {self.engine.capabilities}. "
            "Try a different model, e.g. `model_name=mixedbread-ai/mxbai-rerank-base-v1`"
        )
        raise ValueError(message)

    # The engine returns a pair; only the first element (the scores) is used.
    result = await self.engine.rerank(query=query, docs=docs)
    scores, _ = result
    return scores

async def __aenter__(self, *args, **kwargs):
    """Start the engine if it is not already running and return this instance.

    Returning ``self`` lets ``async with InfinityEmbeddings(...) as emb`` bind
    a usable object rather than ``None``. Callers that await this method only
    for its side effect can ignore the return value.
    """
    if not self.engine.running:
        await self.engine.astart()
    return self

async def __aexit__(self, *args, **kwargs):
    """Stop the engine on leaving the ``async with`` block, if it is running.

    Any arguments passed in are accepted and ignored.
    """
    if not self.engine.running:
        return
    await self.engine.astop()

def __del__(self, *args, **kwargs):
    """Best-effort shutdown of a still-running engine when the object is collected.

    A finalizer cannot usefully raise, so this method never does so itself:
    it returns quietly when there is nothing to stop or no way to stop it.
    """
    # ``engine`` is only assigned at the end of ``__post_init__``; it is
    # missing when construction failed (e.g. infinity_emb is not installed).
    engine = getattr(self, "engine", None)
    if engine is None or not engine.running:
        return
    # Not every engine object is guaranteed to expose a synchronous ``stop``.
    stop = getattr(engine, "stop", None)
    if stop is not None:
        stop()


def embedding_factory(run_config: t.Optional[RunConfig] = None) -> BaseRagasEmbeddings:
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1 +1,52 @@
from __future__ import annotations

import numpy as np
import pytest

from ragas.embeddings.base import InfinityEmbeddings

try:
import infinity_emb # noqa
import torch # noqa

INFINITY_AVAILABLE = True
except ImportError:
INFINITY_AVAILABLE = False


@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.")
@pytest.mark.asyncio
async def test_basic_embedding():
    """Embed 60 sentences and check the shape and relative similarity."""
    # Three base sentences repeated 20 times -> 60 inputs.
    sentences = [
        "Paris is in France",
        "The capital of France is Paris",
        "Infintiy batches embeddings on the fly",
    ] * 20

    embedding_engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    async with embedding_engine:
        embeddings = await embedding_engine.aembed_documents(sentences)

    assert isinstance(embeddings, list)
    array = np.array(embeddings)
    assert array.shape == (60, 384)
    # The two Paris sentences must be closer to each other than to the third.
    assert array[0] @ array[1] > array[0] @ array[2]


@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.")
@pytest.mark.asyncio
async def test_rerank():
    """Rerank three documents and check the first one scores highest."""
    query = "Where is Paris?"
    candidates = [
        "Paris is in France",
        "I don't know the capital of Paris.",
        "Dummy sentence",
    ]

    rerank_engine = InfinityEmbeddings(model_name="BAAI/bge-reranker-base")
    async with rerank_engine:
        rankings = await rerank_engine.arerank(query, candidates)

    assert len(rankings) == 3
    # The first candidate must outrank both of the others.
    assert rankings[0] > rankings[1]
    assert rankings[0] > rankings[2]
Loading