Skip to content

Commit

Permalink
[Data] Fix unequal partitions when grouping by multiple keys (ray-project#47924)
Browse files Browse the repository at this point in the history

Fixes ray-project#45303

---------

Signed-off-by: Balaji Veeramani <[email protected]>
Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
bveeramani authored and ujjawal-khare committed Oct 15, 2024
1 parent 97909f8 commit 988d6f7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
38 changes: 17 additions & 21 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,34 +174,30 @@ def sample_boundaries(
# TODO(zhilong): Update sort sample bar before finished.
samples = sample_bar.fetch_until_complete(sample_results)
del sample_results
samples = [s for s in samples if len(s) > 0]
samples: List[Block] = [s for s in samples if len(s) > 0]
# The dataset is empty
if len(samples) == 0:
return [None] * (num_reducers - 1)

# Convert samples to a sorted list[tuple[...]] where each tuple represents a
# sample.
# TODO: Once we deprecate pandas blocks, we can avoid this conversion and
# directly sort the samples.
builder = DelegatingBlockBuilder()
for sample in samples:
builder.add_block(sample)
samples = builder.build()

sample_dict = BlockAccessor.for_block(samples).to_numpy(columns=columns)
# Compute sorted indices of the samples. In np.lexsort last key is the
# primary key hence have to reverse the order.
indices = np.lexsort(list(reversed(list(sample_dict.values()))))
# Sort each column by indices, and calculate q-ths quantile items.
# Ignore the 1st item as it's not required for the boundary
for k, v in sample_dict.items():
sorted_v = v[indices]
sample_dict[k] = list(
np.quantile(
sorted_v, np.linspace(0, 1, num_reducers), interpolation="nearest"
)[1:]
)
# Return the list of boundaries as tuples
# of a form (col1_value, col2_value, ...)
return [
tuple(sample_dict[k][i] for k in sample_dict)
for i in range(num_reducers - 1)
samples_table = builder.build()
samples_dict = BlockAccessor.for_block(samples_table).to_numpy(columns=columns)
# This zip does the transposition from list of column values to list of tuples.
samples_list = sorted(zip(*samples_dict.values()))

# Each boundary corresponds to a quantile of the data.
quantile_indices = [
int(q * (len(samples_list) - 1))
for q in np.linspace(0, 1, num_reducers + 1)
]
# Exclude the first and last quantiles because they're 0 and 1.
return [samples_list[i] for i in quantile_indices[1:-1]]


def _sample_block(block: Block, n_samples: int, sort_key: SortKey) -> Block:
Expand Down
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ def test_sort_with_specified_boundaries(ray_start_regular, descending, boundarie
assert np.all(block["id"] == expected_block)


def test_sort_multiple_keys_produces_equally_sized_blocks(ray_start_regular):
    """Regression test for https://github.com/ray-project/ray/issues/45303.

    Sorting on multiple keys should spread rows roughly evenly over the output
    blocks rather than collapsing most rows into one partition.
    """
    rows = [{"a": i, "b": j} for i in range(2) for j in range(5)]
    dataset = ray.data.from_items(rows, override_num_blocks=5)

    result = dataset.sort(["a", "b"])

    block_sizes = []
    for ref_bundle in result.iter_internal_ref_bundles():
        block_sizes.append(ref_bundle.num_rows())

    # The sort should preserve the block count: 5 input blocks, 5 output blocks.
    assert len(block_sizes) == 5, len(block_sizes)
    # 10 rows over 5 blocks is ideally 2 rows per block; allow +/- 1 so the
    # test isn't brittle against small boundary-placement differences.
    assert all(1 <= size <= 3 for size in block_sizes), block_sizes


def test_sort_simple(ray_start_regular, use_push_based_shuffle):
num_items = 100
parallelism = 4
Expand Down

0 comments on commit 988d6f7

Please sign in to comment.