Skip to content

Commit

Permalink
Merge pull request #4 from lancedb/cleanup
Browse files Browse the repository at this point in the history
Add support for CSV datasets, synthetic data gen and minor improvements
  • Loading branch information
AyushExel authored May 20, 2024
2 parents 657d730 + c367338 commit dad4f87
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 8 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,31 @@ hit_rate = HitRate(
print(hit_rate.evaluate(top_k=5, query_type="all")) # Evaliate all possible query types
```

## Create custom Dataset, Metrics, Reranking connectors
# TODO
### Generate a custom semantic search dataset
Most of popular toy datasets are not semantically challenging enough to evaluate the performance of LLM based retrieval systems. Most of them work well with simple BM25 based retrieval systems. To generate a custom dataset, that is semantically challenging, you can use the following code snippet.

```python
from ragged.dataset.gen.gen_retrieval_data import gen_query_context_dataset
from ragged.inference_client import OpenAIInferenceClient

clinet = OpenAIInferenceClient()
df = gen_query_context_dataset(directory="data/source_files", inference_client=clinet)

print(df.head())
# save the dataframe
df.to_csv("data.csv")
```

Now, you can evaluate this dataset using the `ragged --quickstart vectordb` GUI or via the API:
```python
from ragged.dataset.csv import CSVDataset
from ragged.metrics.retriever import HitRate
from lancedb.rerankers import CohereReranker

data = CSVDataset(path="data.csv")
reranker = CohereReranker()

hit_rate = HitRate(data, reranker=reranker, embedding_registry_id="openai", embed_model_kwarg={"model":"text-embedding-3-small"})
res = hit_rate.evaluate(top_k=5, query_type="all")
print(res)
```
3 changes: 2 additions & 1 deletion ragged/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .llama_index import LlamaIndexDataset
from .squad import SquadDataset
from .csv import CSVDataset

__all__ = ["LlamaIndexDataset", "SquadDataset"]
__all__ = ["LlamaIndexDataset", "SquadDataset", "CSVDataset"]
32 changes: 32 additions & 0 deletions ragged/dataset/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .base import Dataset, TextNode
from typing import List
from datasets import load_dataset
import pandas as pd


class CSVDataset(Dataset):
def __init__(self, path: str, context_column: str = "context", query_column: str = "query"):
self.dataset = pd.read_csv(path)
# get unique contexts from the train dataframe
contexts = self.dataset[context_column].unique()
self.documents = [TextNode(id=str(i), text=context) for i, context in enumerate(contexts)]


def to_pandas(self):
return self.dataset


def get_contexts(self)->List[TextNode]:
return self.documents

@property
def context_column_name(self):
return "context"

@property
def query_column_name(self):
return "query"

@staticmethod
def available_datasets():
return []
39 changes: 39 additions & 0 deletions ragged/dataset/gen/gen_retrieval_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
try:
import llama_index
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
except ImportError:
raise ImportError("Please install the llama_index package by running `pip install llama_index`")

import pandas as pd
from ragged.dataset.base import TextNode
from .prompts import Q_FROM_CONTEXT_DEFAULT
from .llm_calls import HFInferenceClient, BaseInferenceClient

def gen_query_context_dataset(directory: str,
inference_client: BaseInferenceClient,
num_questions_per_context: int = 2,
query_column: str = "query",
context_column: str = "context"):
"""
Generate query and contexts from a pandas dataframe
"""
docs = SimpleDirectoryReader(input_dir=directory).load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
nodes = [TextNode(id=node.id_, text=node.text) for node in nodes]

pylist = []
for node in nodes:
context = node.text
queries = inference_client(Q_FROM_CONTEXT_DEFAULT.format(context=context, num_questions=num_questions_per_context))
for query in queries:
pylist.append({
query_column: query,
context_column: context
})

# create a dataframe
df = pd.DataFrame(pylist)
return df

79 changes: 79 additions & 0 deletions ragged/dataset/gen/llm_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
import os
from openai import OpenAI
from typing import Optional
from huggingface_hub import InferenceClient
from abc import abstractmethod, ABC

class BaseInferenceClient(ABC):
def __init__(self, model: str, max_retries: int = 0):
self.model = model
self.max_reties = max_retries

def init_model(self):
pass

@abstractmethod
def call(self, prompt: str):
pass

def __call__(self, prompt: str):
return self.call(prompt)


class HFInferenceClient(BaseInferenceClient):
def __init__(self, model: str, max_retries: int = 0, timeout: int = 120, acces_token: Optional[str]=None):
super().__init__(model, max_retries)
self.timeout = timeout
self.acces_token = acces_token or os.environ.get("HF_TOKEN")
self.init_model()

def init_model(self):
self.model = InferenceClient(
model=self.model,
timeout=self.timeout,
token=self.acces_token
)

def call(self, prompt: str):
response = self.model.post(
json={
"inputs": prompt,
"parameters": {"max_new_tokens": 1000},
"task": "text-generation",
},
)
return json.loads(response.decode())[0]["generated_text"]

class OpenAIInferenceClient(BaseInferenceClient):
def __init__(self,
model: str = "gpt-4-turbo",
max_retries: int = 0,
timeout: int = 120,
acces_token: Optional[str]=None):
super().__init__(model, max_retries)
self.timeout = timeout
self.acces_token = acces_token or os.environ.get("OPENAI_API_KEY")
self.init_model()

def init_model(self):
self.client = OpenAI(api_key=self.acces_token)


def call(self, prompt: str):
response = self.client.chat.completions.create(
model=self.model,
response_format={ "type": "json_object" },
messages=[
{"role": "system", "content": "You are a helpful assistant designed to output JSON."},
{"role": "user", "content": prompt}
],
)
json_res = response.choices[0].message.content
# parse the json response
res = json.loads(json_res)
return res["content"]




11 changes: 11 additions & 0 deletions ragged/dataset/gen/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Q_FROM_CONTEXT_DEFAULT = \
""""
Your task is to write a list of {num_questions} detailed statements or questions for given a context. It should satisfy the following conditions:
* should relevant to the specifc context. and should not be a one-liner. Rahter it should be a detailed question or statement.
* MUST not be keyword based. Try to make it semantically similar to the context without using the same words. Basically string matching should not work for searching the answer.
* MUST NOT mention something like "according to the passage" or "context".
Provide your answer as a list named 'content'
Now here is the context.
Context: {context}\n
"""
14 changes: 10 additions & 4 deletions ragged/gui/vectordb.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import json
import streamlit as st
import streamlit.components.v1 as components
from ragged.dataset import LlamaIndexDataset, SquadDataset
from ragged.dataset import LlamaIndexDataset, SquadDataset, CSVDataset
from ragged.metrics.retriever import HitRate, QueryType
from ragged.results import RetriverResult
from lancedb.rerankers import CohereReranker, ColbertReranker, CrossEncoderReranker

def dataset_provider_options():
return {
"Llama-Index": LlamaIndexDataset,
"Squad": SquadDataset
"Squad": SquadDataset,
"CSV": CSVDataset
}

def datasets_options():
return {
"Llama-Index": LlamaIndexDataset.available_datasets(),
"Squad": SquadDataset.available_datasets()
"Squad": SquadDataset.available_datasets(),
"CSV": CSVDataset.available_datasets()
}

def metric_options():
Expand Down Expand Up @@ -60,7 +62,11 @@ def eval_retrieval():
with col1:
provider = st.selectbox("Select a provider", datasets_options().keys(), placeholder="Choose a provider")
with col2:
dataset = st.selectbox("Select a dataset", datasets_options()[provider], placeholder="Choose a dataset", disabled=provider is None)
if provider == "CSV":
# choose a csv file
dataset = st.file_uploader("Upload a CSV file", type=["csv"])
else:
dataset = st.selectbox("Select a dataset", datasets_options()[provider], placeholder="Choose a dataset", disabled=provider is None)

col1, col2 = st.columns(2)
with col1:
Expand Down
2 changes: 1 addition & 1 deletion ragged/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class QueryConfigError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


def deduce_query_type(query_type: str, reranker: Optional[Reranker]):
if query_type == QueryType.AUTO:
Expand Down

0 comments on commit dad4f87

Please sign in to comment.