Skip to content

Commit

Permalink
GH1045 Split overload of groupby on as_index for all cases
Browse files Browse the repository at this point in the history
  • Loading branch information
loicdiridollou committed Nov 21, 2024
1 parent e610b76 commit cb09ac2
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 18 deletions.
100 changes: 86 additions & 14 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin):
dropna: _bool = ...,
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
@overload
def groupby(
def groupby( # pyright: ignore reportOverlappingOverload
self,
by: DatetimeIndex,
axis: AxisIndex | NoDefault = ...,
Expand All @@ -1124,77 +1124,149 @@ class DataFrame(NDFrame, OpsMixin):
dropna: _bool = ...,
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
@overload
def groupby(
def groupby( # pyright: ignore reportOverlappingOverload
self,
by: TimedeltaIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Timedelta, bool]: ...
) -> DataFrameGroupBy[Timedelta, Literal[True]]: ...
@overload
def groupby(
self,
by: TimedeltaIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Timedelta, Literal[False]]: ...
@overload
def groupby( # pyright: ignore reportOverlappingOverload
self,
by: PeriodIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Period, bool]: ...
) -> DataFrameGroupBy[Period, Literal[True]]: ...
@overload
def groupby(
self,
by: PeriodIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Period, Literal[False]]: ...
@overload
def groupby( # pyright: ignore reportOverlappingOverload
self,
by: IntervalIndex[IntervalT],
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[IntervalT, bool]: ...
) -> DataFrameGroupBy[IntervalT, Literal[True]]: ...
@overload
def groupby(
self,
by: IntervalIndex[IntervalT],
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[IntervalT, Literal[False]]: ...
@overload
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
self,
by: MultiIndex | GroupByObjectNonScalar | None = ...,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[tuple, bool]: ...
) -> DataFrameGroupBy[tuple, Literal[True]]: ...
@overload
def groupby( # type: ignore[overload-overlap]
self,
by: MultiIndex | GroupByObjectNonScalar | None = ...,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[tuple, Literal[False]]: ...
@overload
def groupby( # pyright: ignore reportOverlappingOverload
self,
by: Series[SeriesByT],
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[SeriesByT, Literal[True]]: ...
@overload
def groupby(
self,
by: Series[SeriesByT],
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[SeriesByT, Literal[False]]: ...
@overload
def groupby(
self,
by: CategoricalIndex | Index | Series,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[SeriesByT, bool]: ...
) -> DataFrameGroupBy[Any, Literal[True]]: ...
@overload
def groupby(
self,
by: CategoricalIndex | Index | Series,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Any, bool]: ...
) -> DataFrameGroupBy[Any, Literal[False]]: ...
def pivot(
self,
*,
Expand Down
27 changes: 23 additions & 4 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def test_types_mean() -> None:
s2: pd.Series = df.mean(axis=0)
df2: pd.DataFrame = df.groupby(level=0).mean()
if TYPE_CHECKING_INVALID_USAGE:
df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
s3: pd.Series = df.mean(axis=1, skipna=True, numeric_only=False)


Expand All @@ -515,8 +515,8 @@ def test_types_median() -> None:
s2: pd.Series = df.median(axis=0)
df2: pd.DataFrame = df.groupby(level=0).median()
if TYPE_CHECKING_INVALID_USAGE:
df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
s3: pd.Series = df.median(axis=1, skipna=True, numeric_only=False)


Expand Down Expand Up @@ -1066,6 +1066,25 @@ def test_types_groupby_as_index() -> None:
)


def test_types_groupby_as_index_list() -> None:
"""Test type of groupby.size method depending on list of grouper GH1045."""
df = pd.DataFrame({"a": [1, 1, 2], "b": [2, 3, 2]})
check(
assert_type(
df.groupby("a", as_index=False).size(),
pd.DataFrame,
),
pd.DataFrame,
)
check(
assert_type(
df.groupby("a", as_index=True).size(),
"pd.Series[int]",
),
pd.Series,
)


def test_types_groupby_as_index_value_counts() -> None:
"""Test type of groupby.value_counts method depending on `as_index`."""
df = pd.DataFrame({"a": [1, 2, 3]})
Expand Down

0 comments on commit cb09ac2

Please sign in to comment.