Skip to content

Commit

Permalink
Update llama_index.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AyushExel authored Oct 14, 2024
1 parent d528594 commit f1c7c72
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ragged/dataset/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
raise ImportError("Please install the llama_index package by running `pip install llama_index`")

class LlamaIndexDataset(Dataset):
def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = None):
def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = None, context_column_name="reference_contexts", query_column_name="query" ):
if path is None and dataset_name is None:
raise ValueError("Either path or dataset_name must be provided")
if path is not None and dataset_name is not None:
Expand All @@ -41,6 +41,8 @@ def __init__(self, dataset_name: Optional[str] = None, path: Optional[str] = Non
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(documents)
self.documents = [TextNode(id=node.id_, text=node.text) for node in nodes]
self.context_column = context_column_name
self.query_column = query_column_name

def to_pandas(self):
return self.dataset.to_pandas()
Expand All @@ -53,11 +55,11 @@ def get_queries(self) -> List[str]:

@property
def context_column_name(self):
return "reference_contexts"
return self.context_column

@property
def query_column_name(self):
return "query"
return self.query_column

@property
def answer_column_name(self):
Expand Down

0 comments on commit f1c7c72

Please sign in to comment.