Skip to content

Commit

Permalink
Change default index_keys in to_dataframe() to varied_config_keys
Browse files Browse the repository at this point in the history
For RunList.to_dataframe(), among the config keys for dataframe indices
"varied" config keys (i.e., having more than two different unique
values) would be most useful index keys to have as a default.
  • Loading branch information
wookayin committed Oct 1, 2022
1 parent 5d63f48 commit 34c05eb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
30 changes: 27 additions & 3 deletions expt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ def to_list(self) -> List[Run]:
"""Create a new copy of list containing all the runs."""
return list(self._runs)

INDEX_EXCLUDE_DEFAULT = ('seed', 'random_seed', 'log_dir', 'train_dir')

def to_dataframe(
self,
include_config: bool = True,
config_fn: Optional[Callable[[Run], RunConfig]] = None,
index_keys: Optional[Sequence[str]] = None,
index_excludelist: Sequence[str] = ('seed', 'random_seed'),
index_excludelist: Sequence[str] = INDEX_EXCLUDE_DEFAULT,
as_hypothesis: bool = False,
hypothesis_namer: Optional[ # (run_config, runs) -> str
Callable[[RunConfig, Sequence[Run]], str]] = None,
Expand All @@ -208,7 +210,10 @@ def to_dataframe(
and returns Dict[str, number]. Additional series will be added
to the dataframe from the result of this function.
index_keys: A list of column names to include in the multi index.
If omitted, all the keys from run.config will be used (see config_fn).
If omitted (using the default setting), all the keys from run.config
that have at least two different unique values will be used
(see config_fn). If there is only one run, all the columns will be
used as the default index_keys.
index_excludelist: A list of column names to exclude from multi index.
Explicitly set as the empty list if you want to include the
default excludelist names (seed, random_seed).
Expand Down Expand Up @@ -239,12 +244,15 @@ def _default_config_fn(run: Run) -> RunConfig:

config_fn = _default_config_fn

if index_keys is None: # using default index
index_keys = varied_config_keys(self._runs, config_fn=config_fn)

for i, run in enumerate(self._runs):
config: Mapping[str, Any] = config_fn(run)
if not isinstance(config, Mapping):
raise ValueError("config_fn should return a dict-like object.")

r_keys = index_keys if index_keys is not None else config.keys()
r_keys = index_keys
for k in r_keys:
if k not in config:
raise ValueError(
Expand Down Expand Up @@ -406,6 +414,22 @@ def extract(self, pat: str, flags: int = 0) -> pd.DataFrame:
return df


def varied_config_keys(
runs: Sequence[Run],
config_fn: Callable[[Run], RunConfig],
) -> Sequence[str]:
"""Get a list of config keys (or indices in to_dataframe) that have more than
two different unique values."""

key_values = collections.defaultdict(set)
for r in runs:
for k, v in config_fn(r).items():
key_values[k].add(v)

keys = tuple(k for (k, values) in key_values.items() if len(values) > 1)
return keys


@dataclass
class Hypothesis(Iterable[Run]):
name: str
Expand Down
9 changes: 5 additions & 4 deletions expt/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ def test_to_dataframe_multiindex(self, runs_gridsearch: RunList):
# with custom config_fn
def _config_fn(run: Run):
algorithm, env, seed = run.name.split('-')
return dict(algorithm=algorithm, env=env, seed=seed)
return dict(algorithm=algorithm, env=env, seed=seed, common="common")

df = runs.to_dataframe(config_fn=_config_fn)
print(df)
assert df.index.names == ['algorithm', 'env']
assert df.index.names == ['algorithm', 'env'] # should exclude 'common'
assert list(df.columns) == ['seed', 'name', 'run'] # in order!
assert isinstance(df.run[0], Run)

Expand All @@ -209,8 +209,9 @@ def _config_fn(run: Run):
assert df.reward[0] == df.hypothesis[0].summary()['reward'][0]

# Tests index_keys and index_excludelist
df = runs.to_dataframe(config_fn=_config_fn, index_keys=['algorithm'])
assert df.index.names == ['algorithm']
df = runs.to_dataframe(config_fn=_config_fn, \
index_keys=['algorithm', 'common'])
assert df.index.names == ['algorithm', 'common']
df = runs.to_dataframe(config_fn=_config_fn, index_excludelist=['env'])
assert df.index.names == ['algorithm', 'seed']

Expand Down

0 comments on commit 34c05eb

Please sign in to comment.