Skip to content

Commit

Permalink
Wrap BaseEmbeddingModel in a LangChain Embeddings implementation class
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jul 25, 2024
1 parent 058dc24 commit a9c2a6b
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 22 deletions.
36 changes: 19 additions & 17 deletions libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
Expand All @@ -13,6 +13,8 @@
from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
from typing_extensions import override

from ragstack_langchain.colbert.embedding import TokensEmbeddings

CVS = TypeVar("CVS", bound="ColbertVectorStore")


Expand Down Expand Up @@ -209,7 +211,7 @@ async def asimilarity_search_with_score(
def from_documents(
cls,
documents: List[Document],
embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
embedding: Embeddings,
*,
database: Optional[ColbertBaseDatabase] = None,
**kwargs: Any,
Expand All @@ -220,7 +222,7 @@ def from_documents(
return cls.from_texts(
texts=texts,
database=database,
embedding_model=embedding,
embedding=embedding,
metadatas=metadatas,
**kwargs,
)
Expand All @@ -230,7 +232,7 @@ def from_documents(
async def afrom_documents(
cls: Type[CVS],
documents: List[Document],
embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
embedding: Embeddings,
*,
database: Optional[ColbertBaseDatabase] = None,
concurrent_inserts: Optional[int] = 100,
Expand All @@ -242,7 +244,7 @@ async def afrom_documents(
return await cls.afrom_texts(
texts=texts,
database=database,
embedding_model=embedding,
embedding=embedding,
metadatas=metadatas,
concurrent_inserts=concurrent_inserts,
**kwargs,
Expand All @@ -253,22 +255,23 @@ async def afrom_documents(
def from_texts(
cls: Type[CVS],
texts: List[str],
embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
database: Optional[ColbertBaseDatabase] = None,
**kwargs: Any,
) -> CVS:
if isinstance(embedding, Embeddings):
if not isinstance(embedding, TokensEmbeddings):
raise TypeError(
"ColbertVectorStore needs a ColbertBaseEmbeddingModel embedding, "
"not an Embeddings object."
"ColbertVectorStore requires a ColbertEmbeddings embedding."
)
if database is None:
raise ValueError(
"ColbertVectorStore requires a ColbertBaseDatabase database."
)
instance = cls(database=database, embedding_model=embedding, **kwargs)
instance = cls(
database=database, embedding_model=embedding.get_embedding_model(), **kwargs
)
instance.add_texts(texts=texts, metadatas=metadatas)
return instance

Expand All @@ -277,23 +280,22 @@ def from_texts(
async def afrom_texts(
cls: Type[CVS],
texts: List[str],
embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
database: Optional[ColbertBaseDatabase] = None,
concurrent_inserts: Optional[int] = 100,
**kwargs: Any,
) -> CVS:
if isinstance(embedding, Embeddings):
raise TypeError(
"ColbertVectorStore needs a ColbertBaseEmbeddingModel embedding, "
"not an Embeddings object."
)
if not isinstance(embedding, TokensEmbeddings):
raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
if database is None:
raise ValueError(
"ColbertVectorStore requires a ColbertBaseDatabase database."
)
instance = cls(database=database, embedding_model=embedding, **kwargs)
instance = cls(
database=database, embedding_model=embedding.get_embedding_model(), **kwargs
)
await instance.aadd_texts(
texts=texts, metadatas=metadatas, concurrent_inserts=concurrent_inserts
)
Expand Down
58 changes: 58 additions & 0 deletions libs/langchain/ragstack_langchain/colbert/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import List, Optional

from langchain_core.embeddings import Embeddings
from ragstack_colbert import DEFAULT_COLBERT_MODEL, ColbertEmbeddingModel
from ragstack_colbert.base_embedding_model import BaseEmbeddingModel
from typing_extensions import override


class TokensEmbeddings(Embeddings):
    """Adapter wrapping a token-based embedding model as a LangChain ``Embeddings``.

    ColBERT-style models produce one vector *per token* rather than a single
    vector per text, so the standard ``Embeddings`` methods have no meaningful
    implementation here and deliberately raise ``NotImplementedError``. This
    class exists only so a token-based model can be passed where a LangChain
    ``Embeddings`` instance is expected; consumers retrieve the underlying
    model via :meth:`get_embedding_model`.
    """

    def __init__(self, embedding: Optional[BaseEmbeddingModel] = None):
        """Wrap a token-based embedding model.

        Args:
            embedding: The token-based embedding model to wrap. Defaults to a
                ``ColbertEmbeddingModel`` with its standard settings.
        """
        # PEP 484: a None default requires an Optional annotation.
        self.embedding = embedding or ColbertEmbeddingModel()

    @override
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Token-based models emit per-token vectors, not one vector per text.
        raise NotImplementedError

    @override
    def embed_query(self, text: str) -> List[float]:
        # Token-based models emit per-token vectors, not one vector per text.
        raise NotImplementedError

    @override
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        # Token-based models emit per-token vectors, not one vector per text.
        raise NotImplementedError

    @override
    async def aembed_query(self, text: str) -> List[float]:
        # Token-based models emit per-token vectors, not one vector per text.
        raise NotImplementedError

    def get_embedding_model(self) -> BaseEmbeddingModel:
        """Return the wrapped token-based embedding model."""
        return self.embedding

    @staticmethod
    def colbert(
        checkpoint: str = DEFAULT_COLBERT_MODEL,
        doc_maxlen: int = 256,
        nbits: int = 2,
        kmeans_niters: int = 4,
        nranks: int = -1,
        query_maxlen: Optional[int] = None,
        verbose: int = 3,
        chunk_batch_size: int = 640,
    ) -> "TokensEmbeddings":
        """Create a ``TokensEmbeddings`` wrapping a new ColBERT embedding model.

        Args:
            checkpoint: Model checkpoint name or path.
            doc_maxlen: Maximum document token length.
            nbits: Number of bits used for residual quantization.
            kmeans_niters: Number of k-means iterations used during indexing.
            nranks: Number of ranks (GPUs/processes); -1 presumably means
                auto-detect — confirm against ``ColbertEmbeddingModel``.
            query_maxlen: Maximum query token length, or ``None`` for the
                model default.
            verbose: Verbosity level.
            chunk_batch_size: Number of chunks embedded per batch.

        Returns:
            A ``TokensEmbeddings`` wrapping the configured model.
        """
        # NOTE(review): arguments are passed positionally because the
        # ColbertEmbeddingModel constructor's parameter names are not
        # visible here — keep the order in sync with that constructor.
        return TokensEmbeddings(
            ColbertEmbeddingModel(
                checkpoint,
                doc_maxlen,
                nbits,
                kmeans_niters,
                nranks,
                query_maxlen,
                verbose,
                chunk_batch_size,
            )
        )
11 changes: 6 additions & 5 deletions libs/langchain/tests/integration_tests/test_colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from cassandra.cluster import Session
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from ragstack_colbert import CassandraDatabase, ColbertEmbeddingModel
from ragstack_colbert import CassandraDatabase
from ragstack_langchain.colbert import ColbertVectorStore
from ragstack_langchain.colbert.embedding import TokensEmbeddings
from ragstack_tests_utils import TestData
from transformers import BertTokenizer

Expand Down Expand Up @@ -72,7 +73,7 @@ def test_sync_from_docs(session: Session) -> None:
batch_size = 5 # 640 recommended for production use
chunk_size = 250

embedding_model = ColbertEmbeddingModel(
embedding = TokensEmbeddings.colbert(
doc_maxlen=chunk_size,
chunk_batch_size=batch_size,
)
Expand All @@ -81,7 +82,7 @@ def test_sync_from_docs(session: Session) -> None:

doc_chunks: List[Document] = get_test_chunks()
vector_store: ColbertVectorStore = ColbertVectorStore.from_documents(
documents=doc_chunks, database=database, embedding_model=embedding_model
documents=doc_chunks, database=database, embedding=embedding
)

results: List[Document] = vector_store.similarity_search(
Expand Down Expand Up @@ -124,7 +125,7 @@ async def test_async_from_docs(session: Session) -> None:
batch_size = 5 # 640 recommended for production use
chunk_size = 250

embedding_model = ColbertEmbeddingModel(
embedding = TokensEmbeddings.colbert(
doc_maxlen=chunk_size,
chunk_batch_size=batch_size,
)
Expand All @@ -133,7 +134,7 @@ async def test_async_from_docs(session: Session) -> None:

doc_chunks: List[Document] = get_test_chunks()
vector_store: ColbertVectorStore = await ColbertVectorStore.afrom_documents(
documents=doc_chunks, database=database, embedding_model=embedding_model
documents=doc_chunks, database=database, embedding=embedding
)

results: List[Document] = await vector_store.asimilarity_search(
Expand Down

0 comments on commit a9c2a6b

Please sign in to comment.