[Data] support batch_format for Sort and Aggregate (ray-project#48287)
## Why are these changes needed?
When we call `xxx.map_groups(..., batch_format="...")`, the underlying sort can create empty blocks, which still default to pyarrow. If we then run another sort on top of that result, we hit `AttributeError: 'DataFrame' object has no attribute 'num_rows'`, because the merge step assumes every block has the same type as the first one, while the blocks may in fact have mixed types. See ray-project#46748 for more details.
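
A minimal sketch of the failure this fixes (essentially the new `test_batch_format_on_sort` added below):

```python
import ray

ds = ray.data.from_items(
    [
        {"col1": 1, "col2": 2},
        {"col1": 1, "col2": 4},
        {"col1": 5, "col2": 6},
        {"col1": 7, "col2": 8},
    ]
)

# map_groups produces pandas blocks, but empty blocks created during the
# groupby shuffle still default to pyarrow. Before this PR, the follow-up
# sort assumed every block had the same type as the first one and raised
# AttributeError: 'DataFrame' object has no attribute 'num_rows'.
df = (
    ds.groupby("col1")
    .map_groups(lambda g: g, batch_format="pandas")
    .sort("col2", descending=True)
    .to_pandas()
)
print(df)
```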

## Related issue number

Close ray-project#46748

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a
method in Tune, I've added it in `doc/source/tune/api/` under the
corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Xingyu Long <[email protected]>
Co-authored-by: Scott Lee <[email protected]>
xingyu-long and scottjlee authored Nov 13, 2024
1 parent 5788c4b commit 3f195b4
Showing 9 changed files with 147 additions and 9 deletions.
@@ -120,6 +120,7 @@ def __init__(
self,
input_op: LogicalOperator,
sort_key: SortKey,
batch_format: Optional[str] = "default",
):
super().__init__(
"Sort",
@@ -131,6 +132,7 @@ def __init__(
],
)
self._sort_key = sort_key
self._batch_format = batch_format

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
@@ -145,6 +147,7 @@ def __init__(
input_op: LogicalOperator,
key: Optional[str],
aggs: List[AggregateFn],
batch_format: Optional[str] = "default",
):
super().__init__(
"Aggregate",
@@ -157,3 +160,4 @@ def __init__(
)
self._key = key
self._aggs = aggs
self._batch_format = batch_format
2 changes: 2 additions & 0 deletions python/ray/data/_internal/logical/optimizers.py
@@ -6,6 +6,7 @@
PhysicalPlan,
Rule,
)
from ray.data._internal.logical.rules.inherit_batch_format import InheritBatchFormatRule
from ray.data._internal.logical.rules.inherit_target_max_block_size import (
InheritTargetMaxBlockSizeRule,
)
@@ -20,6 +21,7 @@

_LOGICAL_RULES = [
ReorderRandomizeBlocksRule,
InheritBatchFormatRule,
]

_PHYSICAL_RULES = [
42 changes: 42 additions & 0 deletions python/ray/data/_internal/logical/rules/inherit_batch_format.py
@@ -0,0 +1,42 @@
from collections import deque
from typing import Iterable

from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll
from ray.data._internal.logical.operators.map_operator import MapBatches


class InheritBatchFormatRule(Rule):
"""For AbstractAllToAll based operator, apply this rule
to inherit batch_format from upstream operator by traversing
the entire DAG."""

def apply(self, plan: LogicalPlan) -> LogicalPlan:
optimized_dag: LogicalOperator = self._apply(plan.dag)
new_plan = LogicalPlan(dag=optimized_dag, context=plan.context)
return new_plan

def _apply(self, op: LogicalOperator):
# Post-order traversal.
nodes: Iterable[LogicalOperator] = deque()
for node in op.post_order_iter():
nodes.appendleft(node)

while len(nodes) > 0:
current_op = nodes.pop()

if isinstance(current_op, AbstractAllToAll):
# Traverse up the DAG until we find a MapBatches op with a batch_format,
# or we reach the source op, in which case nothing is inherited.
upstream_op = current_op.input_dependencies[0]
while upstream_op.input_dependencies:
if (
isinstance(upstream_op, MapBatches)
and upstream_op._batch_format
):
current_op._batch_format = upstream_op._batch_format
break
upstream_op = upstream_op.input_dependencies[0]

# Return the root op; any inherited batch_format has been set in place.
return op
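
As a rough, illustrative sketch of how this rule surfaces in a pipeline (assuming `ray.data.range` produces an `id` column; the unit tests added below exercise the rule directly on logical operators):

```python
import ray

ds = ray.data.range(8)

# The Sort logical op downstream of the pandas-format MapBatches inherits
# batch_format="pandas" via InheritBatchFormatRule, so its reduce tasks
# normalize blocks to pandas before merging.
df = (
    ds.map_batches(lambda batch: batch, batch_format="pandas")
    .sort("id", descending=True)
    .to_pandas()
)
print(df.head())
```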
2 changes: 2 additions & 0 deletions python/ray/data/_internal/planner/aggregate.py
@@ -24,6 +24,7 @@
def generate_aggregate_fn(
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
"""Generate function to aggregate blocks by the specified key column or key
@@ -67,6 +68,7 @@ def fn(
boundaries=boundaries,
key=key,
aggs=aggs,
batch_format=batch_format,
)
if DataContext.get_current().use_push_based_shuffle:
scheduler = PushBasedShuffleTaskScheduler(agg_spec)
@@ -29,10 +29,11 @@ def __init__(
boundaries: List[KeyType],
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
):
super().__init__(
map_args=[boundaries, key, aggs],
reduce_args=[key, aggs],
reduce_args=[key, aggs, batch_format],
)

@staticmethod
@@ -62,11 +63,15 @@ def map(
def reduce(
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
*mapper_outputs: List[Block],
partial_reduce: bool = False,
) -> Tuple[Block, BlockMetadata]:
return BlockAccessor.for_block(mapper_outputs[0]).aggregate_combined_blocks(
list(mapper_outputs), key, aggs, finalize=not partial_reduce
normalized_blocks = TableBlockAccessor.normalize_block_types(
mapper_outputs, normalize_type=batch_format
)
return BlockAccessor.for_block(normalized_blocks[0]).aggregate_combined_blocks(
list(normalized_blocks), key, aggs, finalize=not partial_reduce
)

@staticmethod
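
For context, a sketch (not code from this PR) of the kind of mixed mapper output that motivates the `normalize_block_types` call here and in the sort reduce below:

```python
import pandas as pd
import pyarrow as pa

# A pandas block produced by a pandas-format UDF alongside an empty block
# that was created as pyarrow by default.
mapper_outputs = [
    pd.DataFrame({"col1": [1, 1], "col2": [2, 4]}),
    pa.table(
        {
            "col1": pa.array([], type=pa.int64()),
            "col2": pa.array([], type=pa.int64()),
        }
    ),
]

# Previously, reduce picked the accessor for mapper_outputs[0] and assumed all
# blocks shared that type; normalizing to the inherited batch_format first
# makes the merge well defined.
```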
12 changes: 9 additions & 3 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
@@ -6,6 +6,7 @@
from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.table_block import TableBlockAccessor
from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata
from ray.types import ObjectRef

@@ -116,10 +117,11 @@ def __init__(
self,
boundaries: List[T],
sort_key: SortKey,
batch_format: str,
):
super().__init__(
map_args=[boundaries, sort_key],
reduce_args=[sort_key],
reduce_args=[sort_key, batch_format],
)

@staticmethod
@@ -138,11 +140,15 @@ def map(
@staticmethod
def reduce(
sort_key: SortKey,
batch_format: str,
*mapper_outputs: List[Block],
partial_reduce: bool = False,
) -> Tuple[Block, BlockMetadata]:
return BlockAccessor.for_block(mapper_outputs[0]).merge_sorted_blocks(
mapper_outputs, sort_key
normalized_blocks = TableBlockAccessor.normalize_block_types(
mapper_outputs, normalize_type=batch_format
)
return BlockAccessor.for_block(normalized_blocks[0]).merge_sorted_blocks(
normalized_blocks, sort_key
)

@staticmethod
9 changes: 7 additions & 2 deletions python/ray/data/_internal/planner/plan_all_to_all_op.py
@@ -71,7 +71,9 @@ def plan_all_to_all_op(
"debug_limit_shuffle_execution_to_num_blocks", None
)
)
fn = generate_sort_fn(op._sort_key, debug_limit_shuffle_execution_to_num_blocks)
fn = generate_sort_fn(
op._sort_key, op._batch_format, debug_limit_shuffle_execution_to_num_blocks
)
target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
elif isinstance(op, Aggregate):
debug_limit_shuffle_execution_to_num_blocks = (
@@ -80,7 +82,10 @@
)
)
fn = generate_aggregate_fn(
op._key, op._aggs, debug_limit_shuffle_execution_to_num_blocks
op._key,
op._aggs,
op._batch_format,
debug_limit_shuffle_execution_to_num_blocks,
)
target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
else:
5 changes: 4 additions & 1 deletion python/ray/data/_internal/planner/sort.py
@@ -20,6 +20,7 @@

def generate_sort_fn(
sort_key: SortKey,
batch_format: str,
_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
"""Generate function to sort blocks by the specified key column or key function."""
@@ -56,7 +57,9 @@ def fn(
_, ascending = sort_key.to_pandas_sort_args()
if not ascending:
boundaries.reverse()
sort_spec = SortTaskSpec(boundaries=boundaries, sort_key=sort_key)
sort_spec = SortTaskSpec(
boundaries=boundaries, sort_key=sort_key, batch_format=batch_format
)

if DataContext.get_current().use_push_based_shuffle:
scheduler = PushBasedShuffleTaskScheduler(sort_spec)
69 changes: 69 additions & 0 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -1172,6 +1172,75 @@ def test_sort_validate_keys(ray_start_regular_shared):
ds_named.sort(invalid_col_name).take_all()


def test_inherit_batch_format_rule():
from ray.data._internal.logical.rules.inherit_batch_format import (
InheritBatchFormatRule,
)

ctx = DataContext.get_current()

operator1 = get_parquet_read_logical_op()
operator2 = MapBatches(operator1, fn=lambda g: g, batch_format="pandas")
sort_key = SortKey("number", descending=True)
operator3 = Sort(operator2, sort_key)
original_plan = LogicalPlan(dag=operator3, context=ctx)

rule = InheritBatchFormatRule()
optimized_plan = rule.apply(original_plan)
assert optimized_plan.dag._batch_format == "pandas"


def test_batch_format_on_sort(ray_start_regular_shared):
"""Checks that the Sort op can inherit batch_format from upstream ops correctly."""
ds = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 1, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)
df_expected = pd.DataFrame(
{
"col1": [7, 5, 1, 1],
"col2": [8, 6, 4, 2],
}
)
df_actual = (
ds.groupby("col1")
.map_groups(lambda g: g, batch_format="pandas")
.sort("col2", descending=True)
.to_pandas()
)
pd.testing.assert_frame_equal(df_actual, df_expected)


def test_batch_format_on_aggregate(ray_start_regular_shared):
"""Checks that the Aggregate op can inherit batch_format
from upstream ops correctly."""
from ray.data.aggregate import AggregateFn

ds = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 1, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)
aggregation = AggregateFn(
init=lambda column: 1,
accumulate_row=lambda a, row: a * row["col2"],
merge=lambda a1, a2: a1 * a2,
name="prod",
)
assert (
ds.groupby("col1")
.map_groups(lambda g: g, batch_format="pandas")
.aggregate(aggregation)
) == {"prod": 384}


def test_aggregate_operator(ray_start_regular_shared):
ctx = DataContext.get_current()
