Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored colbert codebase #407

Merged
merged 15 commits into from
May 10, 2024
2 changes: 1 addition & 1 deletion libs/colbert/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ colbert-ai = "0.2.19"
pyarrow = "14.0.1"
torch = "2.2.1"
cassio = "~0.1.7"
nest-asyncio = "^1.6.0"
pydantic = "^2.7.1"

[tool.poetry.group.test.dependencies]
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
Expand Down
23 changes: 13 additions & 10 deletions libs/colbert/ragstack_colbert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,31 @@
and constants related to the ColBERT model configuration are also provided.

Exports:
- CassandraVectorStore: Implementation of a ColBERT vector store using Cassandra for storage.
- CassandraDatabase: Implementation of a BaseDatabase using Cassandra for storage.
- ColbertEmbeddingModel: Class for generating and managing token embeddings using the ColBERT model.
- ColbertVectorStore: Implementation of a BaseVectorStore.
- ColbertRetriever: Retriever class for executing ColBERT searches within a vector store.
- DEFAULT_COLBERT_MODEL: The default identifier for the ColBERT model.
- DEFAULT_COLBERT_DIM: The default dimensionality for ColBERT model embeddings.
- EmbeddedChunk: Data class for representing a chunk of embedded text.
- RetrievedChunk: Data class for representing a chunk of retrieved text.
- Chunk: Data class for representing a chunk of embedded text.
"""

from .cassandra_vector_store import CassandraVectorStore
from .colbert_retriever import ColbertRetriever
from .cassandra_database import CassandraDatabase
from .colbert_embedding_model import ColbertEmbeddingModel
from .colbert_retriever import ColbertRetriever
from .colbert_vector_store import ColbertVectorStore
from .constant import DEFAULT_COLBERT_DIM, DEFAULT_COLBERT_MODEL
from .objects import ChunkData, EmbeddedChunk, RetrievedChunk
from .objects import Chunk, Embedding, Metadata, Vector

__all__ = [
"CassandraVectorStore",
"ChunkData",
"CassandraDatabase",
"ColbertEmbeddingModel",
"ColbertRetriever",
"ColbertVectorStore",
"DEFAULT_COLBERT_DIM",
"DEFAULT_COLBERT_MODEL",
"EmbeddedChunk",
"RetrievedChunk",
"Chunk",
"Embedding",
"Metadata",
"Vector",
]
79 changes: 79 additions & 0 deletions libs/colbert/ragstack_colbert/base_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
This module defines abstract base classes for implementing storage mechanisms for text chunk
embeddings, specifically designed to work with ColBERT or similar embedding models.
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

from .objects import Chunk, Vector


class BaseDatabase(ABC):
    """
    Abstract base class (ABC) for a storage system designed to hold vector
    representations of text chunks, typically generated by a ColBERT model or
    similar embedding model.

    This class defines the interface for storing and managing the embedded text
    chunks, supporting operations like adding new chunks to the store and
    deleting existing documents by their identifiers.
    """

    @abstractmethod
    def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]:
        """
        Stores a list of embedded text chunks in the vector store.

        Parameters:
            chunks (List[Chunk]): A list of `Chunk` instances to be stored.

        Returns:
            List[Tuple[str, int]]: A list of (doc_id, chunk_id) tuples
            identifying the stored chunks.
        """

    @abstractmethod
    def delete_chunks(self, doc_ids: List[str]) -> bool:
        """
        Deletes chunks from the vector store based on their document id.

        Parameters:
            doc_ids (List[str]): A list of document identifiers specifying the
                chunks to be deleted.

        Returns:
            bool: True if the delete was successful.
        """

    @abstractmethod
    async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]:
        """
        Retrieves 'n' ANN (approximate-nearest-neighbor) results for an
        embedded token vector.

        Parameters:
            vector (Vector): The embedded token vector to search with.
            n (int): The maximum number of results to return.

        Returns:
            List[Chunk]: A list of Chunks with only `doc_id` and `chunk_id`
            set. Fewer than 'n' results may be returned.
        """

    @abstractmethod
    async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk:
        """
        Retrieve the embedding data for a chunk.

        Parameters:
            doc_id (str): The document identifier the chunk belongs to.
            chunk_id (int): The identifier of the chunk within the document.

        Returns:
            Chunk: A chunk with `doc_id`, `chunk_id`, and `embedding` set.
        """

    @abstractmethod
    async def get_chunk_data(
        self, doc_id: str, chunk_id: int, include_embedding: Optional[bool]
    ) -> Chunk:
        """
        Retrieve the text and metadata for a chunk.

        Parameters:
            doc_id (str): The document identifier the chunk belongs to.
            chunk_id (int): The identifier of the chunk within the document.
            include_embedding (Optional[bool]): If truthy, also populate the
                chunk's `embedding` field.

        Returns:
            Chunk: A chunk with `doc_id`, `chunk_id`, `text`, `metadata`, and
            optionally `embedding` set.
        """

    @abstractmethod
    def close(self) -> None:
        """
        Cleans up any open resources.
        """
28 changes: 8 additions & 20 deletions libs/colbert/ragstack_colbert/base_embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from torch import Tensor

from .objects import ChunkData, EmbeddedChunk
from .objects import Embedding


class BaseEmbeddingModel(ABC):
Expand All @@ -20,25 +18,15 @@ class BaseEmbeddingModel(ABC):
"""

