From 3f195b4d02cb49ecff747add757c6e36b01cd066 Mon Sep 17 00:00:00 2001
From: Xingyu Long
Date: Wed, 13 Nov 2024 12:00:59 -0800
Subject: [PATCH] [Data] support batch_format for Sort and Aggregate (#48287)

## Why are these changes needed?

When we call `xxx.map_groups(..., batch_format="...")`, we may invoke the sort function and create empty blocks, which still use pyarrow by default. When we then invoke another sort call on top of it, we hit `AttributeError: 'DataFrame' object has no attribute 'num_rows'`, because we assume the type of the first block even though the blocks may have different types. See #46748 for more details.

## Related issue number

Closes #46748

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Xingyu Long
Co-authored-by: Scott Lee
---
 .../logical/operators/all_to_all_operator.py |  4 ++
 .../ray/data/_internal/logical/optimizers.py |  2 +
 .../logical/rules/inherit_batch_format.py    | 42 +++++++++++
 .../ray/data/_internal/planner/aggregate.py  |  2 +
 .../planner/exchange/aggregate_task_spec.py  | 11 ++-
 .../planner/exchange/sort_task_spec.py       | 12 +++-
 .../_internal/planner/plan_all_to_all_op.py  |  9 ++-
 python/ray/data/_internal/planner/sort.py    |  5 +-
 .../data/tests/test_execution_optimizer.py   | 69 +++++++++++++++++++
 9 files changed, 147 insertions(+), 9 deletions(-)
 create mode 100644 python/ray/data/_internal/logical/rules/inherit_batch_format.py

diff --git a/python/ray/data/_internal/logical/operators/all_to_all_operator.py b/python/ray/data/_internal/logical/operators/all_to_all_operator.py
index 3179871c3685..745103f0036f 100644
--- a/python/ray/data/_internal/logical/operators/all_to_all_operator.py
+++ b/python/ray/data/_internal/logical/operators/all_to_all_operator.py
@@ -120,6 +120,7 @@ def __init__(
         self,
         input_op: LogicalOperator,
         sort_key: SortKey,
+        batch_format: Optional[str] = "default",
     ):
         super().__init__(
             "Sort",
@@ -131,6 +132,7 @@
             ],
         )
         self._sort_key = sort_key
+        self._batch_format = batch_format
 
     def aggregate_output_metadata(self) -> BlockMetadata:
         assert len(self._input_dependencies) == 1, len(self._input_dependencies)
@@ -145,6 +147,7 @@ def __init__(
         input_op: LogicalOperator,
         key: Optional[str],
         aggs: List[AggregateFn],
+        batch_format: Optional[str] = "default",
     ):
         super().__init__(
             "Aggregate",
@@ -157,3 +160,4 @@
         )
         self._key = key
         self._aggs = aggs
+        self._batch_format = batch_format
diff --git a/python/ray/data/_internal/logical/optimizers.py b/python/ray/data/_internal/logical/optimizers.py
index e50013d4a13c..a7c2b68c06fe 100644
--- a/python/ray/data/_internal/logical/optimizers.py
+++ b/python/ray/data/_internal/logical/optimizers.py
@@ -6,6 +6,7 @@
     PhysicalPlan,
     Rule,
 )
+from ray.data._internal.logical.rules.inherit_batch_format import InheritBatchFormatRule
 from ray.data._internal.logical.rules.inherit_target_max_block_size import (
     InheritTargetMaxBlockSizeRule,
 )
@@ -20,6 +21,7 @@
 _LOGICAL_RULES = [
     ReorderRandomizeBlocksRule,
+    InheritBatchFormatRule,
 ]
 
 _PHYSICAL_RULES = [
diff --git a/python/ray/data/_internal/logical/rules/inherit_batch_format.py b/python/ray/data/_internal/logical/rules/inherit_batch_format.py
new file mode 100644
index 000000000000..2dd265cd08b1
--- /dev/null
+++ b/python/ray/data/_internal/logical/rules/inherit_batch_format.py
@@ -0,0 +1,42 @@
+from collections import deque
+from typing import Iterable
+
+from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
+from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll
+from ray.data._internal.logical.operators.map_operator import MapBatches
+
+
+class InheritBatchFormatRule(Rule):
+    """For AbstractAllToAll based operator, apply this rule
+    to inherit batch_format from upstream operator by traversing
+    the entire DAG."""
+
+    def apply(self, plan: LogicalPlan) -> LogicalPlan:
+        optimized_dag: LogicalOperator = self._apply(plan.dag)
+        new_plan = LogicalPlan(dag=optimized_dag, context=plan.context)
+        return new_plan
+
+    def _apply(self, op: LogicalOperator):
+        # Post-order traversal.
+        nodes: Iterable[LogicalOperator] = deque()
+        for node in op.post_order_iter():
+            nodes.appendleft(node)
+
+        while len(nodes) > 0:
+            current_op = nodes.pop()
+
+            if isinstance(current_op, AbstractAllToAll):
+                # traversal up the DAG until we find MapBatches with batch_format
+                # or we reach to source op and do nothing
+                upstream_op = current_op.input_dependencies[0]
+                while upstream_op.input_dependencies:
+                    if (
+                        isinstance(upstream_op, MapBatches)
+                        and upstream_op._batch_format
+                    ):
+                        current_op._batch_format = upstream_op._batch_format
+                        break
+                    upstream_op = upstream_op.input_dependencies[0]
+
+        # just return the default op
+        return op
diff --git a/python/ray/data/_internal/planner/aggregate.py b/python/ray/data/_internal/planner/aggregate.py
index 6a2a6c1482d1..8f177add41d9 100644
--- a/python/ray/data/_internal/planner/aggregate.py
+++ b/python/ray/data/_internal/planner/aggregate.py
@@ -24,6 +24,7 @@ def generate_aggregate_fn(
     key: Optional[str],
     aggs: List[AggregateFn],
+    batch_format: str,
     _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
 ) -> AllToAllTransformFn:
     """Generate function to aggregate blocks by the specified key column or key
@@ -67,6 +68,7 @@ def fn(
             boundaries=boundaries,
             key=key,
             aggs=aggs,
+            batch_format=batch_format,
         )
         if DataContext.get_current().use_push_based_shuffle:
             scheduler = PushBasedShuffleTaskScheduler(agg_spec)
diff --git a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
index 91d77863e40b..7b0aa0dc7ad8 100644
--- a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
+++ b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
@@ -29,10 +29,11 @@ def __init__(
         boundaries: List[KeyType],
         key: Optional[str],
         aggs: List[AggregateFn],
+        batch_format: str,
     ):
         super().__init__(
             map_args=[boundaries, key, aggs],
-            reduce_args=[key, aggs],
+            reduce_args=[key, aggs, batch_format],
         )
 
     @staticmethod
@@ -62,11 +63,15 @@ def map(
     def reduce(
         key: Optional[str],
         aggs: List[AggregateFn],
+        batch_format: str,
         *mapper_outputs: List[Block],
         partial_reduce: bool = False,
     ) -> Tuple[Block, BlockMetadata]:
-        return BlockAccessor.for_block(mapper_outputs[0]).aggregate_combined_blocks(
-            list(mapper_outputs), key, aggs, finalize=not partial_reduce
+        normalized_blocks = TableBlockAccessor.normalize_block_types(
+            mapper_outputs, normalize_type=batch_format
+        )
+        return BlockAccessor.for_block(normalized_blocks[0]).aggregate_combined_blocks(
+            list(normalized_blocks), key, aggs, finalize=not partial_reduce
         )
 
     @staticmethod
diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py
index edeea0639464..299e8793774f 100644
--- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py
+++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py
@@ -6,6 +6,7 @@
 from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
+from ray.data._internal.table_block import TableBlockAccessor
 from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata
 from ray.types import ObjectRef
 
@@ -116,10 +117,11 @@ def __init__(
         self,
         boundaries: List[T],
         sort_key: SortKey,
+        batch_format: str,
     ):
         super().__init__(
             map_args=[boundaries, sort_key],
-            reduce_args=[sort_key],
+            reduce_args=[sort_key, batch_format],
         )
 
     @staticmethod
@@ -138,11 +140,15 @@ def map(
     @staticmethod
     def reduce(
         sort_key: SortKey,
+        batch_format: str,
         *mapper_outputs: List[Block],
         partial_reduce: bool = False,
     ) -> Tuple[Block, BlockMetadata]:
-        return BlockAccessor.for_block(mapper_outputs[0]).merge_sorted_blocks(
-            mapper_outputs, sort_key
+        normalized_blocks = TableBlockAccessor.normalize_block_types(
+            mapper_outputs, normalize_type=batch_format
+        )
+        return BlockAccessor.for_block(normalized_blocks[0]).merge_sorted_blocks(
+            normalized_blocks, sort_key
         )
 
     @staticmethod
diff --git a/python/ray/data/_internal/planner/plan_all_to_all_op.py b/python/ray/data/_internal/planner/plan_all_to_all_op.py
index fc7f7fdac954..13c13ea6a9a2 100644
--- a/python/ray/data/_internal/planner/plan_all_to_all_op.py
+++ b/python/ray/data/_internal/planner/plan_all_to_all_op.py
@@ -71,7 +71,9 @@ def plan_all_to_all_op(
                 "debug_limit_shuffle_execution_to_num_blocks", None
             )
         )
-        fn = generate_sort_fn(op._sort_key, debug_limit_shuffle_execution_to_num_blocks)
+        fn = generate_sort_fn(
+            op._sort_key, op._batch_format, debug_limit_shuffle_execution_to_num_blocks
+        )
         target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
     elif isinstance(op, Aggregate):
         debug_limit_shuffle_execution_to_num_blocks = (
@@ -80,7 +82,10 @@
             )
         )
         fn = generate_aggregate_fn(
-            op._key, op._aggs, debug_limit_shuffle_execution_to_num_blocks
+            op._key,
+            op._aggs,
+            op._batch_format,
+            debug_limit_shuffle_execution_to_num_blocks,
         )
         target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
     else:
diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py
index bf46fdad7039..1a14e9f260ae 100644
--- a/python/ray/data/_internal/planner/sort.py
+++ b/python/ray/data/_internal/planner/sort.py
@@ -20,6 +20,7 @@ def generate_sort_fn(
     sort_key: SortKey,
+    batch_format: str,
     _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
 ) -> AllToAllTransformFn:
     """Generate function to sort blocks by the specified key column or key function."""
@@ -56,7 +57,9 @@ def fn(
         _, ascending = sort_key.to_pandas_sort_args()
         if not ascending:
             boundaries.reverse()
-        sort_spec = SortTaskSpec(boundaries=boundaries, sort_key=sort_key)
+        sort_spec = SortTaskSpec(
+            boundaries=boundaries, sort_key=sort_key, batch_format=batch_format
+        )
         if DataContext.get_current().use_push_based_shuffle:
             scheduler = PushBasedShuffleTaskScheduler(sort_spec)
diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py
index 234162871d3c..d657ce1c9d98 100644
--- a/python/ray/data/tests/test_execution_optimizer.py
+++ b/python/ray/data/tests/test_execution_optimizer.py
@@ -1172,6 +1172,75 @@ def test_sort_validate_keys(ray_start_regular_shared):
         ds_named.sort(invalid_col_name).take_all()
 
 
+def test_inherit_batch_format_rule():
+    from ray.data._internal.logical.rules.inherit_batch_format import (
+        InheritBatchFormatRule,
+    )
+
+    ctx = DataContext.get_current()
+
+    operator1 = get_parquet_read_logical_op()
+    operator2 = MapBatches(operator1, fn=lambda g: g, batch_format="pandas")
+    sort_key = SortKey("number", descending=True)
+    operator3 = Sort(operator2, sort_key)
+    original_plan = LogicalPlan(dag=operator3, context=ctx)
+
+    rule = InheritBatchFormatRule()
+    optimized_plan = rule.apply(original_plan)
+    assert optimized_plan.dag._batch_format == "pandas"
+
+
+def test_batch_format_on_sort(ray_start_regular_shared):
+    """Checks that the Sort op can inherit batch_format from upstream ops correctly."""
+    ds = ray.data.from_items(
+        [
+            {"col1": 1, "col2": 2},
+            {"col1": 1, "col2": 4},
+            {"col1": 5, "col2": 6},
+            {"col1": 7, "col2": 8},
+        ]
+    )
+    df_expected = pd.DataFrame(
+        {
+            "col1": [7, 5, 1, 1],
+            "col2": [8, 6, 4, 2],
+        }
+    )
+    df_actual = (
+        ds.groupby("col1")
+        .map_groups(lambda g: g, batch_format="pandas")
+        .sort("col2", descending=True)
+        .to_pandas()
+    )
+    pd.testing.assert_frame_equal(df_actual, df_expected)
+
+
+def test_batch_format_on_aggregate(ray_start_regular_shared):
+    """Checks that the Aggregate op can inherit batch_format
+    from upstream ops correctly."""
+    from ray.data.aggregate import AggregateFn
+
+    ds = ray.data.from_items(
+        [
+            {"col1": 1, "col2": 2},
+            {"col1": 1, "col2": 4},
+            {"col1": 5, "col2": 6},
+            {"col1": 7, "col2": 8},
+        ]
+    )
+    aggregation = AggregateFn(
+        init=lambda column: 1,
+        accumulate_row=lambda a, row: a * row["col2"],
+        merge=lambda a1, a2: a1 * a2,
+        name="prod",
+    )
+    assert (
+        ds.groupby("col1")
+        .map_groups(lambda g: g, batch_format="pandas")
+        .aggregate(aggregation)
+    ) == {"prod": 384}
+
+
 def test_aggregate_operator(ray_start_regular_shared):
     ctx = DataContext.get_current()
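
For readers skimming this patch, the end-to-end behavior it fixes can be reproduced with a short script adapted from the new tests above. This is only an illustrative sketch (the dataset values mirror the tests; it assumes a local Ray installation), not part of the patch itself:

```python
# Sketch of the scenario this patch addresses, adapted from the tests above.
# A pandas map_groups followed by a sort used to fail: the empty blocks created
# during the sort defaulted to pyarrow, so the reduce step received a mix of
# pandas and pyarrow blocks and assumed the first block's type, raising
#   AttributeError: 'DataFrame' object has no attribute 'num_rows'
# With InheritBatchFormatRule, the Sort op inherits batch_format="pandas" from
# the upstream MapBatches (map_groups) op, and the reduce step normalizes all
# blocks to that format before merging.
import ray

ds = ray.data.from_items(
    [
        {"col1": 1, "col2": 2},
        {"col1": 1, "col2": 4},
        {"col1": 5, "col2": 6},
        {"col1": 7, "col2": 8},
    ]
)

df = (
    ds.groupby("col1")
    .map_groups(lambda g: g, batch_format="pandas")
    .sort("col2", descending=True)
    .to_pandas()
)
print(df)  # col2 in descending order: 8, 6, 4, 2
```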