From 232fd570de3b286e15979df04a1c9df1aa6fab82 Mon Sep 17 00:00:00 2001 From: Chengrun Yang Date: Tue, 22 Oct 2024 09:29:11 -0700 Subject: [PATCH] Enable `lf.eval.inputs_from` to read csv files. Usage: ``` dataset = lf.eval.inputs_from('/path/to/inputs.csv')() ``` or providing `pd.read_csv` arguments: ``` dataset = lf.eval.inputs_from('/path/to/inputs.csv', index_col=0, header=0)() ``` PiperOrigin-RevId: 688580326 --- langfun/core/eval/base.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/langfun/core/eval/base.py b/langfun/core/eval/base.py index c159098..fb74953 100644 --- a/langfun/core/eval/base.py +++ b/langfun/core/eval/base.py @@ -31,6 +31,7 @@ import langfun.core.coding as lf_coding from langfun.core.llms.cache import in_memory import langfun.core.structured as lf_structured +import pandas as pd import pyglove as pg @@ -1684,10 +1685,22 @@ def visualize(cls, evaluations: list['Evaluation']) -> str | None: @pg.functor() -def inputs_from(path: str | list[str]) -> list[Any]: +def inputs_from(path: str | list[str], **kwargs) -> list[Any]: """A functor that returns a list of user-defined objects as eval inputs.""" if isinstance(path, str): - return pg.load(path) + if path.endswith('.json'): + return pg.load(path) + elif path.endswith('.csv'): + dataset_df = pd.read_csv(path, **kwargs) + dataset = [] + for i in range(dataset_df.shape[0]): + row = {} + for col in dataset_df.columns: + row[col] = dataset_df.iloc[i][col] + dataset.append(row) + return dataset + else: + raise ValueError(f'Unsupported file format: {path}') examples = [] for p in path: examples.extend(pg.load(p))