-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from lancedb/cleanup
Add support for CSV datasets, synthetic data gen and minor improvements
- Loading branch information
Showing
8 changed files
with
202 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .llama_index import LlamaIndexDataset | ||
from .squad import SquadDataset | ||
from .csv import CSVDataset | ||
|
||
__all__ = ["LlamaIndexDataset", "SquadDataset"] | ||
__all__ = ["LlamaIndexDataset", "SquadDataset", "CSVDataset"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from .base import Dataset, TextNode | ||
from typing import List | ||
from datasets import load_dataset | ||
import pandas as pd | ||
|
||
|
||
class CSVDataset(Dataset):
    """Dataset backed by a local CSV file containing context and query columns."""

    def __init__(self, path: str, context_column: str = "context", query_column: str = "query"):
        """
        Args:
            path: Path to the CSV file to load.
            context_column: Name of the column holding context passages.
            query_column: Name of the column holding queries.
        """
        self.dataset = pd.read_csv(path)
        # Bug fix: the column names were accepted here but then ignored —
        # the *_column_name properties returned hard-coded "context"/"query".
        # Store them so the properties reflect the configured schema.
        self._context_column = context_column
        self._query_column = query_column
        # Deduplicate contexts so each unique passage becomes one TextNode.
        contexts = self.dataset[context_column].unique()
        self.documents = [TextNode(id=str(i), text=context) for i, context in enumerate(contexts)]

    def to_pandas(self) -> pd.DataFrame:
        """Return the raw dataframe as loaded from the CSV."""
        return self.dataset

    def get_contexts(self) -> List[TextNode]:
        """Return the deduplicated contexts wrapped as TextNode objects."""
        return self.documents

    @property
    def context_column_name(self) -> str:
        # Previously hard-coded "context"; now honors the constructor argument.
        return self._context_column

    @property
    def query_column_name(self) -> str:
        # Previously hard-coded "query"; now honors the constructor argument.
        return self._query_column

    @staticmethod
    def available_datasets():
        """No named datasets ship with the CSV backend."""
        return []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
try: | ||
import llama_index | ||
from llama_index.core import SimpleDirectoryReader | ||
from llama_index.core.node_parser import SentenceSplitter | ||
except ImportError: | ||
raise ImportError("Please install the llama_index package by running `pip install llama_index`") | ||
|
||
import pandas as pd | ||
from ragged.dataset.base import TextNode | ||
from .prompts import Q_FROM_CONTEXT_DEFAULT | ||
from .llm_calls import HFInferenceClient, BaseInferenceClient | ||
|
||
def gen_query_context_dataset(directory: str,
                              inference_client: BaseInferenceClient,
                              num_questions_per_context: int = 2,
                              query_column: str = "query",
                              context_column: str = "context"):
    """
    Build a (query, context) dataframe from the documents in *directory*.

    Each document is split into sentence-level chunks; for every chunk the
    inference client is asked for ``num_questions_per_context`` questions,
    and one row per (question, chunk) pair is collected.

    Args:
        directory: Directory of source documents to read.
        inference_client: LLM client used to generate the questions.
        num_questions_per_context: Questions requested per chunk.
        query_column: Column name for the generated questions.
        context_column: Column name for the source chunks.

    Returns:
        A pandas DataFrame with one row per generated question.
    """
    documents = SimpleDirectoryReader(input_dir=directory).load_data()
    chunks = SentenceSplitter().get_nodes_from_documents(documents)
    text_nodes = [TextNode(id=chunk.id_, text=chunk.text) for chunk in chunks]

    rows = []
    for text_node in text_nodes:
        prompt = Q_FROM_CONTEXT_DEFAULT.format(
            context=text_node.text,
            num_questions=num_questions_per_context,
        )
        # One generated question per row, all sharing the same context.
        rows.extend(
            {query_column: generated, context_column: text_node.text}
            for generated in inference_client(prompt)
        )

    return pd.DataFrame(rows)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import json | ||
import os | ||
from openai import OpenAI | ||
from typing import Optional | ||
from huggingface_hub import InferenceClient | ||
from abc import abstractmethod, ABC | ||
|
||
class BaseInferenceClient(ABC):
    """Abstract base for LLM inference clients.

    Subclasses implement `call` (and optionally `init_model`); instances
    are directly callable, forwarding to `call`.
    """

    def __init__(self, model: str, max_retries: int = 0):
        self.model = model
        # Bug fix: the attribute was misspelled `max_reties`. The correctly
        # spelled attribute is authoritative; the misspelled alias is kept
        # so any existing reader of the old name keeps working.
        self.max_retries = max_retries
        self.max_reties = max_retries

    def init_model(self):
        """Optional hook for subclasses to construct their backing client."""
        pass

    @abstractmethod
    def call(self, prompt: str):
        """Send `prompt` to the model and return the generated output."""
        pass

    def __call__(self, prompt: str):
        # Convenience: `client(prompt)` is equivalent to `client.call(prompt)`.
        return self.call(prompt)
|
||
|
||
class HFInferenceClient(BaseInferenceClient):
    """Inference client backed by the Hugging Face Inference API."""

    def __init__(self, model: str, max_retries: int = 0, timeout: int = 120, acces_token: Optional[str]=None):
        super().__init__(model, max_retries)
        self.timeout = timeout
        # Fall back to the HF_TOKEN environment variable when no token given.
        self.acces_token = acces_token or os.environ.get("HF_TOKEN")
        self.init_model()

    def init_model(self):
        # NOTE(review): this rebinds `self.model` from the model *name* to an
        # `InferenceClient` instance — later reads of `self.model` get the
        # client object, not the name. Preserved as-is for compatibility.
        self.model = InferenceClient(
            model=self.model,
            timeout=self.timeout,
            token=self.acces_token,
        )

    def call(self, prompt: str):
        """POST the prompt as a text-generation task and return the text."""
        payload = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 1000},
            "task": "text-generation",
        }
        raw = self.model.post(json=payload)
        # The API returns a JSON list; the first item carries generated_text.
        return json.loads(raw.decode())[0]["generated_text"]
