From 1f36c91ec28e6e60ce4d6e9b7495f79aa9f4d3c4 Mon Sep 17 00:00:00 2001 From: ming luo Date: Wed, 13 Mar 2024 16:02:22 -0400 Subject: [PATCH] base colbert implementation --- .github/workflows/modules.yml | 57 ++++ ragstack/colbertbase/README.md | 14 + ragstack/colbertbase/colbertbase/__init__.py | 17 ++ .../colbertbase/astra_colbert_embedding.py | 246 ++++++++++++++++++ ragstack/colbertbase/colbertbase/astra_db.py | 180 +++++++++++++ .../colbertbase/astra_retriever.py | 139 ++++++++++ ragstack/colbertbase/colbertbase/constant.py | 4 + .../colbertbase/token_embedding.py | 115 ++++++++ ragstack/colbertbase/pyproject.toml | 21 ++ ragstack/colbertbase/tests/__init__.py | 1 + .../tests/test_astra_colbert_embeddings.py | 60 +++++ .../colbertbase/tests/test_astra_retriever.py | 22 ++ 12 files changed, 876 insertions(+) create mode 100644 .github/workflows/modules.yml create mode 100644 ragstack/colbertbase/README.md create mode 100644 ragstack/colbertbase/colbertbase/__init__.py create mode 100644 ragstack/colbertbase/colbertbase/astra_colbert_embedding.py create mode 100644 ragstack/colbertbase/colbertbase/astra_db.py create mode 100644 ragstack/colbertbase/colbertbase/astra_retriever.py create mode 100644 ragstack/colbertbase/colbertbase/constant.py create mode 100644 ragstack/colbertbase/colbertbase/token_embedding.py create mode 100644 ragstack/colbertbase/pyproject.toml create mode 100644 ragstack/colbertbase/tests/__init__.py create mode 100644 ragstack/colbertbase/tests/test_astra_colbert_embeddings.py create mode 100644 ragstack/colbertbase/tests/test_astra_retriever.py diff --git a/.github/workflows/modules.yml b/.github/workflows/modules.yml new file mode 100644 index 000000000..d5b09bd93 --- /dev/null +++ b/.github/workflows/modules.yml @@ -0,0 +1,57 @@ +# +# this is RAGStack specific modules workflow +# + +name: RAGStack Modules and Adapter CI + +on: + push: + branches: + - main + paths: + - "ragstack/**" + pull_request: + paths: + - "ragstack/**" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Cache Poetry virtualenv and dependencies + uses: actions/cache@v2 + with: + path: | + ~/.cache/pypoetry/virtualenvs + **/poetry.lock + key: ${{ runner.os }}-poetry-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-poetry- + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + cd ragstack/colbertbase + poetry install + +# - name: Lint with flake8 +# run: | +# cd ragstack +# poetry run flake8 . + + - name: Test with pytest + run: | + cd ragstack/colbertbase + poetry run pytest diff --git a/ragstack/colbertbase/README.md b/ragstack/colbertbase/README.md new file mode 100644 index 000000000..43dd81ab4 --- /dev/null +++ b/ragstack/colbertbase/README.md @@ -0,0 +1,14 @@ +# ColBERT retrieval + +This is a ColBERT retrieval based on Astra DB or Cassandra implementation. + +This module only depends on ColBERT and Cassandra driver with no LangChain and LlamaIndex dependencies. 
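+
+The snippet below is a minimal usage sketch, not a tested example: the bundle path and token are placeholders, and it simply wires together the classes this package exports (`ColbertTokenEmbeddings`, `AstraDB`, `ColbertAstraRetriever`).
+
+```python
+from colbertbase import AstraDB, ColbertAstraRetriever, ColbertTokenEmbeddings
+
+# Embed a few passages with the default colbert-ir/colbertv2.0 checkpoint.
+colbert = ColbertTokenEmbeddings(doc_maxlen=220, nbits=1, kmeans_niters=4)
+passages = colbert.embed_documents(
+    texts=["ColBERT scores passages with per-token late interaction.",
+           "Astra DB stores the per-token vectors in a vector-indexed table."],
+    title="colbert-intro",
+)
+
+# Store the chunks and their per-token embeddings.
+astra = AstraDB(
+    secure_connect_bundle="/path/to/secure-connect.zip",  # placeholder
+    astra_token="AstraCS:...",                            # placeholder
+    keyspace="colbert128",  # the keyspace must already exist
+)
+astra.insert_colbert_embeddings_chunks(passages)
+
+# Retrieve the top passages for a query.
+retriever = ColbertAstraRetriever(astraDB=astra, colbertEmbeddings=colbert)
+for answer in retriever.retrieve("what is late interaction?", k=2):
+    print(answer["rank"], answer["title"], answer["score"])
+
+astra.close()
+```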
+ +## Examples + + +## Env +```bash +poetry install +``` + diff --git a/ragstack/colbertbase/colbertbase/__init__.py b/ragstack/colbertbase/colbertbase/__init__.py new file mode 100644 index 000000000..9290e723d --- /dev/null +++ b/ragstack/colbertbase/colbertbase/__init__.py @@ -0,0 +1,17 @@ + +from .astra_colbert_embedding import ColbertTokenEmbeddings +from .astra_db import AstraDB +from .astra_retriever import ColbertAstraRetriever, max_similarity_torch +from .token_embedding import PerTokenEmbeddings, PassageEmbeddings, TokenEmbeddings +from .constant import DEFAULT_COLBERT_MODEL, DEFAULT_COLBERT_DIM + +__all__ = ( + ColbertTokenEmbeddings, AstraDB, + ColbertAstraRetriever, + max_similarity_torch, + PerTokenEmbeddings, + PassageEmbeddings, TokenEmbeddings, + DEFAULT_COLBERT_MODEL, DEFAULT_COLBERT_DIM, +) + +__version__ = "0.0.1" \ No newline at end of file diff --git a/ragstack/colbertbase/colbertbase/astra_colbert_embedding.py b/ragstack/colbertbase/colbertbase/astra_colbert_embedding.py new file mode 100644 index 000000000..e27166bbd --- /dev/null +++ b/ragstack/colbertbase/colbertbase/astra_colbert_embedding.py @@ -0,0 +1,246 @@ +from typing import Any, Dict, List, Union +import datetime +import itertools +import torch # it should part of colbert dependencies +import uuid +from .token_embedding import TokenEmbeddings, PerTokenEmbeddings, PassageEmbeddings + + +from colbert.indexing.collection_indexer import CollectionIndexer +from colbert.infra import Run, RunConfig, ColBERTConfig +from colbert.data import Queries +from colbert.indexing.collection_encoder import CollectionEncoder +from colbert.modeling.checkpoint import Checkpoint + + + +class ColbertTokenEmbeddings(TokenEmbeddings): + """ + Colbert embeddings model. + + The embedding runs locally and requires the colbert library to be installed. + + Example: + Currently the pyarrow module requires a specific version to be installed. + + pip uninstall pyarrow && pip install pyarrow==14.0.0 + pip install colbert-ai==0.2.19 + pip torch + + To take advantage of GPU, please install faiss-gpu + """ + + colbert_config: ColBERTConfig + checkpoint: Checkpoint + encoder: CollectionEncoder + + # these are default values aligned with the colbert library + __doc_maxlen: int = 220, + __nbits: int = 1, + __kmeans_niters: int = 4, + __nranks: int = 1, + __index_bsize: int = 64, + + # TODO: expose these values + # these are default values aligned with the colbert library + __resume: bool = False, + __similarity: str = 'cosine', + __bsize: int = 32, + __accumsteps: int = 1, + __lr: float = 0.000003, + __maxsteps: int = 500000, + __nway: int = 2, + __use_ib_negatives: bool = False, + __reranker: bool = False, + __is_cuda: bool = False + + @classmethod + def validate_environment(self, values: Dict) -> Dict: + """Validate colbert and its dependency is installed.""" + try: + from colbert import Indexer + except ImportError as exc: + raise ImportError( + "Could not import colbert library. " + "Please install it with `pip install colbert`" + ) from exc + + try: + import torch + if torch.cuda.is_available(): + self.__is_cuda = True + try: + import faiss + except ImportError as e: + raise ImportError( + "Could not import faiss library. " + "Please install it with `pip install faiss-gpu`" + ) from e + + except ImportError as exc: + raise ImportError( + "Could not import torch library. 
" + "Please install it with `pip install torch`" + ) from exc + + return values + + + def __init__( + self, + checkpoint: str = "colbert-ir/colbertv2.0", + doc_maxlen: int = 220, + nbits: int = 1, + kmeans_niters: int = 4, + nranks: int = -1, + query_maxlen: int = 32, + **data: Any, + ): + self.__cuda = torch.cuda.is_available() + total_visible_gpus=0 + if self.__cuda: + self.__cuda_device_count = torch.cuda.device_count() + self.__cuda_device_name = torch.cuda.get_device_name() + print(f"nrank {nranks}") + if nranks < 1: + nranks=self.__cuda_device_count + if nranks > 1: + total_visible_gpus=self.__cuda_device_count + print(f"run on {self.__cuda_device_count} gpus and visible {total_visible_gpus} gpus embeddings on {nranks} gpus") + else: + if nranks < 1: + nranks = 1 + + with Run().context(RunConfig(nranks=nranks)): + if self.__cuda: + torch.cuda.empty_cache() + self.colbert_config = ColBERTConfig( + doc_maxlen=doc_maxlen, + nbits=nbits, + kmeans_niters=kmeans_niters, + nranks=nranks, + checkpoint=checkpoint, + query_maxlen=query_maxlen, + gpus=total_visible_gpus, + ) + self.__doc_maxlen = doc_maxlen + self.__nbits = nbits + self.__kmeans_niters = kmeans_niters + self.__nranks = nranks + print("creating checkpoint") + self.checkpoint = Checkpoint(self.colbert_config.checkpoint, colbert_config=self.colbert_config) + self.encoder = CollectionEncoder(config=self.colbert_config, checkpoint=self.checkpoint) + self.__cuda = torch.cuda.is_available() + if self.__cuda: + self.checkpoint = self.checkpoint.cuda() + + self.print_memory_stats("ColbertTokenEmbeddings") + + + def embed_documents(self, texts: List[str], title: str="") -> List[PassageEmbeddings]: + if title == "": + title = str(uuid.uuid4()) + """Embed search docs.""" + return self.encode(texts, title) + + + def embed_query(self, text: str, title: str) -> PassageEmbeddings: + return self.embed_documents([text], title)[0] + + def encode_queries( + self, + query: Union[str, List[str]], + full_length_search: bool = False, + query_maxlen: int = 32, + ): + queries = query if type(query) is list else [query] + bsize = 128 if len(queries) > 128 else None + + self.checkpoint.query_tokenizer.query_maxlen = max(query_maxlen, self.colbert_config.query_maxlen) + Q = self.checkpoint.queryFromText(queries, bsize=bsize, to_cpu=True, full_length_search=full_length_search) + + return Q + + def encode_query( + self, + query: str, + full_length_search: bool = False, + query_maxlen: int = 32, + ): + Q = self.encode_queries(query, full_length_search, query_maxlen=query_maxlen) + return Q[0] + + + def encode(self, texts: List[str], title: str="") -> List[PassageEmbeddings]: + # collection = Collection(texts) + # batches = collection.enumerate_batches(rank=Run().rank) + ''' + config = ColBERTConfig( + doc_maxlen=self.__doc_maxlen, + nbits=self.__nbits, + kmeans_niters=self.__kmeans_niters, + checkpoint=self.checkpoint, + index_bsize=1) + ckp = Checkpoint(config.checkpoint, colbert_config=config) + encoder = CollectionEncoder(config=self.config, checkpoint=self.checkpoint) + ''' + embeddings, count = self.encoder.encode_passages(texts) + + collectionEmbds = [] + # split up embeddings by counts, a list of the number of tokens in each passage + start_indices = [0] + list(itertools.accumulate(count[:-1])) + embeddings_by_part = [embeddings[start:start+count] for start, count in zip(start_indices, count)] + size = len(embeddings_by_part) + for part, embedding in enumerate(embeddings_by_part): + collectionEmbd = PassageEmbeddings(text=texts[part], title=title, 
part=part) + pid = collectionEmbd.id() + for __part_i, perTokenEmbedding in enumerate(embedding): + perToken = PerTokenEmbeddings(parent_id=pid, id=__part_i, title=title, part=part) + perToken.add_embeddings(perTokenEmbedding.tolist()) + # print(f" token embedding part {part} id {__part_i} parent id {pid}") + collectionEmbd.add_token_embeddings(perToken) + collectionEmbds.append(collectionEmbd) + # print(f"embedding part {part} collection id {pid}, collection size {len(collectionEmbd.get_all_token_embeddings())}") + + return collectionEmbds + + def print_message(self, *s, condition=True, pad=False): + s = ' '.join([str(x) for x in s]) + msg = "[{}] {}".format(datetime.datetime.now().strftime("%b %d, %H:%M:%S"), s) + + if condition: + msg = msg if not pad else f'\n{msg}\n' + print(msg, flush=True) + + return msg + + def print_memory_stats(self, message=''): + try: + import psutil # Remove before releases? Or at least make optional with try/except. + except ImportError: + self.print_message("psutil not installed. Memory stats not available.") + return + + global_info = psutil.virtual_memory() + total, available, used, free = global_info.total, global_info.available, global_info.used, global_info.free + + info = psutil.Process().memory_info() + rss, vms, shared = info.rss, info.vms, info.shared + uss = psutil.Process().memory_full_info().uss + + gib = 1024 ** 3 + + summary = f""" + "[PID: {os.getpid()}] + [{message}] + Available: {available / gib:,.1f} / {total / gib:,.1f} + Free: {free / gib:,.1f} / {total / gib:,.1f} + Usage: {used / gib:,.1f} / {total / gib:,.1f} + + RSS: {rss / gib:,.1f} + VMS: {vms / gib:,.1f} + USS: {uss / gib:,.1f} + SHARED: {shared / gib:,.1f} + """.strip().replace('\n', '\t') + + self.print_message(summary, pad=True) diff --git a/ragstack/colbertbase/colbertbase/astra_db.py b/ragstack/colbertbase/colbertbase/astra_db.py new file mode 100644 index 000000000..c51eda3e9 --- /dev/null +++ b/ragstack/colbertbase/colbertbase/astra_db.py @@ -0,0 +1,180 @@ +from typing import Any, Dict, List +from cassandra.cluster import Cluster +from cassandra.query import SimpleStatement +from cassandra.auth import PlainTextAuthProvider +from cassandra import InvalidRequest +from cassandra.concurrent import execute_concurrent_with_args + +from .token_embedding import PassageEmbeddings, PerTokenEmbeddings + +def required_cred(cred: str): + if cred is None or cred == "": + raise ValueError("Please provide credentials") + +class AstraDB: + def __init__( + self, + secure_connect_bundle: str="", + astra_token: str=None, + keyspace: str="colbert128", + verbose: bool=False, + timeout: int=60, + **kwargs, + ): + + required_cred(secure_connect_bundle) + required_cred(astra_token) + + # self.cluster = Cluster(**kwargs) + self.cluster = Cluster( + cloud={ + 'secure_connect_bundle': secure_connect_bundle + }, + auth_provider=PlainTextAuthProvider( + 'token', + astra_token + ) + ) + self.keyspace = keyspace + self.session = self.cluster.connect() + self.session.default_timeout = timeout + self.verbose = verbose + + print(f"set up keyspace {keyspace}, tables and indexes...") + + if keyspace not in self.cluster.metadata.keyspaces.keys(): + raise ValueError(f"Keyspace '{keyspace}' does not exist. 
please create it first.") + + self.create_tables() + + # prepare statements + + chunk_counts_cql = f""" + SELECT COUNT(*) FROM {keyspace}.chunks + """ + self.chunk_counts_stmt = self.session.prepare(chunk_counts_cql) + + insert_chunk_cql = f""" + INSERT INTO {keyspace}.chunks (title, part, body) + VALUES (?, ?, ?) + """ + self.insert_chunk_stmt = self.session.prepare(insert_chunk_cql) + + insert_colbert_cql = f""" + INSERT INTO {keyspace}.colbert_embeddings (title, part, embedding_id, bert_embedding) + VALUES (?, ?, ?, ?) + """ + self.insert_colbert_stmt = self.session.prepare(insert_colbert_cql) + + query_colbert_ann_cql = f""" + SELECT title, part + FROM {keyspace}.colbert_embeddings + ORDER BY bert_embedding ANN OF ? + LIMIT ? + """ + self.query_colbert_ann_stmt = self.session.prepare(query_colbert_ann_cql) + + query_colbert_parts_cql = f""" + SELECT title, part, bert_embedding + FROM {keyspace}.colbert_embeddings + WHERE title = ? AND part = ? + """ + self.query_colbert_parts_stmt = self.session.prepare(query_colbert_parts_cql) + + query_part_by_pk = f""" + SELECT body + FROM {keyspace}.chunks + WHERE title = ? AND part = ? + """ + self.query_part_by_pk_stmt = self.session.prepare(query_part_by_pk) + + print("statements are prepared") + + + def create_tables(self): + self.session.execute(f""" + use {self.keyspace}; + """) + print(f"Using keyspace {self.keyspace}") + + self.session.execute(""" + CREATE TABLE IF NOT EXISTS chunks( + title text, + part int, + body text, + PRIMARY KEY (title, part) + ); + """) + print("Created chunks table") + + self.session.execute(""" + CREATE TABLE IF NOT EXISTS colbert_embeddings ( + title text, + part int, + embedding_id int, + bert_embedding vector, + PRIMARY KEY (title, part, embedding_id) + ); + """) + print("Created colbert_embeddings table") + + self.create_index(""" + CREATE CUSTOM INDEX colbert_ann ON colbert_embeddings(bert_embedding) USING 'StorageAttachedIndex' + WITH OPTIONS = { 'similarity_function': 'DOT_PRODUCT' }; + """) + print("Created index on colbert_embeddings table") + + def create_index(self, command: str): + try: + self.session.execute(command) + except InvalidRequest as e: + if "already exists" in str(e): + print("Index already exists and continue...") + else: + raise e + # throw other exceptions + + # ensure db connection is alive + def ping(self): + self.session.execute("select release_version from system.local").one() + + + def insert_chunk(self, title: str, part: int, body: str): + self.session.execute(self.insert_chunk_stmt, (title, part, body)) + + def insert_colbert_embeddings_chunks( + self, + embeddings: List[PassageEmbeddings], + delete_existed_passage: bool = False + ) -> None: + if delete_existed_passage: + for p in embeddings: + try: + self.delete_title(p.title()) + except Exception as e: + # no need to throw error if the title does not exist + # let the error propagate + print(f"delete title {p.title()} error {e}") + # insert chunks + p_parameters = [(p.title(), p.part(), p.get_text()) for p in embeddings] + execute_concurrent_with_args(self.session, self.insert_chunk_stmt, p_parameters) + if (self.verbose): + print(f"inserting chunks {p_parameters}") + + # insert colbert embeddings + for passageEmbd in embeddings: + title = passageEmbd.title() + parameters = [(title, e[1].part, e[1].id, e[1].get_embeddings()) for e in enumerate(passageEmbd.get_all_token_embeddings())] + execute_concurrent_with_args(self.session, self.insert_colbert_stmt, parameters) + + def delete_title(self, title: str): + # Assuming `title` is 
a variable holding the title you want to delete + query = "DELETE FROM {}.chunks WHERE title = %s".format(self.keyspace) + self.session.execute(query, (title,)) + + query = "DELETE FROM {}.colbert_embeddings WHERE title = %s".format(self.keyspace) + self.session.execute(query, (title,)) + + def close(self): + self.session.shutdown() + self.cluster.shutdown() \ No newline at end of file diff --git a/ragstack/colbertbase/colbertbase/astra_retriever.py b/ragstack/colbertbase/colbertbase/astra_retriever.py new file mode 100644 index 000000000..d359ccf66 --- /dev/null +++ b/ragstack/colbertbase/colbertbase/astra_retriever.py @@ -0,0 +1,139 @@ +from .astra_colbert_embedding import ColbertTokenEmbeddings + +from .astra_db import AstraDB +from torch import tensor +from typing import List +import torch +import math + +# max similarity between a query vector and a list of embeddings +# The function returns the highest similarity score (i.e., the maximum dot product value) between the query vector and any of the embedding vectors in the list. + +''' +# The function iterates over each embedding vector (e) in the embeddings. +# For each e, it performs a dot product operation (@) with the query vector (qv). +# The dot product of two vectors is a measure of their similarity. In the context of embeddings, +# a higher dot product value usually indicates greater similarity. +# The max function then takes the highest value from these dot product operations. +# Essentially, it's picking the embedding vector that has the highest similarity to the query vector qv. +def max_similary_operator_based(qv, embeddings, is_cuda: bool=False): + if is_cuda: + # Assuming qv and embeddings are PyTorch tensors + qv = qv.to('cuda') # Move qv to GPU + embeddings = [e.to('cuda') for e in embeddings] # Move all embeddings to GPU + return max(qv @ e for e in embeddings) +def max_similarity_numpy_based(query_vector, embedding_list): + # Convert the list of embeddings into a numpy matrix for vectorized operation + embedding_matrix = np.vstack(embedding_list) + + # Calculate the dot products in a vectorized manner + sims = np.dot(embedding_matrix, query_vector) + + # Find the maximum similarity (dot product) value + max_sim = np.max(sims) + + return max_sim +''' + +# this torch based max similary has the best performance. +# it is at least 20 times faster than dot product operator and numpy based implementation CuDA and CPU +def max_similarity_torch(query_vector, embedding_list, is_cuda: bool=False): + """ + Calculate the maximum similarity (dot product) between a query vector and a list of embedding vectors, + optimized for performance using PyTorch for GPU acceleration. + + Parameters: + - query_vector: A PyTorch tensor representing the query vector. + - embedding_list: A list of PyTorch tensors, each representing an embedding vector. + + Returns: + - max_sim: A float representing the highest similarity (dot product) score between the query vector and the embedding vectors in the list, computed on the GPU. 
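+
+    Example (illustrative, CPU path; values chosen so the result is exact):
+        >>> q = torch.tensor([1.0, 2.0])
+        >>> embeddings = [torch.tensor([2.0, 1.0]), torch.tensor([0.0, 3.0])]
+        >>> max_similarity_torch(q, embeddings).item()  # dot products are 4.0 and 6.0
+        6.0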
+ """ + # stacks the list of embedding tensors into a single tensor + if is_cuda: + query_vector = query_vector.to('cuda') + embedding_list = torch.stack(embedding_list).to('cuda') + else: + embedding_list = torch.stack(embedding_list) + + # Calculate the dot products in a vectorized manner on the GPU + sims = torch.matmul(embedding_list, query_vector) + + # Find the maximum similarity (dot product) value + max_sim = torch.max(sims) + + # returns a tensor; the item() is the score + return max_sim + + +class ColbertAstraRetriever(): + astra: AstraDB + colbertEmbeddings: ColbertTokenEmbeddings + verbose: bool + is_cuda: bool=False + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + astraDB: AstraDB, + colbertEmbeddings: ColbertTokenEmbeddings, + verbose: bool=False, + **kwargs + ): + # initialize pydantic base model + self.astra = astraDB + self.colbertEmbeddings = colbertEmbeddings + self.verbose = verbose + self.is_cuda = torch.cuda.is_available() + + def retrieve( + self, + query: str, + k: int=10, + query_maxlen: int=64, + **kwargs + ): + # + # if the query has fewer than a predefined number of of tokens Nq, + # colbertEmbeddings will pad it with BERT special [mast] token up to length Nq. + # + query_encodings = self.colbertEmbeddings.encode_query(query, query_maxlen=query_maxlen) + + count = self.astra.session.execute(self.astra.chunk_counts_stmt).one().count + k = min(k, count) + + # the min of query_maxlen is 32 + top_k = max(math.floor(len(query_encodings) / 2), 16) + if self.verbose: + print(f"Total number of chunks: {count}, query length {len(query)} embeddings top_k: {top_k}") + + # find the most relevant documents + docparts = set() + for qv in query_encodings: + # per token based retrieval + rows = self.astra.session.execute(self.astra.query_colbert_ann_stmt, [list(qv), top_k]) + docparts.update((row.title, row.part) for row in rows) + # score each document + scores = {} + import time + for title, part in docparts: + # find all the found parts so that we can do max similarity search + rows = self.astra.session.execute(self.astra.query_colbert_parts_stmt, [title, part]) + embeddings_for_part = [tensor(row.bert_embedding) for row in rows] + # score based on The function returns the highest similarity score + #(i.e., the maximum dot product value) between the query vector and any of the embedding vectors in the list. 
+ scores[(title, part)] = sum(max_similarity_torch(qv, embeddings_for_part, self.is_cuda) for qv in query_encodings) + # load the source chunk for the top k documents + docs_by_score = sorted(scores, key=scores.get, reverse=True)[:k] + answers = [] + rank = 1 + for title, part in docs_by_score: + rs = self.astra.session.execute(self.astra.query_part_by_pk_stmt, [title, part]) + score = scores[(title, part)] + answers.append({'title': title, 'score': score.item(), 'rank': rank, 'body': rs.one().body}) + rank=rank+1 + # clean up on tensor memory on GPU + del scores + return answers diff --git a/ragstack/colbertbase/colbertbase/constant.py b/ragstack/colbertbase/colbertbase/constant.py new file mode 100644 index 000000000..e80305d4d --- /dev/null +++ b/ragstack/colbertbase/colbertbase/constant.py @@ -0,0 +1,4 @@ +# the default colbert model is colbert-ir/colbertv2.0 +DEFAULT_COLBERT_MODEL = "colbert-ir/colbertv2.0" + +DEFAULT_COLBERT_DIM = 128 \ No newline at end of file diff --git a/ragstack/colbertbase/colbertbase/token_embedding.py b/ragstack/colbertbase/colbertbase/token_embedding.py new file mode 100644 index 000000000..52dccb0a5 --- /dev/null +++ b/ragstack/colbertbase/colbertbase/token_embedding.py @@ -0,0 +1,115 @@ +# +# this is a base class for ColBERT per token based embedding + +from abc import ABC, abstractmethod +from typing import Any, List +import uuid + +class PerTokenEmbeddings(): + + __embeddings: List[float] + + def __init__( + self, + id: int, + part: int, + parent_id: uuid.UUID = None, + title: str = "", + ): + self.id = id + self.parent_id = parent_id + self.__embeddings = [] + self.title = title + self.part =part + + def add_embeddings(self, embeddings: List[float]): + self.__embeddings = embeddings + + def get_embeddings(self) -> List[float]: + return self.__embeddings + + def id(self): + return self.id + + def parent_id(self): + return self.parent_id + + def part(self): + return self.part + +class PassageEmbeddings(): + __token_embeddings: List[PerTokenEmbeddings] + __text: str + __title: str + __id: uuid.UUID + + def __init__( + self, + text: str, + title: str = "", + part: int = 0, + id: uuid.UUID = None, + model: str = "colbert-ir/colbertv2.0", + dim: int = 128, + ): + #self.token_ids = token_ids + self.__text = text + self.__token_embeddings = [] + if id is None: + self.__id = uuid.uuid4() + else: + self.__id = id + self.__model = model + self.__dim = dim + self.__title = title + self.__part = part + + def model(self): + return self.__model + + def dim(self): + return self.__dim + + def token_size(self): + return len(self.token_ids) + + def title(self): + return self.__title + + def __len__(self): + return len(self.embeddings) + + def id(self): + return self.__id + + def part(self): + return self.__part + + def add_token_embeddings(self, token_embeddings: PerTokenEmbeddings): + self.__token_embeddings.append(token_embeddings) + + def get_token_embeddings(self, token_id: int) -> PerTokenEmbeddings: + for token in self.__token_embeddings: + if token.token_id == token_id: + return token + return None + + def get_all_token_embeddings(self) -> List[PerTokenEmbeddings]: + return self.__token_embeddings + + def get_text(self): + return self.__text + +# +# This is the base class for token based embedding +# ColBERT token embeddings is an example of a class that inherits from this class +class TokenEmbeddings(ABC): + """Interface for token embedding models.""" + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[PassageEmbeddings]: + """Embed search 
docs.""" + + @abstractmethod + def embed_query(self, text: str) -> PassageEmbeddings: + """Embed query text.""" diff --git a/ragstack/colbertbase/pyproject.toml b/ragstack/colbertbase/pyproject.toml new file mode 100644 index 000000000..d02118067 --- /dev/null +++ b/ragstack/colbertbase/pyproject.toml @@ -0,0 +1,21 @@ +[tool.poetry] +name = "colbertbase" +version = "0.1.0" +description = "" +authors = ["ming luo "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.10" +colbert-ai = "^0.2.19" +cassandra-driver = "^3.29.0" +torch = "^2.2.1" +pyarrow = "14.0.0" + + +[tool.poetry.group.dev.dependencies] +pytest = "^8.1.1" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/ragstack/colbertbase/tests/__init__.py b/ragstack/colbertbase/tests/__init__.py new file mode 100644 index 000000000..8ced11412 --- /dev/null +++ b/ragstack/colbertbase/tests/__init__.py @@ -0,0 +1 @@ +# place holder \ No newline at end of file diff --git a/ragstack/colbertbase/tests/test_astra_colbert_embeddings.py b/ragstack/colbertbase/tests/test_astra_colbert_embeddings.py new file mode 100644 index 000000000..3ae1de3f6 --- /dev/null +++ b/ragstack/colbertbase/tests/test_astra_colbert_embeddings.py @@ -0,0 +1,60 @@ +# test_embeddings.py + +from colbertbase import ColbertTokenEmbeddings +from colbertbase import DEFAULT_COLBERT_MODEL, DEFAULT_COLBERT_DIM + +def test_colbert_token_embeddings(): + colbert = ColbertTokenEmbeddings() + assert colbert.colbert_config is not None + + passagesEmbeddings = colbert.embed_documents(["test1", "test2"]) + + assert len(passagesEmbeddings) == 2 + + assert passagesEmbeddings[0].get_text() == "test1" + assert passagesEmbeddings[1].get_text() == "test2" + + # generate uuid based title + assert passagesEmbeddings[0].title() != "" + assert passagesEmbeddings[1].title() != "" + + passageEmbeddings = colbert.embed_documents(texts=["test1", "test2"], title="test-title") + + assert passageEmbeddings[0].get_text() == "test1" + assert passageEmbeddings[0].title() == "test-title" + assert passageEmbeddings[1].title() == "test-title" + + # test query embedding + # queryEmbeddings = colbert.embed_query(text="test-query", title="test-title") + + tokenEmbeddings = passagesEmbeddings[0].get_all_token_embeddings() + assert len(tokenEmbeddings[0].get_embeddings()) == DEFAULT_COLBERT_DIM + + # test query encoding + queryEncoding = colbert.encode_query("test-query", query_maxlen=512) + assert len(queryEncoding) == 512 + + +def test_colbert_token_embeddings_with_params(): + colbert = ColbertTokenEmbeddings( + doc_maxlen=220, + nbits=1, + kmeans_niters=4, + checkpoint=DEFAULT_COLBERT_MODEL, + query_maxlen=32, + ) + assert colbert.colbert_config is not None + + passagesEmbeddings = colbert.embed_documents(["test1", "test2", "test3"]) + + assert len(passagesEmbeddings) == 3 + + assert passagesEmbeddings[0].get_text() == "test1" + assert passagesEmbeddings[1].get_text() == "test2" + + tokenEmbeddings = passagesEmbeddings[0].get_all_token_embeddings() + assert len(tokenEmbeddings) > 1 + assert len(tokenEmbeddings[0].get_embeddings()) == DEFAULT_COLBERT_DIM + + + diff --git a/ragstack/colbertbase/tests/test_astra_retriever.py b/ragstack/colbertbase/tests/test_astra_retriever.py new file mode 100644 index 000000000..afe4b4277 --- /dev/null +++ b/ragstack/colbertbase/tests/test_astra_retriever.py @@ -0,0 +1,22 @@ +import torch +import pytest + +from colbertbase import max_similarity_torch + +def test_max_similarity_torch(): + # Example query vector and 
embedding list + query_vector = torch.tensor([1, 2, 3], dtype=torch.float32) + embedding_list = [ + torch.tensor([2, 3, 4], dtype=torch.float32), + torch.tensor([1, 0, 1], dtype=torch.float32), + torch.tensor([4, 5, 6], dtype=torch.float32) # This should produce the highest dot product + ] + + # Expected result calculated manually or logically determined + expected_max_similarity = torch.dot(query_vector, embedding_list[2]) # Should be the highest + + # Call the function under test + max_sim = max_similarity_torch(query_vector, embedding_list, is_cuda=False) + + # Check if the returned max similarity matches the expected value + assert max_sim.item() == expected_max_similarity.item(), "The max similarity does not match the expected value."
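+
+
+def test_max_similarity_torch_negative_scores():
+    # Illustrative sketch of an additional edge case: even when every dot product
+    # is negative, the largest one should still be returned.
+    query_vector = torch.tensor([1.0, 1.0], dtype=torch.float32)
+    embedding_list = [
+        torch.tensor([-1.0, -2.0], dtype=torch.float32),  # dot product -3.0
+        torch.tensor([-0.5, -0.5], dtype=torch.float32),  # dot product -1.0 (maximum)
+    ]
+
+    max_sim = max_similarity_torch(query_vector, embedding_list, is_cuda=False)
+
+    assert max_sim.item() == -1.0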