diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py
index b688c2630d68..3b092168707f 100644
--- a/python/ray/data/_internal/datasource/parquet_datasource.py
+++ b/python/ray/data/_internal/datasource/parquet_datasource.py
@@ -23,6 +23,7 @@
 from ray.data._internal.util import (
     _check_pyarrow_version,
     _is_local_scheme,
+    _split_list,
     call_with_retry,
     iterate_with_retry,
 )
@@ -354,9 +355,9 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
 
         read_tasks = []
         for fragments, paths, metadata in zip(
-            np.array_split(pq_fragments, parallelism),
-            np.array_split(pq_paths, parallelism),
-            np.array_split(pq_metadata, parallelism),
+            _split_list(pq_fragments, parallelism),
+            _split_list(pq_paths, parallelism),
+            _split_list(pq_metadata, parallelism),
         ):
             if len(fragments) <= 0:
                 continue
diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py
index b66a9bc5804f..a9bdad3cdc94 100644
--- a/python/ray/data/tests/test_util.py
+++ b/python/ray/data/tests/test_util.py
@@ -1,3 +1,4 @@
+import time
 from typing import Any, Dict, Optional
 
 import numpy as np
@@ -150,6 +151,23 @@ def test_list_splits():
     assert _split_list(["foo", 1, [0], None], 3) == [["foo", 1], [[0]], [None]]
 
 
+def split_performance():
+    int_array = range(1000003)
+    start = time.perf_counter()
+
+    for split in np.array_split(int_array, 100):
+        len(split)
+
+    end1 = time.perf_counter()
+    print(f"============== {end1 - start}")
+
+    for split in _split_list(int_array, 100):
+        len(split)
+    end2 = time.perf_counter()
+    print(f"============== {end2 - end1}")
+    assert (end2 - end1) < (end1 - start)
+
+
 def get_parquet_read_logical_op(
     ray_remote_args: Optional[Dict[str, Any]] = None,
     **read_kwargs,