|
||
class OpenAIInferenceClient(BaseInferenceClient):
    """Inference client backed by the OpenAI chat-completions API.

    Prompts are sent in JSON mode; the model's JSON reply is parsed and its
    "content" field returned.
    """

    def __init__(self,
                 model: str = "gpt-4-turbo",
                 max_retries: int = 0,
                 timeout: int = 120,
                 acces_token: Optional[str] = None):
        super().__init__(model, max_retries)
        self.timeout = timeout
        # Fall back to the OPENAI_API_KEY environment variable.
        self.acces_token = acces_token or os.environ.get("OPENAI_API_KEY")
        self.init_model()

    def init_model(self):
        # Bug fix: `timeout` was stored but never used, so requests ran with
        # the SDK default. Pass it to the client so it is actually honored.
        self.client = OpenAI(api_key=self.acces_token, timeout=self.timeout)

    def call(self, prompt: str):
        """Ask the model for a JSON object and return its "content" field.

        Raises:
            json.JSONDecodeError: if the reply body is not valid JSON.
            KeyError: if the reply JSON lacks a "content" key.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                {"role": "user", "content": prompt},
            ],
        )
        json_res = response.choices[0].message.content
        # Parse the JSON-mode reply and surface only its payload.
        res = json.loads(json_res)
        return res["content"]
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Default prompt for generating questions from a context passage.
# Expects two format fields: {num_questions} and {context}.
# Bug fix: the literal previously opened with four quotes (`""""`), which
# injected a stray `"` as the prompt's first character; prompt typos
# ("specifc", "Rahter", "for given a context") are also corrected.
Q_FROM_CONTEXT_DEFAULT = \
"""
Your task is to write a list of {num_questions} detailed statements or questions for a given context. It should satisfy the following conditions:
* should be relevant to the specific context, and should not be a one-liner. Rather it should be a detailed question or statement.
* MUST not be keyword based. Try to make it semantically similar to the context without using the same words. Basically string matching should not work for searching the answer.
* MUST NOT mention something like "according to the passage" or "context".
Provide your answer as a list named 'content'
Now here is the context.
Context: {context}\n
"""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters