diff --git a/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py b/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py
index 5564d41f1..7e2030f15 100644
--- a/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py
+++ b/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
@@ -13,6 +13,8 @@
 from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
 from typing_extensions import override
 
+from ragstack_langchain.colbert.embedding import TokensEmbeddings
+
 CVS = TypeVar("CVS", bound="ColbertVectorStore")
@@ -209,7 +211,7 @@ async def asimilarity_search_with_score(
     def from_documents(
         cls,
         documents: List[Document],
-        embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
+        embedding: Embeddings,
         *,
         database: Optional[ColbertBaseDatabase] = None,
         **kwargs: Any,
@@ -220,7 +222,7 @@ def from_documents(
         return cls.from_texts(
             texts=texts,
             database=database,
-            embedding_model=embedding,
+            embedding=embedding,
             metadatas=metadatas,
             **kwargs,
         )
@@ -230,7 +232,7 @@ def from_documents(
     async def afrom_documents(
         cls: Type[CVS],
         documents: List[Document],
-        embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
+        embedding: Embeddings,
         *,
         database: Optional[ColbertBaseDatabase] = None,
         concurrent_inserts: Optional[int] = 100,
         **kwargs: Any,
@@ -242,7 +244,7 @@ async def afrom_documents(
         return await cls.afrom_texts(
             texts=texts,
             database=database,
-            embedding_model=embedding,
+            embedding=embedding,
             metadatas=metadatas,
             concurrent_inserts=concurrent_inserts,
             **kwargs,
         )
@@ -253,22 +255,23 @@ async def afrom_documents(
     def from_texts(
         cls: Type[CVS],
         texts: List[str],
-        embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
+        embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
         *,
         database: Optional[ColbertBaseDatabase] = None,
         **kwargs: Any,
     ) -> CVS:
-        if isinstance(embedding, Embeddings):
+        if not isinstance(embedding, TokensEmbeddings):
             raise TypeError(
-                "ColbertVectorStore needs a ColbertBaseEmbeddingModel embedding, "
-                "not an Embeddings object."
+                "ColbertVectorStore requires a TokensEmbeddings embedding."
             )
         if database is None:
             raise ValueError(
                 "ColbertVectorStore requires a ColbertBaseDatabase database."
             )
-        instance = cls(database=database, embedding_model=embedding, **kwargs)
+        instance = cls(
+            database=database, embedding_model=embedding.get_embedding_model(), **kwargs
+        )
         instance.add_texts(texts=texts, metadatas=metadatas)
         return instance
 
@@ -277,23 +280,22 @@ def from_texts(
     async def afrom_texts(
         cls: Type[CVS],
         texts: List[str],
-        embedding: Union[Embeddings, ColbertBaseEmbeddingModel],
+        embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
         *,
         database: Optional[ColbertBaseDatabase] = None,
         concurrent_inserts: Optional[int] = 100,
         **kwargs: Any,
     ) -> CVS:
-        if isinstance(embedding, Embeddings):
-            raise TypeError(
-                "ColbertVectorStore needs a ColbertBaseEmbeddingModel embedding, "
-                "not an Embeddings object."
-            )
+        if not isinstance(embedding, TokensEmbeddings):
+            raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
         if database is None:
             raise ValueError(
                 "ColbertVectorStore requires a ColbertBaseDatabase database."
             )
-        instance = cls(database=database, embedding_model=embedding, **kwargs)
+        instance = cls(
+            database=database, embedding_model=embedding.get_embedding_model(), **kwargs
+        )
         await instance.aadd_texts(
             texts=texts, metadatas=metadatas, concurrent_inserts=concurrent_inserts
         )
diff --git a/libs/langchain/ragstack_langchain/colbert/embedding.py b/libs/langchain/ragstack_langchain/colbert/embedding.py
new file mode 100644
index 000000000..6528dd05e
--- /dev/null
+++ b/libs/langchain/ragstack_langchain/colbert/embedding.py
@@ -0,0 +1,58 @@
+from typing import List, Optional
+
+from langchain_core.embeddings import Embeddings
+from ragstack_colbert import DEFAULT_COLBERT_MODEL, ColbertEmbeddingModel
+from ragstack_colbert.base_embedding_model import BaseEmbeddingModel
+from typing_extensions import override
+
+
+class TokensEmbeddings(Embeddings):
+    """Adapter between token-based embedding models and LangChain Embeddings."""
+
+    def __init__(self, embedding: Optional[BaseEmbeddingModel] = None):
+        self.embedding = embedding or ColbertEmbeddingModel()
+
+    @override
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        raise NotImplementedError
+
+    @override
+    def embed_query(self, text: str) -> List[float]:
+        raise NotImplementedError
+
+    @override
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        raise NotImplementedError
+
+    @override
+    async def aembed_query(self, text: str) -> List[float]:
+        raise NotImplementedError
+
+    def get_embedding_model(self) -> BaseEmbeddingModel:
+        """Get the embedding model."""
+        return self.embedding
+
+    @staticmethod
+    def colbert(
+        checkpoint: str = DEFAULT_COLBERT_MODEL,
+        doc_maxlen: int = 256,
+        nbits: int = 2,
+        kmeans_niters: int = 4,
+        nranks: int = -1,
+        query_maxlen: Optional[int] = None,
+        verbose: int = 3,
+        chunk_batch_size: int = 640,
+    ) -> "TokensEmbeddings":
+        """Create a new ColBERT embedding model."""
+        return TokensEmbeddings(
+            ColbertEmbeddingModel(
+                checkpoint,
+                doc_maxlen,
+                nbits,
+                kmeans_niters,
+                nranks,
+                query_maxlen,
+                verbose,
+                chunk_batch_size,
+            )
+        )
diff --git a/libs/langchain/tests/integration_tests/test_colbert.py b/libs/langchain/tests/integration_tests/test_colbert.py
index 380fdaa32..cd7526c46 100644
--- a/libs/langchain/tests/integration_tests/test_colbert.py
+++ b/libs/langchain/tests/integration_tests/test_colbert.py
@@ -5,8 +5,9 @@
 from cassandra.cluster import Session
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_core.documents import Document
-from ragstack_colbert import CassandraDatabase, ColbertEmbeddingModel
+from ragstack_colbert import CassandraDatabase
 from ragstack_langchain.colbert import ColbertVectorStore
+from ragstack_langchain.colbert.embedding import TokensEmbeddings
 from ragstack_tests_utils import TestData
 from transformers import BertTokenizer
 
@@ -72,7 +73,7 @@ def test_sync_from_docs(session: Session) -> None:
     batch_size = 5  # 640 recommended for production use
     chunk_size = 250
-    embedding_model = ColbertEmbeddingModel(
+    embedding = TokensEmbeddings.colbert(
         doc_maxlen=chunk_size,
         chunk_batch_size=batch_size,
     )
 
@@ -81,7 +82,7 @@ def test_sync_from_docs(session: Session) -> None:
     doc_chunks: List[Document] = get_test_chunks()
 
     vector_store: ColbertVectorStore = ColbertVectorStore.from_documents(
-        documents=doc_chunks, database=database, embedding_model=embedding_model
+        documents=doc_chunks, database=database, embedding=embedding
     )
 
     results: List[Document] = vector_store.similarity_search(
@@ -124,7 +125,7 @@ async def test_async_from_docs(session: Session) -> None:
     batch_size = 5  # 640 recommended for production use
     chunk_size = 250
-    embedding_model = ColbertEmbeddingModel(
+    embedding = TokensEmbeddings.colbert(
         doc_maxlen=chunk_size,
         chunk_batch_size=batch_size,
     )
 
@@ -133,7 +134,7 @@ async def test_async_from_docs(session: Session) -> None:
     doc_chunks: List[Document] = get_test_chunks()
 
     vector_store: ColbertVectorStore = await ColbertVectorStore.afrom_documents(
-        documents=doc_chunks, database=database, embedding_model=embedding_model
+        documents=doc_chunks, database=database, embedding=embedding
    )
 
     results: List[Document] = await vector_store.asimilarity_search(
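Taken together, callers now pass a LangChain `Embeddings` object (specifically the new `TokensEmbeddings` adapter) instead of a raw `ColbertBaseEmbeddingModel`, and `from_texts`/`from_documents` unwrap it via `get_embedding_model()`. Below is a minimal usage sketch based on the updated integration tests; `build_store` is only an illustrative helper, and the `doc_maxlen`/`chunk_batch_size` values echo the test settings rather than recommended defaults.

```python
from typing import List

from langchain_core.documents import Document
from ragstack_colbert import CassandraDatabase
from ragstack_langchain.colbert import ColbertVectorStore
from ragstack_langchain.colbert.embedding import TokensEmbeddings


def build_store(database: CassandraDatabase, chunks: List[Document]) -> ColbertVectorStore:
    # Wrap the token-level ColBERT model in the LangChain Embeddings adapter;
    # from_documents() unwraps it again via embedding.get_embedding_model().
    embedding = TokensEmbeddings.colbert(doc_maxlen=250, chunk_batch_size=5)
    return ColbertVectorStore.from_documents(
        documents=chunks, database=database, embedding=embedding
    )


# e.g. build_store(database, doc_chunks).similarity_search("what is ColBERT?", k=5)
```

Note that `TokensEmbeddings` raises `NotImplementedError` for the dense `embed_*` methods; it exists only so the vector store can accept a standard `Embeddings` argument while still using token-level ColBERT embeddings internally.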