@abstractmethod
def embed_chunks(
self, chunks: List[ChunkData], doc_id: Optional[str] = None
) -> List[EmbeddedChunk]:
def embed_texts(self, texts: List[str]) -> List[Embedding]:
"""
Embeds a list of text chunks into their corresponding vector representations.

This method takes multiple chunks of text and optionally their associated document identifier,
returning a list of `EmbeddedChunk` instances containing the embeddings.
Embeds a list of texts into their corresponding vector embedding representations.

Parameters:
chunks (List[ChunkData]): A list of chunks including document text and any associated metadata.
doc_id (Optional[str], optional): An optional document identifier that all chunks belong to.
This can be used for tracing back embeddings to their
source document. If not passed, a UUID will be generated.
texts (List[str]): A list of string texts.

Returns:
List[EmbeddedChunk]: A list of `EmbeddedChunks` instances with embeddings populated,
corresponding to the input text chunks, ready for insertion into
a vector store.
List[Embedding]: A list of embeddings, in the order of the input list
"""

@abstractmethod
Expand All @@ -47,18 +35,18 @@ def embed_query(
query: str,
full_length_search: Optional[bool] = False,
query_maxlen: int = -1,
) -> Tensor:
) -> Embedding:
"""
Embeds a single query text into its vector representation.

If the query has fewer than query_maxlen tokens it will be padded with BERT special [mask] tokens.

Parameters:
query (str): The query string to encode.
query (str): The query text to encode.
full_length_search (Optional[bool]): Indicates whether to encode the query for a full-length search.
Defaults to False.
query_maxlen (int): The fixed length for the query token embedding. If -1, uses a dynamically calculated value.

Returns:
Tensor: A tensor representing the embedded query.
Embedding: A vector embedding representation of the query text
"""
100 changes: 89 additions & 11 deletions libs/colbert/ragstack_colbert/base_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"""

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple

from .objects import RetrievedChunk
from .objects import Chunk, Embedding


class BaseRetriever(ABC):
Expand All @@ -15,34 +15,112 @@ class BaseRetriever(ABC):
the search and retrieval of text chunks based on query embeddings.
"""

# handles LlamaIndex query
@abstractmethod
def close(self) -> None:
def embedding_search(
self,
query_embedding: Embedding,
k: Optional[int] = None,
include_embedding: Optional[bool] = False,
**kwargs: Any
) -> List[Tuple[Chunk, float]]:
"""
Retrieves a list of text chunks relevant to a given query from the vector store, ranked by
relevance or other metrics.

Parameters:
query_embedding (Embedding): The query embedding to search for relevant text chunks.
k (Optional[int]): The number of top results to retrieve.
include_embedding (Optional[bool]): Optional (default False) flag to include the
embedding vectors in the returned chunks
**kwargs (Any): Additional parameters that implementations might require for customized
retrieval operations.

Returns:
List[Tuple[Chunk, float]]: A list of (Chunk, float) tuples, each pairing a text chunk relevant
to the query with its similarity score.
"""

# handles LlamaIndex async query
@abstractmethod
async def aembedding_search(
self,
query_embedding: Embedding,
k: Optional[int] = None,
include_embedding: Optional[bool] = False,
**kwargs: Any
) -> List[Tuple[Chunk, float]]:
"""
Closes the retriever, releasing any resources or connections used during operation.
Implementations should ensure that all necessary cleanup is performed to avoid resource leaks.
Retrieves a list of text chunks relevant to a given query from the vector store, ranked by
relevance or other metrics.

Parameters:
query_embedding (Embedding): The query embedding to search for relevant text chunks.
k (Optional[int]): The number of top results to retrieve.
include_embedding (Optional[bool]): Optional (default False) flag to include the
embedding vectors in the returned chunks
**kwargs (Any): Additional parameters that implementations might require for customized
retrieval operations.

Returns:
List[Tuple[Chunk, float]]: A list of (Chunk, float) tuples, each pairing a text chunk relevant
to the query with its similarity score.
"""

# handles LangChain search
@abstractmethod
def text_search(
self,
query_text: str,
k: Optional[int] = None,
query_maxlen: Optional[int] = None,
include_embedding: Optional[bool] = False,
**kwargs: Any
) -> List[Tuple[Chunk, float]]:
"""
Retrieves a list of text chunks relevant to a given query from the vector store, ranked by
relevance or other metrics.

Parameters:
query_text (str): The query text to search for relevant text chunks.
k (Optional[int]): The number of top results to retrieve.
query_maxlen (Optional[int]): The maximum length of the query to consider. If None, the
maxlen will be dynamically generated.
include_embedding (Optional[bool]): Optional (default False) flag to include the
embedding vectors in the returned chunks
**kwargs (Any): Additional parameters that implementations might require for customized
retrieval operations.

Returns:
List[Tuple[Chunk, float]]: A list of (Chunk, float) tuples, each pairing a text chunk relevant
to the query with its similarity score.
"""

# handles LangChain async search
@abstractmethod
def retrieve(
async def atext_search(
self,
query: str,
query_text: str,
k: Optional[int] = None,
query_maxlen: Optional[int] = None,
include_embedding: Optional[bool] = False,
**kwargs: Any
) -> List[RetrievedChunk]:
) -> List[Tuple[Chunk, float]]:
"""
Retrieves a list of text chunks relevant to a given query from the vector store, ranked by
relevance or other metrics.

Parameters:
query (str): The query text to search for relevant text chunks.
query_text (str): The query text to search for relevant text chunks.
k (Optional[int]): The number of top results to retrieve.
query_maxlen (Optional[int]): The maximum length of the query to consider. If None, the
maxlen will be dynamically generated.
include_embedding (Optional[bool]): Optional (default False) flag to include the
embedding vectors in the returned chunks
**kwargs (Any): Additional parameters that implementations might require for customized
retrieval operations.

Returns:
List[RetrievedChunk]: A list of `RetrievedChunk` instances representing the retrieved
text chunks, ranked by their relevance to the query.
List[Tuple[Chunk, float]]: A list of (Chunk, float) tuples, each pairing a text chunk relevant
to the query with its similarity score.
"""
Loading
Loading