Skip to content

Commit

Permalink
BUG: Column pruning failed when grouping by multiple Series (#708)
Browse files: browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
ChengjieLi28 and mergify[bot] authored Sep 20, 2023
1 parent f16df9d commit fe7caba
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,10 @@ def df_groupby_agg_select_function(
ret = {}
# group by a series
groupby_series = False
if isinstance(by, list) and len(by) == 1 and isinstance(by[0], BaseSeriesData):
if isinstance(by, list) and all([isinstance(_by, BaseSeriesData) for _by in by]):
groupby_series = True
ret[by[0]] = {by[0].name}
for _by in by:
ret[_by] = {_by.name}

if isinstance(inp, BaseSeriesData):
ret[inp] = {inp.name}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def test_df_groupby_agg():
assert labels.data in input_columns
assert input_columns[labels.data] == {"label"}

label1 = Series([1, 1, 1, 1], name="label1")
label2 = Series([2, 2, 3, 3], name="label2")
s = df.groupby(by=[label1, label2]).sum()
input_columns = InputColumnSelector.select(s.data, {"foo"})
assert len(input_columns) == 3
assert df.data in input_columns
assert input_columns[df.data] == {"foo"}
assert label1.data in input_columns
assert input_columns[label1.data] == {"label1"}
assert label2.data in input_columns
assert input_columns[label2.data] == {"label2"}


@pytest.mark.skip(reason="group by index is not supported yet")
def test_df_groupby_index_agg():
Expand Down

0 comments on commit fe7caba

Please sign in to comment.