From f1c7c72454ddcfca418bec98f62f64c58223def9 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 14 Oct 2024 15:37:40 +0530 Subject: [PATCH] Update llama_index.py --- ragged/dataset/llama_index.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ragged/dataset/llama_index.py b/ragged/dataset/llama_index.py index bf1fbd8..b809ab6 100644 --- a/ragged/dataset/llama_index.py +++ b/ragged/dataset/llama_index.py @@ -15,7 +15,7 @@ raise ImportError("Please install the llama_index package by running `pip install llama_index`") class LlamaIndexDataset(Dataset): - def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = None): + def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = None, context_column_name="reference_contexts", query_column_name="query" ): if path is None and dataset_name is None: raise ValueError("Either path or dataset_name must be provided") if path is not None and dataset_name is not None: @@ -41,6 +41,8 @@ def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = Non parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) self.documents = [TextNode(id=node.id_, text=node.text) for node in nodes] + self.context_column = context_column_name + self.query_column = query_column_name def to_pandas(self): return self.dataset.to_pandas() @@ -53,11 +55,11 @@ def get_queries(self) -> List[str]: @property def context_column_name(self): - return "reference_contexts" + return self.context_column @property def query_column_name(self): - return "query" + return self.query_column @property def answer_column_name(self):