diff --git a/.github/ci.yaml b/.github/workflows/ci.yaml
similarity index 84%
rename from .github/ci.yaml
rename to .github/workflows/ci.yaml
index f11fe18..243944d 100644
--- a/.github/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -22,7 +22,7 @@ jobs:
         python-version: 3.x
     - name: Install dependencies
-      run: pip install -r requirements.txt
+      run: pip install -e .[dev]
     - name: Run tests
-      run: pytest tests
+      run: python -m pytest tests
diff --git a/pyproject.toml b/pyproject.toml
index 8adb866..790608e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,11 +4,11 @@ version = "0.1.dev0"
 dependencies = [
     "lancedb",
     "pandas",
+    "streamlit",
+    "datasets",
+    "tantivy"
 ]
-dev-dependencies = [
-    "pytest",
-    "transformers"
-    ]
+
 description = "ragged"
 license = { file = "LICENSE" }
 readme = "README.md"
@@ -17,6 +17,7 @@ keywords = [
     "data-science",
     "machine-learning",
     "data-analytics",
+
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -39,8 +40,12 @@ classifiers = [
 repository = "https://github.com/lancedb/lancedb"
 
 [project.optional-dependencies]
-dataset_providers = [
-    "llama-index"
+dev = [
+    "llama-index",
+    "pytest",
+    "transformers",
+    "torch",
+    "sentence-transformers",
 ]
 
 [build-system]
@@ -54,4 +59,6 @@ markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "asyncio",
     "s3_test"
-]
\ No newline at end of file
+]
+[project.scripts]
+ragged = "ragged.cli.entry_point:cli"
\ No newline at end of file
diff --git a/ragged/cli/entry_point.py b/ragged/cli/entry_point.py
new file mode 100644
index 0000000..c0d43f8
--- /dev/null
+++ b/ragged/cli/entry_point.py
@@ -0,0 +1,25 @@
+import os
+import argparse
+from pathlib import Path
+
+
+def cli():
+    parser = argparse.ArgumentParser(description="CLI for running VectorDB quickstart")
+    parser.add_argument("--quickstart", type=str, help="Name of the app")
+    args = parser.parse_args()
+
+    if args.quickstart == "vectordb":
+        run_vectordb_quickstart_gui()
+    else:
+        raise ValueError(f"App {args.quickstart} not found. Available apps: vectordb")
+
+
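+# Launches the Streamlit GUI via os.system; resolving the script path relative
+# to this file assumes gui/vectordb.py ships inside the installed package.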
Available apps: vectordb") + + +def run_vectordb_quickstart_gui(): + # get path of the parent directory + parent_dir = Path(__file__).parent.parent + # get path of the executable + executable = os.path.join(parent_dir, "gui/vectordb.py") + # run the executable + os.system(f"streamlit run {executable}") \ No newline at end of file diff --git a/ragged/dataset/__init__.py b/ragged/dataset/__init__.py index ba934d2..b0e9ee2 100644 --- a/ragged/dataset/__init__.py +++ b/ragged/dataset/__init__.py @@ -1,3 +1,4 @@ from .llama_index import LlamaIndexDataset +from .squad import SquadDataset -__all__ = ["LlamaIndexDataset"] \ No newline at end of file +__all__ = ["LlamaIndexDataset", "SquadDataset"] \ No newline at end of file diff --git a/ragged/dataset/base.py b/ragged/dataset/base.py index b1a77b5..5b94d7f 100644 --- a/ragged/dataset/base.py +++ b/ragged/dataset/base.py @@ -1,11 +1,20 @@ from abc import ABC, abstractmethod - +from pydantic import BaseModel +from typing import List import pandas as pd +class TextNode(BaseModel): + id: str + text: str + class Dataset(ABC): @abstractmethod def to_pandas(self)->pd.DataFrame: pass + + @abstractmethod + def get_contexts(self)->List[TextNode]: + pass @staticmethod def available_datasets(): @@ -13,3 +22,13 @@ def available_datasets(): List of available datasets that can be loaded """ return [] + + @property + @abstractmethod + def context_column_name(self): + pass + + @property + @abstractmethod + def query_column_name(self): + pass \ No newline at end of file diff --git a/ragged/dataset/llama_index.py b/ragged/dataset/llama_index.py index 0f6a14e..a5ee294 100644 --- a/ragged/dataset/llama_index.py +++ b/ragged/dataset/llama_index.py @@ -1,4 +1,6 @@ -from typing import Optional +from typing import List, Optional + +from ragged.dataset.base import TextNode from .base import Dataset import logging import os @@ -38,24 +40,30 @@ def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = Non parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) - self.documents = nodes + self.documents = [TextNode(id=node.id_, text=node.text) for node in nodes] def to_pandas(self): return self.dataset.to_pandas() + def get_contexts(self) -> List[TextNode]: + return self.documents + + @property + def context_column_name(self): + return "reference_contexts" + + @property + def query_column_name(self): + return "query" + @staticmethod def available_datasets(): return [ - "PaulGrahamEssayDataset", "Uber10KDataset2021", "MiniEsgBenchDataset", "OriginOfCovid19Dataset", - "BraintrustCodaHelpDeskDataset", - "MiniCovidQaDataset", - "PatronusAIFinanceBenchDataset", - "BlockchainSolanaDataset", "MiniTruthfulQADataset", "Llama2PaperDataset", - "CovidQaDataset", + "OriginOfCovid19Dataset", ] diff --git a/ragged/dataset/squad.py b/ragged/dataset/squad.py new file mode 100644 index 0000000..6d7e6d0 --- /dev/null +++ b/ragged/dataset/squad.py @@ -0,0 +1,31 @@ +from .base import Dataset, TextNode +from typing import List +from datasets import load_dataset + + +class SquadDataset(Dataset): + def __init__(self, dataset_name: str = "rajpurkar/squad"): + self.dataset = load_dataset(dataset_name) + # get unique contexts from the train dataframe + contexts = self.dataset["train"].to_pandas()["context"].unique() + self.documents = [TextNode(id=str(i), text=context) for i, context in enumerate(contexts)] + + + def to_pandas(self): + return self.dataset["train"].to_pandas() + + + def get_contexts(self)->List[TextNode]: + return self.documents + + @property + 
+    @property
+    def context_column_name(self):
+        return "context"
+
+    @property
+    def query_column_name(self):
+        return "question"
+
+    @staticmethod
+    def available_datasets():
+        return ["rajpurkar/squad"]
\ No newline at end of file
diff --git a/ragged/gui/vectordb.py b/ragged/gui/vectordb.py
index 8e70bd6..903b1ec 100644
--- a/ragged/gui/vectordb.py
+++ b/ragged/gui/vectordb.py
@@ -1,19 +1,21 @@
 import json
 import streamlit as st
 import streamlit.components.v1 as components
-from ragged.dataset import LlamaIndexDataset
+from ragged.dataset import LlamaIndexDataset, SquadDataset
 from ragged.metrics.retriever import HitRate, QueryType
 from ragged.results import RetriverResult
 from lancedb.rerankers import CohereReranker, ColbertReranker, CrossEncoderReranker
 
 def dataset_provider_options():
     return {
-        "Llama-Index": LlamaIndexDataset
+        "Llama-Index": LlamaIndexDataset,
+        "Squad": SquadDataset
     }
 
 def datasets_options():
     return {
-        "Llama-Index": LlamaIndexDataset.available_datasets()
+        "Llama-Index": LlamaIndexDataset.available_datasets(),
+        "Squad": SquadDataset.available_datasets()
     }
 
 def metric_options():
@@ -23,6 +25,7 @@ def metric_options():
 
 def reranker_options():
     return {
+        "None": None,
         "CohereReranker": CohereReranker,
         "ColbertReranker": ColbertReranker,
         "CrossEncoderReranker": CrossEncoderReranker
@@ -30,25 +33,26 @@ def reranker_options():
 
 def embedding_provider_options():
     return {
-        "openai": ["text-embedding-ada-002", "ext-embedding-3-small", "text-embedding-3-large"],
+        "openai": ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"],
         "huggingface": ["BAAI/bge-small-en-v1.5", "BAAI/bge-large-en-v1.5"],
         "sentence-transformers": ["all-MiniLM-L12-v2", "all-MiniLM-L6-v2", "all-MiniLM-L12-v1", "BAAI/bge-small-en-v1.5", "BAAI/bge-large-en-v1.5"],
     }
 
-def is_wandb_installed():
+def safe_import_wandb():
     try:
         import wandb
         from wandb import __version__
-        return True
+        return wandb
     except ImportError:
-        return False
+        return None
 
 def init_wandb(dataset: str, embed_model: str):
-    if not is_wandb_installed():
+    wandb = safe_import_wandb()
+    if wandb is None:
         st.error("Please install wandb to log metrics using `pip install wandb`")
         return
-    import wandb
-    wandb.init(project=f"ragged-vectordb", name=f"{dataset}-{embed_model}") if wandb.run is None else None
+    if wandb.run is None:
+        wandb.init(project="ragged-vectordb", name=f"{dataset}-{embed_model}")
 
 def eval_retrieval():
     st.title("Retrieval Evaluator Quickstart")
@@ -81,7 +84,9 @@ def eval_retrieval():
     with col1:
         query_type = st.selectbox("Select a query type", [qt for qt in QueryType.__dict__.keys() if not qt.startswith("__")], placeholder="Choose a query type")
     with col2:
-        log_wandb = st.checkbox("Log to Wandb and plot in real-time", value=False)
+        log_wandb = st.checkbox("Log to WandB and plot in real-time", value=False)
+        use_existing_table = st.checkbox("Use existing table", value=False)
+        create_index = st.checkbox("Create index", value=False)
 
     eval_button = st.button("Evaluate")
 
@@ -89,7 +94,7 @@ def eval_retrieval():
     if eval_button:
         dataset = dataset_provider_options()[provider](dataset)
         reranker_kwargs = json.loads(kwargs)
-        reranker = reranker_options()[reranker](**reranker_kwargs)
+        reranker = reranker_options()[reranker](**reranker_kwargs) if reranker != "None" else None
         query_type = QueryType.__dict__[query_type]
         metric = metric_options()[metric](
             dataset,
@@ -98,7 +103,12 @@ def eval_retrieval():
             reranker=reranker
         )
 
-        results = metric.evaluate(top_k=top_k, query_type=query_type)
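+        # create_index / use_existing_table come from the checkboxes above and
+        # are forwarded through Metric.evaluate down to ingest_docs.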
+        results = metric.evaluate(top_k=top_k,
+                                  query_type=query_type,
+                                  create_index=create_index,
+                                  use_existing_table=use_existing_table)
         total_metrics = len(results.model_dump())
         cols = st.columns(total_metrics)
         for idx, (k,v) in enumerate(results.model_dump().items()):
@@ -106,23 +114,26 @@ def eval_retrieval():
                 st.metric(label=k, value=v)
 
         if log_wandb:
-            init_wandb(dataset, embed_model)
-            if not is_wandb_installed():
+            wandb = safe_import_wandb()
+            if wandb is None:
                 st.error("Please install wandb to log metrics using `pip install wandb`")
                 return
-            import wandb
+            init_wandb(dataset, embed_model)
             wandb.log(results.model_dump())
 
     if log_wandb:
         st.title("Wandb Project Page")
-        if not is_wandb_installed():
+        wandb = safe_import_wandb()
+        if wandb is None:
             st.error("Please install wandb to log metrics using `pip install wandb`")
             return
-        import wandb
         init_wandb(dataset, embed_model)
-        print(wandb.run.get_project_url())
-        components.iframe(wandb.run.get_project_url())
+        project_url = wandb.run.get_project_url()
+        st.markdown(f"""
+        Visit the WandB project page to view the metrics in real-time.
+        [WandB Project Page]({project_url})
+        """)
 
 if __name__ == "__main__":
diff --git a/ragged/metrics/retriever/base.py b/ragged/metrics/retriever/base.py
index 65e7b34..7689956 100644
--- a/ragged/metrics/retriever/base.py
+++ b/ragged/metrics/retriever/base.py
@@ -8,7 +8,8 @@
 import logging
 import sys
 
-#logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger = logging.getLogger("lancedb")
+logger.setLevel(logging.INFO)
 
 class Metric(ABC):
     def __init__(self,
@@ -32,7 +33,7 @@ def __init__(self,
         self.table = None
 
     @abstractmethod
-    def ingest_docs(self):
+    def ingest_docs(self, batched: bool = False, use_existing_table: bool = False):
         """
         Ingest documents into the database and initialize the table
         """
@@ -51,7 +52,12 @@ def evaluate_query_type(self,query_type:str, top_k:5) -> float:
         """
         pass
 
-    def evaluate(self, top_k: int, create_index: bool = False, query_type=QueryType.VECTOR) -> RetriverResult:
+    def evaluate(self,
+                 top_k: int,
+                 create_index: bool = False,
+                 query_type=QueryType.VECTOR,
+                 batched: bool = False,
+                 use_existing_table: bool = False) -> RetriverResult:
         """
         Run evaluaion
@@ -64,10 +70,11 @@ def evaluate(self, top_k: int, create_index: bool = False, query_type=QueryType.
         Type of query to run. Default is QueryType.VECTOR.
             If "all" is passed, all query types will be evaluated
         """
-        self.ingest_docs()
+        self.ingest_docs(batched, use_existing_table)
 
         if create_index:
-            # TODO: Create index
-            pass
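+            # IVF_PQ vector index; num_partitions/num_sub_vectors are generic
+            # starting points rather than values tuned for these datasets.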
Skipping query type: {qt}") + continue results[qt] = self.evaluate_query_type(top_k=top_k, query_type=qt) + logger.info(f"Hit rate for {qt}: {results[qt]}") return RetriverResult(**results) if query_type == "auto": query_type = deduce_query_type(query_type, self.reranker) - logging.info(f"Evaluating query type: {query_type}") + logger.info(f"Evaluating query type: {query_type}") results[query_type] = self.evaluate_query_type(top_k=top_k, query_type=query_type) return RetriverResult(**results) diff --git a/ragged/metrics/retriever/hit_rate.py b/ragged/metrics/retriever/hit_rate.py index 0523436..e848026 100644 --- a/ragged/metrics/retriever/hit_rate.py +++ b/ragged/metrics/retriever/hit_rate.py @@ -11,17 +11,18 @@ from ...results import RetriverResult # Set logging level to INFO -#logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("lancedb") +logger.setLevel(logging.INFO) class HitRate(Metric): def evaluate_query_type(self,query_type:str, top_k:5) -> float: eval_results = [] ds = self.dataset.to_pandas() for idx in tqdm.tqdm(range(len(ds))): - query = ds['query'][idx] - reference_context = ds['reference_contexts'][idx] + query = ds[self.dataset.query_column_name][idx] + reference_context = ds[self.dataset.context_column_name][idx] if not reference_context: - logging.warning("reference_context is None for query: {idx}. \ + logger.warning("reference_context is None for query: {idx}. \ Skipping this query. Please check your dataset.") continue try: @@ -29,9 +30,16 @@ def evaluate_query_type(self,query_type:str, top_k:5) -> float: except Exception as e: if isinstance(e, QueryConfigError): raise e + logger.warn(f'Error with query: {idx} {e}') + eval_results.append({ + 'is_hit': False, + 'retrieved': [], + 'expected': reference_context, + 'query': query, + }) continue retrieved_texts = rs['text'].tolist()[:top_k] - expected_text = reference_context[0] + expected_text = reference_context[0] if isinstance(reference_context, list) else reference_context is_hit = False # HACK: to handle new line characters added my llamaindex doc reader if expected_text in retrieved_texts or expected_text+'\n' in retrieved_texts: @@ -49,7 +57,7 @@ def evaluate_query_type(self,query_type:str, top_k:5) -> float: return hit_rate - def ingest_docs(self): + def ingest_docs(self, batched: bool = False, use_existing_table: bool = False): db = lancedb.connect(self.uri) embed_model = get_registry().get(self.embedding_registry_id).create(**self.embed_model_kwarg) @@ -57,18 +65,27 @@ class Schema(LanceModel): id: str text: str = embed_model.SourceField() vector: Vector(embed_model.ndims()) = embed_model.VectorField(default=None) + if use_existing_table and "documents" in db.table_names(): + logger.info("Using existing table") + self.table = db["documents"] + return tbl = db.create_table("documents", schema=Schema, mode="overwrite") - batch_size = 1000 - num_batches = (len(self.dataset.documents) // batch_size) + 1 if len(self.dataset.documents) % batch_size != 0 else 0 - # tqdm - logging.info(f"Adding {len(self.dataset.documents)} documents to LanceDB, in {num_batches} batches of size {batch_size}") - for i in tqdm.tqdm(range(num_batches), desc="Adding documents to LanceDB"): - batch = self.dataset.documents[i:i+batch_size] + contexts = self.dataset.get_contexts() + batch_size = len(contexts) if not batched else 1000 + num_batches = 1 + if batched: + num_batches = (len(contexts) // batch_size) + 1 if len(contexts) % batch_size != 0 else 0 + + logger.info(f"Adding {len(contexts)} documents to LanceDB, in 
{num_batches} batches of size {batch_size}") + for i in range(num_batches): + batch = contexts[i:i+batch_size] pydantic_batch = [] for doc in tqdm.tqdm(batch, desc="Adding batch to LanceDB"): - pydantic_batch.append(Schema(id=str(doc.id_), text=doc.text)) + pydantic_batch.append(Schema(id=str(doc.id), text=doc.text)) + logger.info(f"Adding batch {i} to LanceDB") tbl.add(pydantic_batch) + logger.info(f"created table with length {len(tbl)}") self.table = tbl