Skip to content

Commit

Permalink
GH203 Split groupby with as_index (tentative) (#1014)
Browse files Browse the repository at this point in the history
* GH203 Split groupby with as_index

* Update to the fix

* Update to the fix

* Experiment for size

* Experiment for size

* GH203 Create new overload for DatetimeIndex

* GH203 Fix lint
  • Loading branch information
loicdiridollou authored Oct 31, 2024
1 parent 1a314a0 commit 53c299f
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 17 deletions.
46 changes: 35 additions & 11 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1054,29 +1054,53 @@ class DataFrame(NDFrame, OpsMixin):
errors: IgnoreRaise = ...,
) -> None: ...
@overload
def groupby(
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
self,
by: Scalar,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Scalar]: ...
) -> DataFrameGroupBy[Scalar, Literal[True]]: ...
# Overload for a scalar `by` key with `as_index` typed as Literal[False]:
# the returned DataFrameGroupBy carries Literal[False] as its second type
# argument, which the `size` overloads on DataFrameGroupBy use to pick a
# DataFrame return type.
# NOTE(review): `as_index` keeps a `= ...` default here, so a call that omits
# `as_index` also structurally matches this overload; presumably resolution
# relies on the Literal[True] overload being listed first -- confirm.
@overload
def groupby(
self,
by: Scalar,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Scalar, Literal[False]]: ...
@overload
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
self,
by: DatetimeIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
@overload
def groupby( # type: ignore[overload-overlap]
self,
by: DatetimeIndex,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Timestamp]: ...
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
@overload
def groupby(
self,
Expand All @@ -1088,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Timedelta]: ...
) -> DataFrameGroupBy[Timedelta, bool]: ...
@overload
def groupby(
self,
Expand All @@ -1100,7 +1124,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Period]: ...
) -> DataFrameGroupBy[Period, bool]: ...
@overload
def groupby(
self,
Expand All @@ -1112,7 +1136,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[IntervalT]: ...
) -> DataFrameGroupBy[IntervalT, bool]: ...
@overload
def groupby(
self,
Expand All @@ -1124,7 +1148,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[tuple]: ...
) -> DataFrameGroupBy[tuple, bool]: ...
@overload
def groupby(
self,
Expand All @@ -1136,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[SeriesByT]: ...
) -> DataFrameGroupBy[SeriesByT, bool]: ...
@overload
def groupby(
self,
Expand All @@ -1148,7 +1172,7 @@ class DataFrame(NDFrame, OpsMixin):
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Any]: ...
) -> DataFrameGroupBy[Any, bool]: ...
def pivot(
self,
*,
Expand Down
16 changes: 14 additions & 2 deletions pandas-stubs/core/groupby/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from typing import (
Generic,
Literal,
NamedTuple,
TypeVar,
final,
overload,
)
Expand All @@ -29,6 +30,7 @@ from typing_extensions import (
)

from pandas._libs.lib import NoDefault
from pandas._libs.tslibs.timestamps import Timestamp
from pandas._typing import (
S1,
AggFuncTypeBase,
Expand Down Expand Up @@ -182,7 +184,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
self,
) -> Iterator[tuple[ByT, Series[S1]]]: ...

class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
# Type variable used as the second generic parameter of DataFrameGroupBy.
# It is bound to Literal[True, False]; the DataFrame.groupby overloads fill it
# in from the type of their `as_index` argument.
_TT = TypeVar("_TT", bound=Literal[True, False])

class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
@overload # type: ignore[override]
def apply( # type: ignore[overload-overlap]
Expand Down Expand Up @@ -236,7 +240,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
@overload
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride, reportOverlappingOverload]
self, key: Iterable[Hashable] | slice
) -> DataFrameGroupBy[ByT]: ...
) -> DataFrameGroupBy[ByT, bool]: ...
def nunique(self, dropna: bool = ...) -> DataFrame: ...
def idxmax(
self,
Expand Down Expand Up @@ -388,3 +392,11 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
def __iter__( # pyright: ignore[reportIncompatibleMethodOverride]
self,
) -> Iterator[tuple[ByT, DataFrame]]: ...
# `size` is overloaded on the type of `self`: the second type argument of
# DataFrameGroupBy (Literal[True] / Literal[False], taken from `as_index` in
# DataFrame.groupby) selects the return type.
#   Literal[True]  -> Series[int]
#   Literal[False] -> DataFrame
@overload
def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ...
@overload
def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ...
# Same mapping spelled out for a Timestamp group key, which is what
# DataFrame.groupby returns when `by` is a DatetimeIndex.
# NOTE(review): presumably needed because the ByT overloads above do not
# match DataFrameGroupBy[Timestamp, ...] -- confirm against ByT's definition.
@overload
def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ...
@overload
def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ...
4 changes: 0 additions & 4 deletions pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
def sem(
self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ...
) -> DataFrame: ...
@final
@overload
def size(self: GroupBy[Series]) -> Series[int]: ...
@overload # return type depends on `as_index` for dataframe groupby
def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ...
@final
def sum(
self,
Expand Down
34 changes: 34 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,40 @@ def test_types_pivot_table() -> None:
)


def test_types_groupby_as_index() -> None:
    """Check that ``groupby(...).size()`` is typed according to ``as_index``."""
    frame = pd.DataFrame({"a": [1, 2, 3]})

    # as_index=False: expected to be a DataFrame, both statically and at runtime.
    sizes_as_frame = frame.groupby("a", as_index=False).size()
    check(assert_type(sizes_as_frame, pd.DataFrame), pd.DataFrame)

    # as_index=True: expected to be Series[int] statically, a Series at runtime.
    sizes_as_series = frame.groupby("a", as_index=True).size()
    check(assert_type(sizes_as_series, "pd.Series[int]"), pd.Series)


def test_types_groupby_size() -> None:
    """Test for GH886.

    Groups a small frame by its ``date`` column without passing ``as_index``,
    takes the group sizes, and checks that converting them with ``to_frame()``
    followed by ``reset_index()`` is typed (and behaves) as a DataFrame.
    """
    records = [
        {"date": "2023-12-01", "val": 12},
        {"date": "2023-12-02", "val": 2},
        {"date": "2023-12-03", "val": 1},
        {"date": "2023-12-03", "val": 10},
    ]

    group_sizes = pd.DataFrame(records).groupby("date").size()
    sizes_frame = group_sizes.to_frame()
    check(assert_type(sizes_frame.reset_index(), pd.DataFrame), pd.DataFrame)


def test_types_groupby() -> None:
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
df.index.name = "ind"
Expand Down

0 comments on commit 53c299f

Please sign in to comment.