From bdfcf56d8b8626ea043e084eddf4c295db7d4e2a Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Tue, 2 Apr 2024 00:12:09 -0700 Subject: [PATCH] contribution: infinity-integration --- Makefile | 7 ++ pyproject.toml | 1 + requirements/test.txt | 1 + src/ragas/embeddings/base.py | 140 ++++++++++++++++++++++++++++------ tests/unit/test_embeddings.py | 51 +++++++++++++ 5 files changed, 176 insertions(+), 24 deletions(-) diff --git a/Makefile b/Makefile index bc562f08b..45959602b 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,13 @@ format: ## Running code formatter: black and isort @find src -name "*.pyi" ! -name "*_pb2*" -exec black --pyi --config pyproject.toml {} \; @echo "(ruff) Running fix only..." @ruff check src docs tests --fix-only +format-check: + @echo "(isort) Checking import order..." + @isort --check . + @echo "(black) Checking code formatting..." + @black --config pyproject.toml --check src tests docs + @echo "(ruff) Linting development project..." + @ruff check src docs tests --fix-only lint: ## Running lint checker: ruff @echo "(ruff) Linting development project..." @ruff check src docs tests diff --git a/pyproject.toml b/pyproject.toml index 0851232f3..5812f2044 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dynamic = ["version", "readme"] [project.optional-dependencies] all = [ "sentence-transformers", + "infinity_emb[all]", ] [tool.setuptools] diff --git a/requirements/test.txt b/requirements/test.txt index e2814f467..13b0db26c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,3 +2,4 @@ pytest pytest-xdist[psutil] pytest-asyncio llama_index +pytest-asyncio \ No newline at end of file diff --git a/src/ragas/embeddings/base.py b/src/ragas/embeddings/base.py index 38d5efd72..a4380c150 100644 --- a/src/ragas/embeddings/base.py +++ b/src/ragas/embeddings/base.py @@ -2,7 +2,7 @@ import asyncio import typing as t -from abc import ABC +from abc import ABC, abstractmethod from dataclasses import field from typing import List @@ -15,6 +15,9 @@ DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5" +if t.TYPE_CHECKING: + from torch import Tensor + class BaseRagasEmbeddings(Embeddings, ABC): run_config: RunConfig @@ -26,7 +29,7 @@ async def embed_text(self, text: str, is_async=True) -> List[float]: async def embed_texts( self, texts: List[str], is_async: bool = True ) -> t.List[t.List[float]]: - if is_async: + if is_async and hasattr(self, "aembed_documents"): aembed_documents_with_retry = add_async_retry( self.aembed_documents, self.run_config ) @@ -41,6 +44,9 @@ async def embed_texts( def set_run_config(self, run_config: RunConfig): self.run_config = run_config + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: ... + class LangchainEmbeddingsWrapper(BaseRagasEmbeddings): def __init__( @@ -60,7 +66,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_query(self, text: str) -> List[float]: return await self.embeddings.aembed_query(text) - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + async def aembed_documents(self, texts: List[str]): return await self.embeddings.aembed_documents(texts) def set_run_config(self, run_config: RunConfig): @@ -110,47 +116,133 @@ def __post_init__(self): ) if self.is_cross_encoder: - self.model = sentence_transformers.CrossEncoder( + self._ce = sentence_transformers.CrossEncoder( self.model_name, **self.model_kwargs ) + self.model = self._ce + self.is_cross_encoder = True else: - self.model = sentence_transformers.SentenceTransformer( + self._st = sentence_transformers.SentenceTransformer( self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) + self.model = self._st + self.is_cross_encoder = False # ensure outputs are tensors - if "convert_to_tensor" not in self.encode_kwargs: - self.encode_kwargs["convert_to_tensor"] = True + self.encode_kwargs["convert_to_tensor"] = True def embed_query(self, text: str) -> List[float]: return self.embed_documents([text])[0] def embed_documents(self, texts: List[str]) -> List[List[float]]: - from sentence_transformers.SentenceTransformer import SentenceTransformer - from torch import Tensor - - assert isinstance( - self.model, SentenceTransformer - ), "Model is not of the type Bi-encoder" - embeddings = self.model.encode( + assert not self.is_cross_encoder, "Model is not of the type Bi-encoder" + embeddings: Tensor = self._st.encode( # type: ignore texts, normalize_embeddings=True, **self.encode_kwargs ) - - assert isinstance(embeddings, Tensor) return embeddings.tolist() def predict(self, texts: List[List[str]]) -> List[List[float]]: - from sentence_transformers.cross_encoder import CrossEncoder - from torch import Tensor + assert self.is_cross_encoder, "Model is not of the type CrossEncoder" + predictions: Tensor = self.model.predict(texts, **self.encode_kwargs) # type: ignore + return predictions.tolist() + + +@dataclass +class InfinityEmbeddings(BaseRagasEmbeddings): + """Infinity embeddings using infinity_emb package. + + usage: + ```python + embedding_engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5") + async with embedding_engine: + embeddings = await embedding_engine.aembed_documents( + ["Paris is in France", "The capital of France is Paris", "Infintiy batches embeddings on the fly"] + ) - assert isinstance( - self.model, CrossEncoder - ), "Model is not of the type CrossEncoder" + reranking_engine = InfinityEmbeddings(model_name="BAAI/bge-reranker-base") + async with reranking_engine: + rankings = await reranking_engine.arerank("Where is Paris?", ["Paris is in France", "I don't know the capital of Paris.", "Dummy sentence"]) + ``` + """ - predictions = self.model.predict(texts, **self.encode_kwargs) + model_name: str = DEFAULT_MODEL_NAME + """Model name to use.""" + infinity_engine_kwargs: t.Dict[str, t.Any] = field(default_factory=dict) + """infinity engine keyword arguments. + { + batch_size: int = 64 + revision: str | None = None, + trust_remote_code: bool = True, + engine: str = torch | optimum | ctranslate2 + model_warmup: bool = False + vector_disk_cache_path: str = "" + device: Device | str = "auto" + lengths_via_tokenize: bool = False + } + """ - assert isinstance(predictions, Tensor) - return predictions.tolist() + def __post_init__(self): + try: + import infinity_emb + except ImportError as exc: + raise ImportError( + "Could not import infinity_emb python package. " + "Please install it with `pip install infinity-emb[torch,optimum]>=0.0.32`." + ) from exc + self.engine = infinity_emb.AsyncEmbeddingEngine( + model_name_or_path=self.model_name, **self.infinity_engine_kwargs + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError( + "Infinity embeddings does not support sync embeddings" + ) + + def embed_query(self, text: str) -> List[float]: + return self.embed_documents([text])[0] + + async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]: + """vectorize documents using an embedding model and return embeddings""" + await self.__aenter__() + if "embed" not in self.engine.capabilities: + raise ValueError( + f"Model={self.model_name} does not have `embed` capability, but only {self.engine.capabilities}. " + "Try a different model, e.g. `model_name=BAAI/bge-small-en-v1.5`" + ) + # return embeddings + embeddings, _ = await self.engine.embed(sentences=texts) + return np.array(embeddings).tolist() + + async def aembed_query(self, text: str) -> t.List[float]: + """vectorize a query using an embedding model and return embeddings""" + embeddings = await self.aembed_documents([text]) + return embeddings[0] + + async def arerank(self, query: str, docs: List[str]) -> List[float]: + """rerank documents against a single query and return scores for each document""" + await self.__aenter__() + if "rerank" not in self.engine.capabilities: + raise ValueError( + f"Model={self.model_name} does not have `rerank` capability, but only {self.engine.capabilities}. " + "Try a different model, e.g. `model_name=mixedbread-ai/mxbai-rerank-base-v1`" + ) + # return predictions + rankings, _ = await self.engine.rerank(query=query, docs=docs) + return rankings + + async def __aenter__(self, *args, **kwargs): + if not self.engine.running: + await self.engine.astart() + + async def __aexit__(self, *args, **kwargs): + if self.engine.running: + await self.engine.astop() + + def __del__(self, *args, **kwargs): + if self.engine.running: + if not hasattr(self.engine, "stop"): + raise AttributeError("Engine does not have a stop method") + self.engine.stop() def embedding_factory(run_config: t.Optional[RunConfig] = None) -> BaseRagasEmbeddings: diff --git a/tests/unit/test_embeddings.py b/tests/unit/test_embeddings.py index 9d48db4f9..2d0810a4b 100644 --- a/tests/unit/test_embeddings.py +++ b/tests/unit/test_embeddings.py @@ -1 +1,52 @@ from __future__ import annotations + +import numpy as np +import pytest + +from ragas.embeddings.base import InfinityEmbeddings + +try: + import infinity_emb # noqa + import torch # noqa + + INFINITY_AVAILABLE = True +except ImportError: + INFINITY_AVAILABLE = False + + +@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.") +@pytest.mark.asyncio +async def test_basic_embedding(): + embedding_engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5") + async with embedding_engine: + embeddings = await embedding_engine.aembed_documents( + [ + "Paris is in France", + "The capital of France is Paris", + "Infintiy batches embeddings on the fly", + ] + * 20 + ) + assert isinstance(embeddings, list) + array = np.array(embeddings) + assert array.shape == (60, 384) + assert array[0] @ array[1] > array[0] @ array[2] + + +@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.") +@pytest.mark.asyncio +async def test_rerank(): + rerank_engine = InfinityEmbeddings(model_name="BAAI/bge-reranker-base") + + async with rerank_engine: + rankings = await rerank_engine.arerank( + "Where is Paris?", + [ + "Paris is in France", + "I don't know the capital of Paris.", + "Dummy sentence", + ], + ) + assert len(rankings) == 3 + assert rankings[0] > rankings[1] + assert rankings[0] > rankings[2]