Merge pull request #5 from lancedb/rag
Add RAG eval GUI and API
AyushExel authored Jun 2, 2024
2 parents 153d4ed + d8a0ab0 commit 9150456
Showing 18 changed files with 901 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -55,7 +55,7 @@ Most popular toy datasets are not semantically challenging enough to evaluate
NOTE: `directory` can contain PDFs, txt files, or any other file format that can be handled by the Llama-index directory reader.
```python
from ragged.dataset.gen.gen_retrieval_data import gen_query_context_dataset
-from ragged.inference_client import OpenAIInferenceClient
+from ragged.dataset.gen.llm_calls import OpenAIInferenceClient

client = OpenAIInferenceClient()
df = gen_query_context_dataset(directory="data/source_files", inference_client=client)
```
17 changes: 16 additions & 1 deletion ragged/dataset/base.py
@@ -15,7 +15,17 @@ def to_pandas(self)->pd.DataFrame:
    @abstractmethod
    def get_contexts(self)->List[TextNode]:
        pass

    @abstractmethod
    def get_queries(self)->List[str]:
        pass

    def get_ground_truths(self)->List[str]:
        """
        Optional to implement. Returns None when the dataset
        provides no ground-truth answers.
        """
        return None

    @staticmethod
    def available_datasets():
        """
@@ -31,4 +41,9 @@ def context_column_name(self):
    @property
    @abstractmethod
    def query_column_name(self):
        pass

    @property
    @abstractmethod
    def answer_column_name(self):
        pass
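With these additions, every concrete dataset has to supply queries and an answer column name alongside contexts. As a rough sketch of what an implementation against this interface looks like — the `InMemoryDataset` name and its wiring are illustrative, not part of this PR:

```python
from typing import List

import pandas as pd
from ragged.dataset.base import Dataset, TextNode  # import path assumed from this diff

class InMemoryDataset(Dataset):
    """Illustrative Dataset over a dict of parallel 'query'/'context' lists."""
    def __init__(self, records: dict):
        self.dataset = pd.DataFrame(records)
        # deduplicate contexts, mirroring what CSVDataset does below
        contexts = self.dataset[self.context_column_name].unique()
        self.documents = [TextNode(id=str(i), text=c) for i, c in enumerate(contexts)]

    def to_pandas(self) -> pd.DataFrame:
        return self.dataset

    def get_contexts(self) -> List[TextNode]:
        return self.documents

    def get_queries(self) -> List[str]:
        return self.dataset[self.query_column_name].tolist()

    @property
    def context_column_name(self):
        return "context"

    @property
    def query_column_name(self):
        return "query"

    @property
    def answer_column_name(self):
        return None  # no answers; get_ground_truths() stays at its None default

    @staticmethod
    def available_datasets():
        return []
```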
23 changes: 17 additions & 6 deletions ragged/dataset/csv.py
@@ -5,27 +5,38 @@


class CSVDataset(Dataset):
-    def __init__(self, path: str, context_column: str = "context", query_column: str = "query"):
+    def __init__(self, path: str, context_column: str = "context", query_column: str = "query", answer_column: str = None):
        self.dataset = pd.read_csv(path)
        # get unique contexts from the train dataframe
        contexts = self.dataset[context_column].unique()
        self.documents = [TextNode(id=str(i), text=context) for i, context in enumerate(contexts)]

        self.context_column = context_column
        self.query_column = query_column
        self.answer_column = answer_column

    def to_pandas(self):
        return self.dataset

    def get_contexts(self)->List[TextNode]:
        return self.documents

    def get_queries(self) -> List[str]:
        return self.dataset[self.query_column_name].tolist()

    def get_answers(self) -> List[str]:
        return self.dataset[self.answer_column_name].tolist()

    @property
    def context_column_name(self):
-        return "context"
+        return self.context_column

    @property
    def query_column_name(self):
-        return "query"
+        return self.query_column

    @property
    def answer_column_name(self):
        return self.answer_column

    @staticmethod
    def available_datasets():
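Since `CSVDataset` now exposes configurable column names and answers, loading a custom eval file looks something like the sketch below — the file name and header names are made up for illustration:

```python
from ragged.dataset import CSVDataset

# eval.csv headers (illustrative): passage, question, gold_answer
ds = CSVDataset("eval.csv", context_column="passage",
                query_column="question", answer_column="gold_answer")

print(ds.context_column_name)   # "passage", echoing the constructor argument
queries = ds.get_queries()      # list of question strings
answers = ds.get_answers()      # answers aligned with the queries
```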
Empty file.
33 changes: 32 additions & 1 deletion ragged/dataset/gen/gen_retrieval_data.py
@@ -5,9 +5,10 @@
except ImportError:
    raise ImportError("Please install the llama_index package by running `pip install llama_index`")

from tqdm import tqdm
import pandas as pd
from ragged.dataset.base import TextNode
from .prompts import Q_FROM_CONTEXT_DEFAULT
from .prompts import Q_FROM_CONTEXT_DEFAULT, QA_FROM_CONTEXT_DEFAULT
from .llm_calls import HFInferenceClient, BaseInferenceClient

def gen_query_context_dataset(directory: str,
@@ -37,3 +38,33 @@ def gen_query_context_dataset(directory: str,
    df = pd.DataFrame(pylist)
    return df

def gen_QA_dataset(
    directory: str,
    inference_client: BaseInferenceClient,
    num_questions_per_context: int = 2,
    query_column: str = "query",
    context_column: str = "context",
    answer_column: str = "answer"
):
    """
    Generate a QA dataset from the documents in a directory
    """
    docs = SimpleDirectoryReader(input_dir=directory).load_data()
    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)
    nodes = [TextNode(id=node.id_, text=node.text) for node in nodes]

    pylist = []
    for node in tqdm(nodes):
        context = node.text
        queries = inference_client(QA_FROM_CONTEXT_DEFAULT.format(context=context, num_questions=num_questions_per_context, question=query_column, answer=answer_column))
        for query in queries:
            # the prompt asks the model to key each item by the configured column names
            pylist.append({
                query_column: query[query_column],
                context_column: context,
                answer_column: query[answer_column]
            })

    df = pd.DataFrame(pylist)
    return df
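Mirroring the README's `gen_query_context_dataset` example, the new `gen_QA_dataset` can presumably be driven the same way — the output path here is illustrative:

```python
from ragged.dataset.gen.gen_retrieval_data import gen_QA_dataset
from ragged.dataset.gen.llm_calls import OpenAIInferenceClient

client = OpenAIInferenceClient()
# one row per generated question, with "query", "context" and "answer" columns
df = gen_QA_dataset(directory="data/source_files",
                    inference_client=client,
                    num_questions_per_context=2)
df.to_csv("qa_dataset.csv", index=False)  # illustrative destination
```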

12 changes: 12 additions & 0 deletions ragged/dataset/gen/prompts.py
@@ -8,4 +8,16 @@
Provide your answer as a list named 'content'
Now here is the context.
Context: {context}\n
"""

QA_FROM_CONTEXT_DEFAULT = \
"""
Your task is to write a list of {num_questions} detailed statements or questions & their answers for a given context. It should satisfy the following conditions:
* should be relevant to the specific context, and should not be a one-liner. Rather, it should be a detailed question or statement.
* MUST NOT be keyword based. Try to make it semantically similar to the context without using the same words. Basically, string matching should not work for searching the answer.
* MUST NOT mention something like "according to the passage" or "context".
Provide your questions and answers as a list named 'content' with each element being a dictionary with keys '{question}' and '{answer}'.
Now here is the context.
Context: {context}\n
"""
7 changes: 7 additions & 0 deletions ragged/dataset/llama_index.py
@@ -47,6 +47,9 @@ def to_pandas(self):

    def get_contexts(self) -> List[TextNode]:
        return self.documents

    def get_queries(self) -> List[str]:
        return self.dataset[self.query_column_name].tolist()

    @property
    def context_column_name(self):
@@ -56,6 +59,10 @@ def context_column_name(self):
    def query_column_name(self):
        return "query"

    @property
    def answer_column_name(self):
        return None

    @staticmethod
    def available_datasets():
        return [
7 changes: 7 additions & 0 deletions ragged/dataset/squad.py
@@ -18,13 +18,20 @@ def to_pandas(self):
    def get_contexts(self)->List[TextNode]:
        return self.documents

    def get_queries(self) -> List[str]:
        return self.dataset[self.query_column_name].tolist()

    @property
    def context_column_name(self):
        return "context"

    @property
    def query_column_name(self):
        return "question"

    @property
    def answer_column_name(self):
        return None

    @staticmethod
    def available_datasets():
33 changes: 33 additions & 0 deletions ragged/gui/choices.py
@@ -0,0 +1,33 @@
from ragged.dataset import LlamaIndexDataset, SquadDataset, CSVDataset
from lancedb.rerankers import CohereReranker, ColbertReranker, CrossEncoderReranker


def dataset_provider_options():
    return {
        "Llama-Index": LlamaIndexDataset,
        "Squad": SquadDataset,
        "CSV": CSVDataset
    }

def datasets_options():
    return {
        "Llama-Index": LlamaIndexDataset.available_datasets(),
        "Squad": SquadDataset.available_datasets(),
        "CSV": CSVDataset.available_datasets()
    }


def reranker_options():
    return {
        "None": None,
        "CohereReranker": CohereReranker,
        "ColbertReranker": ColbertReranker,
        "CrossEncoderReranker": CrossEncoderReranker
    }

def embedding_provider_options():
    return {
        "openai": ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"],
        "huggingface": ["BAAI/bge-small-en-v1.5", "BAAI/bge-large-en-v1.5"],
        "sentence-transformers": ["all-MiniLM-L12-v2", "all-MiniLM-L6-v2", "all-MiniLM-L12-v1", "BAAI/bge-small-en-v1.5", "BAAI/bge-large-en-v1.5"],
    }
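These option maps are what the GUI below feeds from. Resolving a selection by hand would look roughly like this — the chosen names mirror the dict keys above, empty kwargs match the GUI default, and the flat `choices` import assumes you run from `ragged/gui`:

```python
import json
from choices import reranker_options, dataset_provider_options

choice = "CohereReranker"                 # value a user would pick in the selectbox
reranker_cls = reranker_options()[choice]
kwargs = json.loads("{}")                 # GUI default for "Reranker kwargs"
reranker = reranker_cls(**kwargs) if reranker_cls is not None else None

dataset_cls = dataset_provider_options()["Squad"]  # -> SquadDataset
```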
127 changes: 127 additions & 0 deletions ragged/gui/rag.py
@@ -0,0 +1,127 @@
import json
import streamlit as st
import streamlit.components.v1 as components
from ragged.metrics.retriever import HitRate, QueryType
from ragged.results import RetriverResult
from choices import dataset_provider_options, datasets_options, reranker_options, embedding_provider_options
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_precision,
    context_recall,
)
from ragas.metrics.critique import harmfulness

def metric_options():
    return {
        "faithfulness": faithfulness,
        "answer_relevancy": answer_relevancy,
        "context_precision": context_precision,
        "context_recall": context_recall,
        "harmfulness": harmfulness
    }

def safe_import_wandb():
    try:
        import wandb
        from wandb import __version__  # confirms a working install
        return wandb
    except ImportError:
        return None

def init_wandb(dataset: str, embed_model: str):
    wandb = safe_import_wandb()
    if wandb is None:
        st.error("Please install wandb to log metrics using `pip install wandb`")
        return
    if wandb.run is None:
        wandb.init(project="ragged-vectordb", name=f"{dataset}-{embed_model}")

def eval_retrieval():
    st.title("RAG Evaluator Quickstart")
    st.write("For custom dataset and retriever evaluation, use the API")
    col1, col2 = st.columns(2)
    with col1:
        provider = st.selectbox("Select a provider", datasets_options().keys(), placeholder="Choose a provider")
    with col2:
        if provider == "CSV":
            # choose a csv file
            dataset = st.file_uploader("Upload a CSV file", type=["csv"])
        else:
            dataset = st.selectbox("Select a dataset", datasets_options()[provider], placeholder="Choose a dataset", disabled=provider is None)

    col1, col2 = st.columns(2)
    with col1:
        metrics = st.multiselect("Select metrics", metric_options().keys(), default=["faithfulness", "answer_relevancy", "context_precision", "context_recall"])
    with col2:
        top_k = st.number_input("Top K (Not used currently)", value=5, disabled=True)

    col1, col2 = st.columns(2)
    with col1:
        embed_provider = st.selectbox("Select an embedding provider", embedding_provider_options().keys(), placeholder="Choose an embedding provider")
    with col2:
        embed_model = st.selectbox("Select an embedding model", embedding_provider_options()[embed_provider], placeholder="Choose an embedding model", disabled=embed_provider is None)

    col1, col2 = st.columns(2)
    with col1:
        reranker = st.selectbox("Select a reranker", reranker_options(), placeholder="Choose a reranker")
    with col2:
        kwargs = st.text_input("Reranker kwargs", value="{}")

    col1, col2 = st.columns(2)
    with col1:
        query_type = st.selectbox("Select a query type", [qt for qt in QueryType.__dict__.keys() if not qt.startswith("__")], placeholder="Choose a query type")
    with col2:
        log_wandb = st.checkbox("Log to WandB and plot in real-time", value=False)
        use_existing_table = st.checkbox("Use existing table", value=False)
        create_index = st.checkbox("Create index", value=False)

    eval_button = st.button("Evaluate")
    results = RetriverResult()
    if eval_button:
        dataset = dataset_provider_options()[provider](dataset)
        reranker_kwargs = json.loads(kwargs)
        reranker = reranker_options()[reranker](**reranker_kwargs) if reranker != "None" else None
        query_type = QueryType.__dict__[query_type]
        # NOTE: assumes the retriever is scored with HitRate (imported above);
        # the ragas metrics selected in the multiselect are not yet wired in
        metric = HitRate(
            dataset,
            embedding_registry_id=embed_provider,
            embed_model_kwarg={"name": embed_model},
            reranker=reranker
        )

        results = metric.evaluate(top_k=top_k,
                                  query_type=query_type,
                                  create_index=create_index,
                                  use_existing_table=use_existing_table)
        total_metrics = len(results.model_dump())
        cols = st.columns(total_metrics)
        for idx, (k, v) in enumerate(results.model_dump().items()):
            with cols[idx]:
                st.metric(label=k, value=v)

        if log_wandb:
            wandb = safe_import_wandb()
            if wandb is None:
                st.error("Please install wandb to log metrics using `pip install wandb`")
                return
            init_wandb(dataset, embed_model)
            wandb.log(results.model_dump())


if log_wandb:
st.title("Wandb Project Page")
wandb = safe_import_wandb()
if wandb is None:
st.error("Please install wandb to log metrics using `pip install wandb`")
return
init_wandb(dataset, embed_model)
project_url = wandb.run.get_project_url()
st.markdown("""
Visit the WandB project page to view the metrics in real-time.
[WandB Project Page]({project_url})
""")


if __name__ == "__main__":
    eval_retrieval()
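Since `rag.py` imports `choices` as a flat local module, the quickstart is presumably launched from inside `ragged/gui`, e.g. with `streamlit run rag.py`.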